Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TrussServer supports request/repsonse #1148

Merged
merged 9 commits into from
Sep 19, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions truss/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
import warnings
from pathlib import Path

from pydantic import PydanticDeprecatedSince20
from single_source import get_version

# Suppress Pydantic V1 warnings, because we have to use it for backwards compat.
warnings.filterwarnings("ignore", category=PydanticDeprecatedSince20)


__version__ = get_version(__name__, Path(__file__).parent.parent)


Expand Down
6 changes: 5 additions & 1 deletion truss/config/trt_llm.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
import json
import logging
import warnings
from enum import Enum
from typing import Optional

from huggingface_hub.errors import HFValidationError
from huggingface_hub.utils import validate_repo_id
from pydantic import BaseModel, validator
from pydantic import BaseModel, PydanticDeprecatedSince20, validator
from rich.console import Console

# Suppress Pydantic V1 warnings, because we have to use it for backwards compat.
warnings.filterwarnings("ignore", category=PydanticDeprecatedSince20)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

Expand Down
4 changes: 4 additions & 0 deletions truss/remote/baseten/service.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import enum
import time
import urllib.parse
import warnings
from typing import (
Any,
Dict,
Expand All @@ -17,6 +18,9 @@
from truss.truss_handle import TrussHandle
from truss.util.errors import RemoteNetworkError

# "classes created inside an enum will not become a member" -> intended here anyway.
warnings.filterwarnings("ignore", category=DeprecationWarning, message=".*enum.*")

DEFAULT_STREAM_ENCODING = "utf-8"


Expand Down
7 changes: 4 additions & 3 deletions truss/remote/remote_factory.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import inspect

try:
from configparser import DEFAULTSECT, ConfigParser # type: ignore
except ImportError:
# We need to do this for old python.
from configparser import DEFAULTSECT
from configparser import SafeConfigParser as ConfigParser
except ImportError:
# We need to do this for py312 and onwards.
from configparser import DEFAULTSECT, ConfigParser # type: ignore


from functools import partial
from operator import is_not
Expand Down
26 changes: 16 additions & 10 deletions truss/templates/control/control/application.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import contextlib
import logging
import re
from pathlib import Path
Expand Down Expand Up @@ -45,6 +46,20 @@ async def handle_model_load_failed(_, error):
return JSONResponse({"error": str(error)}, 503)


@contextlib.asynccontextmanager
async def lifespan_context(app: FastAPI):
# Before start.
yield # Run.
# Shutdown.
# FastApi handles the term signal to start the shutdown flow. Here we
# make sure that the inference server is stopeed when control server
# shuts down. Inference server has logic to wait until all requests are
# finished before exiting. By waiting on that, we inherit the same
# behavior for control server.
app.state.logger.info("Term signal received, shutting down.")
app.state.inference_server_process_controller.terminate_with_wait()


def create_app(base_config: Dict):
app_state = State()
setup_logging()
Expand Down Expand Up @@ -99,20 +114,11 @@ async def start_background_inference_startup():
ModelLoadFailed: handle_model_load_failed,
Exception: generic_error_handler,
},
lifespan=lifespan_context,
)
app.state = app_state
app.include_router(control_app)

@app.on_event("shutdown")
def on_shutdown():
# FastApi handles the term signal to start the shutdown flow. Here we
# make sure that the inference server is stopeed when control server
# shuts down. Inference server has logic to wait until all requests are
# finished before exiting. By waiting on that, we inherit the same
# behavior for control server.
app.state.logger.info("Term signal received, shutting down.")
app.state.inference_server_process_controller.terminate_with_wait()

return app


Expand Down
9 changes: 6 additions & 3 deletions truss/templates/server/common/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
)

import fastapi
import starlette.responses
from fastapi import HTTPException
from fastapi.responses import JSONResponse

Expand Down Expand Up @@ -62,6 +63,10 @@ def _make_baseten_error_headers(error_code: int) -> Mapping[str, str]:
}


def add_error_headers_to_user_response(response: starlette.responses.Response) -> None:
response.headers.update(_make_baseten_error_headers(_BASETEN_CLIENT_ERROR_CODE))


