diff --git a/examples/keras_recipes/ipynb/packaging_keras_models_for_wide_distribution.ipynb b/examples/keras_recipes/ipynb/packaging_keras_models_for_wide_distribution.ipynb
new file mode 100644
index 0000000000..9a3a4a883c
--- /dev/null
+++ b/examples/keras_recipes/ipynb/packaging_keras_models_for_wide_distribution.ipynb
@@ -0,0 +1,487 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "# Packaging Keras models for wide distribution using Functional Subclassing\n",
+ "\n",
+ "**Author:** Martin G\u00f6rner
\n",
+ "**Date created:** 2023-12-15
\n",
+ "**Last modified:** 2023-12-15
\n",
+ "**Description:** When sharing your deep learning models, package them using the Functional Subclassing pattern."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Introduction\n",
+ "\n",
+ "Keras is the ideal framework for sharing your cutting-edge deep learning models, in a\n",
+ "library of pre-trained (or not) models. Millions of ML engineers are fluent in the\n",
+ "familiar Keras API, making your models accessible to a global community, whatever their\n",
+ "preferred backend (Jax, PyTorch or TensorFlow).\n",
+ "\n",
+ "One of the benefits of the Keras API is that it lets users programmatically inspect or\n",
+ "edit a model, a feature that is necessary when creating new architectures or workflows\n",
+ "based on a pre-trained model.\n",
+ "\n",
+ "When distributing models, the Keras team recommends packaging them using the **Functional\n",
+ "Subclassing** pattern. Models implemented in this way combine two benefits:\n",
+ "\n",
+ "* They can be instantiated in the normal pythonic way:
\n",
+ "`model = model_collection_xyz.AmazingModel()`\n",
+ "\n",
+ "* They are Keras functional models which means that they have a programmatically\n",
+ "accessible graph of layers, for introspection or model surgery.\n",
+ "\n",
+ "This guide explains [how to use](#functional-subclassing-model) the Functional\n",
+ "Subclassing pattern, and showcases its benefits for [programmatic model\n",
+ "introspection](#model-introspection) and [model surgery](#model-surgery). It also shows\n",
+ "two other best practices for sharable Keras models: [configuring\n",
+ "models](#unconstrained-inputs) for the widest range of supported inputs, for example\n",
+ "images of various sizes, and [using dictionary inputs](#model-with-dictionary-inputs) for\n",
+ "clarity in more complex models."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Setup"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "import keras\n",
+ "import tensorflow as tf # only for tf.data\n",
+ "\n",
+ "print(\"Keras version\", keras.version())\n",
+ "print(\"Keras is running on\", keras.config.backend())"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Dataset\n",
+ "\n",
+ "Let's load an MNIST dataset so that we have something to train with."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "# tf.data is a great API for putting together a data stream.\n",
+ "# It works wether you use the TensorFlow, PyTorch or Jax backend,\n",
+ "# as long as you use it in the data stream only and not inside of a model.\n",
+ "\n",
+ "BATCH_SIZE = 256\n",
+ "\n",
+ "(x_train, train_labels), (x_test, test_labels) = tf.keras.datasets.mnist.load_data()\n",
+ "\n",
+ "train_data = tf.data.Dataset.from_tensor_slices((x_train, train_labels))\n",
+ "train_data = train_data.map(\n",
+ " lambda x, y: (tf.expand_dims(x, axis=-1), y)\n",
+ ") # 1-channel monochrome\n",
+ "train_data = train_data.batch(BATCH_SIZE)\n",
+ "train_data = train_data.cache()\n",
+ "train_data = train_data.shuffle(5000, reshuffle_each_iteration=True)\n",
+ "train_data = train_data.repeat()\n",
+ "\n",
+ "test_data = tf.data.Dataset.from_tensor_slices((x_test, test_labels))\n",
+ "test_data = test_data.map(\n",
+ " lambda x, y: (tf.expand_dims(x, axis=-1), y)\n",
+ ") # 1-channel monochrome\n",
+ "test_data = test_data.batch(10000)\n",
+ "test_data = test_data.cache()\n",
+ "\n",
+ "STEPS_PER_EPOCH = len(train_labels) // BATCH_SIZE\n",
+ "EPOCHS = 5"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Functional Subclassing Model\n",
+ "\n",
+ "The model is wrapped in a class so that end users can instantiate it normally by calling\n",
+ "the constructor `MnistModel()` rather than calling a factory function."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "class MnistModel(keras.Model):\n",
+ " def __init__(self, **kwargs):\n",
+ " # Keras Functional model definition. This could have used Sequential as\n",
+ " # well. Sequential is just syntactic sugar for simple functional models.\n",
+ "\n",
+ " # 1-channel monochrome input\n",
+ " inputs = keras.layers.Input(shape=(None, None, 1), dtype=\"uint8\")\n",
+ " # pixel format conversion from uint8 to float32\n",
+ " y = keras.layers.Rescaling(1 / 255.0)(inputs)\n",
+ "\n",
+ " # 3 convolutional layers\n",
+ " y = keras.layers.Conv2D(\n",
+ " filters=16, kernel_size=3, padding=\"same\", activation=\"relu\"\n",
+ " )(y)\n",
+ " y = keras.layers.Conv2D(\n",
+ " filters=32, kernel_size=6, padding=\"same\", activation=\"relu\", strides=2\n",
+ " )(y)\n",
+ " y = keras.layers.Conv2D(\n",
+ " filters=48, kernel_size=6, padding=\"same\", activation=\"relu\", strides=2\n",
+ " )(y)\n",
+ "\n",
+ " # 2 dense layers\n",
+ " y = keras.layers.GlobalAveragePooling2D()(y)\n",
+ " y = keras.layers.Dense(48, activation=\"relu\")(y)\n",
+ " y = keras.layers.Dropout(0.4)(y)\n",
+ " outputs = keras.layers.Dense(\n",
+ " 10, activation=\"softmax\", name=\"classification_head\" # 10 classes\n",
+ " )(y)\n",
+ "\n",
+ " # A Keras Functional model is created by calling keras.Model(inputs, outputs)\n",
+ " super().__init__(inputs=inputs, outputs=outputs, **kwargs)\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "Let's instantiate and train this model."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "model = MnistModel()\n",
+ "\n",
+ "model.compile(\n",
+ " optimizer=\"adam\",\n",
+ " loss=\"sparse_categorical_crossentropy\",\n",
+ " metrics=[\"sparse_categorical_accuracy\"],\n",
+ ")\n",
+ "\n",
+ "history = model.fit(\n",
+ " train_data,\n",
+ " steps_per_epoch=STEPS_PER_EPOCH,\n",
+ " epochs=EPOCHS,\n",
+ " validation_data=test_data,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Unconstrained inputs\n",
+ "\n",
+ "Notice, in the model definition above, that the input is specified with undefined\n",
+ "dimensions: `Input(shape=(None, None, 1)`\n",
+ "\n",
+ "This allows the model to accept any image size as an input. However, this\n",
+ "only works if the loosely defined shape can be propagated through all the layers and\n",
+ "still determine the size of all weights.\n",
+ "\n",
+ "* So if you have a model architecture that can handle different input sizes\n",
+ "with the same weights (like here), then your users will be able to instantiate it without\n",
+ "parameters:
`model = MnistModel()`\n",
+ "\n",
+ "* If on the other hand, the model must provision different weights for different input\n",
+ "sizes, you will have to ask your users to specify the size in the constructor:
\n",
+ "`model = ModelXYZ(input_size=...)`"
+ ]
+ },
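+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "In the sketch below (the `FixedSizeModel` name and its `input_size` parameter are\n",
+ "hypothetical, for illustration only), the same `MnistModel` weights first run on two\n",
+ "different image sizes; then, a model whose `Flatten` layer ties its weight shapes to\n",
+ "the input resolution has to ask for the size at construction time."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "\n",
+ "# Size-agnostic case: the same weights accept any image size.\n",
+ "flexible_model = MnistModel()\n",
+ "print(flexible_model(np.zeros((1, 28, 28, 1), dtype=\"uint8\")).shape)  # (1, 10)\n",
+ "print(flexible_model(np.zeros((1, 64, 64, 1), dtype=\"uint8\")).shape)  # (1, 10)\n",
+ "\n",
+ "\n",
+ "# Hypothetical size-dependent case: Flatten ties the first Dense kernel\n",
+ "# to the image resolution, so the constructor must ask for input_size.\n",
+ "class FixedSizeModel(keras.Model):\n",
+ "    def __init__(self, input_size, **kwargs):\n",
+ "        inputs = keras.layers.Input(shape=(input_size, input_size, 1), dtype=\"uint8\")\n",
+ "        y = keras.layers.Rescaling(1 / 255.0)(inputs)\n",
+ "        y = keras.layers.Flatten()(y)  # weight shapes now depend on input_size\n",
+ "        outputs = keras.layers.Dense(10, activation=\"softmax\")(y)\n",
+ "        super().__init__(inputs=inputs, outputs=outputs, **kwargs)\n",
+ "\n",
+ "\n",
+ "FixedSizeModel(input_size=28).summary()"
+ ]
+ },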
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Model introspection\n",
+ "\n",
+ "Keras maintains a programmatically accessible graph of layers for every model. It can be\n",
+ "used for introspection and is accessed through the `model.layers` or `layer.layers`\n",
+ "attribute. The utility function `model.summary()` also uses this mechanism internally."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "model = MnistModel()\n",
+ "\n",
+ "# Model summary works\n",
+ "model.summary()\n",
+ "\n",
+ "\n",
+ "# Recursively walking the layer graph works as well\n",
+ "def walk_layers(layer):\n",
+ " if hasattr(layer, \"layers\"):\n",
+ " for layer in layer.layers:\n",
+ " walk_layers(layer)\n",
+ " else:\n",
+ " print(layer.name)\n",
+ "\n",
+ "\n",
+ "print(\"\\nWalking model layers:\\n\")\n",
+ "walk_layers(model)"
+ ]
+ },
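+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "The layer graph can also be filtered programmatically. For example (a minimal sketch\n",
+ "using the `model` instantiated above), listing every convolution layer together with\n",
+ "its filter count:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "# list every Conv2D layer in the graph with its filter count\n",
+ "for layer in model.layers:\n",
+ "    if isinstance(layer, keras.layers.Conv2D):\n",
+ "        print(layer.name, \"filters:\", layer.filters)"
+ ]
+ },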
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Model surgery\n",
+ "\n",
+ "End users might want to instantiate the model from your library but modify it before use.\n",
+ "Functional models have a programmatically accessible graph of layers. Edits are possible\n",
+ "by slicing and splicing the graph and creating a new functional model.\n",
+ "\n",
+ "The alternative is to fork the model code and make the modifications but that forces\n",
+ "users to then maintain their fork indefinitely.\n",
+ "\n",
+ "Example: instantiate the model but change the classification head to do a binary\n",
+ "classification, \"0\" or \"not 0\", instead of the original 10-way digits classification."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "model = MnistModel()\n",
+ "\n",
+ "input = model.input\n",
+ "# cut before the classification head\n",
+ "y = model.get_layer(\"classification_head\").input\n",
+ "\n",
+ "# add a new classification head\n",
+ "output = keras.layers.Dense(\n",
+ " 1, # single class for binary classification\n",
+ " activation=\"sigmoid\",\n",
+ " name=\"binary_classification_head\",\n",
+ ")(y)\n",
+ "\n",
+ "# create a new functional model\n",
+ "binary_model = keras.Model(input, output)\n",
+ "\n",
+ "binary_model.summary()"
+ ]
+ },
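+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "Optionally, if the reused layers carried pre-trained weights that you wanted to\n",
+ "preserve, you could freeze everything except the new head before compiling. A minimal\n",
+ "sketch using the standard `trainable` flag:\n",
+ "\n",
+ "```python\n",
+ "# freeze all layers except the new binary head so that only\n",
+ "# the head's weights are updated during training\n",
+ "for layer in binary_model.layers:\n",
+ "    layer.trainable = layer.name == \"binary_classification_head\"\n",
+ "```"
+ ]
+ },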
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "We can now train the new model as a binary classifier."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "# new dataset with 0 / 1 labels (1 = digit '0', 0 = all other digits)\n",
+ "bin_train_data = train_data.map(\n",
+ " lambda x, y: (x, tf.cast(tf.math.equal(y, tf.zeros_like(y)), dtype=tf.uint8))\n",
+ ")\n",
+ "bin_test_data = test_data.map(\n",
+ " lambda x, y: (x, tf.cast(tf.math.equal(y, tf.zeros_like(y)), dtype=tf.uint8))\n",
+ ")\n",
+ "\n",
+ "# appropriate loss and metric for binary classification\n",
+ "binary_model.compile(\n",
+ " optimizer=\"adam\", loss=\"binary_crossentropy\", metrics=[\"binary_accuracy\"]\n",
+ ")\n",
+ "\n",
+ "history = binary_model.fit(\n",
+ " bin_train_data,\n",
+ " steps_per_epoch=STEPS_PER_EPOCH,\n",
+ " epochs=EPOCHS,\n",
+ " validation_data=bin_test_data,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Model with dictionary inputs\n",
+ "\n",
+ "In more complex models, with multiple inputs, structuring the inputs as a dictionary can\n",
+ "improve readability and usability. This is straightforward to do with a functional model:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "class MnistDictModel(keras.Model):\n",
+ " def __init__(self, **kwargs):\n",
+ " #\n",
+ " # The input is a dictionary\n",
+ " #\n",
+ " inputs = {\n",
+ " \"image\": keras.layers.Input(\n",
+ " shape=(None, None, 1), # 1-channel monochrome\n",
+ " dtype=\"uint8\",\n",
+ " name=\"image\",\n",
+ " )\n",
+ " }\n",
+ "\n",
+ " # pixel format conversion from uint8 to float32\n",
+ " y = keras.layers.Rescaling(1 / 255.0)(inputs[\"image\"])\n",
+ "\n",
+ " # 3 conv layers\n",
+ " y = keras.layers.Conv2D(\n",
+ " filters=16, kernel_size=3, padding=\"same\", activation=\"relu\"\n",
+ " )(y)\n",
+ " y = keras.layers.Conv2D(\n",
+ " filters=32, kernel_size=6, padding=\"same\", activation=\"relu\", strides=2\n",
+ " )(y)\n",
+ " y = keras.layers.Conv2D(\n",
+ " filters=48, kernel_size=6, padding=\"same\", activation=\"relu\", strides=2\n",
+ " )(y)\n",
+ "\n",
+ " # 2 dense layers\n",
+ " y = keras.layers.GlobalAveragePooling2D()(y)\n",
+ " y = keras.layers.Dense(48, activation=\"relu\")(y)\n",
+ " y = keras.layers.Dropout(0.4)(y)\n",
+ " outputs = keras.layers.Dense(\n",
+ " 10, activation=\"softmax\", name=\"classification_head\" # 10 classes\n",
+ " )(y)\n",
+ "\n",
+ " # A Keras Functional model is created by calling keras.Model(inputs, outputs)\n",
+ " super().__init__(inputs=inputs, outputs=outputs, **kwargs)\n",
+ ""
+ ]
+ },
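+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "As a quick sanity check (a minimal sketch with a random batch, not needed for\n",
+ "training), the model can be called directly on a dictionary of arrays:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "\n",
+ "# a random uint8 batch, keyed like the model's input dictionary\n",
+ "batch = {\"image\": np.random.randint(0, 256, size=(4, 28, 28, 1), dtype=\"uint8\")}\n",
+ "print(MnistDictModel()(batch).shape)  # (4, 10)"
+ ]
+ },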
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "We can now train the model on inputs structured as a dictionary."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "model = MnistDictModel()\n",
+ "\n",
+ "# reformat the dataset as a dictionary\n",
+ "dict_train_data = train_data.map(lambda x, y: ({\"image\": x}, y))\n",
+ "dict_test_data = test_data.map(lambda x, y: ({\"image\": x}, y))\n",
+ "\n",
+ "model.compile(\n",
+ " optimizer=\"adam\",\n",
+ " loss=\"sparse_categorical_crossentropy\",\n",
+ " metrics=[\"sparse_categorical_accuracy\"],\n",
+ ")\n",
+ "\n",
+ "history = model.fit(\n",
+ " dict_train_data,\n",
+ " steps_per_epoch=STEPS_PER_EPOCH,\n",
+ " epochs=EPOCHS,\n",
+ " validation_data=dict_test_data,\n",
+ ")"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "collapsed_sections": [],
+ "name": "packaging_keras_models_for_wide_distribution",
+ "private_outputs": false,
+ "provenance": [],
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
\ No newline at end of file
diff --git a/examples/keras_recipes/md/packaging_keras_models_for_wide_distribution.md b/examples/keras_recipes/md/packaging_keras_models_for_wide_distribution.md
new file mode 100644
index 0000000000..d64f91a27f
--- /dev/null
+++ b/examples/keras_recipes/md/packaging_keras_models_for_wide_distribution.md
@@ -0,0 +1,505 @@
+# Packaging Keras models for wide distribution using Functional Subclassing
+
+**Author:** Martin Görner<br>
+**Date created:** 2023-12-15<br>
+**Last modified:** 2023-12-15<br>
+**Description:** When sharing your deep learning models, package them using the Functional Subclassing pattern.
+
+
+ [**View in Colab**](https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/keras_recipes/ipynb/packaging_keras_models_for_wide_distribution.ipynb) • [**GitHub source**](https://github.com/keras-team/keras-io/blob/master/examples/keras_recipes/packaging_keras_models_for_wide_distribution.py)
+
+
+
+---
+## Introduction
+
+Keras is the ideal framework for sharing your cutting-edge deep learning models, in a
+library of pre-trained (or not) models. Millions of ML engineers are fluent in the
+familiar Keras API, making your models accessible to a global community, whatever their
+preferred backend (JAX, PyTorch or TensorFlow).
+
+One of the benefits of the Keras API is that it lets users programmatically inspect or
+edit a model, a feature that is necessary when creating new architectures or workflows
+based on a pre-trained model.
+
+When distributing models, the Keras team recommends packaging them using the **Functional
+Subclassing** pattern. Models implemented in this way combine two benefits:
+
+* They can be instantiated in the normal pythonic way:
+`model = model_collection_xyz.AmazingModel()`
+
+* They are Keras functional models, which means that they have a programmatically
+accessible graph of layers, for introspection or model surgery.
+
+This guide explains [how to use](#functional-subclassing-model) the Functional
+Subclassing pattern, and showcases its benefits for [programmatic model
+introspection](#model-introspection) and [model surgery](#model-surgery). It also shows
+two other best practices for sharable Keras models: [configuring
+models](#unconstrained-inputs) for the widest range of supported inputs, for example
+images of various sizes, and [using dictionary inputs](#model-with-dictionary-inputs) for
+clarity in more complex models.
+
+---
+## Setup
+
+
+```python
+import keras
+import tensorflow as tf # only for tf.data
+
+print("Keras version", keras.version())
+print("Keras is running on", keras.config.backend())
+```
+
+
Model: "mnist_model_1"
+
+
+
+
+
+┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓ +┃ Layer (type) ┃ Output Shape ┃ Param # ┃ +┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩ +│ input_layer_1 (InputLayer) │ (None, None, None, 1) │ 0 │ +├─────────────────────────────────┼───────────────────────────┼────────────┤ +│ rescaling_1 (Rescaling) │ (None, None, None, 1) │ 0 │ +├─────────────────────────────────┼───────────────────────────┼────────────┤ +│ conv2d_3 (Conv2D) │ (None, None, None, 16) │ 160 │ +├─────────────────────────────────┼───────────────────────────┼────────────┤ +│ conv2d_4 (Conv2D) │ (None, None, None, 32) │ 18,464 │ +├─────────────────────────────────┼───────────────────────────┼────────────┤ +│ conv2d_5 (Conv2D) │ (None, None, None, 48) │ 55,344 │ +├─────────────────────────────────┼───────────────────────────┼────────────┤ +│ global_average_pooling2d_1 │ (None, 48) │ 0 │ +│ (GlobalAveragePooling2D) │ │ │ +├─────────────────────────────────┼───────────────────────────┼────────────┤ +│ dense_1 (Dense) │ (None, 48) │ 2,352 │ +├─────────────────────────────────┼───────────────────────────┼────────────┤ +│ dropout_1 (Dropout) │ (None, 48) │ 0 │ +├─────────────────────────────────┼───────────────────────────┼────────────┤ +│ classification_head (Dense) │ (None, 10) │ 490 │ +└─────────────────────────────────┴───────────────────────────┴────────────┘ ++ + + + +
Total params: 76,810 (300.04 KB) ++ + + + +
Trainable params: 76,810 (300.04 KB) ++ + + + +
Non-trainable params: 0 (0.00 B) ++ + + + +
Model: "functional_1"
+
+
+
+
+
+┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓ +┃ Layer (type) ┃ Output Shape ┃ Param # ┃ +┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩ +│ input_layer_2 (InputLayer) │ (None, None, None, 1) │ 0 │ +├─────────────────────────────────┼───────────────────────────┼────────────┤ +│ rescaling_2 (Rescaling) │ (None, None, None, 1) │ 0 │ +├─────────────────────────────────┼───────────────────────────┼────────────┤ +│ conv2d_6 (Conv2D) │ (None, None, None, 16) │ 160 │ +├─────────────────────────────────┼───────────────────────────┼────────────┤ +│ conv2d_7 (Conv2D) │ (None, None, None, 32) │ 18,464 │ +├─────────────────────────────────┼───────────────────────────┼────────────┤ +│ conv2d_8 (Conv2D) │ (None, None, None, 48) │ 55,344 │ +├─────────────────────────────────┼───────────────────────────┼────────────┤ +│ global_average_pooling2d_2 │ (None, 48) │ 0 │ +│ (GlobalAveragePooling2D) │ │ │ +├─────────────────────────────────┼───────────────────────────┼────────────┤ +│ dense_2 (Dense) │ (None, 48) │ 2,352 │ +├─────────────────────────────────┼───────────────────────────┼────────────┤ +│ dropout_2 (Dropout) │ (None, 48) │ 0 │ +├─────────────────────────────────┼───────────────────────────┼────────────┤ +│ binary_classification_head │ (None, 1) │ 49 │ +│ (Dense) │ │ │ +└─────────────────────────────────┴───────────────────────────┴────────────┘ ++ + + + +
Total params: 76,369 (298.32 KB) ++ + + + +
Trainable params: 76,369 (298.32 KB) ++ + + + +
Non-trainable params: 0 (0.00 B) ++ + + +We can now train the new model as a binary classifier. + + +```python +# new dataset with 0 / 1 labels (1 = digit '0', 0 = all other digits) +bin_train_data = train_data.map( + lambda x, y: (x, tf.cast(tf.math.equal(y, tf.zeros_like(y)), dtype=tf.uint8)) +) +bin_test_data = test_data.map( + lambda x, y: (x, tf.cast(tf.math.equal(y, tf.zeros_like(y)), dtype=tf.uint8)) +) + +# appropriate loss and metric for binary classification +binary_model.compile( + optimizer="adam", loss="binary_crossentropy", metrics=["binary_accuracy"] +) + +history = binary_model.fit( + bin_train_data, + steps_per_epoch=STEPS_PER_EPOCH, + epochs=EPOCHS, + validation_data=bin_test_data, +) +``` + +