diff --git a/Project.toml b/Project.toml
index a0734472..978a1087 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "Optimisers"
 uuid = "3bd65402-5787-11e9-1adc-39752487f4e2"
 authors = ["Mike J Innes <mike.j.innes@gmail.com>"]
-version = "0.2.1"
+version = "0.2.2"
 
 [deps]
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
diff --git a/src/rules.jl b/src/rules.jl
index 65392a0f..9f161f4f 100644
--- a/src/rules.jl
+++ b/src/rules.jl
@@ -442,20 +442,24 @@ end
 """
     WeightDecay(γ = 5f-4)
 
-Decay weights by `γ`.
+Decay weights by ``γ``, that is, add `γ .* x` to the gradient `x̄`, which will be
+subtracted from `x`.
+
+Typically composed with other optimisers as the first transformation in an [`OptimiserChain`](@ref).
+This is equivalent to adding ``L_2`` regularisation with coefficient ``γ`` to the loss.
 
 # Parameters
 - Weight decay (`γ`): Decay applied to weights during optimisation.
 """
 struct WeightDecay{T}
-  wd::T
+  gamma::T
 end
 WeightDecay() = WeightDecay(5f-4)
 
 init(o::WeightDecay, x::AbstractArray) = nothing
 
 function apply!(o::WeightDecay, state, x, dx)
-  dx′ = @lazy dx + o.wd * x
+  dx′ = @lazy dx + o.gamma * x
 
   return state, dx′
 end
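
For reviewers, a minimal usage sketch of what the new docstring recommends. The model and gradients here are made-up placeholders (in practice the gradients would come from e.g. `Zygote.gradient`), assuming the usual `Optimisers.setup`/`update` API:

```julia
using Optimisers

# Hypothetical model: any Functors-compatible structure of arrays works.
model = (weight = rand(Float32, 3, 3), bias = zeros(Float32, 3))

# WeightDecay goes first in the chain, so `γ .* x` is added to each
# gradient *before* Descent scales the result by the learning rate.
rule  = OptimiserChain(WeightDecay(5f-4), Descent(0.1f0))
state = Optimisers.setup(rule, model)

# Placeholder gradients, with the same structure as the model.
grads = (weight = ones(Float32, 3, 3), bias = ones(Float32, 3))
state, model = Optimisers.update(state, model, grads)
```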
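And a sketch of the docstring's ``L_2`` claim: `OptimiserChain(WeightDecay(γ), Descent(η))` takes the same step as plain `Descent(η)` on the gradient of the loss plus ``(γ/2)\|x\|^2``, whose gradient is `dx .+ γ .* x`. The arrays below are placeholders, not part of the diff:

```julia
using Optimisers

γ, η = 5f-4, 0.1f0
x  = rand(Float32, 4)
dx = rand(Float32, 4)   # gradient of the unregularised loss

# Path 1: WeightDecay composed with Descent.
st1    = Optimisers.setup(OptimiserChain(WeightDecay(γ), Descent(η)), x)
_, x1  = Optimisers.update(st1, x, dx)

# Path 2: plain Descent on the L₂-regularised gradient dx .+ γ .* x.
st2    = Optimisers.setup(Descent(η), x)
_, x2  = Optimisers.update(st2, x, dx .+ γ .* x)

x1 ≈ x2   # true: both compute x .- η .* (dx .+ γ .* x)
```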