
Implicit Diff for hyperparam optimization with stochastic solver #227

Answered by Algue-Rythme
marcociccone asked this question in Q&A

Hi,

In JAXopt, the general rule is that the signature of the decorated solver function inner_loop_solver and the signature of the optimality_fun (the one you pass to custom_root) must match.

Hence:

  • the inner_loop_solver must always take some sort of init_params argument with the same pytree shape/dtype as the solution it returns, even though this argument is often ignored (it is only useful for warm starts)
  • all the "hyper-parameters" you are interested in (i.e., l2reg in your case) must appear in both the inner_loop_solver and the optimality_fun signatures

As you noticed, the data: Tuple[jnp.ndarray, jnp.ndarray] argument of l2_multiclass_logreg must also appear in the inner_loop_solver signature, and there is a very good reason for that.

Implicit diff is not possible …
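
To make the signature-matching rule concrete, here is a minimal sketch built on jaxopt.implicit_diff.custom_root. It uses a toy ridge-regression inner problem with a closed-form solver; the names ridge_objective, ridge_solver, and validation_loss are illustrative stand-ins, not the l2_multiclass_logreg / inner_loop_solver functions from the original question.

```python
import jax
import jax.numpy as jnp
from jaxopt import implicit_diff


def ridge_objective(params, l2reg, data):
    # Inner objective: mean squared error plus an l2 penalty weighted by l2reg.
    X, y = data
    residuals = jnp.dot(X, params) - y
    return 0.5 * jnp.mean(residuals ** 2) + 0.5 * l2reg * jnp.sum(params ** 2)


# optimality_fun(params, l2reg, data): gradient of the inner objective w.r.t. params.
# Its arguments after `params` must match the solver's arguments after `init_params`.
optimality_fun = jax.grad(ridge_objective, argnums=0)


@implicit_diff.custom_root(optimality_fun)
def ridge_solver(init_params, l2reg, data):
    # init_params is required by the signature convention but ignored here,
    # because this toy solver returns the closed-form ridge solution.
    del init_params
    X, y = data
    n_samples, n_features = X.shape
    gram = jnp.dot(X.T, X) / n_samples + l2reg * jnp.eye(n_features)
    return jnp.linalg.solve(gram, jnp.dot(X.T, y) / n_samples)


def validation_loss(l2reg, train_data, val_data):
    # Outer objective: solve the inner problem on train_data, evaluate on val_data.
    init_params = jnp.zeros(train_data[0].shape[1])
    params = ridge_solver(init_params, l2reg, train_data)
    X_val, y_val = val_data
    return 0.5 * jnp.mean((jnp.dot(X_val, params) - y_val) ** 2)


# Hyper-gradient w.r.t. l2reg, obtained by implicit differentiation through
# the decorated solver instead of unrolling the inner optimization.
hypergrad_fn = jax.grad(validation_loss, argnums=0)
```

Note that the hyper-gradient produced this way is only meaningful if the solver actually returns an (approximate) root of optimality_fun, whatever solver is used internally.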
