
Commit

Fixups
marius-baseten committed Sep 18, 2024
1 parent 957024c commit 25e67cf
Showing 6 changed files with 75 additions and 75 deletions.
2 changes: 1 addition & 1 deletion truss/templates/server/common/schema.py
@@ -26,7 +26,7 @@ class TrussSchema(BaseModel):
    supports_streaming: bool

    @classmethod
    def from_parameters(
    def from_signature(
        cls, input_parameters: MappingProxyType, output_annotation: Any
    ) -> Optional["TrussSchema"]:
        """
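Note: the rename reflects that the classmethod consumes the parts of an inspected predict signature rather than a parameter list of its own. A minimal sketch of the calling pattern, mirroring the tests further down (this Model class is a stand-in, not part of the commit):

import inspect

class ModelInput:
    pass  # stand-in; a real truss model would typically annotate with a pydantic model

class Model:
    def predict(self, request: ModelInput) -> ModelInput:
        return request

model = Model()
parameters = inspect.signature(model.predict).parameters  # a MappingProxyType
return_annotation = inspect.signature(model.predict).return_annotation
# model_wrapper.py below now passes this pair as:
# TrussSchema.from_signature(parameters, return_annotation)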
2 changes: 1 addition & 1 deletion truss/templates/server/model_wrapper.py
@@ -232,7 +232,7 @@ def from_model(cls, model) -> "ModelDescriptor":
            preprocess=preprocess,
            predict=predict,
            postprocess=postprocess,
            truss_schema=TrussSchema.from_parameters(parameters, return_annotation),
            truss_schema=TrussSchema.from_signature(parameters, return_annotation),
        )


42 changes: 2 additions & 40 deletions truss/test_data/test_streaming_truss_with_tracing/config.yaml
@@ -1,42 +1,4 @@
apply_library_patches: true
base_image: null
build:
  arguments: {}
  model_server: TrussServer
  secret_to_path_mapping: {}
build_commands: []
bundled_packages_dir: packages
data_dir: data
description: null
environment_variables:
  OTEL_TRACING_NDJSON_FILE: /tmp/otel_traces.ndjson
examples_filename: examples.yaml
external_data: null
external_package_dirs: []
input_type: Any
live_reload: false
model_cache: []
model_class_filename: model.py
model_class_name: Model
model_framework: custom
model_metadata: {}
model_module_dir: model
model_name: Test Streaming
model_type: Model
python_version: py39
requirements: []
requirements_file: null
resources:
  accelerator: null
  cpu: '1'
  memory: 2Gi
  use_gpu: false
runtime:
  enable_tracing_data: false
  num_workers: 1
  predict_concurrency: 1
  streaming_read_timeout: 60
secrets: {}
spec_version: '2.0'
system_packages: []
trt_llm: null
environment_variables:
  OTEL_TRACING_NDJSON_FILE: "/tmp/otel_traces.ndjson"
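The trimmed config keeps only what the tracing test needs: OTEL_TRACING_NDJSON_FILE points trace output at an NDJSON file. A purely hypothetical sketch of the kind of consumer such a variable implies (not truss's actual OpenTelemetry exporter):

import json
import os
import time

TRACE_FILE = os.environ.get("OTEL_TRACING_NDJSON_FILE")

def export_span(name, attributes):
    # Hypothetical exporter: appends one JSON object per line (NDJSON).
    if TRACE_FILE is None:
        return  # file tracing disabled
    record = {"name": name, "time": time.time(), "attributes": attributes}
    with open(TRACE_FILE, "a") as f:
        f.write(json.dumps(record) + "\n")

export_span("predict", {"status": "ok"})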
38 changes: 20 additions & 18 deletions truss/tests/templates/server/test_model_wrapper.py
@@ -13,17 +13,24 @@
from starlette.requests import Request


@pytest.fixture
def anyio_backend():
    return "asyncio"


@pytest.fixture
def app_path(truss_container_fs: Path, helpers: Any):
    truss_container_app_path = truss_container_fs / "app"
    model_file_content = """
class Model:
    def __init__(self):
        self.load_count = 0
    def load(self):
        self.load_count += 1
        if self.load_count <= 2:
            raise RuntimeError('Simulated error')
    def predict(self, request):
        return request
"""
@@ -34,45 +41,40 @@ def predict(self, request):
    yield truss_container_app_path


# TODO: Make this test work
@pytest.mark.skip(
    reason="Succeeds when tests in this file are run alone, but fails with the whole suit"
)
def test_model_wrapper_load_error_once(app_path):
@pytest.mark.anyio
async def test_model_wrapper_load_error_once(app_path):
    if "model_wrapper" in sys.modules:
        model_wrapper_module = sys.modules["model_wrapper"]
        importlib.reload(model_wrapper_module)
    else:
        model_wrapper_module = importlib.import_module("model_wrapper")
    model_wraper_class = getattr(model_wrapper_module, "ModelWrapper")
    model_wrapper_class = getattr(model_wrapper_module, "ModelWrapper")
    config = yaml.safe_load((app_path / "config.yaml").read_text())
    model_wrapper = model_wraper_class(config)
    os.chdir(app_path)
    model_wrapper = model_wrapper_class(config, sdk_trace.NoOpTracer())
    model_wrapper.load()
    # Allow load thread to execute
    time.sleep(1)
    output = model_wrapper.predict({}, MagicMock(spec=Request))
    output = await model_wrapper.predict({}, MagicMock(spec=Request))
    assert output == {}
    assert model_wrapper._model.load_count == 3
    assert model_wrapper._model.load_count == 2


# TODO: Make this test work
@pytest.mark.skip(
    reason="Succeeds when tests in this file are run alone, but fails with the whole suit"
)
def test_model_wrapper_load_error_more_than_allowed(app_path, helpers):
    with helpers.env_var("NUM_LOAD_RETRIES_TRUSS", "0"):
        if "model_wrapper" in sys.modules:
            model_wrapper_module = sys.modules["model_wrapper"]
            importlib.reload(model_wrapper_module)
        else:
            model_wrapper_module = importlib.import_module("model_wrapper")
        model_wraper_class = getattr(model_wrapper_module, "ModelWrapper")
        model_wrapper_class = getattr(model_wrapper_module, "ModelWrapper")
        config = yaml.safe_load((app_path / "config.yaml").read_text())
        model_wrapper = model_wraper_class(config)
        os.chdir(app_path)
        model_wrapper = model_wrapper_class(config, sdk_trace.NoOpTracer())
        model_wrapper.load()
        # Allow load thread to execute
        time.sleep(1)
        assert model_wrapper.load_failed()
        assert model_wrapper.load_failed
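Both tests above exercise the same contract: ModelWrapper.load() retries a failing Model.load() a limited number of times (configurable via NUM_LOAD_RETRIES_TRUSS), and load_failed, now a property rather than a method, reports whether the retries were exhausted. A hedged sketch of that contract; the real ModelWrapper differs in detail (background thread, tracing), and the default retry count here is an assumption:

import os

class RetryingLoader:
    def __init__(self, model):
        self._model = model
        self._load_failed = False

    @property
    def load_failed(self):  # a property, matching the updated assert style
        return self._load_failed

    def load(self):
        retries = int(os.environ.get("NUM_LOAD_RETRIES_TRUSS", "3"))  # default is assumed
        for _ in range(retries + 1):
            try:
                self._model.load()
                return  # load succeeded
            except Exception:
                continue  # retry until attempts are exhausted
        self._load_failed = True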


@pytest.mark.anyio
@@ -83,12 +85,12 @@ async def test_model_wrapper_streaming_timeout(app_path):
        importlib.reload(model_wrapper_module)
    else:
        model_wrapper_module = importlib.import_module("model_wrapper")
    model_wraper_class = getattr(model_wrapper_module, "ModelWrapper")
    model_wrapper_class = getattr(model_wrapper_module, "ModelWrapper")

    # Create an instance of ModelWrapper with streaming_read_timeout set to 5 seconds
    config = yaml.safe_load((app_path / "config.yaml").read_text())
    config["runtime"]["streaming_read_timeout"] = 5
    model_wrapper = model_wraper_class(config)
    model_wrapper = model_wrapper_class(config, sdk_trace.NoOpTracer())
    model_wrapper.load()
    assert model_wrapper._config.get("runtime").get("streaming_read_timeout") == 5
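The option under test bounds how long the server waits for the next chunk from a streaming model. A hedged sketch of what per-chunk enforcement can look like (not the actual truss implementation):

import asyncio

async def stream_with_timeout(agen, timeout: float):
    # Wrap an async generator so each chunk must arrive within `timeout` seconds.
    iterator = agen.__aiter__()
    while True:
        try:
            chunk = await asyncio.wait_for(iterator.__anext__(), timeout)
        except StopAsyncIteration:
            return  # model generator finished normally
        except asyncio.TimeoutError:
            raise TimeoutError(f"no chunk produced within {timeout}s")
        yield chunk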

30 changes: 15 additions & 15 deletions truss/tests/templates/server/test_schema.py
@@ -24,7 +24,7 @@ def predict(self, request):
    input_signature = inspect.signature(model.predict).parameters
    output_signature = inspect.signature(model.predict).return_annotation

    schema = TrussSchema.from_parameters(input_signature, output_signature)
    schema = TrussSchema.from_signature(input_signature, output_signature)

    assert schema is None

@@ -39,7 +39,7 @@ def predict(self, request: ModelInput) -> ModelOutput:
    input_signature = inspect.signature(model.predict).parameters
    output_signature = inspect.signature(model.predict).return_annotation

    schema = TrussSchema.from_parameters(input_signature, output_signature)
    schema = TrussSchema.from_signature(input_signature, output_signature)

    assert schema.input_type == ModelInput
    assert schema.output_type == ModelOutput
@@ -56,7 +56,7 @@ def predict(self, request: str) -> ModelOutput:
    input_signature = inspect.signature(model.predict).parameters
    output_signature = inspect.signature(model.predict).return_annotation

    schema = TrussSchema.from_parameters(input_signature, output_signature)
    schema = TrussSchema.from_signature(input_signature, output_signature)

    assert schema is None

@@ -71,7 +71,7 @@ def predict(self, request: ModelInput) -> str:
    input_signature = inspect.signature(model.predict).parameters
    output_signature = inspect.signature(model.predict).return_annotation

    schema = TrussSchema.from_parameters(input_signature, output_signature)
    schema = TrussSchema.from_signature(input_signature, output_signature)

    assert schema is None

@@ -86,7 +86,7 @@ def predict(self, request: list[str]) -> list[str]:
    input_signature = inspect.signature(model.predict).parameters
    output_signature = inspect.signature(model.predict).return_annotation

    schema = TrussSchema.from_parameters(input_signature, output_signature)
    schema = TrussSchema.from_signature(input_signature, output_signature)

    assert schema is None

@@ -101,7 +101,7 @@ def predict(self, request: dict[str, str]) -> dict[str, str]:
    input_signature = inspect.signature(model.predict).parameters
    output_signature = inspect.signature(model.predict).return_annotation

    schema = TrussSchema.from_parameters(input_signature, output_signature)
    schema = TrussSchema.from_signature(input_signature, output_signature)

    assert schema is None

@@ -116,7 +116,7 @@ async def predict(self, request: ModelInput) -> Awaitable[ModelOutput]:
    input_signature = inspect.signature(model.predict).parameters
    output_signature = inspect.signature(model.predict).return_annotation

    schema = TrussSchema.from_parameters(input_signature, output_signature)
    schema = TrussSchema.from_signature(input_signature, output_signature)

    assert schema.input_type == ModelInput
    assert schema.output_type == ModelOutput
@@ -133,7 +133,7 @@ def predict(self, request: ModelInput) -> Generator[str, None, None]:
    input_signature = inspect.signature(model.predict).parameters
    output_signature = inspect.signature(model.predict).return_annotation

    schema = TrussSchema.from_parameters(input_signature, output_signature)
    schema = TrussSchema.from_signature(input_signature, output_signature)

    assert schema.input_type == ModelInput
    assert schema.output_type is None
@@ -150,7 +150,7 @@ async def predict(self, request: ModelInput) -> AsyncGenerator[str, None]:
    input_signature = inspect.signature(model.predict).parameters
    output_signature = inspect.signature(model.predict).return_annotation

    schema = TrussSchema.from_parameters(input_signature, output_signature)
    schema = TrussSchema.from_signature(input_signature, output_signature)

    assert schema.input_type == ModelInput
    assert schema.output_type is None
@@ -172,7 +172,7 @@ def predict(
    input_signature = inspect.signature(model.predict).parameters
    output_signature = inspect.signature(model.predict).return_annotation

    schema = TrussSchema.from_parameters(input_signature, output_signature)
    schema = TrussSchema.from_signature(input_signature, output_signature)
    assert schema.input_type == ModelInput
    assert schema.output_type == ModelOutput
    assert schema.supports_streaming
@@ -199,7 +199,7 @@ def inner():
    input_signature = inspect.signature(model.predict).parameters
    output_signature = inspect.signature(model.predict).return_annotation

    schema = TrussSchema.from_parameters(input_signature, output_signature)
    schema = TrussSchema.from_signature(input_signature, output_signature)
    assert schema.input_type == ModelInput
    assert schema.output_type is ModelOutput
    assert schema.supports_streaming
@@ -218,7 +218,7 @@ async def predict(
    input_signature = inspect.signature(model.predict).parameters
    output_signature = inspect.signature(model.predict).return_annotation

    schema = TrussSchema.from_parameters(input_signature, output_signature)
    schema = TrussSchema.from_signature(input_signature, output_signature)
    assert schema is None


@@ -232,7 +232,7 @@ def predict(self, request: ModelInput) -> Union[str, int]:
    input_signature = inspect.signature(model.predict).parameters
    output_signature = inspect.signature(model.predict).return_annotation

    schema = TrussSchema.from_parameters(input_signature, output_signature)
    schema = TrussSchema.from_signature(input_signature, output_signature)

    assert schema is None

@@ -247,7 +247,7 @@ async def predict(self, request: str) -> Awaitable[str]:
    input_signature = inspect.signature(model.predict).parameters
    output_signature = inspect.signature(model.predict).return_annotation

    schema = TrussSchema.from_parameters(input_signature, output_signature)
    schema = TrussSchema.from_signature(input_signature, output_signature)
    assert schema is None


@@ -268,6 +268,6 @@ def predict(
    input_signature = inspect.signature(model.predict).parameters
    output_signature = inspect.signature(model.predict).return_annotation

    schema = TrussSchema.from_parameters(input_signature, output_signature)
    schema = TrussSchema.from_signature(input_signature, output_signature)

    assert schema is None
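Read together, these tests pin down the rules from_signature must implement: a schema exists only when the input annotation is a pydantic model; the output may be a pydantic model, optionally wrapped in Awaitable; a generator or async generator return marks streaming with no output type; and Union[Model, Generator] yields a typed output plus streaming support. A hedged reconstruction of that rule set, with tuples standing in for (input_type, output_type, supports_streaming); the actual truss implementation may be structured differently:

import collections.abc
from typing import Union, get_args, get_origin

import pydantic

def _as_pydantic(annotation):
    # Only a plain pydantic model class qualifies; str, list[str], dict[str, str]
    # and other generics all fall through to None.
    if (
        get_origin(annotation) is None
        and isinstance(annotation, type)
        and issubclass(annotation, pydantic.BaseModel)
    ):
        return annotation
    return None

def sketch_from_signature(parameters, return_annotation):
    # `parameters` is the bound predict signature's mapping, so the first
    # entry is the `request` parameter.
    params = list(parameters.values())
    input_type = _as_pydantic(params[0].annotation) if params else None
    if input_type is None:
        return None  # untyped, str, list or dict inputs -> no schema
    origin = get_origin(return_annotation)
    if origin is collections.abc.Awaitable:
        output = _as_pydantic(get_args(return_annotation)[0])
        return (input_type, output, False) if output else None
    if origin in (collections.abc.Generator, collections.abc.AsyncGenerator):
        return (input_type, None, True)  # streaming; chunk type is not modeled
    if origin is Union:
        members = get_args(return_annotation)
        models = [m for a in members if (m := _as_pydantic(a)) is not None]
        generators = [
            a for a in members
            if get_origin(a) in (collections.abc.Generator, collections.abc.AsyncGenerator)
        ]
        if len(models) == 1 and generators:
            return (input_type, models[0], True)  # e.g. Union[ModelOutput, Generator]
        return None  # e.g. Union[str, int]
    output = _as_pydantic(return_annotation)
    return (input_type, output, False) if output else None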
36 changes: 36 additions & 0 deletions truss/tests/test_model_inference.py
@@ -13,6 +13,7 @@
from threading import Thread
from typing import Iterator, Mapping

import httpx
import opentelemetry.trace.propagation.tracecontext as tracecontext
import pytest
import requests
@@ -1109,3 +1110,38 @@ def postprocess(self, inputs):
        "If the predict function returns a response object, you cannot "
        "use postprocessing.",
    )


@pytest.mark.integration
def test_async_streaming_with_cancellation():
    model = """
import fastapi, asyncio, logging
class Model:
    async def predict(self, inputs, request: fastapi.Request):
        await asyncio.sleep(1)
        if await request.is_disconnected():
            logging.warning("Cancelled (before gen).")
            return
        for i in range(5):
            await asyncio.sleep(1.0)
            logging.warning(i)
            yield str(i)
            if await request.is_disconnected():
                logging.warning("Cancelled (during gen).")
                return
"""
    with ensure_kill_all(), temp_truss(model, "") as tr:
        container = tr.docker_run(
            local_port=8090, detach=True, wait_for_server_ready=True
        )
        with pytest.raises(httpx.ReadTimeout):
            with httpx.Client(
                timeout=httpx.Timeout(1.0, connect=1.0, read=1.0)
            ) as client:
                response = client.post(PREDICT_URL, json={}, timeout=1.0)
                response.raise_for_status()

        time.sleep(2)  # Wait a bit to get all logs.
        assert "Cancelled (during gen)." in container.logs()
