Question about LBFGS #307
-
Hello JAXopt Team, …
-
Hi DoTulip,

The tool `jaxopt.ScipyMinimize` is just a wrapper for SciPy: it is equivalent to calling `scipy.optimize.minimize` on your function directly (the same code runs under the hood). In particular, this code is not jittable and does not benefit from GPU/TPU speed-ups. The one advantage over calling SciPy directly is that it is actually possible to differentiate through the wrapper, thanks to implicit differentiation.
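For concreteness, here is a minimal sketch of that usage; the toy least-squares objective and all variable names are mine, not from the thread:

```python
import jax.numpy as jnp
import jaxopt

# Toy least-squares objective (hypothetical; for illustration only).
def loss(w, X, y):
    return jnp.mean((X @ w - y) ** 2)

X = jnp.ones((8, 3))
y = jnp.zeros(8)
w0 = jnp.zeros(3)

# ScipyMinimize wraps scipy.optimize.minimize: SciPy's own L-BFGS-B
# runs on CPU under the hood, so run() cannot be jitted, but it is
# still differentiable thanks to implicit differentiation.
solver = jaxopt.ScipyMinimize(fun=loss, method="L-BFGS-B")
result = solver.run(w0, X, y)
print(result.params)
```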
`jaxopt.LBFGS` is a pure re-implementation of L-BFGS in JAX: it is differentiable, runs on GPU/TPU, and can be wrapped in `jax.jit`. This should be your preferred tool if performance is an issue (it is definitely what you want to use for the `train_step` function of your neural network); a jitted sketch follows below.

For neural networks, I assume you are interested in a stochastic variant of…
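To illustrate the jit-wrapped usage described above, here is a minimal sketch reusing the same toy objective; the step-by-step variant at the end is my own illustration of a `train_step`-style pattern, not code from the thread:

```python
import jax
import jax.numpy as jnp
import jaxopt

# Same toy objective as above (hypothetical).
def loss(w, X, y):
    return jnp.mean((X @ w - y) ** 2)

X = jnp.ones((8, 3))
y = jnp.zeros(8)
w0 = jnp.zeros(3)

# Pure-JAX L-BFGS: differentiable, runs on GPU/TPU, jittable.
solver = jaxopt.LBFGS(fun=loss, maxiter=100)
run = jax.jit(solver.run)       # the whole solve as one jitted call
params, state = run(w0, X, y)

# Alternatively, jit a single update step for use inside a training
# loop (a train_step-style function).
step = jax.jit(solver.update)
state0 = solver.init_state(w0, X, y)
params1, state1 = step(w0, state0, X, y)
```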