Optimisers.jl defines many standard gradient-based optimisation rules, and tools for applying them to deeply nested models.
This is the future of training for Flux.jl neural networks, and the present for Lux.jl. But it can be used separately on any array, or anything else understood by Functors.jl.
Install it from Julia's package prompt:

```julia
] add Optimisers
```
The core idea is that optimiser state (such as momentum) is explicitly handled. It is initialised by `setup`, and then at each step, `update` returns both the new state, and the model with its trainable parameters adjusted:
```julia
state = Optimisers.setup(Optimisers.Adam(), model)    # just once

grad = Zygote.gradient(m -> loss(m(x), y), model)[1]

state, model = Optimisers.update(state, model, grad)  # at every step
```
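For a self-contained illustration, here is a sketch of the same loop where the "model" is just a plain parameter vector; the data `x`, `y` and the loss are invented for the example, and only the `setup`/`update` calls are the package's API:

```julia
using Optimisers, Zygote

# Toy example: the "model" is a plain parameter vector, and the loss is
# the squared error of a dot product against a target value.
model = [1.0, 2.0, 3.0]
x, y = [1.0, 1.0, 1.0], 4.0
loss(ŷ, y) = abs2(ŷ - y)

state = Optimisers.setup(Optimisers.Adam(0.1), model)               # just once

for _ in 1:200
    grad = Zygote.gradient(m -> loss(sum(m .* x), y), model)[1]
    global state, model = Optimisers.update(state, model, grad)     # at every step
end

sum(model .* x)   # now close to the target 4.0
```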
For models with deeply nested layers containing the parameters (like Flux.jl models), this state is a similarly nested tree. As is the gradient: if using Zygote, you must use the "explicit" style as shown, not the "implicit" one with `Params`.
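To make the nested-tree point concrete, here is a small sketch using a hand-made NamedTuple in place of a real Flux.jl model; the layer names are invented for the example:

```julia
using Optimisers

# Stand-in for a nested model: a NamedTuple of NamedTuples of arrays.
model = (layer1 = (weight = rand(3, 2), bias = zeros(3)),
         layer2 = (weight = rand(1, 3), bias = zeros(1)))

state = Optimisers.setup(Optimisers.Momentum(0.01), model)

# `state` mirrors the model's structure, with a leaf at the position of
# every trainable array holding the rule and its momentum buffer:
state.layer1.weight   # leaf containing Momentum(...) and a 3×2 buffer of zeros
```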
The function `destructure` collects all the trainable parameters into one vector, and returns this along with a function to re-build a similar model:
```julia
vector, re = Optimisers.destructure(model)

model2 = re(2 .* vector)
```
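As a usage sketch, the flat vector and the rebuilt model behave as you would expect; the NamedTuple below is just an illustrative stand-in for a real model:

```julia
using Optimisers

model = (weight = [1.0 2.0; 3.0 4.0], bias = [0.0, 0.0])

vector, re = Optimisers.destructure(model)
length(vector)            # 6 -- all trainable parameters, flattened

model2 = re(2 .* vector)
model2.weight             # [2.0 4.0; 6.0 8.0] -- same structure, doubled values
```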
The documentation explains usage in more detail, describes all the optimisation rules, and shows how to define new ones.
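For a flavour of what defining a new rule looks like, here is a minimal sketch of a plain gradient-descent rule. The struct name is invented, and the `init`/`apply!` methods follow the rule interface as described in the documentation, which should be consulted for the exact requirements:

```julia
using Optimisers

# A hypothetical re-implementation of plain gradient descent.
struct MyDescent <: Optimisers.AbstractRule
    eta::Float64
end

# No per-parameter state is needed for plain descent.
Optimisers.init(o::MyDescent, x::AbstractArray) = nothing

# Return the (unchanged) state and the increment to subtract from x.
function Optimisers.apply!(o::MyDescent, state, x, dx)
    return state, o.eta .* dx
end

# The new rule then works with setup/update like any built-in one:
state = Optimisers.setup(MyDescent(0.1), [1.0, 2.0])
```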