-
Notifications
You must be signed in to change notification settings - Fork 75
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #774 from basetenlabs/bump-version-0.7.22
Release 0.7.22
- Loading branch information
Showing
8 changed files
with
186 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
--- | ||
title: How to load model weights | ||
description: "Load model weights without Hugging Face or S3" | ||
--- | ||
|
||
Serving a model requires access to model files, such as model weights. These files are often many gigabytes. | ||
|
||
For many models, these files are loaded from Hugging Face. However, model files can come from other sources or be stored directly in the Truss. Model weights and other model data can be: | ||
|
||
* Public on Hugging Face (default, [example here](/examples/04-image-generation)) | ||
* [Private on Hugging Face](/examples/09-private-huggingface) | ||
* [Bundled directly with the Truss](#bundling-model-weights-in-truss) | ||
* [Public cloud storage like S3](#loading-public-model-weights-from-s3) | ||
* [Private cloud storage like S3](#loading-private-model-weights-from-s3) | ||
|
||
## Bundling model weights in Truss | ||
|
||
You can bundle model data directly with your model in Truss. To do so, use the Truss' `data` folder to store any necessary files. | ||
|
||
Here's an example of the `data` folder for [a Truss of Stable Diffusion 2.1](https://github.com/basetenlabs/truss-examples/tree/main/stable-diffusion/stable-diffusion). | ||
|
||
``` | ||
data/ | ||
scheduler/ | ||
scheduler_config.json | ||
text_encoder/ | ||
config.json | ||
diffusion_pytorch_model.bin | ||
tokenizer/ | ||
merges.txt | ||
special_tokens_map.json | ||
tokenizer_config.json | ||
vocab.json | ||
unet/ | ||
config.json | ||
diffusion_pytorch_model.bin | ||
vae/ | ||
config.json | ||
diffusion_pytorch_model.bin | ||
model_index.json | ||
``` | ||
|
||
To access the data in the model, use the `self._data_dir` variable in the `load()` function of `model/model.py`: | ||
|
||
```python | ||
class Model: | ||
def __init__(self, **kwargs) -> None: | ||
self._data_dir = kwargs["data_dir"] | ||
|
||
def load(self): | ||
self.model = StableDiffusionPipeline.from_pretrained( | ||
str(self._data_dir), # Set to "data" by default from config.yaml | ||
revision="fp16", | ||
torch_dtype=torch.float16, | ||
).to("cuda") | ||
``` | ||
|
||
## Loading public model weights from S3 | ||
|
||
Bundling multi-gigabyte files with your Truss can be difficult if you have limited local storage and can make deployment slow. Instead, you can store your model weights and other files in cloud storage like S3. | ||
|
||
Using files from S3 requires four steps: | ||
|
||
1. Uploading the content of your data directory to S3 | ||
2. Setting `external_data` in config.yaml | ||
3. Removing unneeded files from the `data` directory | ||
4. Accessing data correctly in the model | ||
|
||
Here's an example of that setup for Stable Diffusion, where we have already uploaded the content of our `data/` directory to S3. | ||
|
||
First, add the URLs for hosted versions of the large files to `config.yaml`: | ||
|
||
```yaml | ||
external_data: | ||
- url: https://baseten-public.s3.us-west-2.amazonaws.com/models/stable-diffusion-truss/unet/diffusion_pytorch_model.bin | ||
local_data_path: unet/diffusion_pytorch_model.bin | ||
- url: https://baseten-public.s3.us-west-2.amazonaws.com/models/stable-diffusion-truss/text_encoder/pytorch_model.bin | ||
local_data_path: text_encoder/pytorch_model.bin | ||
- url: https://baseten-public.s3.us-west-2.amazonaws.com/models/stable-diffusion-truss/vae/diffusion_pytorch_model.bin | ||
local_data_path: vae/diffusion_pytorch_model.bin | ||
``` | ||
Each URL matches with a local data path that represents where the model data would be stored if everything was bundled together locally. This is how your model code will know where to look for the data. | ||
Then, get rid of the large files from your `data` folder. The Stable Diffusion Truss has the following directory structure after large files are removed: | ||
|
||
``` | ||
data/ | ||
scheduler/ | ||
scheduler_config.json | ||
text_encoder/ | ||
config.json | ||
tokenizer/ | ||
merges.txt | ||
special_tokens_map.json | ||
tokenizer_config.json | ||
vocab.json | ||
unet/ | ||
config.json | ||
vae/ | ||
config.json | ||
model_index.json | ||
``` | ||
The code in `model/model.py` does not need to be changed and will automatically pull the large files from the provided links. | ||
## Loading private model weights from S3 | ||
If your model weights are proprietary, you'll be storing them in a private S3 bucket or similar access-restricted data store. Accessing these model files works exactly the same as above, but first uses [secrets](/guides/secrets) to securely authenticate your model with the data store. | ||
First, set the following secrets in `config.yaml`. Set the values to `null`, only the keys are needed here. | ||
```yaml | ||
secrets: | ||
aws_access_key_id: null | ||
aws_secret_access_key: null | ||
aws_region: null # e.g. us-east-1 | ||
aws_bucket: null | ||
``` | ||
|
||
Then, [add secrets to your Baseten account](https://docs.baseten.co/observability/secrets) for your AWS access key id, secret access key, region, and bucket. This time, use the actual values as they will be securely stored and provided to your model at runtime. | ||
|
||
In your model code, authenticate with AWS in the `__init__()` function: | ||
|
||
```python | ||
def __init__(self, **kwargs) -> None: | ||
self._config = kwargs.get("config") | ||
secrets = kwargs.get("secrets") | ||
self.s3_config = ( | ||
{ | ||
"aws_access_key_id": secrets["aws_access_key_id"], | ||
"aws_secret_access_key": secrets["aws_secret_access_key"], | ||
"aws_region": secrets["aws_region"], | ||
} | ||
) | ||
self.s3_bucket = (secrets["aws_bucket"]) | ||
``` | ||
|
||
You can then use the `boto3` package to access your model weights in `load()`. | ||
|
||
When you're ready to deploy your model, make sure to pass `is_trusted=True` to `baseten.deploy()`: | ||
|
||
```python | ||
import baseten | ||
import truss | ||
|
||
my_model = truss.load("my-model") | ||
baseten.deploy( | ||
my_model, | ||
model_name="My model", | ||
is_trusted=True | ||
) | ||
``` | ||
|
||
For further details, see [docs on using secrets in models](/guides/secrets). |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters