Skip to content

Commit

Permalink
Merge pull request #144 from mcabbott/bc_norm
Browse files Browse the repository at this point in the history
Make ClipNorm work on GPU Broadcasted
  • Loading branch information
mcabbott authored Jul 10, 2023
2 parents 8a37946 + d73a0ee commit 6eaf26d
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 3 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Optimisers"
uuid = "3bd65402-5787-11e9-1adc-39752487f4e2"
authors = ["Mike J Innes <mike.j.innes@gmail.com>"]
version = "0.2.18"
version = "0.2.19"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
25 changes: 24 additions & 1 deletion src/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -611,7 +611,7 @@ ClipNorm(ω = 10f0, p = 2; throw::Bool = true) = ClipNorm{float(typeof(ω))}(ω,
init(o::ClipNorm, x::AbstractArray) = nothing

function apply!(o::ClipNorm, state, x, dx)
nrm = norm(dx, o.p)
nrm = _norm(dx, o.p)
if o.throw && !isfinite(nrm)
throw(DomainError("gradient has $(o.p)-norm $nrm, for array $(summary(x))"))
end
Expand All @@ -620,6 +620,29 @@ function apply!(o::ClipNorm, state, x, dx)
return state, @lazy dx * λ
end

_norm(dx::AbstractArray, p::Real) = norm(dx, p) # LinearAlgebra, CUDA
function _norm(dx::Broadcast.Broadcasted, p::Real)
if p == 2
# This lacks the undeflow/overflow tests of LinearAlgebra's version
sqrt(sum(abs2, dx))
elseif p == 1
float(sum(abs, dx))
elseif p == Inf
float(maximum(abs, dx))
elseif p == 0
cnt = count(!iszero, dx)
T = Base.@default_eltype dx
T <: Number ? convert(float(T), cnt) : cnt
elseif p == -Inf
float(minimum(abs, dx))
else
# This isn't optimally fast but does ensure p::Float64 doesn't promote
tmp = abs.(dx)
q = convert(float(eltype(tmp)), p)
sum(tmp .^ q) ^ (1/q)
end
end

"""
OptimiserChain(opts...)
Expand Down
17 changes: 16 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ using Optimisers
using ChainRulesCore, Functors, StaticArrays, Zygote, Yota
using LinearAlgebra, Statistics, Test, Random
using Optimisers: @.., @lazy
using Base.Broadcast: broadcasted, instantiate, Broadcasted

Random.seed!(1)

Expand Down Expand Up @@ -89,7 +90,8 @@ y2z(x) = x
_, m2 = Optimisers.update(s2, m, (α = ([0.1], nothing), γ = [1,10,100],))
@test only(m.α[1] .- m2.α[1]) 0.1
@test norm(m.γ .- m2.γ) 10
@test_throws DomainError Optimisers.update(s2, m, (α = [0.1], γ = [1,10,NaN],))
# This error is thrown by apply! due to NaN input.
@test_throws DomainError Optimisers.update(s2, m, (α = ([0.1], nothing), γ = [1,10,NaN],))

s3 = Optimisers.setup(ClipNorm(5, 1; throw=false), m)
_, m3 = Optimisers.update(s3, m, (α = ([0.1], nothing), γ = [1,10,100],))
Expand Down Expand Up @@ -506,6 +508,19 @@ y2z(x) = x
y = Optimisers.subtract!(x, nothing)
@test y === x
end

@testset "_norm(dx, p) works" begin
bc = instantiate(broadcasted(+, randn(Float32, 10), randn(Float32, 10)'));
arr = collect(bc)
bc2 = instantiate(broadcasted(+, [1, 0, -3, 4], 0))
arr2 = collect(bc2)
for p in (-Inf, -3, -1, 0, 0.5, 1, 1.5, 2, 3f0, Inf32)
@test Optimisers._norm(bc, p) norm(arr, p)
@test Optimisers._norm(bc, p) isa Float32
@test Optimisers._norm(bc2, p) norm(arr2, p)
@test Optimisers._norm(bc2, p) isa Float64
end
end
end
@testset verbose=true "Destructure" begin
include("destructure.jl")
Expand Down

2 comments on commit 6eaf26d

@mcabbott
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/88252

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.19 -m "<description of version>" 6eaf26da61f3f1f7f6f663b0d413e3c124742014
git push origin v0.2.19

Please sign in to comment.