From 1fa71aba74de797bf8b9c4a58cf5e52de8ddef43 Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Thu, 12 Dec 2024 10:39:09 +0000 Subject: [PATCH] update validation --- src/litserve/loops.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/src/litserve/loops.py b/src/litserve/loops.py index 07188cfe..521c6783 100644 --- a/src/litserve/loops.py +++ b/src/litserve/loops.py @@ -725,11 +725,18 @@ def pre_setup(self, lit_api: LitAPI, spec: Optional[LitSpec]): "Continuous batching loop requires streaming to be enabled. Please set LitServe(..., stream=True)" ) - if lit_api.max_batch_size <= 1: - raise ValueError( - "Continuous batching loop requires max_batch_size to be greater than 1. " - "Please set LitServe(..., max_batch_size=2)" - ) + if not hasattr(lit_api, "step") and not hasattr(lit_api, "predict"): + raise ValueError("""Using the default step method with Continuous batching loop requires the lit_api to
have a `predict` method which accepts decoded request inputs and a list of generated_sequence.
Please implement the predict method in the lit_api.

 class ExampleAPI(LitAPI):
     ...
     def predict(self, inputs, generated_sequence):
         # implement predict logic
         # return list of new tokens
         ...
 """) if not hasattr(lit_api, "step") and not hasattr(lit_api, "has_finished"): raise ValueError("""Using the default step method with Continuous batching loop