def _make_baseten_response(
http_status: int,
info: Union[str, Exception],
Expand All @@ -75,9 +80,7 @@ def _make_baseten_response(
)


async def exception_handler(
request: fastapi.Request, exc: Exception
) -> fastapi.Response:
async def exception_handler(_: fastapi.Request, exc: Exception) -> fastapi.Response:
if isinstance(exc, ModelMissingError):
return _make_baseten_response(
HTTPStatus.NOT_FOUND.value, exc, _BASETEN_DOWNSTREAM_ERROR_CODE
Expand Down
13 changes: 7 additions & 6 deletions truss/templates/server/model_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,18 +119,18 @@ def from_signature(
if _is_request_type(param1.annotation):
raise errors.ModelDefinitionError(
f"`{method_name}` method with two arguments is not allowed to "
"have only request as first argument, must be second. "
"have request as first argument, request must be second. "
f"Got: {signature}"
)
if not (param2.annotation and _is_request_type(param1.annotation)):
if not (param2.annotation and _is_request_type(param2.annotation)):
raise errors.ModelDefinitionError(
f"`{method_name}` method with two arguments must have request as "
f"second argument (type annotated). Got: {signature} "
)
return cls.INPUTS_AND_REQUEST
else:
raise errors.ModelDefinitionError(
f"`{method_name}` method cannot have more than to arguments. "
f"`{method_name}` method cannot have more than two arguments. "
f"Got: {signature}"
)

Expand Down Expand Up @@ -545,7 +545,8 @@ async def __call__(
with errors.intercept_exceptions(self._logger):
raise errors.ModelDefinitionError(
"If the predict function returns a generator (streaming), "
"you cannot use postprocessing."
"you cannot use postprocessing. Include all processing in "
"the predict method."
)

if request.headers.get("accept") == "application/json":
Expand All @@ -564,8 +565,8 @@ async def __call__(
if self.model_descriptor.postprocess:
with errors.intercept_exceptions(self._logger):
raise errors.ModelDefinitionError(
"If the predict function returns a response object, "
"you cannot use postprocessing."
"If the predict function returns a response object, you cannot "
"use postprocessing."
)
else:
return predict_result
Expand Down
3 changes: 3 additions & 0 deletions truss/templates/server/truss_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import socket
import sys
import time
from http import HTTPStatus
from pathlib import Path
from typing import Dict, List, Optional, Union

Expand Down Expand Up @@ -195,6 +196,8 @@ async def predict(
# media_type in StreamingResponse sets the Content-Type header
return StreamingResponse(result, media_type="application/octet-stream")
elif isinstance(result, Response):
if result.status_code >= HTTPStatus.MULTIPLE_CHOICES.value:
errors.add_error_headers_to_user_response(result)
return result

response_headers = {}
Expand Down
7 changes: 4 additions & 3 deletions truss/tests/templates/server/test_model_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import importlib

Check failure on line 1 in truss/tests/templates/server/test_model_wrapper.py

View workflow job for this annotation

GitHub Actions / JUnit Test Report

test_model_wrapper.test_trt_llm_truss_init_extension[trio]

ModuleNotFoundError: No module named 'trio'
Raw output
asynclib_name = 'trio'

    def get_async_backend(asynclib_name: str | None = None) -> type[AsyncBackend]:
        if asynclib_name is None:
            asynclib_name = sniffio.current_async_library()
    
        # We use our own dict instead of sys.modules to get the already imported back-end
        # class because the appropriate modules in sys.modules could potentially be only
        # partially initialized
        try:
>           return loaded_backends[asynclib_name]
E           KeyError: 'trio'

.venv/lib/python3.9/site-packages/anyio/_core/_eventloop.py:162: KeyError

During handling of the above exception, another exception occurred:

pyfuncitem = <Function test_trt_llm_truss_init_extension[trio]>

    @pytest.hookimpl(tryfirst=True)
    def pytest_pyfunc_call(pyfuncitem: Any) -> bool | None:
        def run_with_hypothesis(**kwargs: Any) -> None:
            with get_runner(backend_name, backend_options) as runner:
                runner.run_test(original_func, kwargs)
    
        backend = pyfuncitem.funcargs.get("anyio_backend")
        if backend:
            backend_name, backend_options = extract_backend_and_options(backend)
    
            if hasattr(pyfuncitem.obj, "hypothesis"):
                # Wrap the inner test function unless it's already wrapped
                original_func = pyfuncitem.obj.hypothesis.inner_test
                if original_func.__qualname__ != run_with_hypothesis.__qualname__:
                    if iscoroutinefunction(original_func):
                        pyfuncitem.obj.hypothesis.inner_test = run_with_hypothesis
    
                return None
    
            if iscoroutinefunction(pyfuncitem.obj):
                funcargs = pyfuncitem.funcargs
                testargs = {arg: funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames}
>               with get_runner(backend_name, backend_options) as runner:

.venv/lib/python3.9/site-packages/anyio/pytest_plugin.py:123: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
/opt/hostedtoolcache/Python/3.9.9/x64/lib/python3.9/contextlib.py:119: in __enter__
    return next(self.gen)
.venv/lib/python3.9/site-packages/anyio/pytest_plugin.py:35: in get_runner
    asynclib = get_async_backend(backend_name)
.venv/lib/python3.9/site-packages/anyio/_core/_eventloop.py:164: in get_async_backend
    module = import_module(f"anyio._backends._{asynclib_name}")
/opt/hostedtoolcache/Python/3.9.9/x64/lib/python3.9/importlib/__init__.py:127: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import array
    import math
    import socket
    import sys
    import types
    import weakref
    from collections.abc import AsyncIterator, Iterable
    from concurrent.futures import Future
    from dataclasses import dataclass
    from functools import partial
    from io import IOBase
    from os import PathLike
    from signal import Signals
    from socket import AddressFamily, SocketKind
    from types import TracebackType
    from typing import (
        IO,
        Any,
        AsyncGenerator,
        Awaitable,
        Callable,
        Collection,
        ContextManager,
        Coroutine,
        Generic,
        Mapping,
        NoReturn,
        Sequence,
        TypeVar,
        cast,
        overload,
    )
    
>   import trio.from_thread
E   ModuleNotFoundError: No module named 'trio'

.venv/lib/python3.9/site-packages/anyio/_backends/_trio.py:36: ModuleNotFoundError

Check failure on line 1 in truss/tests/templates/server/test_model_wrapper.py

View workflow job for this annotation

GitHub Actions / JUnit Test Report

test_model_wrapper.test_trt_llm_truss_predict[trio]

ModuleNotFoundError: No module named 'trio'
Raw output
asynclib_name = 'trio'

    def get_async_backend(asynclib_name: str | None = None) -> type[AsyncBackend]:
        if asynclib_name is None:
            asynclib_name = sniffio.current_async_library()
    
        # We use our own dict instead of sys.modules to get the already imported back-end
        # class because the appropriate modules in sys.modules could potentially be only
        # partially initialized
        try:
>           return loaded_backends[asynclib_name]
E           KeyError: 'trio'

.venv/lib/python3.9/site-packages/anyio/_core/_eventloop.py:162: KeyError

During handling of the above exception, another exception occurred:

pyfuncitem = <Function test_trt_llm_truss_predict[trio]>

    @pytest.hookimpl(tryfirst=True)
    def pytest_pyfunc_call(pyfuncitem: Any) -> bool | None:
        def run_with_hypothesis(**kwargs: Any) -> None:
            with get_runner(backend_name, backend_options) as runner:
                runner.run_test(original_func, kwargs)
    
        backend = pyfuncitem.funcargs.get("anyio_backend")
        if backend:
            backend_name, backend_options = extract_backend_and_options(backend)
    
            if hasattr(pyfuncitem.obj, "hypothesis"):
                # Wrap the inner test function unless it's already wrapped
                original_func = pyfuncitem.obj.hypothesis.inner_test
                if original_func.__qualname__ != run_with_hypothesis.__qualname__:
                    if iscoroutinefunction(original_func):
                        pyfuncitem.obj.hypothesis.inner_test = run_with_hypothesis
    
                return None
    
            if iscoroutinefunction(pyfuncitem.obj):
                funcargs = pyfuncitem.funcargs
                testargs = {arg: funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames}
>               with get_runner(backend_name, backend_options) as runner:

.venv/lib/python3.9/site-packages/anyio/pytest_plugin.py:123: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
/opt/hostedtoolcache/Python/3.9.9/x64/lib/python3.9/contextlib.py:119: in __enter__
    return next(self.gen)
.venv/lib/python3.9/site-packages/anyio/pytest_plugin.py:35: in get_runner
    asynclib = get_async_backend(backend_name)
.venv/lib/python3.9/site-packages/anyio/_core/_eventloop.py:164: in get_async_backend
    module = import_module(f"anyio._backends._{asynclib_name}")
/opt/hostedtoolcache/Python/3.9.9/x64/lib/python3.9/importlib/__init__.py:127: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import array
    import math
    import socket
    import sys
    import types
    import weakref
    from collections.abc import AsyncIterator, Iterable
    from concurrent.futures import Future
    from dataclasses import dataclass
    from functools import partial
    from io import IOBase
    from os import PathLike
    from signal import Signals
    from socket import AddressFamily, SocketKind
    from types import TracebackType
    from typing import (
        IO,
        Any,
        AsyncGenerator,
        Awaitable,
        Callable,
        Collection,
        ContextManager,
        Coroutine,
        Generic,
        Mapping,
        NoReturn,
        Sequence,
        TypeVar,
        cast,
        overload,
    )
    
>   import trio.from_thread
E   ModuleNotFoundError: No module named 'trio'

.venv/lib/python3.9/site-packages/anyio/_backends/_trio.py:36: ModuleNotFoundError

Check failure on line 1 in truss/tests/templates/server/test_model_wrapper.py

View workflow job for this annotation

GitHub Actions / JUnit Test Report

test_model_wrapper.test_trt_llm_truss_missing_model_py[trio]

ModuleNotFoundError: No module named 'trio'
Raw output
asynclib_name = 'trio'

    def get_async_backend(asynclib_name: str | None = None) -> type[AsyncBackend]:
        if asynclib_name is None:
            asynclib_name = sniffio.current_async_library()
    
        # We use our own dict instead of sys.modules to get the already imported back-end
        # class because the appropriate modules in sys.modules could potentially be only
        # partially initialized
        try:
>           return loaded_backends[asynclib_name]
E           KeyError: 'trio'

.venv/lib/python3.9/site-packages/anyio/_core/_eventloop.py:162: KeyError

During handling of the above exception, another exception occurred:

pyfuncitem = <Function test_trt_llm_truss_missing_model_py[trio]>

    @pytest.hookimpl(tryfirst=True)
    def pytest_pyfunc_call(pyfuncitem: Any) -> bool | None:
        def run_with_hypothesis(**kwargs: Any) -> None:
            with get_runner(backend_name, backend_options) as runner:
                runner.run_test(original_func, kwargs)
    
        backend = pyfuncitem.funcargs.get("anyio_backend")
        if backend:
            backend_name, backend_options = extract_backend_and_options(backend)
    
            if hasattr(pyfuncitem.obj, "hypothesis"):
                # Wrap the inner test function unless it's already wrapped
                original_func = pyfuncitem.obj.hypothesis.inner_test
                if original_func.__qualname__ != run_with_hypothesis.__qualname__:
                    if iscoroutinefunction(original_func):
                        pyfuncitem.obj.hypothesis.inner_test = run_with_hypothesis
    
                return None
    
            if iscoroutinefunction(pyfuncitem.obj):
                funcargs = pyfuncitem.funcargs
                testargs = {arg: funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames}
>               with get_runner(backend_name, backend_options) as runner:

.venv/lib/python3.9/site-packages/anyio/pytest_plugin.py:123: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
/opt/hostedtoolcache/Python/3.9.9/x64/lib/python3.9/contextlib.py:119: in __enter__
    return next(self.gen)
.venv/lib/python3.9/site-packages/anyio/pytest_plugin.py:35: in get_runner
    asynclib = get_async_backend(backend_name)
.venv/lib/python3.9/site-packages/anyio/_core/_eventloop.py:164: in get_async_backend
    module = import_module(f"anyio._backends._{asynclib_name}")
/opt/hostedtoolcache/Python/3.9.9/x64/lib/python3.9/importlib/__init__.py:127: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import array
    import math
    import socket
    import sys
    import types
    import weakref
    from collections.abc import AsyncIterator, Iterable
    from concurrent.futures import Future
    from dataclasses import dataclass
    from functools import partial
    from io import IOBase
    from os import PathLike
    from signal import Signals
    from socket import AddressFamily, SocketKind
    from types import TracebackType
    from typing import (
        IO,
        Any,
        AsyncGenerator,
        Awaitable,
        Callable,
        Collection,
        ContextManager,
        Coroutine,
        Generic,
        Mapping,
        NoReturn,
        Sequence,
        TypeVar,
        cast,
        overload,
    )
    
>   import trio.from_thread
E   ModuleNotFoundError: No module named 'trio'

.venv/lib/python3.9/site-packages/anyio/_backends/_trio.py:36: ModuleNotFoundError
import os
import sys
import time
Expand Down Expand Up @@ -75,6 +75,7 @@
assert model_wrapper.load_failed()


@pytest.mark.anyio
@pytest.mark.integration
async def test_model_wrapper_streaming_timeout(app_path):
if "model_wrapper" in sys.modules:
Expand All @@ -92,7 +93,7 @@
assert model_wrapper._config.get("runtime").get("streaming_read_timeout") == 5


@pytest.mark.asyncio
@pytest.mark.anyio
async def test_trt_llm_truss_init_extension(trt_llm_truss_container_fs, helpers):
app_path = trt_llm_truss_container_fs / "app"
packages_path = trt_llm_truss_container_fs / "packages"
Expand All @@ -116,7 +117,7 @@
), "Expected extension_name was not called"


@pytest.mark.asyncio
@pytest.mark.anyio
async def test_trt_llm_truss_predict(trt_llm_truss_container_fs, helpers):
app_path = trt_llm_truss_container_fs / "app"
packages_path = trt_llm_truss_container_fs / "packages"
Expand Down Expand Up @@ -151,7 +152,7 @@
assert resp == expected_predict_response


@pytest.mark.asyncio
@pytest.mark.anyio
async def test_trt_llm_truss_missing_model_py(trt_llm_truss_container_fs, helpers):
app_path = trt_llm_truss_container_fs / "app"
(app_path / "model" / "model.py").unlink()
Expand Down
Loading
Loading