diff --git a/guides/writing_a_custom_training_loop_in_jax.py b/guides/writing_a_custom_training_loop_in_jax.py
index 107a72c634..86c9353258 100644
--- a/guides/writing_a_custom_training_loop_in_jax.py
+++ b/guides/writing_a_custom_training_loop_in_jax.py
@@ -154,7 +154,7 @@ def compute_loss_and_updates(trainable_variables, non_trainable_variables, x, y)
 ```
 
 Once you have such a function, you can get the gradient function by
-specifying `hax_aux` in `value_and_grad`: it tells JAX that the loss
+specifying `has_aux` in `value_and_grad`: it tells JAX that the loss
 computation function returns more outputs than just the loss. Note that the
 loss should always be the first output.
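
For context, here is a minimal sketch of the pattern the corrected prose describes: a loss function whose first output is the loss, wrapped with `jax.value_and_grad(..., has_aux=True)`. The `compute_loss_and_updates` name comes from the hunk header above; the `model` and `loss_fn` definitions here are illustrative stand-ins, not necessarily the guide's own setup.

```python
import jax
import keras

# Illustrative model and loss; the guide defines its own.
model = keras.Sequential([keras.layers.Dense(10)])
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)


def compute_loss_and_updates(trainable_variables, non_trainable_variables, x, y):
    # Stateless forward pass: returns predictions plus updated non-trainable state.
    y_pred, non_trainable_variables = model.stateless_call(
        trainable_variables, non_trainable_variables, x
    )
    loss = loss_fn(y, y_pred)
    # The loss must be the first output; anything else is auxiliary output.
    return loss, non_trainable_variables


# has_aux=True tells JAX the function returns (loss, aux), so gradients are
# computed with respect to the loss only, and the aux output is passed through.
grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)
```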