
AdamW optimizer implemented incorrectly - weight decay does not incorporate learning rate #182

Closed
BioTurboNick opened this issue Oct 25, 2024 · 2 comments

BioTurboNick commented Oct 25, 2024

In Optimisers.jl, AdamW is implemented as an OptimiserChain of Adam and WeightDecay:

Optimisers.jl/src/rules.jl, lines 510 to 514 at c2ae321:

AdamW(η, β = (0.9, 0.999), λ = 0.0, ϵ = 1e-8) =
  OptimiserChain(Adam(η, β, ϵ), WeightDecay(λ))

AdamW(; eta = 0.001, beta = (0.9, 0.999), lambda = 0, epsilon = 1e-8) =
  OptimiserChain(Adam(eta, beta, epsilon), WeightDecay(lambda))

WeightDecay here simply multiplies the decay value by the parameter:

Optimisers.jl/src/rules.jl, lines 569 to 574 at c2ae321:

function apply!(o::WeightDecay, state, x::AbstractArray{T}, dx) where T
  λ = T(o.lambda)
  dx′ = @lazy dx + λ * x

  return state, dx′
end
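
To double-check this, here is a minimal sketch (single-element parameter, zero gradient so the Adam step vanishes and only the decay term remains; η = 0.001 and λ = 0.1 are just illustrative values) showing that the chained rule shrinks the parameter by λ per step, independent of η:

using Optimisers

x  = [1.0]    # parameter
dx = [0.0]    # zero gradient isolates the decay term
opt = OptimiserChain(Adam(0.001), WeightDecay(0.1))
st = Optimisers.setup(opt, x)
st, x′ = Optimisers.update(st, x, dx)
# x′ ≈ [0.9]: the parameter shrinks by λ = 0.1, not by η * λ = 1e-4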

In AdamW, and indeed in PyTorch, the WeightDecay value needs to be multiplied by the learning rate too:
[Screenshot of the AdamW update rule (Algorithm 2, line 12) from the paper: θ_t ← θ_{t-1} − η_t (α m̂_t / (√v̂_t + ε) + λ θ_{t-1})]
From: https://arxiv.org/pdf/1711.05101
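
Concretely, taking λ = 5e-4 and η = 1e-3 as illustrative values (the defaults in my workaround below), the per-step decay applied to each weight θ is:

PyTorch AdamW:                       η * λ * θ = 1e-3 * 5e-4 * θ = 5e-7 * θ
OptimiserChain(Adam, WeightDecay):       λ * θ =                   5e-4 * θ

so after porting, the effective weight decay was three orders of magnitude stronger than what the PyTorch model was trained with.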

This turned out to be the source of a great deal of frustration for me, as I was observing severe misbehavior in the model I've been trying to port from PyTorch.

The following optimiser produces the correct behavior:


Optimisers.@def struct LearningWeightDecay <: Optimisers.AbstractRule
  lambda = 5e-4
  eta = 0.001
end

Optimisers.init(o::LearningWeightDecay, x::AbstractArray) = nothing

function Optimisers.apply!(o::LearningWeightDecay, state, x::AbstractArray{T}, dx) where T
  λ, η = T(o.lambda), T(o.eta)
  dx′ = Optimisers.@lazy dx + η * λ * x

  return state, dx′
end

CorrectAdamW(η, β = (0.9, 0.999), λ = 0.0, ϵ = 1e-8) =
  Optimisers.OptimiserChain(Optimisers.Adam(η, β, ϵ), LearningWeightDecay(λ, η))
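
For completeness, a usage sketch, where model and grads are placeholders for your own Functors-compatible model and the gradient you computed for it:

opt_state = Optimisers.setup(CorrectAdamW(1e-3, (0.9, 0.999), 5e-4), model)
opt_state, model = Optimisers.update(opt_state, model, grads)
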
ToucheSir (Member) commented Oct 25, 2024

Duplicate of FluxML/Flux.jl#2433. The TL;DR is that Optimisers and Flux try to follow the original paper while PyTorch decided to do its own thing. For more information on why this is less than ideal, see my comment at FluxML/Flux.jl#2433 (comment).

For better or worse, the PyTorch approach has become the de facto one because PyTorch is the 300 lb gorilla in the room. If someone wants to carry out the work mentioned in that Flux issue, we'll be compatible with PyTorch by default.

ToucheSir closed this as not planned on Oct 25, 2024
BioTurboNick (Author) commented

Ah. Well, dang.
