---
title: Caching model weights
description: "Accelerate cold starts by caching your weights"
---

Truss natively supports automatic caching of model weights. Caching is a simple yet effective way to speed up cold starts and improve operational efficiency when scaling beyond a single replica.

<Tip>
### What is a "cold start"?

A "cold start" is the time it takes a model to boot up after being idle. In serverless environments, cold starts can significantly affect model response time, customer satisfaction, and cost.

Without caching, model weights must be downloaded every time the deployment scales up. Caching the weights circumvents this download: when a new instance boots up, the server automatically finds the cached weights and can proceed directly with starting the endpoint.

In practice, this reduces the cold start for large models to just a few seconds. For example, Stable Diffusion XL can take a few minutes to boot up without caching; with caching, it takes just under 10 seconds.
</Tip>

## Enabling Caching for a Model

To enable caching, simply add `hf_cache` to your `config.yml` with a valid `repo_id`. The `hf_cache` key has a few configuration options:

- `repo_id` (required): The endpoint for your cloud bucket. Currently, we support Hugging Face and Google Cloud Storage.
- `revision`: The revision to pull, such as a branch name or commit hash. This is only relevant if you want to pin a specific version of the weights; by default, it refers to `main`. See the sketch after this list.
- `allow_patterns`: Only cache files that match the specified patterns. Use Unix shell-style wildcards to denote these patterns.
- `ignore_patterns`: Conversely, denote file patterns to ignore, streamlining the caching process.

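For instance, here is a minimal sketch that pins a repo to a fixed revision; the value shown is a placeholder, not a real branch or commit:

```yaml config.yml
...
hf_cache:
  - repo_id: stabilityai/stable-diffusion-xl-base-1.0
    revision: main  # placeholder: swap in a branch, tag, or commit hash to pin the weights
...
```
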
Here is an example of a well-written `hf_cache` for Stable Diffusion XL. Note how it only pulls the model weights that it needs using `allow_patterns`.

```yaml config.yml
...
hf_cache:
  - repo_id: madebyollin/sdxl-vae-fp16-fix
    allow_patterns:
      - config.json
      - diffusion_pytorch_model.safetensors
  - repo_id: stabilityai/stable-diffusion-xl-base-1.0
    allow_patterns:
      - "*.json"
      - "*.fp16.safetensors"
      - sd_xl_base_1.0.safetensors
  - repo_id: stabilityai/stable-diffusion-xl-refiner-1.0
    allow_patterns:
      - "*.json"
      - "*.fp16.safetensors"
      - sd_xl_refiner_1.0.safetensors
...
```

Many Hugging Face repos store model weights in several formats (`.bin`, `.safetensors`, `.h5`, `.msgpack`, etc.), but you usually only need one of them. To minimize cold starts, make sure you only cache the weights you actually need, for example by ignoring redundant formats as sketched below.

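As an illustration (the repo shown is hypothetical for this guide), a sketch that uses `ignore_patterns` to skip weight formats you won't load:

```yaml config.yml
...
hf_cache:
  - repo_id: openai/whisper-small  # illustrative repo, not from the examples above
    ignore_patterns:
      - "*.h5"       # skip TensorFlow weights
      - "*.msgpack"  # skip Flax weights
      - "*.bin"      # skip PyTorch .bin files if you load .safetensors instead
...
```
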
There are also some additional steps depending on the cloud bucket you want to query.

### Hugging Face 🤗

For any public Hugging Face repo, you don't need to do anything else. Adding the `hf_cache` key with an appropriate `repo_id` should be enough.

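For example, a minimal config for a public repo could be as simple as this sketch (the repo shown is illustrative):

```yaml config.yml
...
hf_cache:
  - repo_id: bert-base-uncased  # any public Hugging Face repo
...
```
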
However, if you want to deploy a model from a gated repo like [Llama 2](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf) to Baseten, there are a few steps you need to take:

<Steps>
<Step title="Get Hugging Face API Key">
[Grab an API key](https://huggingface.co/settings/tokens) from Hugging Face with `read` access. Make sure you have access to the model you want to serve.
</Step>
<Step title="Add it to Baseten Secrets Manager">
Paste your API key in your [secrets manager in Baseten](https://app.baseten.co/settings/secrets) under the key `hf_access_token`. You can read more about secrets [here](https://truss.baseten.co/guides/secrets).
</Step>
<Step title="Update Config">
In your Truss's `config.yml`, add the following code:

```yaml config.yml
...
secrets:
  hf_access_token: null
...
```

Make sure that the key `secrets` only shows up once in your `config.yml`.
</Step>
</Steps>

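Putting the pieces together, a sketch of a gated-repo config might look like this (the `allow_patterns` are illustrative; the repo and secret name come from the steps above):

```yaml config.yml
...
hf_cache:
  - repo_id: meta-llama/Llama-2-70b-chat-hf  # gated repo from the example above
    allow_patterns:
      - "*.json"
      - "*.safetensors"  # illustrative: cache only the weight format you load
secrets:
  hf_access_token: null  # the actual token is stored in Baseten's secrets manager
...
```
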
If you run into any issues, run through all the steps above again and make sure you did not misspell the name of the repo or paste an incorrect API key.

<Tip>
Weights will be cached in the default Hugging Face cache directory, `~/.cache/huggingface/hub/models--{your_model_name}/`. You can change this directory by setting the `HF_HOME` or `HUGGINGFACE_HUB_CACHE` environment variable in your `config.yml`.

[Read more here](https://huggingface.co/docs/huggingface_hub/guides/manage-cache).
</Tip>

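As a sketch, assuming your Truss version supports the `environment_variables` key in `config.yml` (the directory shown is hypothetical), relocating the cache could look like:

```yaml config.yml
...
environment_variables:
  HF_HOME: /app/hf_home  # hypothetical directory; weights will be cached here instead
...
```
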
### Google Cloud Storage

Google Cloud Storage is a great alternative to Hugging Face when you have a custom model or fine-tune you want to gate, especially if you are already using GCP and care about security and compliance.

Your `hf_cache` should look something like this:

```yaml config.yml
...
hf_cache:
  - repo_id: gcs://path-to-my-bucket
...
```

For a private GCS bucket, first export your service account key. Rename it to `service_account.json` and add it to the `data` directory of your Truss.

Your file structure should look something like this:

```
your-truss/
├── model/
│   └── model.py
└── data/
    └── service_account.json
```

<Warning>
If you are using version control, like git, for your Truss, make sure to add `service_account.json` to your `.gitignore` file. You don't want to accidentally expose your service account key.
</Warning>

Weights will be cached at `/app/hf_cache/{your_bucket_name}`.

### Other Buckets

We're currently working on adding support for more bucket types, including AWS S3. If you have any suggestions, please [leave an issue](https://github.com/basetenlabs/truss/issues) on our GitHub repo.

The same commit also registers the new guide in the docs navigation (`mint.json`):

```diff mint.json
   "pages": [
     "guides/secrets",
     "guides/base-images",
+    "guides/model-cache",
     "guides/concurrency"
   ]
 },
```