Add model cache docs (#675)
* added model cache docs

* added model cache docs

* add fixes

* typo

* fix

* remove whitespace

* fixed whitespace issue

---------

Co-authored-by: Sidharth Shanker <sid.shanker@baseten.co>
Varun Shenoy and squidarth authored Sep 29, 2023
1 parent b906180 commit 9e7fe01
Showing 3 changed files with 122 additions and 2 deletions.
4 changes: 2 additions & 2 deletions docs/examples/performance/cached-weights.mdx
@@ -1,6 +1,6 @@
 ---
-title: Load cached model weights
-description: "Deploy a model with private Hugging Face weights"
+title: Deploy Llama 2 with Caching
+description: "Enable fast cold starts for a model with private Hugging Face weights"
 ---

In this example, we will cover how you can use the `hf_cache` key in your Truss's `config.yml` to automatically bundle model weights from a private Hugging Face repo.
119 changes: 119 additions & 0 deletions docs/guides/model-cache.mdx
@@ -0,0 +1,119 @@
---
title: Caching model weights
description: "Accelerate cold starts by caching your weights"
---

Truss natively supports automatic caching for model weights. This is a simple yet effective strategy to enhance deployment speed and operational efficiency when it comes to cold starts and scaling beyond a single replica.

<Tip>
### What is a "cold start"?

"Cold start" is a term used to refer to the time taken by a model to boot up after being idle. This process can become a critical factor in serverless environments, as it can significantly influence the model response time, customer satisfaction, and cost.

Without caching our model's weights, we would need to download weights every time we scale up. Caching model weights circumvents this download process. When our new instance boots up, the server automatically finds the cached weights and can proceed with starting up the endpoint.

In practice, this reduces the cold start for large models to just a few seconds. For example, Stable Diffusion XL can take a few minutes to boot up without caching. With caching, it takes just under 10 seconds.

</Tip>

## Enabling Caching for a Model

To enable caching, simply add `hf_cache` to your `config.yml` with a valid `repo_id`. The `hf_cache` key has a few configurations:
- `repo_id` (required): The ID of the Hugging Face repo or the endpoint for your cloud bucket. Currently, we support Hugging Face and Google Cloud Storage.
- `revision`: Points to the revision you want to pull, such as a branch name or commit hash. This is only relevant if you are pulling a specific version of the model. By default, it refers to `main`.
- `allow_patterns`: Only cache files that match specified patterns. Utilize Unix shell-style wildcards to denote these patterns.
- `ignore_patterns`: Conversely, you can also denote file patterns to ignore, hence streamlining the caching process.

Here is an example of a well-written `hf_cache` for Stable Diffusion XL. Note how it only pulls the model weights that it needs using `allow_patterns`.

```yaml config.yml
...
hf_cache:
  - repo_id: madebyollin/sdxl-vae-fp16-fix
    allow_patterns:
      - config.json
      - diffusion_pytorch_model.safetensors
  - repo_id: stabilityai/stable-diffusion-xl-base-1.0
    allow_patterns:
      - "*.json"
      - "*.fp16.safetensors"
      - sd_xl_base_1.0.safetensors
  - repo_id: stabilityai/stable-diffusion-xl-refiner-1.0
    allow_patterns:
      - "*.json"
      - "*.fp16.safetensors"
      - sd_xl_refiner_1.0.safetensors
...
```

Many Hugging Face repos store their model weights in multiple formats (`.bin`, `.safetensors`, `.h5`, `.msgpack`, etc.), and you usually only need one of them. To minimize cold starts, ensure that you only cache the weights you need.
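For instance, if a repo ships weights as both `.bin` and `.safetensors` files and you only want the latter, `ignore_patterns` can exclude the rest. A minimal sketch, with an illustrative repo ID and patterns:

```yaml config.yml
...
hf_cache:
  - repo_id: stabilityai/stable-diffusion-xl-base-1.0
    revision: main
    ignore_patterns:
      - "*.bin"       # skip PyTorch .bin weights in favor of safetensors
      - "*.h5"        # skip TensorFlow weights
      - "*.msgpack"   # skip Flax weights
...
```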

There are also some additional steps depending on the cloud bucket you want to query.

### Hugging Face 🤗
For any public Hugging Face repo, you don't need to do anything else. Adding the `hf_cache` key with an appropriate `repo_id` should be enough.
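For example, caching the weights of a public model takes just a couple of lines. A minimal sketch, using an arbitrary public repo for illustration:

```yaml config.yml
...
hf_cache:
  - repo_id: openai/whisper-small
...
```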

However, if you want to deploy a model from a gated repo like [Llama 2](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf) to Baseten, there are a few steps you need to take:
<Steps>
<Step title="Get Hugging Face API Key">
[Grab an API key](https://huggingface.co/settings/tokens) from Hugging Face with `read` access. Make sure you have access to the model you want to serve.
</Step>
<Step title="Add it to Baseten Secrets Manager">
Paste your API key in your [secrets manager in Baseten](https://app.baseten.co/settings/secrets) under the key `hf_access_token`. You can read more about secrets [here](https://truss.baseten.co/guides/secrets).
</Step>
<Step title="Update Config">
In your Truss's `config.yml`, add the following code:

```yaml config.yml
...
secrets:
  hf_access_token: null
...
```

Make sure that the key `secrets` only shows up once in your `config.yml`.
</Step>
</Steps>
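Putting the pieces together, a `config.yml` for a gated repo like Llama 2 might combine `hf_cache` and `secrets`. This is a sketch; the `allow_patterns` are illustrative, so adjust them to match the files your model actually needs:

```yaml config.yml
...
hf_cache:
  - repo_id: meta-llama/Llama-2-70b-chat-hf
    allow_patterns:
      - "*.json"          # model and tokenizer configs
      - "*.safetensors"   # model weights
secrets:
  hf_access_token: null   # the real token lives in Baseten's secrets manager
...
```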

If you run into any issues, run through all the steps above again and make sure you did not misspell the name of the repo or paste an incorrect API key.

<Tip>
Weights will be cached in the default Hugging Face cache directory, `~/.cache/huggingface/hub/models--{your_model_name}/`. You can change this directory by setting the `HF_HOME` or `HUGGINGFACE_HUB_CACHE` environment variable in your `config.yml`.

[Read more here](https://huggingface.co/docs/huggingface_hub/guides/manage-cache).
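For example, assuming your `config.yml` exposes the standard Truss `environment_variables` key, relocating the cache might look like this (the path is illustrative):

```yaml config.yml
...
environment_variables:
  HF_HOME: /path/to/your/cache
...
```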
</Tip>

### Google Cloud Storage
Google Cloud Storage is a great alternative to Hugging Face when you have a custom model or fine-tune you want to gate, especially if you are already using GCP and care about security and compliance.

Your `hf_cache` should look something like this:

```yaml config.yml
...
hf_cache:
  - repo_id: gcs://path-to-my-bucket
...
```

For a private GCS bucket, first export your service account key. Rename it to `service_account.json` and add it to the `data` directory of your Truss.

Your file structure should look something like this:

```
your-truss
|--model
|  └── model.py
|--data
|  └── service_account.json
```

<Warning>
If you are using version control, like git, for your Truss, make sure to add `service_account.json` to your `.gitignore` file. You don't want to accidentally expose your service account key.
</Warning>
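A one-line entry does the trick, assuming your `.gitignore` lives at the root of your Truss:

```
# .gitignore
service_account.json
```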

Weights will be cached at `/app/hf_cache/{your_bucket_name}`.


### Other Buckets

We're currently working on adding support for more bucket types, including AWS S3. If you have any suggestions, please [leave an issue](https://github.com/basetenlabs/truss/issues) on our GitHub repo.
1 change: 1 addition & 0 deletions docs/mint.json
@@ -69,6 +69,7 @@
"pages": [
"guides/secrets",
"guides/base-images",
"guides/model-cache",
"guides/concurrency"
]
},
