Commit fdfc41b
Fix variable name typo in text (#1887)
kralka authored Jul 20, 2024
1 parent bd3a790 commit fdfc41b
Showing 1 changed file with 1 addition and 1 deletion.
guides/writing_a_custom_training_loop_in_jax.py (1 addition, 1 deletion)
@@ -154,7 +154,7 @@ def compute_loss_and_updates(trainable_variables, non_trainable_variables, x, y)
 ```
 Once you have such a function, you can get the gradient function by
-specifying `hax_aux` in `value_and_grad`: it tells JAX that the loss
+specifying `has_aux` in `value_and_grad`: it tells JAX that the loss
 computation function returns more outputs than just the loss. Note that the loss
 should always be the first output.
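
For reference, `has_aux=True` is the standard flag on `jax.value_and_grad` that the corrected text refers to. The sketch below is a minimal, hypothetical illustration: the simplified `compute_loss_and_updates` body stands in for the guide's real one (which works with Keras model state), and only the `has_aux` usage reflects the line changed in this commit.

```
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the guide's compute_loss_and_updates: it returns
# the loss first, followed by extra outputs (here, the non-trainable
# variables), which is why has_aux=True is needed below.
def compute_loss_and_updates(trainable_variables, non_trainable_variables, x, y):
    w, b = trainable_variables
    preds = x @ w + b
    loss = jnp.mean((preds - y) ** 2)
    return loss, non_trainable_variables

# has_aux=True tells JAX that the function returns (loss, aux); gradients are
# taken with respect to the first output (the loss) only, and the auxiliary
# outputs are passed through unchanged.
grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)

# Toy call with made-up shapes: grad_fn returns ((loss, aux), grads).
trainable_variables = (jnp.ones((3, 1)), jnp.zeros((1,)))
non_trainable_variables = {}
x, y = jnp.ones((4, 3)), jnp.zeros((4, 1))
(loss, non_trainable_variables), grads = grad_fn(
    trainable_variables, non_trainable_variables, x, y
)
```

With `has_aux=True`, the returned function yields `((loss, aux), grads)`, which is why the loss must be the first output of the wrapped function, as the guide notes.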
