Skip to content

Commit

Permalink
Add MVP workflow without pydantic types. Add programmatic deployment
Browse files Browse the repository at this point in the history
  • Loading branch information
marius-baseten committed Mar 18, 2024
1 parent 426236f commit 568b73a
Show file tree
Hide file tree
Showing 99 changed files with 11,149 additions and 173 deletions.
6 changes: 3 additions & 3 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
"[python]": {
"editor.formatOnSave": true,
"editor.codeActionsOnSave": {
"source.fixAll": true,
"source.organizeImports": true,
"source.sortMembers": true
"source.fixAll": "explicit",
"source.organizeImports": "explicit",
"source.sortMembers": "explicit"
}
}
}
179 changes: 179 additions & 0 deletions example_workflow_mvp/.slay_gen/processor_GenerateData/processor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
import logging
import pathlib

from slay import definitions
from truss.templates.shared import secrets_resolver

log_format = "%(levelname).1s%(asctime)s %(filename)s:%(lineno)d] %(message)s"
date_format = "%m%d %H:%M:%S"
logging.basicConfig(level=logging.DEBUG, format=log_format, datefmt=date_format)


import random
import string
import subprocess
from typing import Protocol

import pydantic
import slay
from user_package import shared_processor

IMAGE_COMMON = slay.Image().pip_requirements_txt("common_requirements.txt")


class GenerateData(slay.ProcessorBase):

default_config = slay.Config(image=IMAGE_COMMON)

def run(self, length: int) -> str:
return "".join(random.choices(string.ascii_letters + string.digits, k=length))


IMAGE_TRANSFORMERS_GPU = (
slay.Image()
.cuda("12.8")
.pip_requirements_txt("common_requirements.txt")
.pip_install("transformers")
)


class MistraLLMConfig(pydantic.BaseModel):
hf_model_name: str


class MistralLLM(slay.ProcessorBase[MistraLLMConfig]):

default_config = slay.Config(
image=IMAGE_TRANSFORMERS_GPU,
resources=slay.Resources().cpu(12).gpu("A100"),
user_config=MistraLLMConfig(hf_model_name="EleutherAI/mistral-6.7B"),
)
# default_config = slay.Config(config_path="mistral_config.yaml")

def __init__(
self,
context: slay.Context = slay.provide_context(),
) -> None:
super().__init__(context)
try:
subprocess.check_output(["nvidia-smi"], text=True)
except:
raise RuntimeError(
f"Cannot run `{self.__class__}`, because host has no CUDA."
)
import transformers

model_name = self.user_config.hf_model_name
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
model = transformers.AutoModelForCausalLM.from_pretrained(model_name)
self._model = transformers.pipeline(
"text-generation", model=model, tokenizer=tokenizer
)

def run(self, data: str) -> str:
return self._model(data, max_length=50)


class MistralP(Protocol):
def __init__(self, context: slay.Context) -> None:
...

def run(self, data: str) -> str:
...


class TextToNum(slay.ProcessorBase):
default_config = slay.Config(image=IMAGE_COMMON)

def __init__(
self,
context: slay.Context = slay.provide_context(),
mistral: MistralP = slay.provide(MistralLLM),
) -> None:
super().__init__(context)
self._mistral = mistral

def run(self, data: str) -> int:
number = 0
generated_text = self._mistral.run(data)
for char in generated_text:
number += ord(char)

return number


class Workflow(slay.ProcessorBase):
default_config = slay.Config(image=IMAGE_COMMON)

def __init__(
self,
context: slay.Context = slay.provide_context(),
data_generator: GenerateData = slay.provide(GenerateData),
splitter: shared_processor.SplitText = slay.provide(shared_processor.SplitText),
text_to_num: TextToNum = slay.provide(TextToNum),
) -> None:
super().__init__(context)
self._data_generator = data_generator
self._data_splitter = splitter
self._text_to_num = text_to_num

async def run(self, length: int, num_partitions: int) -> tuple[int, str, int]:
data = self._data_generator.run(length)
text_parts, number = await self._data_splitter.run(data, num_partitions)
value = 0
for part in text_parts:
value += self._text_to_num.run(part)
return value, data, number


if __name__ == "__main__":
import asyncio

# Local test or dev execution - context manager makes sure local processors
# are instantiated and injected.
# with slay.run_local():
# wf = Workflow()
# params = Parameters()
# result = wf.run(params=params)
# print(result)
# class FakeMistralLLM(slay.ProcessorBase):
# def run(self, data: str) -> str:
# return data.upper()
# with slay.run_local():
# text_to_num = TextToNum(mistral=FakeMistralLLM())
# wf = Workflow(text_to_num=text_to_num)
# result = asyncio.run(wf.run(length=80, num_partitions=5))
# print(result)
# # Gives a `UsageError`, because not in `run_local` context.
# try:
# wf = Workflow()
# except slay.UsageError as e:
# print(e)
# A "marker" to designate which processors should be deployed as public remote
# service points. Depenedency processors will also be deployed, but only as
# "internal" services, not as a "public" sevice endpoint.
slay.deploy_remotely([Workflow])


class Model:
_context: definitions.Context
_processor: GenerateData

def __init__(
self, config: dict, data_dir: pathlib.Path, secrets: secrets_resolver.Secrets
) -> None:
truss_metadata = definitions.TrussMetadata.parse_obj(
config["model_metadata"]["slay_metadata"]
)
self._context = definitions.Context(
user_config=truss_metadata.user_config,
stub_cls_to_url=truss_metadata.stub_cls_to_url,
secrets=secrets,
)

def load(self) -> None:
self._processor = GenerateData(context=self._context)

def predict(self, payload):
result = self._processor.run(length=payload["length"])
return result
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
git+https://github.com/basetenlabs/truss.git@marius-orchestration
httpx
libcst
rope
black
isort
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
base_image:
image: python:3.11-slim
python_executable_path: ''
environment_variables: {}
external_package_dirs: []
model_metadata:
slay_metadata:
model_config:
arbitrary_types_allowed: true
stub_cls_to_url: {}
user_config: null
model_name: GenerateData
python_version: '3.11'
requirements: []
requirements_file: requirements.txt
resources:
accelerator: null
cpu: '1'
memory: 2Gi
use_gpu: false
secrets:
baseten_api_key: BASETEN_API_KEY
system_packages: []
Loading

0 comments on commit 568b73a

Please sign in to comment.