
Question about LBFGS #307

Answered by Algue-Rythme
DoTulip asked this question in Q&A
Sep 14, 2022 · 1 comment · 3 replies

Hi DoTulip,

The tool jaxopt.ScipyMinimize is just a wrapper around SciPy: it is equivalent to calling scipy.optimize.minimize on your function directly (the same code runs under the hood). In particular, this code is not jittable and does not benefit from GPU/TPU speed-ups. The one advantage over calling SciPy directly is that it is possible to differentiate through the wrapper thanks to implicit differentiation.
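
Concretely, here is a minimal sketch of the wrapper and of differentiating through it; the quadratic objective, the method choice, and the values below are illustrative, not part of the original answer:

```python
import jax
import jax.numpy as jnp
import jaxopt

def fun(w, theta):
    # Quadratic objective whose minimizer is w = theta.
    return jnp.sum((w - theta) ** 2)

# Runs scipy.optimize.minimize under the hood: no jit, no GPU/TPU.
solver = jaxopt.ScipyMinimize(fun=fun, method="BFGS")

def solution(theta):
    return solver.run(jnp.zeros(3), theta).params

theta = jnp.array([1.0, 2.0, 3.0])
w_star = solution(theta)             # approximately equal to theta
# Implicit differentiation lets us differentiate the solution with
# respect to theta, even though SciPy itself is not traceable by JAX:
jac = jax.jacobian(solution)(theta)  # approximately the identity matrix
```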

jaxopt.LBFGS is a pure re-implementation of L-BFGS in JAX: it is differentiable, runs on GPU/TPU, and can be wrapped in jax.jit. This should be your preferred tool if performance is an issue (it is definitely what you want to use for the train_step function of your neural network).
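
As a sketch of that jitted use case (the least-squares objective and the data below are made up for illustration):

```python
import jax
import jax.numpy as jnp
import jaxopt

def fun(w, X, y):
    # Least-squares objective; extra args are forwarded by run().
    return jnp.mean((X @ w - y) ** 2)

solver = jaxopt.LBFGS(fun=fun, maxiter=100)

@jax.jit
def fit(w0, X, y):
    # The whole solver loop is traceable, so it can live inside
    # jax.jit and run on GPU/TPU.
    return solver.run(w0, X, y).params

key = jax.random.PRNGKey(0)
X = jax.random.normal(key, (32, 4))
y = X @ jnp.ones(4)
w = fit(jnp.zeros(4), X, y)  # recovers w close to all-ones
```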

For neural network…

Replies: 1 comment · 3 replies


@Algue-Rythme
@DoTulip
@mblondel

Answer selected by DoTulip