Merge pull request #774 from basetenlabs/bump-version-0.7.22

Release 0.7.22

squidarth authored Dec 14, 2023
2 parents 019583d + 06d8da2 commit 7649d3e
Showing 8 changed files with 186 additions and 3 deletions.
155 changes: 155 additions & 0 deletions docs/guides/data-directory.mdx
@@ -0,0 +1,155 @@
---
title: How to load model weights
description: "Load model weights without Hugging Face or S3"
---

Serving a model requires access to model files, such as model weights. These files are often many gigabytes.

For many models, these files are loaded from Hugging Face. However, model files can come from other sources or be stored directly in the Truss. Model weights and other model data can be:

* Public on Hugging Face (default, [example here](/examples/04-image-generation))
* [Private on Hugging Face](/examples/09-private-huggingface)
* [Bundled directly with the Truss](#bundling-model-weights-in-truss)
* [Public cloud storage like S3](#loading-public-model-weights-from-s3)
* [Private cloud storage like S3](#loading-private-model-weights-from-s3)

## Bundling model weights in Truss

You can bundle model data directly with your model in Truss. To do so, store any necessary files in the Truss's `data` folder.

Here's an example of the `data` folder for [a Truss of Stable Diffusion 2.1](https://github.com/basetenlabs/truss-examples/tree/main/stable-diffusion/stable-diffusion).

```
data/
  scheduler/
    scheduler_config.json
  text_encoder/
    config.json
    diffusion_pytorch_model.bin
  tokenizer/
    merges.txt
    special_tokens_map.json
    tokenizer_config.json
    vocab.json
  unet/
    config.json
    diffusion_pytorch_model.bin
  vae/
    config.json
    diffusion_pytorch_model.bin
  model_index.json
```

To access the data in the model, use the `self._data_dir` variable in the `load()` function of `model/model.py`:

```python
class Model:
    def __init__(self, **kwargs) -> None:
        self._data_dir = kwargs["data_dir"]

    def load(self):
        self.model = StableDiffusionPipeline.from_pretrained(
            str(self._data_dir),  # Set to "data" by default from config.yaml
            revision="fp16",
            torch_dtype=torch.float16,
        ).to("cuda")
```

## Loading public model weights from S3

Bundling multi-gigabyte files with your Truss can be difficult if you have limited local storage and can make deployment slow. Instead, you can store your model weights and other files in cloud storage like S3.

Using files from S3 requires four steps:

1. Uploading the content of your data directory to S3
2. Setting `external_data` in config.yaml
3. Removing unneeded files from the `data` directory
4. Accessing data correctly in the model

Here's an example of that setup for Stable Diffusion, where we have already uploaded the content of our `data/` directory to S3.

First, add the URLs for hosted versions of the large files to `config.yaml`:

```yaml
external_data:
  - url: https://baseten-public.s3.us-west-2.amazonaws.com/models/stable-diffusion-truss/unet/diffusion_pytorch_model.bin
    local_data_path: unet/diffusion_pytorch_model.bin
  - url: https://baseten-public.s3.us-west-2.amazonaws.com/models/stable-diffusion-truss/text_encoder/pytorch_model.bin
    local_data_path: text_encoder/pytorch_model.bin
  - url: https://baseten-public.s3.us-west-2.amazonaws.com/models/stable-diffusion-truss/vae/diffusion_pytorch_model.bin
    local_data_path: vae/diffusion_pytorch_model.bin
```
Each URL is paired with a `local_data_path` that specifies where the file would live if it were bundled in the Truss. This is how your model code knows where to look for the data.
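To make the mapping concrete, here is a minimal sketch of how each `local_data_path` resolves relative to the Truss data directory at runtime (the paths are taken from the config above; the `data` directory name is the Truss default):

```python
from pathlib import Path

# Illustrative only: each local_data_path from config.yaml ends up under
# the Truss data directory, so model code reads it like a bundled file.
data_dir = Path("data")  # default data directory for a Truss

local_data_paths = [
    "unet/diffusion_pytorch_model.bin",
    "text_encoder/pytorch_model.bin",
    "vae/diffusion_pytorch_model.bin",
]

resolved = [data_dir / p for p in local_data_paths]
print(resolved[0])  # data/unet/diffusion_pytorch_model.bin on Linux/macOS
```

This is why the model code in `model/model.py` is unchanged: it reads from `self._data_dir` either way.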

Then, remove the large files from your `data` folder. The Stable Diffusion Truss has the following directory structure after the large files are removed:

```
data/
  scheduler/
    scheduler_config.json
  text_encoder/
    config.json
  tokenizer/
    merges.txt
    special_tokens_map.json
    tokenizer_config.json
    vocab.json
  unet/
    config.json
  vae/
    config.json
  model_index.json
```
The code in `model/model.py` does not need to change; the large files are automatically pulled from the provided URLs.

## Loading private model weights from S3

If your model weights are proprietary, you'll store them in a private S3 bucket or a similar access-restricted data store. Accessing these model files works the same as above, except you first use [secrets](/guides/secrets) to securely authenticate your model with the data store.

First, set the following secrets in `config.yaml`. Set the values to `null`; only the keys are needed here.
```yaml
secrets:
  aws_access_key_id: null
  aws_secret_access_key: null
  aws_region: null # e.g. us-east-1
  aws_bucket: null
```

Then, [add secrets to your Baseten account](https://docs.baseten.co/observability/secrets) for your AWS access key ID, secret access key, region, and bucket. This time, use the actual values; they will be stored securely and provided to your model at runtime.

In your model code, authenticate with AWS in the `__init__()` function:

```python
def __init__(self, **kwargs) -> None:
    self._config = kwargs.get("config")
    secrets = kwargs.get("secrets")
    self.s3_config = {
        "aws_access_key_id": secrets["aws_access_key_id"],
        "aws_secret_access_key": secrets["aws_secret_access_key"],
        "aws_region": secrets["aws_region"],
    }
    self.s3_bucket = secrets["aws_bucket"]
```

You can then use the `boto3` package to access your model weights in `load()`.
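As a sketch of what that could look like, `load()` might download each weight file from the private bucket into the data directory before loading the model. This is illustrative, not the only way to do it: the object keys below are hypothetical (they mirror the bundled `data/` layout above), and `boto3` must be listed in your Truss requirements.

```python
from pathlib import Path


class Model:
    def __init__(self, **kwargs) -> None:
        self._data_dir = kwargs["data_dir"]
        secrets = kwargs["secrets"]
        self.s3_config = {
            "aws_access_key_id": secrets["aws_access_key_id"],
            "aws_secret_access_key": secrets["aws_secret_access_key"],
            "aws_region": secrets["aws_region"],
        }
        self.s3_bucket = secrets["aws_bucket"]

    def load(self):
        # Imported lazily so the class can be constructed without boto3 installed
        import boto3

        s3 = boto3.client(
            "s3",
            aws_access_key_id=self.s3_config["aws_access_key_id"],
            aws_secret_access_key=self.s3_config["aws_secret_access_key"],
            region_name=self.s3_config["aws_region"],
        )
        # Hypothetical object keys: one per large weight file
        for key in [
            "unet/diffusion_pytorch_model.bin",
            "text_encoder/pytorch_model.bin",
            "vae/diffusion_pytorch_model.bin",
        ]:
            local_path = Path(self._data_dir) / key
            local_path.parent.mkdir(parents=True, exist_ok=True)
            s3.download_file(self.s3_bucket, key, str(local_path))
        # ...then load the model from self._data_dir as in the bundled example
```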

When you're ready to deploy your model, make sure to pass `is_trusted=True` to `baseten.deploy()`:

```python
import baseten
import truss

my_model = truss.load("my-model")
baseten.deploy(
    my_model,
    model_name="My model",
    is_trusted=True
)
```

For further details, see [docs on using secrets in models](/guides/secrets).
2 changes: 1 addition & 1 deletion docs/guides/model-cache.mdx
@@ -1,5 +1,5 @@
---
title: Caching model weights
title: How to cache model weights
description: "Accelerate cold starts by caching your weights"
---

1 change: 1 addition & 0 deletions docs/mint.json
@@ -62,6 +62,7 @@
"pages": [
  "guides/secrets",
  "guides/base-images",
  "guides/data-directory",
  "guides/model-cache",
  "guides/concurrency",
  "guides/tgi"
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "truss"
version = "0.7.21"
version = "0.7.22"
description = "A seamless bridge from model development to model delivery"
license = "MIT"
readme = "README.md"
16 changes: 15 additions & 1 deletion truss/cli/cli.py
@@ -409,13 +409,22 @@ def predict(
    default=False,
    help="Trust truss with hosted secrets.",
)
@click.option(
    "--promote",
    type=bool,
    is_flag=True,
    required=False,
    default=False,
    help="After deploy completes, promotes the truss to production.",
)
@error_handling
def push(
    target_directory: str,
    remote: str,
    model_name: str,
    publish: bool = False,
    trusted: bool = False,
    promote: bool = False,
) -> None:
    """
    Pushes a truss to a TrussRemote.
@@ -440,7 +449,7 @@ def push(
    tr.spec.config.write_to_yaml_file(tr.spec.config_path, verbose=False)

    # TODO(Abu): This needs to be refactored to be more generic
    service = remote_provider.push(tr, model_name, publish=publish, trusted=trusted)  # type: ignore
    service = remote_provider.push(tr, model_name, publish=publish, trusted=trusted, promote=promote)  # type: ignore

    click.echo(f"✨ Model {model_name} was successfully pushed ✨")

@@ -467,6 +476,11 @@ def push(
        """
        console.print(not_trusted_text, style="red")

    if promote:
        promotion_text = """Your Truss has been deployed as a production model. After it successfully deploys,
it will become the next production deployment of your model."""
        console.print(promotion_text, style="green")

    logs_url = remote_provider.get_remote_logs_url(service)  # type: ignore[attr-defined]
    rich.print(f"🪵 View logs for your deployment at {logs_url}")

2 changes: 2 additions & 0 deletions truss/remote/baseten/api.py
@@ -100,6 +100,7 @@ def create_model_version_from_truss(
    semver_bump: str,
    client_version: str,
    is_trusted: bool,
    promote: bool = False,
):
    query_string = f"""
    mutation {{
@@ -110,6 +111,7 @@
            semver_bump: "{semver_bump}",
            client_version: "{client_version}",
            is_trusted: {'true' if is_trusted else 'false'}
            promote_after_deploy: {'true' if promote else 'false'}
        ) {{
            id
        }}
3 changes: 3 additions & 0 deletions truss/remote/baseten/core.py
@@ -149,6 +149,7 @@ def create_truss_service(
    config: str,
    semver_bump: str = "MINOR",
    is_trusted: bool = False,
    promote: bool = False,
    is_draft: Optional[bool] = False,
    model_id: Optional[str] = None,
) -> Tuple[str, str]:
@@ -162,6 +163,7 @@
        config: Base64 encoded JSON string of the Truss config
        semver_bump: Semver bump type, defaults to "MINOR"
        is_trusted: Whether the model is trusted, defaults to False
        promote: Whether to promote the model after deploy, defaults to False

    Returns:
        A tuple of the model ID and version ID
@@ -196,6 +198,7 @@
        semver_bump=semver_bump,
        client_version=f"truss=={truss.version()}",
        is_trusted=is_trusted,
        promote=promote,
    )
    model_version_id = model_version_json["id"]
    return (model_id, model_version_id)
8 changes: 8 additions & 0 deletions truss/remote/baseten/remote.py
@@ -46,6 +46,7 @@ def push(  # type: ignore
        model_name: str,
        publish: bool = True,
        trusted: bool = False,
        promote: bool = False,
    ):
        if model_name.isspace():
            raise ValueError("Model name cannot be empty")
@@ -55,6 +56,12 @@
        gathered_truss = TrussHandle(truss_handle.gather())
        if gathered_truss.spec.model_server != ModelServer.TrussServer:
            publish = True

        if promote:
            # If we are promoting a model after deploy, it must be published.
            # Draft models cannot be promoted.
            publish = True

        encoded_config_str = base64_encoded_json_str(
            gathered_truss._spec._config.to_dict()
        )
@@ -70,6 +77,7 @@
            is_draft=not publish,
            model_id=model_id,
            is_trusted=trusted,
            promote=promote,
        )

        return BasetenService(