Update knowledge distillation for Keras3 #1577

Closed
29 commits
272fa95
Add multi-backend examples/nlp/bidirectional_lstm_imdb (#1569)
mattdangerw Nov 4, 2023
727d762
Minor fixes
fchollet Nov 4, 2023
9c9f194
Merge branch 'keras-3' of github.com:keras-team/keras-io into keras-3
fchollet Nov 4, 2023
5248cac
Try/except around domain package imports
fchollet Nov 4, 2023
ebe14fb
Add MNIST convnet example for Keras 3
fchollet Nov 4, 2023
65a1b76
Add multi-backend examples/nlp/addition_rnn (#1570)
mattdangerw Nov 4, 2023
1549a79
Add attention_mil_classification for Keras 3
fchollet Nov 4, 2023
069f641
Merge branch 'keras-3' of github.com:keras-team/keras-io into HEAD
fchollet Nov 4, 2023
1365480
Tutobooks fixes
fchollet Nov 4, 2023
52d07c5
Merge branch 'keras-3' of github.com:keras-team/keras-io into keras-3
fchollet Nov 4, 2023
cbb5b60
Update CCT example.
fchollet Nov 4, 2023
1c387ae
Convert convmixer example
fchollet Nov 5, 2023
33a8871
Merge branch 'keras-3' of github.com:keras-team/keras-io into HEAD
fchollet Nov 5, 2023
b00ea42
Update contributing.md
fchollet Nov 5, 2023
7c31058
Merge branch 'keras-3' of github.com:keras-team/keras-io into keras-3
fchollet Nov 5, 2023
855e37d
Update guides, templates (wip)
fchollet Nov 5, 2023
e8a126a
update distributed training guides
fchollet Nov 6, 2023
623290c
Update more guides to Keras 3
fchollet Nov 6, 2023
732a3a3
Update custom training loop guides
fchollet Nov 6, 2023
05b64da
Merge branch 'keras-3' of github.com:keras-team/keras-io into keras-3
fchollet Nov 6, 2023
dd90b6d
Add the rendered transfer learning example (#1572)
mattdangerw Nov 7, 2023
8530fac
Update templates
fchollet Nov 7, 2023
3c00024
Merge branch 'keras-3' of github.com:keras-team/keras-io into keras-3
fchollet Nov 7, 2023
af05bf4
Update guides
fchollet Nov 7, 2023
27fa3a2
Update image_classification_with_vision_transformer. (#1573)
qlzh727 Nov 7, 2023
f7735df
Update vision/oxford_pets_image_segmentation (#1574)
qlzh727 Nov 7, 2023
6cd83a8
Update pixelcnn. (#1575)
qlzh727 Nov 7, 2023
356e619
Update vision/knowledge_distillation example
grasskin Nov 8, 2023
0412848
Update vision/knowledge_distillation example
grasskin Nov 8, 2023
Binary file added examples/generative/img/pixelcnn/pixelcnn_10_57.png
Binary file added examples/generative/img/pixelcnn/pixelcnn_10_58.png
Binary file added examples/generative/img/pixelcnn/pixelcnn_10_59.png
Binary file added examples/generative/img/pixelcnn/pixelcnn_10_60.png
44 changes: 21 additions & 23 deletions examples/generative/ipynb/pixelcnn.ipynb
@@ -29,9 +29,9 @@
"probability distribution of later elements. In the following example, images are generated\n",
"in this fashion, pixel-by-pixel, via a masked convolution kernel that only looks at data\n",
"from previously generated pixels (origin at the top left) to generate later pixels.\n",
"During inference, the output of the network is used as a probability distribution\n",
"During inference, the output of the network is used as a probability distribution\n",
"from which new pixel values are sampled to generate a new image\n",
"(here, with MNIST, the pixel values range from white (0) to black (255).\n"
"(here, with MNIST, the pixel values are either black or white)."
]
},
{
@@ -43,10 +43,10 @@
"outputs": [],
"source": [
"import numpy as np\n",
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"from tensorflow.keras import layers\n",
"from tqdm import tqdm\n"
"import keras\n",
"from keras import layers\n",
"from keras import ops\n",
"from tqdm import tqdm"
]
},
{
@@ -55,7 +55,7 @@
"colab_type": "text"
},
"source": [
"## Getting the data\n"
"## Getting the Data"
]
},
{
@@ -78,7 +78,7 @@
"# anything above this value gets rounded up to 1 so that all values are either\n",
"# 0 or 1\n",
"data = np.where(data < (0.33 * 256), 0, 1)\n",
"data = data.astype(np.float32)\n"
"data = data.astype(np.float32)"
]
},
{
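Review note: the thresholding step in the hunk above can be checked in isolation. The following is a minimal sketch, assuming a small hand-made `patch` array (hypothetical, not part of the notebook) in place of the MNIST data; the `0.33 * 256` cutoff is the one shown in the diff.

```python
import numpy as np

# Hypothetical 2x3 grayscale patch with values in [0, 255];
# it stands in for the MNIST images used by the notebook.
patch = np.array([[0, 84, 85], [120, 200, 255]], dtype=np.uint8)

# Cutoff taken from the diff: 0.33 * 256 = 84.48.
threshold = 0.33 * 256

# Values below the cutoff become 0, everything else becomes 1,
# then the result is cast to float32 as in the notebook.
binary = np.where(patch < threshold, 0, 1).astype(np.float32)
print(binary)
```

With this cutoff a pixel value of 84 maps to 0 and 85 maps to 1, so the model only ever sees two pixel levels.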
@@ -87,7 +87,7 @@
"colab_type": "text"
},
"source": [
"## Create two classes for the requisite Layers for the model\n"
"## Create two classes for the requisite Layers for the model"
]
},
{
@@ -98,6 +98,7 @@
},
"outputs": [],
"source": [
"\n",
"# The first layer is the PixelCNN layer. This layer simply\n",
"# builds on the 2D convolutional layer, but includes masking.\n",
"class PixelConvLayer(layers.Layer):\n",
@@ -110,7 +111,7 @@
" # Build the conv2d layer to initialize kernel variables\n",
" self.conv.build(input_shape)\n",
" # Use the initialized kernel to create the mask\n",
" kernel_shape = self.conv.kernel.get_shape()\n",
" kernel_shape = ops.shape(self.conv.kernel)\n",
" self.mask = np.zeros(shape=kernel_shape)\n",
" self.mask[: kernel_shape[0] // 2, ...] = 1.0\n",
" self.mask[kernel_shape[0] // 2, : kernel_shape[1] // 2, ...] = 1.0\n",
@@ -146,7 +147,7 @@
" x = self.pixel_conv(x)\n",
" x = self.conv2(x)\n",
" return keras.layers.add([inputs, x])\n",
"\n"
""
]
},
{
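Review note: the mask construction touched by this change is easier to follow on a tiny kernel. The sketch below is NumPy only and uses a hypothetical helper `build_mask`. The first two mask assignments mirror the lines visible in the diff. The type "B" branch is collapsed in this view, so it is written here as an assumption following the usual PixelCNN convention (type "B" also exposes the centre pixel). `kernel_shape` stands in for whatever `ops.shape(self.conv.kernel)` returns.

```python
import numpy as np


def build_mask(kernel_shape, mask_type):
    """Hypothetical helper: build a PixelCNN-style mask for a conv kernel.

    kernel_shape is (height, width, in_channels, out_channels).
    """
    k_h, k_w = kernel_shape[0], kernel_shape[1]
    mask = np.zeros(shape=kernel_shape)
    # All rows above the centre row are visible.
    mask[: k_h // 2, ...] = 1.0
    # In the centre row, only pixels left of the centre are visible.
    mask[k_h // 2, : k_w // 2, ...] = 1.0
    # Assumption: type "B" additionally exposes the centre pixel itself.
    if mask_type == "B":
        mask[k_h // 2, k_w // 2, ...] = 1.0
    return mask


# Inspect the spatial part of a 3x3 kernel with one input and one output channel.
mask_a = build_mask((3, 3, 1, 1), "A")[..., 0, 0]
mask_b = build_mask((3, 3, 1, 1), "B")[..., 0, 0]
print(mask_a)
print(mask_b)
```

The only thing the mask code needs from the kernel is a shape that can be indexed and passed to `np.zeros`, which is why swapping `get_shape()` for `ops.shape(...)` leaves the rest of `build` unchanged.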
@@ -155,7 +156,7 @@
"colab_type": "text"
},
"source": [
"## Build the model based on the original paper\n"
"## Build the model based on the original paper"
]
},
{
@@ -166,7 +167,7 @@
},
"outputs": [],
"source": [
"inputs = keras.Input(shape=input_shape)\n",
"inputs = keras.Input(shape=input_shape, batch_size=128)\n",
"x = PixelConvLayer(\n",
" mask_type=\"A\", filters=128, kernel_size=7, activation=\"relu\", padding=\"same\"\n",
")(inputs)\n",
@@ -195,7 +196,7 @@
"pixel_cnn.summary()\n",
"pixel_cnn.fit(\n",
" x=data, y=data, batch_size=128, epochs=50, validation_split=0.1, verbose=2\n",
")\n"
")"
]
},
{
@@ -208,10 +209,7 @@
"\n",
"The PixelCNN cannot generate the full image at once. Instead, it must generate each pixel in\n",
"order, append the last generated pixel to the current image, and feed the image back into the\n",
"model to repeat the process.\n",
"\n",
"You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/pixel-cnn-mnist) ",
"and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/pixelcnn-mnist-image-generation)."
"model to repeat the process."
]
},
{
@@ -238,13 +236,13 @@
" probs = pixel_cnn.predict(pixels)[:, row, col, channel]\n",
" # Use the probabilities to pick pixel values and append the values to the image\n",
" # frame.\n",
" pixels[:, row, col, channel] = tf.math.ceil(\n",
" probs - tf.random.uniform(probs.shape)\n",
" pixels[:, row, col, channel] = ops.ceil(\n",
" probs - keras.random.uniform(probs.shape)\n",
" )\n",
"\n",
"\n",
"def deprocess_image(x):\n",
" # Stack the single channeled black and white image to RGB values.\n",
" # Stack the single channeled black and white image to rgb values.\n",
" x = np.stack((x, x, x), 2)\n",
" # Undo preprocessing\n",
" x *= 255.0\n",
@@ -262,7 +260,7 @@
"display(Image(\"generated_image_0.png\"))\n",
"display(Image(\"generated_image_1.png\"))\n",
"display(Image(\"generated_image_2.png\"))\n",
"display(Image(\"generated_image_3.png\"))\n"
"display(Image(\"generated_image_3.png\"))"
]
}
],
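Review note: the sampling line `ceil(probs - uniform(...))` is a compact way to draw Bernoulli samples. The sketch below reproduces the trick in plain NumPy so it can be checked without a trained model. The helper `sample_pixels`, the seed, and the example probabilities are hypothetical, and NumPy's generator stands in for `keras.random.uniform`.

```python
import numpy as np


def sample_pixels(probs, rng):
    """Hypothetical helper: draw 0/1 pixel values from per-pixel probabilities."""
    # With u ~ Uniform[0, 1), ceil(p - u) is 1 when u < p and 0 otherwise,
    # so each output equals 1 with probability p.
    u = rng.uniform(size=probs.shape)
    return np.ceil(probs - u)


rng = np.random.default_rng(0)  # seed chosen arbitrarily for this sketch
probs = np.array([0.0, 0.3, 1.0])

# Repeat the draw many times to compare empirical frequencies with probs.
draws = np.stack([sample_pixels(probs, rng) for _ in range(10_000)])
print(draws.mean(axis=0))
```

The empirical means should land close to the input probabilities, with the first column always 0 and the last always 1.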
@@ -295,4 +293,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
}
}