Home

Awesome

TensorRules.jl

Build Status Code Style: Blue

TensorRules.jl provides a macro @∇ (type it as \nabla<tab>), which enables the use of automatic differentiation (AD) libraries (e.g., Zygote.jl, Diffractor.jl) with the @tensor and @tensoropt macros from TensorOperations.jl.

TensorRules.jl uses ChainRulesCore.jl to define custom adjoints, so you can use any AD library that supports ChainRulesCore.jl.

How to use

julia> using TensorOperations, TensorRules, Zygote;
julia> function foo(a, b, c) # define function with Einstein summation
           # d_F = \sum_{A,B,C,D,E} a_{A,B,C} b_{C,D,E,F} c_{A,B,D,E}
           @tensor d[F] := a[A, B, C] * b[C, D, E, F] * c[A, B, D, E]
           return d[1]
       end;
julia> a, b, c = randn(3, 4, 5), randn(5, 6, 7, 8), randn(3, 4, 6, 7);
julia> gradient(foo, a, b, c); # try to obtain gradient of `foo` by Zygote
ERROR: this intrinsic must be compiled to be called
Stacktrace:
...
julia> @∇ function foo(a, b, c) # use @∇
           @tensor d[F] := a[A, B, C] * b[C, D, E, F] * c[A, B, D, E]
           return d[1]
       end;
julia> gradient(foo, a, b, c); # it works!

How it works

The strategy of TensorRules.jl is very similar to that of TensorGrad.jl.

@∇ converts functions that contain the @tensor or @tensoropt macro. First, @∇ detects @tensor or @tensoropt expressions in the function definition and converts them to inlined functions. Then, @∇ defines custom adjoint rules for the generated functions.

For example, the following definition

@∇ function foo(a, b, c, d, e, f)
    @tensoropt !C x[A, B] := conj(a[A, C]) * sin.(b)[C, D] * c.d[D, B] + d * e[1, 2][A, B]
    x = x + f
    @tensor x[A, B] += a[A, C] * (a * a)[C, B]
    return x
end

will be converted to code equivalent to

function foo(a, b, c, d, e, f)
    x = _foo_1(a, sin.(b), c.d, d, e[1, 2])
    x = x + f
    x += _foo_2(a, a * a)
    return x
end

@inline _foo_1(x1, x2, x3, x4, x5) =
    @tensoropt !C _[A, B] := conj(x1[A, C]) * x2[C, D] * x3[D, B] + x4 * x5[A, B]

@inline _foo_2(x1, x2) = @tensor _[A, B] := x1[A, C] * x2[C, B]

function rrule(::typeof(_foo_1), x1, x2, x3, x4, x5)
    f = _foo_1(x1, x2, x3, x4, x5)
    Px1, Px2, Px3, Px4, Px5 = ProjectTo(x1), ProjectTo(x2), ProjectTo(x3), ProjectTo(x4), ProjectTo(x5)
    function _foo_1_pullback(Δf)
        fnΔx1(Δf, x1, x2, x3, x4, x5) = @tensoropt !C _[A, C] := conj(Δf[A, B]) * x2[C, D] * x3[D, B]
        fnΔx1add!!(x, Δf, x1, x2, x3, x4, x5) = @tensoropt !C x[A, C] += conj(Δf[A, B]) * x2[C, D] * x3[D, B]
        fnΔx2(Δf, x1, x2, x3, x4, x5) = @tensoropt !C _[C, D] := conj(conj(x1[A, C]) * conj(Δf[A, B]) * x3[D, B])
        fnΔx2add!!(x, Δf, x1, x2, x3, x4, x5) = @tensoropt !C x[C, D] += conj(conj(x1[A, C]) * conj(Δf[A, B]) * x3[D, B])
        fnΔx3(Δf, x1, x2, x3, x4, x5) = @tensoropt !C _[D, B] := conj(conj(x1[A, C]) * x2[C, D] * conj(Δf[A, B]))
        fnΔx3add!!(x, Δf, x1, x2, x3, x4, x5) = @tensoropt !C x[D, B] += conj(conj(x1[A, C]) * x2[C, D] * conj(Δf[A, B]))
        fnΔx4(Δf, x1, x2, x3, x4, x5) = first(@tensoropt !C _[] := conj(conj(Δf[A, B]) * x5[A, B]))
        fnΔx5(Δf, x1, x2, x3, x4, x5) = @tensoropt !C _[A, B] := conj(x4 * conj(Δf[A, B]))
        fnΔx5add!!(x, Δf, x1, x2, x3, x4, x5) = @tensoropt !C x[A, B] += conj(x4 * conj(Δf[A, B]))
        Δx1 = InplaceableThunk(
            Thunk(() -> Px1(fnΔx1(Δf, x1, x2, x3, x4, x5))),
            x -> fnΔx1add!!(x, Δf, x1, x2, x3, x4, x5)
        )
        Δx2 = InplaceableThunk(
            Thunk(() -> Px2(fnΔx2(Δf, x1, x2, x3, x4, x5))),
            x -> fnΔx2add!!(x, Δf, x1, x2, x3, x4, x5)
        )
        Δx3 = InplaceableThunk(
            Thunk(() -> Px3(fnΔx3(Δf, x1, x2, x3, x4, x5))),
            x -> fnΔx3add!!(x, Δf, x1, x2, x3, x4, x5)
        )
        Δx4 = Thunk(() -> Px4(fnΔx4(Δf, x1, x2, x3, x4, x5)))
        Δx5 = InplaceableThunk(
            Thunk(() -> Px5(fnΔx5(Δf, x1, x2, x3, x4, x5))),
            x -> fnΔx5add!!(x, Δf, x1, x2, x3, x4, x5)
        )
        return (NoTangent(), Δx1, Δx2, Δx3, Δx4, Δx5)
    end
    return f, _foo_1_pullback
end

function rrule(::typeof(_foo_2), x1, x2)
    ...
end

Because the adjoints are wrapped in Thunk and InplaceableThunk, they are evaluated only if they are needed.

Unsupported features

TODO