Document base Loss class
fchollet committed Oct 3, 2024
1 parent c1f2c92 commit 5418991
Showing 5 changed files with 27 additions and 17 deletions.
7 changes: 6 additions & 1 deletion scripts/api_master.py
@@ -1190,7 +1190,7 @@
"keras.ops.inv",
"keras.ops.logdet",
"keras.ops.lstsq",
"keras.ops.lu_factor",
"keras.ops.lu_factor",
"keras.ops.norm",
"keras.ops.qr",
"keras.ops.solve",
@@ -1493,6 +1493,9 @@
"path": "losses/",
"title": "Losses",
"toc": True,
"generate": [
"keras.losses.Loss",
],
"children": [
{
"path": "probabilistic_losses",
@@ -1526,6 +1529,7 @@
"keras.losses.Huber",
"keras.losses.LogCosh",
"keras.losses.Tversky",
"keras.losses.Dice",
"keras.losses.mean_squared_error",
"keras.losses.mean_absolute_error",
"keras.losses.mean_absolute_percentage_error",
@@ -1534,6 +1538,7 @@
"keras.losses.huber",
"keras.losses.log_cosh",
"keras.losses.tversky",
"keras.losses.dice",
],
},
{
2 changes: 1 addition & 1 deletion scripts/cv_api_master.py
@@ -532,7 +532,7 @@
"generate": [
"keras_cv.models.RetinaNet",
"keras_cv.models.RetinaNet.from_preset",
"keras_cv.models.retinanet.PredictionHead"
"keras_cv.models.retinanet.PredictionHead",
],
},
{
2 changes: 1 addition & 1 deletion scripts/master.py
@@ -50,7 +50,7 @@
"path": "keras_cv/",
"title": "KerasCV: Computer Vision Workflows",
},
{
{
"path": "keras_nlp/",
"title": "KerasNLP: Natural Language Workflows",
},
30 changes: 18 additions & 12 deletions templates/api/losses/index.md
@@ -15,6 +15,11 @@ and they perform reduction by default when used in a standalone way (see details

{{toc}}

---

## Base Loss API

{{autogenerated}}

---

@@ -74,8 +79,9 @@ A loss is a callable with arguments `loss_fn(y_true, y_pred, sample_weight=None)`
By default, loss functions return one scalar loss value per input sample, e.g.

```
>>> keras.losses.mean_squared_error(tf.ones((2, 2,)), tf.zeros((2, 2)))
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([1., 1.], dtype=float32)>
>>> from keras import ops
>>> keras.losses.mean_squared_error(ops.ones((2, 2,)), ops.zeros((2, 2)))
<Array: shape=(2,), dtype=float32, numpy=array([1., 1.], dtype=float32)>
```

However, loss class instances feature a `reduction` constructor argument,
@@ -89,18 +95,18 @@ which defaults to `"sum_over_batch_size"` (i.e. average). Allowable values are

```
>>> loss_fn = keras.losses.MeanSquaredError(reduction='sum_over_batch_size')
>>> loss_fn(tf.ones((2, 2,)), tf.zeros((2, 2)))
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>
>>> loss_fn(ops.ones((2, 2,)), ops.zeros((2, 2)))
<Array: shape=(), dtype=float32, numpy=1.0>
```
```
>>> loss_fn = keras.losses.MeanSquaredError(reduction='sum')
>>> loss_fn(tf.ones((2, 2,)), tf.zeros((2, 2)))
<tf.Tensor: shape=(), dtype=float32, numpy=2.0>
>>> loss_fn(ops.ones((2, 2,)), ops.zeros((2, 2)))
<Array: shape=(), dtype=float32, numpy=2.0>
```
```
>>> loss_fn = keras.losses.MeanSquaredError(reduction='none')
>>> loss_fn(tf.ones((2, 2,)), tf.zeros((2, 2)))
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([1., 1.], dtype=float32)>
>>> loss_fn(ops.ones((2, 2,)), ops.zeros((2, 2)))
<Array: shape=(2,), dtype=float32, numpy=array([1., 1.], dtype=float32)>
```

Note that this is an important difference between loss functions like `keras.losses.mean_squared_error`
@@ -109,13 +115,13 @@ does not perform reduction, but by default the class instance does.

```
>>> loss_fn = keras.losses.mean_squared_error
>>> loss_fn(tf.ones((2, 2,)), tf.zeros((2, 2)))
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([1., 1.], dtype=float32)>
>>> loss_fn(ops.ones((2, 2,)), ops.zeros((2, 2)))
<Array: shape=(2,), dtype=float32, numpy=array([1., 1.], dtype=float32)>
```
```
>>> loss_fn = keras.losses.MeanSquaredError()
>>> loss_fn(tf.ones((2, 2,)), tf.zeros((2, 2)))
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>
>>> loss_fn(ops.ones((2, 2,)), ops.zeros((2, 2)))
<Array: shape=(), dtype=float32, numpy=1.0>
```

When using `fit()`, this difference is irrelevant since reduction is handled by the framework.
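The function-vs-class distinction and the three reduction modes documented in the diff above can be mirrored with a small, Keras-independent Python sketch. Everything below (the `Loss` base class, the `MeanSquaredError` subclass, and the reduction logic) is illustrative only, not the actual Keras implementation:

```python
# Illustrative, Keras-independent sketch of the behavior documented above.
# None of this is the Keras source; names mirror the API for clarity.

def mean_squared_error(y_true, y_pred):
    """Function form: returns one loss value per sample, no reduction."""
    return [
        sum((t - p) ** 2 for t, p in zip(rt, rp)) / len(rt)
        for rt, rp in zip(y_true, y_pred)
    ]

class Loss:
    """Base-class form: applies the `reduction` argument in __call__."""

    def __init__(self, reduction="sum_over_batch_size"):
        self.reduction = reduction

    def call(self, y_true, y_pred):
        raise NotImplementedError  # subclasses return per-sample losses

    def __call__(self, y_true, y_pred):
        per_sample = self.call(y_true, y_pred)
        if self.reduction == "none":
            return per_sample
        if self.reduction == "sum":
            return sum(per_sample)
        if self.reduction == "sum_over_batch_size":  # i.e. the mean
            return sum(per_sample) / len(per_sample)
        raise ValueError(f"Unknown reduction: {self.reduction}")

class MeanSquaredError(Loss):
    def call(self, y_true, y_pred):
        return mean_squared_error(y_true, y_pred)

y_true = [[1.0, 1.0], [1.0, 1.0]]
y_pred = [[0.0, 0.0], [0.0, 0.0]]
print(mean_squared_error(y_true, y_pred))                 # [1.0, 1.0] -- no reduction
print(MeanSquaredError()(y_true, y_pred))                 # 1.0 -- averaged by default
print(MeanSquaredError(reduction="sum")(y_true, y_pred))  # 2.0
```

This captures the key point the documentation makes: the bare function returns one value per sample, while the class instance reduces by default and only matches the function's output when constructed with `reduction="none"`.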
3 changes: 1 addition & 2 deletions templates/api/optimizers/index.md
@@ -50,12 +50,11 @@ Check out [the learning rate schedule API documentation](/api/optimizers/learnin

---

## Core Optimizer API
## Base Optimizer API

These methods and attributes are common to all Keras optimizers.

{{autogenerated}}



