-
Notifications
You must be signed in to change notification settings - Fork 75
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add MVP workflow without pydantic types. Add programmatic deployment
- Loading branch information
1 parent
426236f
commit 568b73a
Showing
99 changed files
with
11,149 additions
and
173 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
179 changes: 179 additions & 0 deletions
179
example_workflow_mvp/.slay_gen/processor_GenerateData/processor.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,179 @@ | ||
import logging | ||
import pathlib | ||
|
||
from slay import definitions | ||
from truss.templates.shared import secrets_resolver | ||
|
||
log_format = "%(levelname).1s%(asctime)s %(filename)s:%(lineno)d] %(message)s" | ||
date_format = "%m%d %H:%M:%S" | ||
logging.basicConfig(level=logging.DEBUG, format=log_format, datefmt=date_format) | ||
|
||
|
||
import random | ||
import string | ||
import subprocess | ||
from typing import Protocol | ||
|
||
import pydantic | ||
import slay | ||
from user_package import shared_processor | ||
|
||
IMAGE_COMMON = slay.Image().pip_requirements_txt("common_requirements.txt") | ||
|
||
|
||
class GenerateData(slay.ProcessorBase): | ||
|
||
default_config = slay.Config(image=IMAGE_COMMON) | ||
|
||
def run(self, length: int) -> str: | ||
return "".join(random.choices(string.ascii_letters + string.digits, k=length)) | ||
|
||
|
||
IMAGE_TRANSFORMERS_GPU = ( | ||
slay.Image() | ||
.cuda("12.8") | ||
.pip_requirements_txt("common_requirements.txt") | ||
.pip_install("transformers") | ||
) | ||
|
||
|
||
class MistraLLMConfig(pydantic.BaseModel): | ||
hf_model_name: str | ||
|
||
|
||
class MistralLLM(slay.ProcessorBase[MistraLLMConfig]): | ||
|
||
default_config = slay.Config( | ||
image=IMAGE_TRANSFORMERS_GPU, | ||
resources=slay.Resources().cpu(12).gpu("A100"), | ||
user_config=MistraLLMConfig(hf_model_name="EleutherAI/mistral-6.7B"), | ||
) | ||
# default_config = slay.Config(config_path="mistral_config.yaml") | ||
|
||
def __init__( | ||
self, | ||
context: slay.Context = slay.provide_context(), | ||
) -> None: | ||
super().__init__(context) | ||
try: | ||
subprocess.check_output(["nvidia-smi"], text=True) | ||
except: | ||
raise RuntimeError( | ||
f"Cannot run `{self.__class__}`, because host has no CUDA." | ||
) | ||
import transformers | ||
|
||
model_name = self.user_config.hf_model_name | ||
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name) | ||
model = transformers.AutoModelForCausalLM.from_pretrained(model_name) | ||
self._model = transformers.pipeline( | ||
"text-generation", model=model, tokenizer=tokenizer | ||
) | ||
|
||
def run(self, data: str) -> str: | ||
return self._model(data, max_length=50) | ||
|
||
|
||
class MistralP(Protocol): | ||
def __init__(self, context: slay.Context) -> None: | ||
... | ||
|
||
def run(self, data: str) -> str: | ||
... | ||
|
||
|
||
class TextToNum(slay.ProcessorBase): | ||
default_config = slay.Config(image=IMAGE_COMMON) | ||
|
||
def __init__( | ||
self, | ||
context: slay.Context = slay.provide_context(), | ||
mistral: MistralP = slay.provide(MistralLLM), | ||
) -> None: | ||
super().__init__(context) | ||
self._mistral = mistral | ||
|
||
def run(self, data: str) -> int: | ||
number = 0 | ||
generated_text = self._mistral.run(data) | ||
for char in generated_text: | ||
number += ord(char) | ||
|
||
return number | ||
|
||
|
||
class Workflow(slay.ProcessorBase): | ||
default_config = slay.Config(image=IMAGE_COMMON) | ||
|
||
def __init__( | ||
self, | ||
context: slay.Context = slay.provide_context(), | ||
data_generator: GenerateData = slay.provide(GenerateData), | ||
splitter: shared_processor.SplitText = slay.provide(shared_processor.SplitText), | ||
text_to_num: TextToNum = slay.provide(TextToNum), | ||
) -> None: | ||
super().__init__(context) | ||
self._data_generator = data_generator | ||
self._data_splitter = splitter | ||
self._text_to_num = text_to_num | ||
|
||
async def run(self, length: int, num_partitions: int) -> tuple[int, str, int]: | ||
data = self._data_generator.run(length) | ||
text_parts, number = await self._data_splitter.run(data, num_partitions) | ||
value = 0 | ||
for part in text_parts: | ||
value += self._text_to_num.run(part) | ||
return value, data, number | ||
|
||
|
||
if __name__ == "__main__": | ||
import asyncio | ||
|
||
# Local test or dev execution - context manager makes sure local processors | ||
# are instantiated and injected. | ||
# with slay.run_local(): | ||
# wf = Workflow() | ||
# params = Parameters() | ||
# result = wf.run(params=params) | ||
# print(result) | ||
# class FakeMistralLLM(slay.ProcessorBase): | ||
# def run(self, data: str) -> str: | ||
# return data.upper() | ||
# with slay.run_local(): | ||
# text_to_num = TextToNum(mistral=FakeMistralLLM()) | ||
# wf = Workflow(text_to_num=text_to_num) | ||
# result = asyncio.run(wf.run(length=80, num_partitions=5)) | ||
# print(result) | ||
# # Gives a `UsageError`, because not in `run_local` context. | ||
# try: | ||
# wf = Workflow() | ||
# except slay.UsageError as e: | ||
# print(e) | ||
# A "marker" to designate which processors should be deployed as public remote | ||
# service points. Depenedency processors will also be deployed, but only as | ||
# "internal" services, not as a "public" sevice endpoint. | ||
slay.deploy_remotely([Workflow]) | ||
|
||
|
||
class Model: | ||
_context: definitions.Context | ||
_processor: GenerateData | ||
|
||
def __init__( | ||
self, config: dict, data_dir: pathlib.Path, secrets: secrets_resolver.Secrets | ||
) -> None: | ||
truss_metadata = definitions.TrussMetadata.parse_obj( | ||
config["model_metadata"]["slay_metadata"] | ||
) | ||
self._context = definitions.Context( | ||
user_config=truss_metadata.user_config, | ||
stub_cls_to_url=truss_metadata.stub_cls_to_url, | ||
secrets=secrets, | ||
) | ||
|
||
def load(self) -> None: | ||
self._processor = GenerateData(context=self._context) | ||
|
||
def predict(self, payload): | ||
result = self._processor.run(length=payload["length"]) | ||
return result |
6 changes: 6 additions & 0 deletions
6
example_workflow_mvp/.slay_gen/processor_GenerateData/requirements.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
git+https://github.com/basetenlabs/truss.git@marius-orchestration | ||
httpx | ||
libcst | ||
rope | ||
black | ||
isort |
23 changes: 23 additions & 0 deletions
23
example_workflow_mvp/.slay_gen/processor_GenerateData/truss/config.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
base_image: | ||
image: python:3.11-slim | ||
python_executable_path: '' | ||
environment_variables: {} | ||
external_package_dirs: [] | ||
model_metadata: | ||
slay_metadata: | ||
model_config: | ||
arbitrary_types_allowed: true | ||
stub_cls_to_url: {} | ||
user_config: null | ||
model_name: GenerateData | ||
python_version: '3.11' | ||
requirements: [] | ||
requirements_file: requirements.txt | ||
resources: | ||
accelerator: null | ||
cpu: '1' | ||
memory: 2Gi | ||
use_gpu: false | ||
secrets: | ||
baseten_api_key: BASETEN_API_KEY | ||
system_packages: [] |
Oops, something went wrong.