Add more Duplicated methods for Enzyme.jl support #2471

Draft
wants to merge 3 commits into
base: master

mcabbott
Member

This adds a method like gradient(f, ::Duplicated) which, like train!(loss, model::Duplicated, data, opt) from #2446, uses the Duplicated type to signal that you want to use Enzyme, not Zygote. It returns the gradient (for compatibility?) and also mutates the Duplicated object.

  • To avoid piracy, this creates a new function Flux.gradient, which by default calls Zygote.gradient. Unfortunately, that means every using Flux, Zygote will now produce ambiguities... so probably it should not be exported? That would be a breaking change, which means 0.15.

  • There's also withgradient, but it doesn't yet allow you to return a tuple the way Zygote does.

  • There's also a method of update! which either needs to move to Optimisers.jl, or again we need to let Flux own the function.

  • Finally, @layer Chain defines a 1-argument Duplicated(c::Chain) method, so that you don't need to construct the dual by hand.
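
The dispatch pattern described in the list above can be sketched with stand-ins that need no packages. This is only an illustration under stated assumptions, not the code in this PR: ToyDuplicated, toy_zygote_gradient and flux_gradient are hypothetical names, the "default backend" is a finite-difference placeholder rather than Zygote, and the Duplicated method reuses that placeholder where a real method would call Enzyme.

```julia
# Hypothetical stand-in for Enzyme's Duplicated: a value plus a
# same-shaped shadow that receives the gradient.
struct ToyDuplicated{T}
    val::T
    dval::T
end

# One-argument constructor, so the shadow need not be built by hand.
ToyDuplicated(x::AbstractArray) = ToyDuplicated(x, zero(x))

# Hypothetical stand-in for the default backend (central finite differences).
function toy_zygote_gradient(f, x::AbstractVector{<:AbstractFloat})
    h = 1e-6
    g = similar(x)
    for i in eachindex(x)
        xp = copy(x); xp[i] += h
        xm = copy(x); xm[i] -= h
        g[i] = (f(xp) - f(xm)) / (2h)
    end
    return (g,)
end

# A function owned by the package: by default it forwards to the default backend.
flux_gradient(f, args...) = toy_zygote_gradient(f, args...)

# Because the package owns flux_gradient, adding this method is not piracy.
# The wrapper type selects the other code path, and the result is both
# written into the shadow and returned.
function flux_gradient(f, d::ToyDuplicated)
    g = toy_zygote_gradient(f, d.val)[1]  # assumption: a real method would call Enzyme here
    d.dval .= g
    return (d.dval,)
end

loss(x) = sum(abs2, x)
x = [1.0, 2.0, 3.0]

g_plain = flux_gradient(loss, x)[1]   # default path
d = ToyDuplicated(x)                  # shadow created by the 1-argument constructor
g_dup = flux_gradient(loss, d)[1]     # Duplicated path, also mutates d.dval
```

The second method is more specific than the varargs fallback, so wrapping the argument is enough to choose the backend, and the returned array is the same object as the shadow stored in the wrapper.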

WIP, RFC?

Needs tests, and docs.

PR Checklist

  • Tests are added
  • Entry in NEWS.md
  • Documentation, if applicable
