Add optional incremental training code #258

Open — wants to merge 3 commits into base: main
@@ -38,7 +38,8 @@
" * [Retrieve Training Artifacts](#3.1.-Retrieve-Training-Artifacts)\n",
" * [Set Training parameters](#3.2.-Set-Training-parameters)\n",
" * [Start Training](#3.3.-Start-Training)\n",
" * [Deploy and run inference on the fine-tuned model](#3.4.-Deploy-and-run-inference-on-the-fine-tuned-model)\n"
" * [Incremental Training](#3.4.-Incremental-Training)\n",
"4. [Deploy and run inference on the fine-tuned model](#3.4.-Deploy-and-run-inference-on-the-fine-tuned-model)\n"
]
},
{
@@ -86,6 +87,7 @@
"import sagemaker, boto3, json\n",
"from sagemaker import get_execution_role\n",
"import os\n",
"import tarfile\n",
"\n",
"try:\n",
" aws_role = sagemaker.get_execution_role()\n",
@@ -131,7 +133,6 @@
"source": [
"use_local_images = False # If False, notebook will use the example dataset provided by JumpStart\n",
"\n",
"\n",
"if not use_local_images:\n",
" # Downloading example dog images from JumpStart S3 bucket\n",
"\n",
@@ -452,12 +453,209 @@
"sd_estimator.fit({\"training\": train_s3_path}, logs=True)"
]
},
{
"cell_type": "markdown",
"id": "44a2a153-fe2c-4d08-9685-38bdc1fbab49",
"metadata": {},
"source": [
"# Skip to Section 3 to deploy and perform inference on your trained model"
]
},
{
"cell_type": "markdown",
"id": "5313df99-43fd-4031-8c3e-e20c60dcabb5",
"metadata": {},
"source": [
"### 2.4 *(Optional) Incremental Training* \n",
"---\n",
"The next optional section details the steps required to incrementally fine-tune a Stable Diffusion model. The use case is multi-class training, for example training the model on both your dog and your cat. While multi-class training is not currently supported by JumpStart, we can work around this by fine-tuning the already fine-tuned model on a new class: first fine-tune the pretrained model on images of your dog, save that model, then train it again on images of your cat. The result is a model that is fine-tuned on images of both your dog and your cat.\n",
"\n",
"As mentioned in the blog post, this is not officially supported. It has also been found to work better on classes that look significantly different (e.g. a bird and a dog) than on more similar classes (e.g. different breeds of dog). You will need to provide your own images to run the incremental training code.\n",
"\n",
"---"
]
},
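The chained fine-tuning flow described above can be sketched as a simple loop, where each training run's output model becomes the next run's input. This is purely illustrative structure, not the JumpStart API; `train_fn` is a stand-in for launching an `Estimator.fit` job as done elsewhere in this notebook.

```python
# Illustrative sketch only: incremental fine-tuning as a chain of runs,
# where each run starts from the previous run's output model artifact.
# `train_fn` is a stand-in for launching a SageMaker training job.
def chain_training(initial_model_uri, dataset_uris, train_fn):
    """Thread the model artifact through one training run per dataset."""
    model_uri = initial_model_uri
    for dataset_uri in dataset_uris:
        model_uri = train_fn(model_uri, dataset_uri)
    return model_uri
```

With JumpStart this corresponds to pointing `model_uri` on each new `Estimator` at the previous job's output location, which is what Section 2.4.3 does.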
{
"cell_type": "markdown",
"id": "c6fd61e0-908a-472b-87df-e4fc6a4930b4",
"metadata": {},
"source": [
"#### 2.4.1 Prepare images of the second class"
]
},
{
"cell_type": "markdown",
"id": "551de2e7-0ecc-4dc1-904f-0e535029f15f",
"metadata": {},
"source": [
"To run through the rest of the code, we assume you have placed ~10 images of a new animal in a local directory named 'training_images_2'."
]
},
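Before uploading, it can help to sanity-check the folder contents. A minimal sketch, assuming JPEG/PNG images; the folder name and extension list are assumptions about your local layout:

```python
import os

# Sanity-check helper for a local image folder before uploading to S3.
# The extension list is an assumption about your training images.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}

def list_training_images(folder):
    """Return the image files in `folder`, ignoring non-image files."""
    return sorted(
        f
        for f in os.listdir(folder)
        if os.path.splitext(f)[1].lower() in IMAGE_EXTENSIONS
    )
```

For example, `len(list_training_images("training_images_2"))` should be roughly 10 before you proceed.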
{
"cell_type": "code",
"execution_count": null,
"id": "8e5cc45a-9c54-4549-ab13-dfedcb5065dd",
"metadata": {},
"outputs": [],
"source": [
"# local_training_dataset_folder_2 = \"training_images_2\"\n",
"\n",
"# # Example instance prompt below, replace this with a prompt relevant to your images\n",
"# instance_prompt = \"A photo of a Whiskers Cat\"\n",
"\n",
"\n",
"# # Create new manifest file\n",
"# with open(os.path.join(local_training_dataset_folder_2, \"dataset_info.json\"), \"w\") as f:\n",
"# f.write(json.dumps({\"instance_prompt\": instance_prompt}))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "253f16fd-8d8e-441a-932a-194af1534c87",
"metadata": {},
"outputs": [],
"source": [
"# # Copy locally stored data to S3 bucket in new directory\n",
"# train_s3_path = f\"s3://{training_bucket}/custom_dog_stable_diffusion_dataset/training_images_2\"\n",
"# !aws s3 cp --recursive $local_training_dataset_folder_2 $train_s3_path"
]
},
{
"cell_type": "markdown",
"id": "19856f33-4df9-4d79-844b-cdfb7438c658",
"metadata": {},
"source": [
"#### 2.4.2 Update saved (already fine tuned) model"
]
},
{
"cell_type": "markdown",
"id": "7536f021-c2ff-49b6-9658-58b915c9a4f6",
"metadata": {},
"source": [
"Currently the model is saved without a __model_info__.json file. To fine-tune this model, we need to download the model.tar.gz file, untar it, add the missing file, then re-tar it and upload it to S3."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fb51128e-3536-450b-90ff-5c13082eddd2",
"metadata": {},
"outputs": [],
"source": [
"# # Download the fine-tuned model into a local 'model' file\n",
"# s3_model_location = f\"s3://{output_bucket}/{output_prefix}/output/{sd_estimator.latest_training_job.name}/output/\"\n",
"\n",
"# !aws s3 cp --recursive {s3_model_location} fine_tuned_model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ab05a314-b1e6-4b93-b44f-81b0d0cf88cd",
"metadata": {},
"outputs": [],
"source": [
"# # Create the missing __model_info.json file\n",
"# data = {\n",
"# \"use_schedular\": True,\n",
"# \"resolution\": 512\n",
"# }\n",
"\n",
"# # Write the data to a JSON file\n",
"# with open(\"__model_info__.json\", \"w\") as f:\n",
"# json.dump(data, f)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d4b67cd1-e8d2-4601-bc49-a1001dfb9100",
"metadata": {},
"outputs": [],
"source": [
"# # Untar the model and add the created __model_info__.json file\n",
"\n",
"# # Local paths\n",
"# tar_gz_file_path = 'fine_tuned_model/model.tar.gz'\n",
"# extracted_folder = 'fine_tuned_model/un_tar/'\n",
"# file_to_copy = '__model_info__.json'\n",
"\n",
"# # Extract the tar file\n",
"# with tarfile.open(tar_gz_file_path, 'r:gz') as tar:\n",
"# tar.extractall(path=extracted_folder)\n",
"\n",
"# # Copy a file to the extracted folder\n",
"# copied_file_name = os.path.basename(file_to_copy)\n",
"# copied_file_path = os.path.join(extracted_folder, copied_file_name)\n",
"# os.replace(file_to_copy, copied_file_path) # Move the file to the extracted folder\n",
"\n",
"# # Re-tar the folder\n",
"# re_tar_file_path = 'model_updated.tar.gz'\n",
"# with tarfile.open(re_tar_file_path, 'w:gz') as tar:\n",
"# tar.add(extracted_folder, arcname=os.path.basename(extracted_folder))\n",
"\n",
"# print(\"Tarring process completed successfully.\")"
]
},
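Since the next cell deletes the old model from S3 before uploading the replacement, a quick check that the repacked archive really contains `__model_info__.json` is cheap insurance. A minimal sketch; the archive path mirrors the cell above:

```python
import tarfile

# Sketch: confirm the repacked archive contains __model_info__.json before
# the old model in S3 is removed.
def archive_contains(tar_path, filename):
    """Return True if any member of the .tar.gz archive ends with `filename`."""
    with tarfile.open(tar_path, "r:gz") as tar:
        return any(member.name.endswith(filename) for member in tar.getmembers())
```

For example, `archive_contains("model_updated.tar.gz", "__model_info__.json")` should return `True` before running the upload cell.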
{
"cell_type": "code",
"execution_count": null,
"id": "6b76ab9b-c810-4ef4-98ca-b4420131b24e",
"metadata": {},
"outputs": [],
"source": [
"# # Remove old model and upload new model to S3\n",
"\n",
"# !aws s3 rm {s3_model_location} --recursive\n",
"# !aws s3 cp model_updated.tar.gz {s3_model_location}"
]
},
{
"cell_type": "markdown",
"id": "e184a2b4-4fb0-42b0-b426-48e24e2e3392",
"metadata": {},
"source": [
"#### 2.4.3 Perform training on the updated model\n",
"\n",
"*We reuse the same hyperparameters and training artifacts as earlier.*\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "01e15e53-ffe7-4377-af50-f01bb9df4d9b",
"metadata": {},
"outputs": [],
"source": [
"# %time\n",
"\n",
"# # Create SageMaker Estimator instance\n",
"# sd_estimator = Estimator(\n",
"# role=aws_role,\n",
"# image_uri=train_image_uri,\n",
"# source_dir=train_source_uri,\n",
"# model_uri=s3_model_location, # updated to point to the new model we have uploaded in S3.\n",
"# entry_point=\"transfer_learning.py\",\n",
"# instance_count=1,\n",
"# instance_type=training_instance_type,\n",
"# max_run=360000,\n",
"# hyperparameters=hyperparameters,\n",
"# output_path=s3_output_location,\n",
"# base_job_name=training_job_name,\n",
"# )\n",
"\n",
"# # Launch a SageMaker Training job by passing s3 path of the training data\n",
"# sd_estimator.fit({\"training\": train_s3_path}, logs=True)"
]
},
{
"cell_type": "markdown",
"id": "6fadc21e",
"metadata": {},
"source": [
"### 2.4. Deploy and run inference on the fine-tuned model\n",
"### 3. Deploy and run inference on the fine-tuned model\n",
"\n",
"---\n",
"\n",
@@ -573,6 +771,7 @@
},
"outputs": [],
"source": [
"# Update prompts to include references to both classes when performing inference on the incrementally trained model\n",
"all_prompts = [\n",
" \"A photo of a Doppler dog on a beach\",\n",
" \"A pencil sketch of a Doppler dog\",\n",
@@ -1117,9 +1316,9 @@
],
"instance_type": "ml.t3.medium",
"kernelspec": {
"display_name": "sagemaker:Python",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "conda-env-sagemaker-py"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
@@ -1131,7 +1330,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.15"
"version": "3.10.13"
},
"pycharm": {
"stem_cell": {