Collecting PyTorch -> Flux migration notes #2410

Open
BioTurboNick opened this issue Mar 25, 2024 · 1 comment
Comments

@BioTurboNick
Contributor

BioTurboNick commented Mar 25, 2024

I'm in the process of moving a model from PyTorch to Flux, and I'm going to catalog the challenges I've found in migrating, perhaps for a doc page or to improve docs generally. If others wish to add anything, please do!

Weights initialization:

In PyTorch, the default weight init method is kaiming_uniform aka He initialization.
In Flux, the default weight init method is glorot_uniform aka Xavier initialization.

PyTorch chooses a gain for the init function based on the type of nonlinearity specified, which defaults to leaky_relu, and uses this:

a = √5 # argument passed into PyTorch's `kaiming_uniform_` function, the "negative slope" of the rectifier
gain = √(2 / (1 + a ^ 2))

Flux defaults the kaiming_uniform gain to √2, which is what PyTorch would use if relu were specified rather than leaky_relu.

To replicate the PyTorch default, kaiming_uniform(gain = √(2 / (1 + a ^ 2))) can be provided for the init keyword argument of the layer constructor.
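To make the numbers concrete, here is a small worked sketch in plain Python (no PyTorch or Flux required). It assumes the usual Kaiming-uniform sampling bound, gain · √(3 / fan_in); the helper names and the fan_in value are made up for illustration. With a = √5 the gain works out to √(1/3), so the weight bound collapses to 1 / √(fan_in).

```python
import math

def leaky_relu_gain(negative_slope):
    # Gain formula quoted above: sqrt(2 / (1 + a^2)).
    return math.sqrt(2.0 / (1.0 + negative_slope ** 2))

def kaiming_uniform_bound(fan_in, gain):
    # Assumed bound of the uniform distribution: gain * sqrt(3 / fan_in).
    return gain * math.sqrt(3.0 / fan_in)

a = math.sqrt(5)                      # PyTorch's default "negative slope" argument
pytorch_gain = leaky_relu_gain(a)     # sqrt(2 / 6) = sqrt(1 / 3)
relu_gain = leaky_relu_gain(0.0)      # sqrt(2), the Flux kaiming_uniform default

fan_in = 128                          # hypothetical layer input size
pytorch_bound = kaiming_uniform_bound(fan_in, pytorch_gain)    # equals 1 / sqrt(fan_in)
flux_kaiming_bound = kaiming_uniform_bound(fan_in, relu_gain)  # sqrt(6) times larger
```

So with the respective default gains, Flux's kaiming_uniform draws weights from a range √6 ≈ 2.45 times wider than the PyTorch default.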

Bias initialization:

PyTorch initializes bias parameters with uniformly random values in the range ±1 / √(fan_in). In Flux terms, fan_in is first(nfan(filter..., cin÷groups, cout)) for Conv layers; for Dense layers it is the input dimension, first(nfan(out, in)). Flux initializes all biases to zero.
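As a rough illustration of that rule, the following plain-Python sketch computes fan_in for a convolution and samples a bias vector from ±1 / √(fan_in). The function names and layer sizes are hypothetical; it only mirrors the arithmetic described above, not either library's actual code.

```python
import math
import random

def conv_fan_in(kernel_size, c_in, groups=1):
    # fan_in for a conv layer: product of the kernel dims times input channels per group.
    return math.prod(kernel_size) * (c_in // groups)

def pytorch_style_bias(n_out, fan_in, rng):
    # Uniform samples in [-1/sqrt(fan_in), +1/sqrt(fan_in)], one per output channel.
    bound = 1.0 / math.sqrt(fan_in)
    return [rng.uniform(-bound, bound) for _ in range(n_out)]

rng = random.Random(0)                    # seeded only to make the sketch repeatable
fan_in = conv_fan_in((3, 3), c_in=16)     # 3 * 3 * 16 = 144
bias = pytorch_style_bias(32, fan_in, rng)
```

In Flux, values generated this way would have to be supplied explicitly when constructing the layer, since the default bias is all zeros.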

Layers:

In PyTorch, there are separate layer classes for each dimensionality (e.g. Conv1d, Conv2d, Conv3d). In Flux, the dimensionality is specified by the tuple provided for the kernel of Conv.

In PyTorch, activation functions are inserted as separate steps in the chain, on an equal footing with layers. In Flux, they are provided as an argument to the layer constructor.

Sequential => Chain

Linear => Dense

Upsample in Flux (via NNlib) is equivalent to align_corners=True with PyTorch's Upsample, but the default there is False. Note that this makes the gradients depend on image size.
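The difference between the two conventions is easiest to see in one dimension. The sketch below is plain Python and hand-rolls linear interpolation under my understanding of the two coordinate mappings (endpoints aligned versus pixel centers aligned, with negative source positions clamped to zero); the function name is made up and this is not either library's implementation.

```python
def upsample_linear_1d(values, out_size, align_corners):
    n = len(values)
    out = []
    for i in range(out_size):
        if align_corners:
            # First and last output samples land exactly on the first and last inputs.
            src = i * (n - 1) / (out_size - 1) if out_size > 1 else 0.0
        else:
            # Pixel centers are aligned instead; positions left of the first input clamp to 0.
            src = max((i + 0.5) * n / out_size - 0.5, 0.0)
        lo = min(int(src), n - 1)
        hi = min(lo + 1, n - 1)
        frac = src - lo
        out.append(values[lo] * (1.0 - frac) + values[hi] * frac)
    return out

corners = upsample_linear_1d([0.0, 1.0], 4, align_corners=True)    # 0, 1/3, 2/3, 1
centers = upsample_linear_1d([0.0, 1.0], 4, align_corners=False)   # 0, 0.25, 0.75, 1
```

The interior values differ, so a ported model that relied on the PyTorch default will not reproduce the same activations without accounting for this.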

When building a custom layer, the (::MyLayer)(input) = method is the equivalent of def forward(self, input):

Often, if PyTorch has a method, Flux (in many cases via MLUtils.jl) has a method of the same name, e.g. unsqueeze to insert a 1-length dimension, or erf to compute the error function. Note that the inverse of unsqueeze is actually Base.dropdims.

Training:

A single step in Flux is simply gradient followed by update!. In PyTorch there are more steps: the optimizer's gradients must be zeroed with .zero_grad(), then the loss is calculated, then the tensor returned from the loss function is backward-propagated with .backward() to compute the gradients, and finally the optimizer is stepped forward and the model parameters are updated with .step(). In Flux, both parts can be combined with the train! function, which can also be used to iterate over a set of paired training inputs and outputs.

In Flux, an optimizer state object is first obtained by setup, and this state is passed to the training loop. In PyTorch, the optimizer object itself is manipulated in the training loop.

Added 10/13/24:

Upsample in PyTorch is actually deprecated in favor of nn.functional.interpolate, but the former just relies on the latter anyway.

clip_grad_norm_ in PyTorch (side note: they've adopted a trailing underscore to indicate a modifying function) can be accomplished by creating an Optimisers.OptimiserChain(ClipNorm(___), optimizer).

Added 10/25/24:

torch.where, which produces an array whose elements are taken from one of two inputs depending on a mask, can be accomplished with a broadcasted ifelse.
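For anyone unfamiliar with the semantics, this is all the operation does, written as a plain-Python sketch over lists (the helper name `where` here is just a stand-in, not the torch function). The Julia counterpart would be along the lines of ifelse.(mask, a, b).

```python
def where(mask, a, b):
    # Element-wise select: take from `a` where the mask is true, otherwise from `b`.
    return [x if m else y for m, x, y in zip(mask, a, b)]

values = [-2.0, 0.5, 3.0]
mask = [v > 0 for v in values]
clamped = where(mask, values, [0.0] * len(values))   # negative entries replaced by 0.0
```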

The AdamW optimizer is implemented differently. In PyTorch, the weight decay is scaled by the learning rate. In Flux, it is not. See FluxML/Optimisers.jl#182 for a workaround until Flux makes something built-in available.
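To show how much this matters, here is a plain-Python sketch of only the decay part of one step on a single scalar parameter, following the description above. The Adam part of the update is left out, and the function names and hyperparameter values are made up for illustration.

```python
def decay_scaled_by_lr(p, lr, weight_decay):
    # PyTorch-style behaviour as described above: decay term is lr * weight_decay * p.
    return p - lr * weight_decay * p

def decay_unscaled(p, weight_decay):
    # Flux-style behaviour as described above: decay term is weight_decay * p.
    return p - weight_decay * p

p0, lr, wd = 1.0, 1e-3, 1e-2                 # hypothetical values
p_torch = decay_scaled_by_lr(p0, lr, wd)     # shrinks by a factor of 1 - 1e-5
p_flux = decay_unscaled(p0, wd)              # shrinks by a factor of 1 - 1e-2

# Under these assumptions, passing lr * wd as the decay on the unscaled side matches.
p_flux_matched = decay_unscaled(p0, lr * wd)
```

With the same nominal weight decay, the unscaled form shrinks the parameter a thousand times faster per step in this example, which is why hyperparameters copied straight from a PyTorch config can behave very differently.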

@darsnack
Member

@BioTurboNick reported on Slack that with the default Flux initialization, his model would get stuck in an all-zeros state, but not with the PyTorch init.
