Use eltype(x) everywhere, ignore typeof(η) (#151)
* always convert learning rate to eltype(momentum) before use

* don't parameterise rule structs, just use Float64

* fix tests like === 0.2f0

* fix a constructor

* fix AdamW

* fix Rprop

* use a macro to define structs with default values

* use T = eltype(x)

* a few more structs

* more structs

* fix tests

* doc fixes

* fix docstrings

* skip Yota on nightly

* docstrings

* breaking change, v0.3-dev

* print Adam(0.01f0) without 0.009999999776482582

* Revert "skip Yota on nightly"

This reverts commit abf0e13.

* don't accidentally write to stdout
mcabbott authored Aug 21, 2023
1 parent 322a6bb commit 5ac2d6f
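In short: rule structs lose their `{T}` type parameter and store plain `Float64` hyper-parameters, while `apply!` converts the learning rate to the element type of the array it is updating. A minimal sketch of the pattern (the rule names below are invented for illustration; a real rule would also define `Optimisers.init` for its state):

```julia
using Optimisers

# Before this change: rules were parameterised by the learning-rate type.
struct OldDescent{T} <: Optimisers.AbstractRule
    eta::T
end

# After: hyper-parameters are plain Float64 ...
struct NewDescent <: Optimisers.AbstractRule
    eta::Float64
end

# ... and apply! converts the learning rate to eltype(x), so Float32
# parameters receive Float32 updates even though eta is stored as Float64.
function Optimisers.apply!(o::NewDescent, state, x, dx)
    T = eltype(x)
    return state, T(o.eta) .* dx
end
```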
Showing 7 changed files with 214 additions and 178 deletions.
Project.toml (2 changes: 1 addition & 1 deletion)
@@ -1,7 +1,7 @@
name = "Optimisers"
uuid = "3bd65402-5787-11e9-1adc-39752487f4e2"
authors = ["Mike J Innes <mike.j.innes@gmail.com>"]
version = "0.2.20"
version = "0.3.0-DEV"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
docs/src/index.md (7 changes: 4 additions & 3 deletions)
@@ -7,14 +7,15 @@ These act on one array of parameters:

```julia
# Define a container to hold any optimiser specific parameters (if any):
struct DecayDescent{T} <: Optimisers.AbstractRule
eta::T
struct DecayDescent <: Optimisers.AbstractRule
eta::Float64
end

# Define an `apply!` rule which encodes how the gradients will be used to
# update the parameters:
function Optimisers.apply!(o::DecayDescent, state, x, x̄)
newx̄ = (o.eta / state) .* x̄
T = eltype(x)
newx̄ = T(o.eta / state) .* x̄
nextstate = state + 1
return nextstate, newx̄
end
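A rough usage check of the example above (hypothetical, not part of the docs diff; it assumes `Optimisers.init` for `DecayDescent` returns `1`, which the full docs example defines): the `Float64` learning rate no longer promotes `Float32` parameters.

```julia
using Optimisers

# Assumed, as in the full docs example: the state counts steps, starting at 1.
Optimisers.init(o::DecayDescent, x::AbstractArray) = 1

x = Float32[1, 2, 3]
st = Optimisers.setup(DecayDescent(0.1), x)
st, x = Optimisers.update(st, x, Float32[1, 1, 1])
eltype(x)   # Float32: the Float64 eta was converted via T = eltype(x)
```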
src/Optimisers.jl (14 changes: 7 additions & 7 deletions)
Expand Up @@ -78,7 +78,7 @@ or [`update!`](@ref).
julia> m = (x = rand(3), y = (true, false), z = tanh);
julia> Optimisers.setup(Momentum(), m) # same field names as m
(x = Leaf(Momentum{Float32}(0.01, 0.9), [0.0, 0.0, 0.0]), y = ((), ()), z = ())
(x = Leaf(Momentum(0.01, 0.9), [0.0, 0.0, 0.0]), y = ((), ()), z = ())
```
The recursion into structures uses Functors.jl, and any new `struct`s containing parameters
@@ -91,15 +91,15 @@ julia> struct Layer; mat; fun; end
julia> model = (lay = Layer([1 2; 3 4f0], sin), vec = [5, 6f0]);
julia> Optimisers.setup(Momentum(), model) # new struct is by default ignored
(lay = (), vec = Leaf(Momentum{Float32}(0.01, 0.9), Float32[0.0, 0.0]))
(lay = (), vec = Leaf(Momentum(0.01, 0.9), Float32[0.0, 0.0]))
julia> destructure(model)
(Float32[5.0, 6.0], Restructure(NamedTuple, ..., 2))
julia> using Functors; @functor Layer # annotate this type as containing parameters
julia> Optimisers.setup(Momentum(), model)
(lay = (mat = Leaf(Momentum{Float32}(0.01, 0.9), Float32[0.0 0.0; 0.0 0.0]), fun = ()), vec = Leaf(Momentum{Float32}(0.01, 0.9), Float32[0.0, 0.0]))
(lay = (mat = Leaf(Momentum(0.01, 0.9), Float32[0.0 0.0; 0.0 0.0]), fun = ()), vec = Leaf(Momentum(0.01, 0.9), Float32[0.0, 0.0]))
julia> destructure(model)
(Float32[1.0, 3.0, 2.0, 4.0, 5.0, 6.0], Restructure(NamedTuple, ..., 6))
@@ -120,13 +120,13 @@ See also [`update!`](@ref), which will be faster for models of ordinary `Array`s
```jldoctest
julia> m = (x = Float32[1,2,3], y = tanh);
julia> t = Optimisers.setup(Descent(0.1f0), m)
(x = Leaf(Descent{Float32}(0.1), nothing), y = ())
julia> t = Optimisers.setup(Descent(0.1), m)
(x = Leaf(Descent(0.1), nothing), y = ())
julia> g = (x = [1,1,1], y = nothing); # fake gradient
julia> Optimisers.update(t, m, g)
((x = Leaf(Descent{Float32}(0.1), nothing), y = ()), (x = Float32[0.9, 1.9, 2.9], y = tanh))
((x = Leaf(Descent(0.1), nothing), y = ()), (x = Float32[0.9, 1.9, 2.9], y = tanh))
```
"""
update
@@ -152,7 +152,7 @@ julia> using StaticArrays, Zygote, Optimisers
julia> m = (x = [1f0, 2f0], y = SA[4f0, 5f0]); # partly mutable model
julia> t = Optimisers.setup(Momentum(1/30, 0.9), m) # tree of states
(x = Leaf(Momentum{Float64}(0.0333333, 0.9), Float32[0.0, 0.0]), y = Leaf(Momentum{Float64}(0.0333333, 0.9), Float32[0.0, 0.0]))
(x = Leaf(Momentum(0.0333333, 0.9), Float32[0.0, 0.0]), y = Leaf(Momentum(0.0333333, 0.9), Float32[0.0, 0.0]))
julia> g = gradient(m -> sum(abs2.(m.x .+ m.y)), m)[1] # structural gradient
(x = Float32[10.0, 14.0], y = Float32[10.0, 14.0])
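These doctest outputs change because rules are no longer parameterised by the learning-rate type: a rule built from a `Float32` still stores `Float64` fields, and the new compact `show` keeps the printing short. A hypothetical REPL check, not taken from the diff:

```julia
julia> Momentum(0.01f0)   # compact printing from the new show method
Momentum(0.01, 0.9)

julia> typeof(Momentum(0.01f0).eta)   # the field itself is a Float64
Float64
```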
src/adjust.jl (18 changes: 9 additions & 9 deletions)
Expand Up @@ -23,15 +23,15 @@ julia> Optimisers.freeze!(s.x)
julia> Optimisers.update!(s, m, (x = ([pi], 10pi), y = [100pi])); # with fake gradient
julia> m
(x = ([1.0], 2.0), y = [-0.14159258336972558])
(x = ([1.0], 2.0), y = [-0.14159265358979312])
julia> s
(x = (Leaf(Momentum{Float32}(0.01, 0.9), [0.0], frozen = true), ()), y = Leaf(Momentum{Float32}(0.01, 0.9), [3.14159]))
(x = (Leaf(Momentum(0.01, 0.9), [0.0], frozen = true), ()), y = Leaf(Momentum(0.01, 0.9), [3.14159]))
julia> Optimisers.thaw!(s)
julia> s.x
(Leaf(Momentum{Float32}(0.01, 0.9), [0.0]), ())
(Leaf(Momentum(0.01, 0.9), [0.0]), ())
```
"""
freeze!(tree) = foreach(freeze!, tree)
@@ -72,17 +72,17 @@ To change just the learning rate, provide a number `η::Real`.
julia> m = (vec = rand(Float32, 2), fun = sin);
julia> st = Optimisers.setup(Nesterov(), m) # stored momentum is initialised to zero
(vec = Leaf(Nesterov{Float32}(0.001, 0.9), Float32[0.0, 0.0]), fun = ())
(vec = Leaf(Nesterov(0.001, 0.9), Float32[0.0, 0.0]), fun = ())
julia> st, m = Optimisers.update(st, m, (vec = [16, 88], fun = nothing)); # with fake gradient
julia> st
(vec = Leaf(Nesterov{Float32}(0.001, 0.9), Float32[-0.016, -0.088]), fun = ())
(vec = Leaf(Nesterov(0.001, 0.9), Float32[-0.016, -0.088]), fun = ())
julia> Optimisers.adjust!(st, 0.123) # change learning rate, stored momentum untouched
julia> st
(vec = Leaf(Nesterov{Float32}(0.123, 0.9), Float32[-0.016, -0.088]), fun = ())
(vec = Leaf(Nesterov(0.123, 0.9), Float32[-0.016, -0.088]), fun = ())
```
To change other parameters, `adjust!` also accepts keyword arguments matching the field
@@ -93,13 +93,13 @@ julia> fieldnames(Adam)
(:eta, :beta, :epsilon)
julia> st2 = Optimisers.setup(OptimiserChain(ClipGrad(), Adam()), m)
(vec = Leaf(OptimiserChain(ClipGrad{Float32}(10.0), Adam{Float32}(0.001, (0.9, 0.999), 1.19209f-7)), (nothing, (Float32[0.0, 0.0], Float32[0.0, 0.0], (0.9, 0.999)))), fun = ())
(vec = Leaf(OptimiserChain(ClipGrad(10.0), Adam(0.001, (0.9, 0.999), 1.0e-8)), (nothing, (Float32[0.0, 0.0], Float32[0.0, 0.0], (0.9, 0.999)))), fun = ())
julia> Optimisers.adjust(st2; beta = (0.777, 0.909), delta = 11.1) # delta acts on ClipGrad
(vec = Leaf(OptimiserChain(ClipGrad{Float32}(11.1), Adam{Float32}(0.001, (0.777, 0.909), 1.19209f-7)), (nothing, (Float32[0.0, 0.0], Float32[0.0, 0.0], (0.9, 0.999)))), fun = ())
(vec = Leaf(OptimiserChain(ClipGrad(11.1), Adam(0.001, (0.777, 0.909), 1.0e-8)), (nothing, (Float32[0.0, 0.0], Float32[0.0, 0.0], (0.9, 0.999)))), fun = ())
julia> Optimisers.adjust(st; beta = "no such field") # silently ignored!
(vec = Leaf(Nesterov{Float32}(0.123, 0.9), Float32[-0.016, -0.088]), fun = ())
(vec = Leaf(Nesterov(0.123, 0.9), Float32[-0.016, -0.088]), fun = ())
```
"""
adjust!(tree, eta::Real) = foreach(st -> adjust!(st, eta), tree)
src/interface.jl (43 changes: 43 additions & 0 deletions)
Expand Up @@ -7,6 +7,10 @@ const Zero = Union{Nothing, AbstractZero} # Union{Zygote, Diffractor}

abstract type AbstractRule end

function Base.show(io::IO, rule::AbstractRule) # makes Adam(0.01f0) prettier
invoke(show, Tuple{IO,Any}, IOContext(io, :compact => true), rule)
end

###
### setup
###
@@ -225,3 +229,42 @@ Broadcast.materialize(x::Lazy) = Broadcast.instantiate(x.bc)
onevalue(λ::T, x::AbstractArray{T}) where T = map(_ -> λ, x)
onevalue(λ, x::AbstractArray{T}) where T = onevalue(convert(float(T), λ), x)

nonneg(η::Real) = η < 0 ? throw(DomainError(η, "the learning rate cannot be negative")) : η

"""
@def struct Rule; eta = 0.1; beta = (0.7, 0.8); end
Helper macro for defining rules with default values.
The types of the literal values are used in the `struct`,
like this:
```
struct Rule
eta::Float64
beta::Tuple{Float64, Float64}
Rule(eta = 0.1, beta = (0.7, 0.8)) = eta < 0 ? error() : new(eta, beta)
end
```
Any field called `eta` is assumed to be a learning rate, and cannot be negative.
"""
macro def(expr)
Meta.isexpr(expr, :struct) || throw("@def must act on a struct definition")
lines = expr.args[3].args
names, vals = [], []
for i in eachindex(lines)
lines[i] isa Symbol && throw("@def requires a default for every field")
Meta.isexpr(lines[i], :(=)) || continue
name, val = lines[i].args
push!(names, name)
push!(vals, val)
lines[i] = :($name::$typeof($val))
end
rule = Meta.isexpr(expr.args[2], :<:) ? expr.args[2].args[1] : expr.args[2]
check = :eta in names ? :(eta < 0 && throw(DomainError(eta, "the learning rate cannot be negative"))) : nothing
inner = :(function $rule($([Expr(:kw, nv...) for nv in zip(names,vals)]...))
$check
new($(names...))
end)
push!(lines, inner)
esc(expr)
end
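A hypothetical example of how this macro might be used (`MyRule` and its default values are invented): `@def` reads each field's type off its literal default and generates a constructor with positional defaults that rejects a negative `eta`.

```julia
Optimisers.@def struct MyRule <: Optimisers.AbstractRule
    eta = 0.01
    rho = 0.9
end

MyRule()       # MyRule(0.01, 0.9); both fields are Float64, from the literal defaults
MyRule(0.1)    # MyRule(0.1, 0.9); arguments are positional, with defaults
MyRule(-1.0)   # throws DomainError: "the learning rate cannot be negative"
```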

