In Optimisers.jl, AdamW is implemented as an OptimiserChain of Adam and WeightDecay (Optimisers.jl/src/rules.jl, lines 510 to 514 in c2ae321). WeightDecay here simply multiplies the decay value by the parameter (Optimisers.jl/src/rules.jl, lines 569 to 574 in c2ae321):
function apply!(o::WeightDecay, state, x::AbstractArray{T}, dx) where T
  λ = T(o.lambda)
  dx′ = @lazy dx + λ * x

  return state, dx′
end
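For context, the AdamW definition referenced above (lines 510 to 514) is roughly the following one-liner; this is a paraphrase, so the exact default values may differ from what is at that commit:

AdamW(η = 0.001, β = (0.9, 0.999), λ = 0.0, ϵ = 1e-8) =
  OptimiserChain(Adam(η, β, ϵ), WeightDecay(λ))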
In AdamW, and indeed in PyTorch, the WeightDecay value needs to be multiplied by the learning rate too:
(From the AdamW paper: https://arxiv.org/pdf/1711.05101)
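For reference, the decoupled update in Algorithm 2 of that paper is roughly (in the paper's notation, with $\alpha$ the step size, $\eta_t$ the schedule multiplier, and $\lambda$ the decay factor):

$$\theta_t \leftarrow \theta_{t-1} - \eta_t \left( \alpha\,\frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon} + \lambda\,\theta_{t-1} \right)$$

so the decay term is scaled together with the Adam step (PyTorch's torch.optim.AdamW likewise multiplies the decay by lr), whereas OptimiserChain(Adam(η, β, ϵ), WeightDecay(λ)) scales only the Adam step by η and leaves the λ * x term unscaled.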
This turned out to be the source of considerable frustration for me, as I was observing severe misbehavior in a model I've been trying to port from PyTorch.
The following optimiser produces the correct behavior:
import Optimisers

# A WeightDecay variant that also stores the learning rate, so the decay term
# can be scaled by η as PyTorch's AdamW does.
Optimisers.@def struct LearningWeightDecay <: Optimisers.AbstractRule
  lambda = 5e-4
  eta = 0.001
end

Optimisers.init(o::LearningWeightDecay, x::AbstractArray) = nothing

function Optimisers.apply!(o::LearningWeightDecay, state, x::AbstractArray{T}, dx) where T
  λ, η = T(o.lambda), T(o.eta)
  dx′ = Optimisers.@lazy dx + η * λ * x
  return state, dx′
end

CorrectAdamW(η, β = (0.9, 0.999), λ = 0.0, ϵ = 1e-8) =
  Optimisers.OptimiserChain(Optimisers.Adam(η, β, ϵ), LearningWeightDecay(λ, η))
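A minimal usage sketch, with a toy NamedTuple standing in for a real model and random arrays standing in for gradients (in practice both would come from Flux/Zygote):

# Any Functors-compatible structure of arrays works with Optimisers.setup.
model = (W = randn(Float32, 4, 4), b = zeros(Float32, 4))
opt_state = Optimisers.setup(CorrectAdamW(1e-3, (0.9, 0.999), 5e-4), model)

# Stand-in gradients with the same structure as `model`.
grads = (W = randn(Float32, 4, 4), b = randn(Float32, 4))
opt_state, model = Optimisers.update!(opt_state, model, grads)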
Duplicate of FluxML/Flux.jl#2433. The TL;DR is that Optimisers and Flux try to follow the original paper while PyTorch decided to do its own thing. For more information on why this is less than ideal, see my comment at FluxML/Flux.jl#2433 (comment).
For better or worse, the PyTorch approach has become the de facto one because PyTorch is the 300 lb gorilla in the room. If someone wants to carry out the work mentioned in that Flux issue, we'll be compatible with them by default.