Skip to content

Commit

Permalink
add custom api path
Browse files Browse the repository at this point in the history
  • Loading branch information
aniketmaurya committed Jun 13, 2024
1 parent 5e6eedd commit 066ee66
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 1 deletion.
10 changes: 9 additions & 1 deletion src/litserve/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ def __init__(
timeout: Union[float, bool] = 30,
max_batch_size: int = 1,
batch_timeout: float = 0.0,
api_path: str = "/predict",
stream: bool = False,
spec: Optional[LitSpec] = None,
):
Expand All @@ -323,6 +324,13 @@ def __init__(
if isinstance(spec, OpenAISpec):
stream = True

if not api_path.startswith("/"):
raise ValueError(
"api_path must start with '/'. "
"Please provide a valid api path like '/predict', '/classify', or '/v1/predict'"
)

self.api_path = api_path
lit_api.stream = stream
lit_api.sanitize(max_batch_size, spec=spec)
self.app = FastAPI(lifespan=self.lifespan)
Expand Down Expand Up @@ -549,7 +557,7 @@ async def stream_predict(request: self.request_type, background_tasks: Backgroun
stream = self.lit_api.stream
# In the future we might want to differentiate endpoints for streaming vs non-streaming
# For now we allow either one or the other
endpoint = "/predict"
endpoint = self.api_path
methods = ["POST"]
self.app.add_api_route(
endpoint, stream_predict if stream else predict, methods=methods, dependencies=[Depends(setup_auth())]
Expand Down
12 changes: 12 additions & 0 deletions tests/test_lit_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
)
from litserve.server import LitServer
import litserve as ls
from fastapi.testclient import TestClient


def test_index(sync_testclient):
Expand Down Expand Up @@ -370,3 +371,14 @@ def dummy_load_and_raise(resp):
with pytest.raises(TypeError, match=re.escape("predict() missing 1 required positional argument: 'y'")):
async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac:
resp = await ac.post("/predict", json={"input": 5.0}, timeout=10)


def test_custom_api_path(sync_testclient):
with pytest.raises(ValueError, match="api_path must start with '/'. "):
LitServer(SimpleLitAPI(), api_path="predict")

server = LitServer(SimpleLitAPI(), api_path="/v1/custom_predict")
url = server.api_path
with TestClient(server.app) as client:
response = client.post(url, json={"input": 4.0})
assert response.status_code == 200

0 comments on commit 066ee66

Please sign in to comment.