diff --git a/src/litserve/server.py b/src/litserve/server.py
index 42e4f097..f676c7f1 100644
--- a/src/litserve/server.py
+++ b/src/litserve/server.py
@@ -43,7 +43,6 @@
 from litserve.loops import LitLoop, get_default_loop, inference_worker
 from litserve.middlewares import MaxSizeMiddleware, RequestCountMiddleware
 from litserve.python_client import client_template
-from litserve.specs import OpenAISpec
 from litserve.specs.base import LitSpec
 from litserve.utils import LitAPIStatus, WorkerSetupStatus, call_after_stream
 
@@ -146,8 +145,8 @@ def __init__(
             raise ValueError("batch_timeout must be less than timeout")
         if max_batch_size <= 0:
             raise ValueError("max_batch_size must be greater than 0")
-        if isinstance(spec, OpenAISpec):
-            stream = True
+        if isinstance(spec, LitSpec):
+            stream = spec.stream
 
         if loop is None:
             loop = "auto"
diff --git a/src/litserve/specs/base.py b/src/litserve/specs/base.py
index 6f5afc7b..2db26e95 100644
--- a/src/litserve/specs/base.py
+++ b/src/litserve/specs/base.py
@@ -26,6 +26,10 @@
     def __init__(self):
         self._server: LitServer = None
 
+    @property
+    def stream(self):
+        return False
+
     def pre_setup(self, lit_api: "LitAPI"):
         pass
 
diff --git a/src/litserve/specs/openai.py b/src/litserve/specs/openai.py
index f7447d88..90a63618 100644
--- a/src/litserve/specs/openai.py
+++ b/src/litserve/specs/openai.py
@@ -262,6 +262,10 @@ def __init__(
         self.add_endpoint("/v1/chat/completions", self.chat_completion, ["POST"])
         self.add_endpoint("/v1/chat/completions", self.options_chat_completions, ["OPTIONS"])
 
+    @property
+    def stream(self):
+        return True
+
     def pre_setup(self, lit_api: "LitAPI"):
         from litserve import LitAPI