# TensorGrad.jl
This package adds gradient definitions for Zygote.jl
to most calculations using TensorOperations.jl,
and some using Einsum.jl.
It exports a macro `@grad`
which rewrites an expression like

```julia
@grad @tensor A[i,k] := B[i,j] * C[j,k] * D[l,l]
```

into something equivalent to this:
```julia
fun(b,c,d) = @tensor a[i,k] := b[i,j] * c[j,k] * d[l,l]  # define a function

@adjoint function fun(b,c,d)
    fwd = @tensor a[i,k] := b[i,j] * c[j,k] * d[l,l]     # forward pass
    function back(Δa)
        @tensor Δb[i,j] := Δa[i,k] * c[j,k] * d[l,l]     # reverse pass
        @tensor Δc[j,k] := b[i,j] * Δa[i,k] * d[l,l]
        δ = Diagonal(ones(size(d,1)))
        @tensor Δd[l,l′] := b[i,j] * c[j,k] * Δa[i,k] * δ[l,l′]
        return (Δb, Δc, Δd)
    end
    return (fwd, back)
end

A = fun(B,C,D)  # apply this to B, C, D
```
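As a sanity check, independent of the package, the hand-written reverse-pass formula for `Δb` above can be compared against a central finite difference. This is a standalone sketch using only TensorOperations, with arbitrary small test sizes:

```julia
using TensorOperations

b, c, d = rand(2,3), rand(3,4), rand(5,5)
Δa = rand(2,4)                                   # incoming sensitivity, same shape as a

# scalar loss whose gradient w.r.t. b is exactly the Δb formula above
function loss(b)
    @tensor a[i,k] := b[i,j] * c[j,k] * d[l,l]
    return sum(Δa .* a)
end

@tensor Δb[i,j] := Δa[i,k] * c[j,k] * d[l,l]     # analytic reverse pass

ϵ = 1e-6
fd = similar(b)
for I in CartesianIndices(b)                      # central finite differences
    bp, bm = copy(b), copy(b)
    bp[I] += ϵ; bm[I] -= ϵ
    fd[I] = (loss(bp) - loss(bm)) / (2ϵ)
end

maximum(abs, fd .- Δb) < 1e-6                     # should be true
```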
You may also write `@grad B C @tensor A[i,k] := B[i,j] * C[j,k] * D[l,l]`
to specify that only sensitivities for `B` and `C`
are needed; this removes the calculation of `Δd`
above.
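Schematically, and hedging on the exact expansion, the reverse pass of this restricted form would then compute only two gradients, returning `nothing` for the skipped argument (Zygote's convention for arguments without sensitivities):

```julia
function back(Δa)
    @tensor Δb[i,j] := Δa[i,k] * c[j,k] * d[l,l]
    @tensor Δc[j,k] := b[i,j] * Δa[i,k] * d[l,l]
    return (Δb, Δc, nothing)   # Δd is never computed
end
```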
To see what is being defined, call `TensorGrad.verbose(true)`
before the macro
(rather than using `@macroexpand1`).
If Tracker.jl is loaded, then it will now
define the same gradients for `B::TrackedArray`
etc.
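A speculative usage sketch with Tracker: the `Tracker.param` / `back!` / `grad` calls below are standard Tracker, but the assumption that `@grad @tensor ...` simply returns `A` (as in the expansion above) is mine, not something this README spells out:

```julia
using TensorGrad, TensorOperations, Tracker

B = Tracker.param(rand(2,3))
C = Tracker.param(rand(3,2))
D = Tracker.param(rand(4,4))

A = @grad @tensor A[i,k] := B[i,j] * C[j,k] * D[l,l]   # assumed: A isa TrackedArray
Tracker.back!(sum(A))                                   # reverse pass from a scalar
Tracker.grad(B)                                         # ∂ sum(A) / ∂B
```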
Note that this is a fairly crude experiment, probably not something to rely on.
Limitations:
1. The expression must be one term, and scalar factors are not handled yet.
2. It makes no attempt to cache intermediate contractions for re-use,
and thus if there are many tensors it will do the same work several times
(like `b[i,j] * c[j,k]` above, done twice).
3. Requires you to add `@grad`
everywhere, so it won't work in other people's code.
I can solve 1. But 2 seems hard to solve with this design.
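For concreteness, a hand-written adjoint with the shared `b*c` contraction cached would look something like this. It is a sketch of what re-use would mean here (the name `fun_cached` is mine), not something the macro produces:

```julia
using TensorOperations, LinearAlgebra, Zygote

fun_cached(b, c, d) = @tensor a[i,k] := b[i,j] * c[j,k] * d[l,l]

Zygote.@adjoint function fun_cached(b, c, d)
    @tensor bc[i,k] := b[i,j] * c[j,k]       # shared intermediate, computed once
    trd = tr(d)
    fwd = trd .* bc                           # forward result re-uses bc
    function back(Δa)
        @tensor Δb[i,j] := Δa[i,k] * c[j,k]
        @tensor Δc[j,k] := b[i,j] * Δa[i,k]
        Δd = dot(bc, Δa) * Diagonal(ones(size(d,1)))   # re-uses bc again
        return (trd .* Δb, trd .* Δc, Δd)
    end
    return (fwd, back)
end

# e.g. Zygote.gradient((b,c,d) -> sum(fun_cached(b,c,d)), rand(2,3), rand(3,4), rand(5,5))
```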
It now understands other macros like `@einsum`
which share the same syntax.
This allows it to treat non-Einstein contractions, such as batched matrix multiplication:

```julia
@grad x @einsum z[i,k,b] := x[i,j,b] * y[j,k,b]
```
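For reference, the reverse pass of that batched contraction is itself a batched contraction. This is just the mathematics written out with Einsum.jl, not the macro's exact output:

```julia
using Einsum

x, y = rand(2,3,5), rand(3,4,5)
@einsum z[i,k,b] := x[i,j,b] * y[j,k,b]       # forward: batched matrix multiplication

Δz = rand(size(z)...)                          # incoming sensitivity
@einsum Δx[i,j,b] := Δz[i,k,b] * y[j,k,b]      # reverse pass for x, the only one requested
size(Δx) == size(x)                            # true
```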
Those are also handled by `@ein`
from OMEinsum.jl,
which may be pointless, as that package has its own gradients built in.
Probably you should use that instead!
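For comparison, a minimal sketch of the OMEinsum route, which (as noted) comes with its own gradient definitions; the shapes here are arbitrary:

```julia
using OMEinsum, Zygote

x, y = rand(2,3,5), rand(3,4,5)

loss(x, y) = sum(@ein z[i,k,b] := x[i,j,b] * y[j,k,b])   # batched matmul via OMEinsum

Δx, Δy = Zygote.gradient(loss, x, y)
size(Δx) == size(x)   # true
```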
An earlier attempt is now TensorTrack.jl, which works at the level of
functions like `contract!`
etc., and thus gets some re-use (limitation 2 above).
But it is deeply plugged into TensorOperations, and completely tied to that package.
Finally, note also that TensorCast.jl should be almost fully differentiable (although focused on operations other than contractions).
--- Michael Abbott, August 2019
Update:
Essentially the same code has been bolted onto Tullio.jl,
originally in PR#6, and moved to `@tensor`
in PR#92. It has the same limitations as above.
(But it avoids `eval`
by always attaching gradients to a callable struct `Eval`,
rather than to the newly defined functions.)
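In practice the Tullio route looks like this; a minimal sketch, assuming Tullio's automatic gradient machinery with Zygote loaded:

```julia
using Tullio, Zygote

x, y = rand(2,3), rand(3,4)

mul(x, y) = @tullio z[i,k] := x[i,j] * y[j,k]   # Tullio writes the symbolic gradient itself

Δx, Δy = Zygote.gradient((x, y) -> sum(mul(x, y)), x, y)
size(Δx) == size(x)   # true
```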
The package TensorRules.jl has a macro `@∇`
which performs
manipulations of `@tensor`
expressions, acting on whole functions containing them.
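Based only on the description above, applying such a macro to a whole function might look roughly like this; a speculative sketch (the function name is mine), so consult TensorRules.jl's own documentation for the actual interface:

```julia
using TensorRules, TensorOperations

@∇ function contract_trace(b, c, d)
    @tensor a[i,k] := b[i,j] * c[j,k] * d[l,l]
    return sum(a)
end
```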