I have been trying to take the derivative of a Flux model in test mode, and noticed that the BatchNorm layer behaves incorrectly for 4D and 5D CUDA arrays.
Here is a minimal working example (MWE) of this behaviour, computing the gradient of the BatchNorm for differently reshaped inputs:
using Flux, CUDA, Zygote

# Move the model to `device`, switch it to test mode, reshape the input
# to `n_dims` singleton dimensions, and return the input gradient on the CPU.
function gradient_varying_shape(m, x, n_dims, device)
    m = m |> device
    Flux.testmode!(m)
    x = reshape(x, ntuple(i -> 1, n_dims)) |> device
    return gradient(input -> sum(m(input).^2), x)[1] |> cpu
end
model = BatchNorm(1)
x = [1f0]
for i in 2:7
    cpu_gradient = gradient_varying_shape(model, x, i, cpu)
    gpu_gradient = gradient_varying_shape(model, x, i, gpu)
    println("n_dim=$i, cpu: $(cpu_gradient[1]), gpu: $(gpu_gradient[1])")
end
Looking through the code, I found that the implementation of the CUDA backward batchnorm here ignores the training argument. Could this be the origin of this behaviour?
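For context, in test mode BatchNorm applies a fixed affine map, y = γ .* (x .- μ) ./ sqrt.(σ² .+ ϵ) .+ β, so the input gradient of sum(y.^2) is 2 .* y .* γ ./ sqrt.(σ² .+ ϵ) and should not depend on the input shape. A minimal CPU-only sketch checking Zygote against that closed form (assuming Flux's BatchNorm field names γ, β, μ, σ² and ϵ):

using Flux, Zygote

m = BatchNorm(1)
Flux.testmode!(m)
x = reshape([1f0], 1, 1, 1, 1)   # 4D input; the channel dimension has size 1

# Gradient from AD
g_ad = gradient(input -> sum(m(input).^2), x)[1]

# Closed form: in test mode y = s * (x - μ) + β with s = γ / sqrt(σ² + ϵ),
# so d(sum(y²))/dx = 2 * y * s elementwise.
s = m.γ[1] / sqrt(m.σ²[1] + m.ϵ)
y = s * (x[1] - m.μ[1]) + m.β[1]
g_ref = 2 * y * s

@assert g_ad[1] ≈ g_ref   # holds on the CPU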
I'm using Julia 1.9.3 with NNlib version 0.9.7 and this environment:
I have been looking into fixing this issue, but I have a hard time understanding the function signature of cudnnBNBackward! and which variables it is supposed to act on. Is there any additional documentation or information on that?
No, but the good news is that it's merely a thin wrapper over cudnnBatchNormalizationBackward and that is documented by Nvidia. If you have any questions during your effort, I'd be happy to answer them.
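As I understand the Nvidia docs, cudnnBatchNormalizationBackward implements only the training-mode gradient (it consumes the batch statistics saved during the forward pass). So one conceivable direction for a fix, sketched here as an illustration rather than as the actual NNlib patch, is to bypass cuDNN in test mode and apply the affine backward with the running statistics directly:

# Hypothetical helper, not part of the NNlib API: input gradient of
# batchnorm in test mode. With fixed running statistics the layer is
# affine, y = γ .* (x .- μ) ./ sqrt.(σ² .+ ϵ) .+ β, so the input
# gradient is dx = dy .* γ ./ sqrt.(σ² .+ ϵ), broadcast along the
# channel dimension (dimension N-1 by Flux's convention).
function ∇batchnorm_inference_x(γ, dy::AbstractArray{T,N}, σ², ϵ) where {T,N}
    shape = ntuple(i -> i == N - 1 ? length(γ) : 1, N)
    return dy .* reshape(γ ./ sqrt.(σ² .+ ϵ), shape)
end

The parameter gradients would follow similarly, with dβ = sum(dy) and dγ = sum(dy .* x̂) reduced over all non-channel dimensions, where x̂ is formed from the running statistics rather than the batch statistics.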