I don't know why this is unstable; the ways of Zygote are sometimes mysterious.
The loss broadcasts the function below, which contains two odd things. First, `abs_error .< δ` is strange because both arguments are scalars. Second, `ignore_derivatives` is strange because Zygote shouldn't reach this code at all: the broadcast is differentiated with ForwardDiff, as you can confirm with `@show` (δ prints as a `Dual`). But commenting out that line doesn't fix anything.
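For comparison, here is a minimal sketch of the same scalar Huber metric that picks the branch with `ifelse` instead of multiplying both branches by a `Bool` mask, next to a copy of the mask formulation. Both names (`huber_metric_sketch`, `huber_metric_mask`) are hypothetical, `oftype(float(x), 0.5)` stands in for Flux's internal `ofeltype`, and neither function is Flux's implementation. The point is only that, on scalars, `abs_error .< δ` is an ordinary `Bool` comparison and both forms give the same values.

```julia
# Hypothetical branch-selecting version (NOT Flux's code).
function huber_metric_sketch(abs_error::Real, δ::Real)
    half = oftype(float(abs_error), 0.5)   # stand-in for Flux's ofeltype(abs_error, 0.5)
    quadratic = half * abs_error^2         # used when abs_error < δ
    linear = δ * (abs_error - half * δ)    # used when abs_error >= δ
    return ifelse(abs_error < δ, quadratic, linear)
end

# Hypothetical copy of the mask formulation quoted in this comment.
function huber_metric_mask(abs_error::Real, δ::Real)
    half = oftype(float(abs_error), 0.5)
    temp = abs_error .< δ                  # broadcasting over two scalars returns a plain Bool
    return ((abs_error * abs_error) * temp) * half + δ * (abs_error - half * δ) * (1 - temp)
end

# The two formulations agree on both sides of the threshold δ = 1.0.
for a in (0.0, 0.25, 0.999, 1.0, 2.0, 10.0)
    @assert huber_metric_sketch(a, 1.0) ≈ huber_metric_mask(a, 1.0)
end
```

The `ifelse` form avoids relying on `Bool * Float` arithmetic to zero out the unused branch, but it is shown here only as an illustration of what the mask computes.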
```julia
julia> @eval Flux.Losses @inline function _huber_metric(abs_error, δ)
           # TODO: remove ignore_derivatives when Zygote can handle this function with CuArrays
           temp = false # Zygote.ignore_derivatives(abs_error .< δ)
           x = ofeltype(abs_error, 0.5)
           @show δ
           ((abs_error * abs_error) * temp) * x + δ * (abs_error - x * δ) * (1 - temp)
       end
_huber_metric (generic function with 7 methods)
```
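For reference, writing $a$ for `abs_error` and $t = \mathbb{1}[a < \delta]$ for the mask `temp`, the returned expression (with `x` $= \tfrac12$) is the usual piecewise Huber metric. Note that with `temp = false`, as in the modified version above, $t = 0$ and only the linear branch is ever evaluated.

$$
\tfrac12 a^2\, t + \delta\left(a - \tfrac12\delta\right)(1 - t)
= \begin{cases}
\tfrac12 a^2, & a < \delta,\\
\delta\left(a - \tfrac12\delta\right), & a \ge \delta,
\end{cases}
\qquad
\frac{\partial}{\partial a} = \begin{cases} a, & a < \delta,\\ \delta, & a \ge \delta,\end{cases}
\qquad
\frac{\partial}{\partial \delta} = \begin{cases} 0, & a < \delta,\\ a - \delta, & a \ge \delta.\end{cases}
$$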
```julia
julia> wrapfunc(fc, fobs_ar, labels_ar, internfunc_nobroad_huberloss)
δ = Dual{Nothing}(1.0,0.0,1.0)
δ = Dual{Nothing}(1.0,0.0,1.0)
δ = Dual{Nothing}(1.0,0.0,1.0)
δ = Dual{Nothing}(1.0,0.0,1.0)
δ = Dual{Nothing}(1.0,0.0,1.0)
δ = Dual{Nothing}(1.0,0.0,1.0)
δ = Dual{Nothing}(1.0,0.0,1.0)
δ = Dual{Nothing}(1.0,0.0,1.0)
δ = Dual{Nothing}(1.0,0.0,1.0)
δ = Dual{Nothing}(1.0,0.0,1.0)
((layers = ((weight = Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], bias = Float32[0.0, 0.0, 0.0], σ = nothing), (weight = Float32[0.0 0.0 0.0; 0.0 0.0 0.0; 0.0 0.0 0.0], bias = Float32[0.0, 0.0, 0.0], σ = nothing), (weight = Float32[0.0 0.0 0.0], bias = Float32[1.0000001], σ = nothing)),),)
```
It looks like `Flux.huber_loss` is type-unstable when differentiated with Zygote?