Implicit Diff for hyperparam optimization with stochastic solver #227
-
Hi! I'm playing around to better understand the mechanics of this great library before testing it in my current project. In the inner optimization loop I'm using a stochastic solver (sgd). First, is it possible to implicit diff through the stochastic solver? Thanks for your help!
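For context, here is a minimal sketch of the deterministic version of this setup (my own example with made-up names, not code from this discussion): a full-batch inner solver run inside the outer objective, with the hypergradient w.r.t. l2reg obtained through jaxopt's implicit differentiation. The question is what changes when the inner solver becomes stochastic (SGD over mini-batches).

import jax
from jaxopt import GradientDescent, objective

def outer_objective(l2reg, data_train, data_val, init_W):
    # Inner problem: fit W on the training split for the given l2reg (full batch).
    inner = GradientDescent(fun=objective.l2_multiclass_logreg, maxiter=100)
    W_fit = inner.run(init_W, l2reg=l2reg, data=data_train).params
    # Outer objective: validation loss of the fitted weights (no regularization).
    return objective.l2_multiclass_logreg(W_fit, 0.0, data_val)

# Because jaxopt solvers use implicit differentiation by default, jax.grad goes
# "through" inner.run without unrolling its iterations.
hypergrad_fn = jax.grad(outer_objective)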
Replies: 3 comments
-
I debugged the error a bit more, and I found that the problem occurs when jaxopt calls the optimality condition. Is that really necessary? I'm starting to think that implicit diff is not possible with stochastic solvers at the moment.
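For intuition about why the optimality condition has to be evaluated at the solution, here is a hand-rolled sketch of the implicit function theorem on a toy ridge-regression problem (not jaxopt's internal code): the hypergradient is obtained by solving a linear system built from the Jacobians of the optimality condition at the returned solution, so that condition must be well defined and (approximately) zero there.

import jax
import jax.numpy as jnp

X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y = jnp.array([1.0, 2.0, 3.0])

def optimality(w, l2reg):
    # Gradient of 0.5*||Xw - y||^2 + 0.5*l2reg*||w||^2; zero at the inner solution.
    return X.T @ (X @ w - y) + l2reg * w

def solve(l2reg):
    # Closed-form inner solution (stands in for an iterative solver).
    return jnp.linalg.solve(X.T @ X + l2reg * jnp.eye(2), X.T @ y)

l2reg = 0.1
w_star = solve(l2reg)
# Implicit function theorem: dw*/dl2reg = -(dF/dw)^{-1} dF/dl2reg, evaluated at w_star.
A = jax.jacobian(optimality, argnums=0)(w_star, l2reg)
b = jax.jacobian(optimality, argnums=1)(w_star, l2reg)
dw_dl2reg = -jnp.linalg.solve(A, b)

With a stochastic inner solver the returned point generally does not zero the optimality condition for any single mini-batch, which is part of why this setting is delicate.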
-
Hi,

In Jaxopt the general rule is that the signature of the decorated function inner_loop_solver and the optimality_fun (that you pass to custom_root) must match. Hence:

- inner_loop_solver must take an init_params argument with the same pytree shape/dtype as the value it returns, even though this parameter is often ignored (it is only useful for warm starts).
- As you noticed, the data: Tuple[jnp.ndarray, jnp.ndarray] argument of l2_multiclass_logreg must appear in the inner_loop_solver signature, and there is a very good reason for that (a small sketch of this signature rule follows below).
Let's take a step back and look at your problem for a moment. First we need to remember that Jax is a functional language without side effects, one that rejects any modification of a global state; this is what allows an easy translation between maths and code. You want to differentiate through an inner solver that draws its mini-batches from an iterator. The drawback: an iterator is not a Jax object, it is even mutable (!!), which contradicts the functional, side-effect-free model above.

Solution 1

This solution is the well-posed one: it drops the stochasticity of the inner problem. The outer problem remains stochastic, as you can see.

@custom_root(jax.grad(inner_loss, has_aux=True), has_aux=True)
def inner_loop_solver(params, l2reg, data):
    inner_sol = params
    state = solver.init_state(inner_sol)
    for idx in range(inner_iters):
        print(idx)
        batch = data[idx]  # simulate iterations over mini-batches, with data a list or any pytree you like
        inner_sol, state = solver.update(
            params=inner_sol,
            state=state,
            l2reg=l2reg,
            data=(batch[0].reshape(-1, 784) / 255., batch[1])
        )
    return inner_sol
# we now construct the outer loss and perform gradient descent on it
def outer_loss(l2reg, data):
    inner_sol = inner_loop_solver(params, jnp.exp(l2reg), data)  # params: initial weights defined by the user earlier (not shown here)
    print("Outer iter")
    return objective.l2_multiclass_logreg(
        W=inner_sol, l2reg=0, data=(images_val, labels_val)), inner_sol

gd_outer = GradientDescent(fun=outer_loss, tol=1e-3, maxiter=50, has_aux=True)

data = [batch for batch in ds_train.take(inner_iters)]  # create the sequence of mini-batches
outer_state = gd_outer.init_state(l2reg)
for _ in range(outer_iters):
    data = [batch for batch in ds_train.take(inner_iters)]  # create a fresh sequence of mini-batches
    l2reg, outer_state = gd_outer.update(l2reg, outer_state, data)  # use the mini-batches for the current inner minimization step

This solution is really the only one that makes sense from a mathematical viewpoint.

Solution 2

On second thought, we could actually want to find a point that (approximately) satisfies the optimality condition on a single, fixed "representative" mini-batch. In this case, you can choose the representative mini-batch in advance and use it for implicit diff, hoping for the best: it is worth checking the value of the optimality condition at the returned solution to see how far off it is (see the sanity check further below).

@custom_root(jax.grad(inner_loss, has_aux=True), has_aux=True)
def inner_loop_solver(params, l2reg, data):  # data is now your representative mini-batch: you must choose it in advance
    inner_sol = params
    state = solver.init_state(inner_sol)
    for idx in range(inner_iters):
        print(idx)
        batch = next(ds_train)  # the forward updates stay stochastic; `data` is only consumed by the optimality condition during the backward pass
        inner_sol, state = solver.update(
            params=inner_sol,
            state=state,
            l2reg=l2reg,
            data=(batch[0].reshape(-1, 784) / 255., batch[1])
        )
    return inner_sol

# we now construct the outer loss and perform gradient descent on it
def outer_loss(l2reg, data):
    inner_sol = inner_loop_solver(params, jnp.exp(l2reg), data)
    print("Outer iter")
    return objective.l2_multiclass_logreg(
        W=inner_sol, l2reg=0, data=(images_val, labels_val)), inner_sol

gd_outer = GradientDescent(fun=outer_loss, tol=1e-3, maxiter=50, has_aux=True)

data = next(ds_train)
outer_state = gd_outer.init_state(l2reg)
for _ in range(outer_iters):
    data = next(ds_train)  # use the current mini-batch as a "representative mini-batch" of the whole dataset
    l2reg, outer_state = gd_outer.update(l2reg, outer_state, data)  # use the representative mini-batch for the current inner minimization step

Solution 2 is closer to what you are trying to achieve. Giving formal guarantees about the soundness of this approach is certainly possible, but it requires some work and some thinking about the well-posedness of your problem.
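Following up on the remark above about checking the value of the optimality condition, here is one way to do it. This is a sketch that reuses names from the Solution 2 snippet (inner_loss, inner_loop_solver, params, l2reg and the representative mini-batch data); the exact arguments must mirror whatever inner_loss actually takes in your code.

import jax
import jax.numpy as jnp

# Diagnostic only: recompute an inner solution for the current l2reg, then evaluate
# the optimality condition (gradient of the inner loss) on the representative batch.
inner_sol = inner_loop_solver(params, jnp.exp(l2reg), data)
residual, _ = jax.grad(inner_loss, has_aux=True)(inner_sol, jnp.exp(l2reg), data)
print("optimality residual norm:", jnp.linalg.norm(residual))
# A large norm means inner_sol is far from a stationary point for this batch, so the
# linearization behind implicit diff (and hence the hypergradient) may be unreliable.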
-
Hi! One thing that I noticed is that implicit diff is actually faster and more precise than unrolling the computational graph, but the memory footprint doesn't seem to be constant as I expected. It feels like a memory leak; are you aware of anything like this? You can try my example while monitoring the allocated GPU memory and compare unrolling with implicit diff using, for instance, 500 outer iterations.
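Not an answer to the growth itself, but one way to localize it is to dump device-memory profiles at regular intervals and diff them. This is a sketch that reuses gd_outer, l2reg, outer_state, data and outer_iters from the snippets above; jax.profiler.save_device_memory_profile is JAX's standard device-memory profiling hook, and the pprof comparison is an assumption about your tooling.

import jax

for it in range(outer_iters):
    l2reg, outer_state = gd_outer.update(l2reg, outer_state, data)
    if it % 50 == 0:
        l2reg.block_until_ready()  # make sure pending device work has finished
        jax.profiler.save_device_memory_profile(f"memory_{it}.prof")
# Comparing dumps, e.g. `pprof --diff_base=memory_0.prof memory_450.prof`, shows which
# allocations keep accumulating across outer iterations (implicit diff vs. unrolling).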