Tracker.jl
This was the original automatic differentiation engine for Flux.jl, before being replaced by Zygote.jl in 2019. Both were written by Mike Innes.
This package is solid and still in active use, but is no longer heavily maintained. PRs and issues may go unanswered.
Introduction
Like ReverseDiff.jl and AutoGrad.jl, Tracker traces through a program by wrapping arrays in a special `TrackedArray` type. The final answer contains a "tape" of the operations performed, which is reversed by `back!`:
```julia
using Tracker

x = param([1,2,3])             # Tracked 3-element Vector{Float64}
f(x) = sum(abs2, x) + prod(x[2:end])
y = f(x)                       # TrackedReal
back!(y)                       # run back-propagation
Tracker.grad(x)                # extract gradient from TrackedArray
```
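As a sanity check on what the tape computes, the gradient of this `f` can also be worked out by hand. The `sum(abs2, x)` term contributes `2x`, and the product term contributes `prod(x[2:end]) / x[i]` for every index except the first. The sketch below compares the two. It uses only `param`, `back!`, `Tracker.data` and `Tracker.grad`. The helper `manual_grad` is hypothetical and assumes `x[2:end]` has no zero entries.

```julia
using Tracker

f(x) = sum(abs2, x) + prod(x[2:end])

# Hypothetical helper: gradient of f derived by hand (assumes no zeros in x[2:end])
function manual_grad(x::AbstractVector)
    g = 2 .* x                      # derivative of sum(abs2, x)
    p = prod(x[2:end])
    g[2:end] .+= p ./ x[2:end]      # derivative of prod(x[2:end]) for indices 2:end
    return g
end

x = param([1.0, 2.0, 3.0])          # TrackedArray wrapping a Vector{Float64}
y = f(x)                            # TrackedReal; the operations are recorded on the tape
back!(y)                            # walk the tape backwards, accumulating into x

Tracker.data(y)                     # 20.0
Tracker.grad(x)                     # [2.0, 7.0, 8.0]
manual_grad([1.0, 2.0, 3.0])        # [2.0, 7.0, 8.0]
```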
This is a much lower-tech approach than that of Zygote, Yota and Diffractor. At best, those packages produce fast, compiled Julia code for the reverse pass instead of interpreting a tape. At worst, they can have extremely long compile times and be difficult to debug.
Interface
Instead of calling `back!` yourself, you can pass the function and the input to `gradient`:
```julia
gradient(f, [1,2,3])           # returns ([2.0, 7.0, 8.0],)
withgradient(f, [1,2,3])       # returns (val = 20, grad = ([2.0, 7.0, 8.0],))
```
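Both functions also accept several inputs and return one gradient per argument, collected in a tuple. The sketch below uses a hypothetical two-argument function `g` (a dot product) whose partial derivatives are easy to check by hand: `∂g/∂a = b` and `∂g/∂b = a`. The names are written as `Tracker.gradient` and `Tracker.withgradient` so they stay unambiguous when other AD packages are loaded.

```julia
using Tracker

# Hypothetical two-argument function: a dot product written with broadcasting
g(a, b) = sum(a .* b)

a = [1.0, 2.0]
b = [3.0, 4.0]

ga, gb = Tracker.gradient(g, a, b)      # one gradient per input
res = Tracker.withgradient(g, a, b)     # NamedTuple with fields val and grad

ga                                      # [3.0, 4.0], i.e. b
gb                                      # [1.0, 2.0], i.e. a
res.val                                 # 11.0
```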
The original interface to Flux.jl involved a dictionary of arrays called `Params`, much like Zygote's "implicit" parameter interface. This appears not to be documented.
A more modern way to use Flux relies on `withgradient`'s ability to take gradients with respect to complex nested structures. This is what Optimisers.jl is designed to accept:
```julia
julia> using Flux, Tracker

julia> model = Chain(Dense(2 => 1, tanh), Dense(1 => 1, bias=false));

julia> Tracker.withgradient(model, rand(Float32, 2)) do m, x
           sum(abs2, m(x))
       end
(val = 0.035716165f0,
 grad = ((layers = ((weight = Float32[-0.4241869 -0.16741231], bias = Float32[-0.5529184], σ = nothing),
                    (weight = Float32[-0.04804218;;], bias = nothing, σ = nothing)),),
         Float32[0.12706584, -0.08858479]))
```
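The same pattern should work without Flux. A plain `NamedTuple` of arrays can be differentiated in the same way, and the gradient mirrors its shape. The sketch below then takes one plain gradient-descent step with Optimisers.jl, using its `setup`, `Descent` and `update`. The parameter names `w` and `b` and the `loss` function are made up for illustration.

```julia
using Tracker, Optimisers

# Hypothetical parameters, stored as a nested NamedTuple instead of a Flux model
ps = (w = [1.0, 2.0], b = [3.0])

loss(p) = sum(abs2, p.w) + sum(p.b)

res = Tracker.withgradient(loss, ps)
res.val                  # 8.0
res.grad[1].w            # [2.0, 4.0], same shape as ps.w
res.grad[1].b            # [1.0], same shape as ps.b

# One step of plain gradient descent with learning rate 0.1
state = Optimisers.setup(Optimisers.Descent(0.1), ps)
state, ps = Optimisers.update(state, ps, res.grad[1])
ps.w                     # ≈ [0.8, 1.6]
ps.b                     # ≈ [2.9]
```

Here `update` is the non-mutating variant: it returns a new optimiser state and new parameters rather than overwriting the old arrays.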
Rules
Tracker.jl contains rules for many common operations. It relies on DiffRules.jl for many definitions, and does not connect to the newer ChainRules.jl at all.
To define more rules, use `track` and `@grad`. See the source for more examples:
```julia
using Tracker: TrackedArray, track, @grad, data

f(x::TrackedArray) = track(f, x)         # entry point, via dispatch
@grad function f(x)
    y = f(data(x))                       # forward pass, without tracking
    back(dy) = (dy * ∂f∂x(data(x)),)     # pullback function, returns a tuple
    return y, back                       # ∂f∂x stands for your derivative of f
end
```
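For a concrete instance of this template, the sketch below writes a rule for a hypothetical elementwise square, `mysquare`, whose pullback is `2x .* dy`. The plain-array method does the real work. The `TrackedArray` method routes tracked inputs through `track`, so Tracker uses the hand-written pullback instead of tracing inside `mysquare`.

```julia
using Tracker
using Tracker: TrackedArray, track, @grad, data

# Hypothetical function, used only to illustrate a custom rule
mysquare(x::AbstractArray) = x .^ 2

# Entry point: tracked inputs go to the custom rule instead of being traced
mysquare(x::TrackedArray) = track(mysquare, x)

@grad function mysquare(x)
    xd = data(x)                     # underlying untracked array
    y = mysquare(xd)                 # forward pass on plain arrays
    back(dy) = (2 .* xd .* dy,)      # pullback: one gradient per argument
    return y, back
end

g, = Tracker.gradient(x -> sum(mysquare(x)), [1.0, 2.0, 3.0])
g                                    # [2.0, 4.0, 6.0]
```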