Merge pull request #837 from basetenlabs/bump-version-0.9.2
Release 0.9.2
htrivedi99 authored Feb 26, 2024
2 parents 3f5d53e + 76ff824 commit 1211956
Showing 10 changed files with 191 additions and 16 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "truss"
version = "0.9.1"
version = "0.9.2"
description = "A seamless bridge from model development to model delivery"
license = "MIT"
readme = "README.md"
23 changes: 14 additions & 9 deletions truss/templates/server/common/schema.py
@@ -1,4 +1,3 @@
-from inspect import Signature
from types import MappingProxyType
from typing import (
Any,
@@ -68,14 +67,20 @@ def _parse_input_type(input_parameters: MappingProxyType) -> Optional[type]:

input_type = parameter_types[0].annotation

-if (
-    input_type == Signature.empty
-    or not isinstance(input_type, type)
-    or not issubclass(input_type, BaseModel)
-):
-    return None
-
-return input_type
+if _annotation_is_pydantic_model(input_type):
+    return input_type
+
+return None
+
+
+def _annotation_is_pydantic_model(annotation: Any) -> bool:
+    # This try/except clause is a workaround for the fact that issubclass()
+    # does not work with generic types (i.e., list, dict) and raises a
+    # TypeError.
+    try:
+        return issubclass(annotation, BaseModel)
+    except TypeError:
+        return False


def _parse_output_type(output_annotation: Any) -> Optional[OutputType]:
@@ -90,7 +95,7 @@ def _parse_output_type(output_annotation: Any) -> Optional[OutputType]:
If the output_annotation does not match one of these cases, returns None
"""
-if isinstance(output_annotation, type) and issubclass(output_annotation, BaseModel):
+if _annotation_is_pydantic_model(output_annotation):
return OutputType(type=output_annotation, supports_streaming=False)

if _is_generator_type(output_annotation):
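Why the new helper is needed: issubclass() only accepts classes as its first argument, so a parameterized generic annotation such as list[str] makes the old check raise instead of returning False. A minimal standalone sketch of that failure mode (illustrative, not part of this diff):

from pydantic import BaseModel


class ModelInput(BaseModel):
    prompt: str


# A real pydantic model class works fine with issubclass().
print(issubclass(ModelInput, BaseModel))  # True

# A parameterized generic alias is not a class, so issubclass() raises
# TypeError rather than returning False; the new helper catches this.
try:
    issubclass(list[str], BaseModel)
except TypeError:
    print("TypeError: issubclass() arg 1 must be a class")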
8 changes: 7 additions & 1 deletion truss/templates/server/model_wrapper.py
@@ -275,6 +275,11 @@ async def __call__(
Generator: In case of streaming response
"""

+# The streaming read timeout is the maximum time allowed between streamed
+# chunks before a timeout is triggered.
+streaming_read_timeout = self._config.get("runtime", {}).get(
+    "streaming_read_timeout", STREAMING_RESPONSE_QUEUE_READ_TIMEOUT_SECS
+)

if self.truss_schema is not None:
try:
body = self.truss_schema.input_type(**body)
@@ -328,11 +333,12 @@ async def __call__(
task.add_done_callback(lambda _: semaphore_release_function())
task.add_done_callback(self._background_tasks.discard)

+# The gap between responses in a stream must be < streaming_read_timeout
async def _response_generator():
while True:
chunk = await asyncio.wait_for(
response_queue.get(),
-timeout=STREAMING_RESPONSE_QUEUE_READ_TIMEOUT_SECS,
+timeout=streaming_read_timeout,
)
if chunk is None:
return
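For reference, a self-contained sketch of the pattern used above: wrapping each queue read in asyncio.wait_for bounds the gap between streamed chunks, raising asyncio.TimeoutError if the next chunk does not arrive in time. Names like producer and consume are illustrative, not from the Truss codebase:

import asyncio


async def producer(queue: asyncio.Queue) -> None:
    for i in range(3):
        await asyncio.sleep(2)  # 2s gap between chunks exceeds the 1s timeout
        await queue.put(str(i))
    await queue.put(None)  # sentinel marking the end of the stream


async def consume(queue: asyncio.Queue, read_timeout: float) -> None:
    while True:
        # Each individual read must complete within read_timeout.
        chunk = await asyncio.wait_for(queue.get(), timeout=read_timeout)
        if chunk is None:
            return
        print(chunk)


async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    task = asyncio.create_task(producer(queue))
    try:
        await consume(queue, read_timeout=1)
    except asyncio.TimeoutError:
        print("no chunk arrived within 1s; aborting the stream")
        task.cancel()


asyncio.run(main())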
12 changes: 12 additions & 0 deletions truss/test_data/test_streaming_read_timeout/config.yaml
@@ -0,0 +1,12 @@
model_metadata: {}
model_name: null
model_type: custom
python_version: py39
requirements: []
resources:
accelerator: null
cpu: 500m
memory: 512Mi
use_gpu: false
runtime:
streaming_read_timeout: 1
24 changes: 24 additions & 0 deletions truss/test_data/test_streaming_read_timeout/model/model.py
@@ -0,0 +1,24 @@
import time
from typing import Any, Dict, List


class Model:
def __init__(self, **kwargs) -> None:
self._data_dir = kwargs["data_dir"]
self._config = kwargs["config"]
self._secrets = kwargs["secrets"]
self._model = None

def load(self):
# Load model here and assign to self._model.
pass

def predict(self, model_input: Any) -> Dict[str, List]:
# Invoke model on model_input and calculate predictions here.
def inner():
time.sleep(2)
for i in range(5):
time.sleep(3)
yield str(i)

return inner()
17 changes: 17 additions & 0 deletions truss/tests/templates/server/test_model_wrapper.py
@@ -68,3 +68,20 @@ def test_model_wrapper_load_error_more_than_allowed(app_path, helpers):
# Allow load thread to execute
time.sleep(1)
assert model_wrapper.load_failed()


@pytest.mark.integration
async def test_model_wrapper_streaming_timeout(app_path):
if "model_wrapper" in sys.modules:
model_wrapper_module = sys.modules["model_wrapper"]
importlib.reload(model_wrapper_module)
else:
model_wrapper_module = importlib.import_module("model_wrapper")
model_wrapper_class = getattr(model_wrapper_module, "ModelWrapper")

# Create an instance of ModelWrapper with streaming_read_timeout set to 5 seconds
config = yaml.safe_load((app_path / "config.yaml").read_text())
config["runtime"]["streaming_read_timeout"] = 5
model_wrapper = model_wrapper_class(config)
model_wrapper.load()
assert model_wrapper._config.get("runtime").get("streaming_read_timeout") == 5
45 changes: 45 additions & 0 deletions truss/tests/templates/server/test_schema.py
@@ -14,6 +14,21 @@ class ModelOutput(BaseModel):
output: str


def test_truss_schema_pydantic_empty_annotations():
class Model:
def predict(self, request):
return "hello"

model = Model()

input_signature = inspect.signature(model.predict).parameters
output_signature = inspect.signature(model.predict).return_annotation

schema = TrussSchema.from_signature(input_signature, output_signature)

assert schema is None


def test_truss_schema_pydantic_input_and_output():
class Model:
def predict(self, request: ModelInput) -> ModelOutput:
@@ -61,6 +76,36 @@ def predict(self, request: ModelInput) -> str:
assert schema is None


def test_truss_schema_list_types():
class Model:
def predict(self, request: list[str]) -> list[str]:
return ["foo", "bar"]

model = Model()

input_signature = inspect.signature(model.predict).parameters
output_signature = inspect.signature(model.predict).return_annotation

schema = TrussSchema.from_signature(input_signature, output_signature)

assert schema is None


def test_truss_schema_dict_types():
class Model:
def predict(self, request: dict[str, str]) -> dict[str, str]:
return {"foo": "bar"}

model = Model()

input_signature = inspect.signature(model.predict).parameters
output_signature = inspect.signature(model.predict).return_annotation

schema = TrussSchema.from_signature(input_signature, output_signature)

assert schema is None


def test_truss_schema_async():
class Model:
async def predict(self, request: ModelInput) -> Awaitable[ModelOutput]:
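Background for the new empty-annotation test (a standalone sketch, not part of the change): inspect reports a missing annotation as the Signature.empty sentinel, which is why an unannotated predict yields no schema.

import inspect


class Model:
    def predict(self, request):
        return "hello"


sig = inspect.signature(Model.predict)
# Both the parameter annotation and the return annotation are the
# inspect.Signature.empty sentinel when no type hints are given.
print(sig.parameters["request"].annotation is inspect.Signature.empty)  # True
print(sig.return_annotation is inspect.Signature.empty)  # True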
42 changes: 38 additions & 4 deletions truss/tests/test_model_inference.py
@@ -20,18 +20,22 @@

logger = logging.getLogger(__name__)

+DEFAULT_LOG_ERROR = "Internal Server Error"
+

-def _log_contains_error(line: dict, error: str):
+def _log_contains_error(line: dict, error: str, message: str):
    return (
        line["levelname"] == "ERROR"
-        and line["message"] == "Internal Server Error"
+        and line["message"] == message
        and error in line["exc_info"]
    )


-def assert_logs_contain_error(logs: str, error: str):
+def assert_logs_contain_error(logs: str, error: str, message=DEFAULT_LOG_ERROR):
    loglines = logs.splitlines()
-    assert any(_log_contains_error(json.loads(line), error) for line in loglines)
+    assert any(
+        _log_contains_error(json.loads(line), error, message) for line in loglines
+    )


class PropagatingThread(Thread):
@@ -232,6 +236,36 @@ def test_async_streaming():
assert predict_non_stream_response.json() == "01234"


@pytest.mark.integration
def test_async_streaming_timeout():
with ensure_kill_all():
truss_root = Path(__file__).parent.parent.parent.resolve() / "truss"

truss_dir = truss_root / "test_data" / "test_streaming_read_timeout"

tr = TrussHandle(truss_dir)

container = tr.docker_run(
local_port=8090, detach=True, wait_for_server_ready=True
)
truss_server_addr = "http://localhost:8090"
predict_url = f"{truss_server_addr}/v1/models/model:predict"

# ChunkedEncodingError is raised when the server cuts the stream off after the streaming read timeout
with pytest.raises(requests.exceptions.ChunkedEncodingError):
response = requests.post(predict_url, json={}, stream=True)

for chunk in response.iter_content():
pass

# Check to ensure the Timeout error is in the container logs
assert_logs_contain_error(
container.logs(),
error="raise exceptions.TimeoutError()",
message="Exception in ASGI application\n",
)


@pytest.mark.integration
def test_streaming_with_error():
with ensure_kill_all():
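A hedged client-side sketch mirroring the test above (the URL assumes the locally started container from the test): requests surfaces a stream that the server aborted mid-response as a ChunkedEncodingError during iteration, which is exactly what the test asserts.

import requests

try:
    response = requests.post(
        "http://localhost:8090/v1/models/model:predict", json={}, stream=True
    )
    # Iterating pulls chunks off the wire; if the server hits its
    # streaming_read_timeout and aborts, this loop raises.
    for chunk in response.iter_content():
        print(chunk)
except requests.exceptions.ChunkedEncodingError:
    print("stream aborted by the server-side streaming read timeout")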
27 changes: 26 additions & 1 deletion truss/tests/test_model_schema.py
@@ -33,7 +33,32 @@ def test_truss_with_no_annotations():
schema_response = requests.get(SCHEMA_URL)
assert schema_response.status_code == 404

-schema_response.json()["error"] == "No schema found"
+assert schema_response.json()["error"] == "No schema found"


@pytest.mark.integration
def test_truss_with_non_pydantic_annotations():
truss_non_pydantic_annotations = """
class Model:
def predict(self, request: str) -> list[str]:
return ["hello"]
"""

with ensure_kill_all(), tempfile.TemporaryDirectory(dir=".") as tmp_work_dir:
truss_dir = Path(tmp_work_dir, "truss")

create_truss(truss_dir, DEFAULT_CONFIG, truss_non_pydantic_annotations)

tr = TrussHandle(truss_dir)
_ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)

response = requests.post(INFERENCE_URL, json={"prompt": "value"})
assert response.json() == ["hello"]

schema_response = requests.get(SCHEMA_URL)
assert schema_response.status_code == 404

assert schema_response.json()["error"] == "No schema found"


@pytest.mark.integration
7 changes: 7 additions & 0 deletions truss/truss_config.py
@@ -31,6 +31,7 @@
DEFAULT_SPEC_VERSION = "2.0"
DEFAULT_PREDICT_CONCURRENCY = 1
DEFAULT_NUM_WORKERS = 1
+DEFAULT_STREAMING_RESPONSE_READ_TIMEOUT = 60

DEFAULT_CPU = "1"
DEFAULT_MEMORY = "2Gi"
@@ -140,21 +141,27 @@ def to_list(self, verbose=False) -> List[Dict[str, str]]:
class Runtime:
predict_concurrency: int = DEFAULT_PREDICT_CONCURRENCY
num_workers: int = DEFAULT_NUM_WORKERS
+streaming_read_timeout: int = DEFAULT_STREAMING_RESPONSE_READ_TIMEOUT

@staticmethod
def from_dict(d):
predict_concurrency = d.get("predict_concurrency", DEFAULT_PREDICT_CONCURRENCY)
num_workers = d.get("num_workers", DEFAULT_NUM_WORKERS)
+streaming_read_timeout = d.get(
+    "streaming_read_timeout", DEFAULT_STREAMING_RESPONSE_READ_TIMEOUT
+)

return Runtime(
predict_concurrency=predict_concurrency,
num_workers=num_workers,
+streaming_read_timeout=streaming_read_timeout,
)

def to_dict(self):
return {
"predict_concurrency": self.predict_concurrency,
"num_workers": self.num_workers,
"streaming_read_timeout": self.streaming_read_timeout,
}


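A short usage sketch of the new field (assuming the import path truss.truss_config, which matches this file's location in the repo): omitted keys fall back to their defaults, and explicit values round-trip through from_dict/to_dict.

from truss.truss_config import Runtime

# With no key present, the new 60-second default applies.
runtime = Runtime.from_dict({})
assert runtime.streaming_read_timeout == 60

# An explicit value, as in the test_streaming_read_timeout config above,
# is preserved and serialized back out.
runtime = Runtime.from_dict({"streaming_read_timeout": 1})
assert runtime.to_dict()["streaming_read_timeout"] == 1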
