diff --git a/vetiver/server.py b/vetiver/server.py index eea55ab..cf7d24b 100644 --- a/vetiver/server.py +++ b/vetiver/server.py @@ -1,21 +1,20 @@ -from typing import Callable, List, Union -from urllib.parse import urljoin - import re import httpx import json -import pandas as pd import requests import uvicorn +import logging +import pandas as pd from fastapi import FastAPI, Request, testclient from fastapi.exceptions import RequestValidationError from fastapi.openapi.utils import get_openapi -from fastapi.responses import HTMLResponse, RedirectResponse -from fastapi.responses import PlainTextResponse +from fastapi.responses import HTMLResponse, RedirectResponse, PlainTextResponse from textwrap import dedent from warnings import warn +from urllib.parse import urljoin +from typing import Callable, List, Union -from .utils import _jupyter_nb +from .utils import _jupyter_nb, get_workbench_path from .vetiver_model import VetiverModel from .meta import VetiverMeta from .helpers import api_data_to_frame, response_to_frame @@ -71,6 +70,7 @@ def __init__( self.model = model self.app_factory = app_factory self.app = app_factory() + self.workbench_path = None if "check_ptype" in kwargs: check_prototype = kwargs.pop("check_ptype") @@ -93,6 +93,14 @@ def _init_app(self): app = self.app app.openapi = self._custom_openapi + @app.on_event("startup") + async def startup_event(): + logger = logging.getLogger("uvicorn.error") + if self.workbench_path: + logger.info(f"VetiverAPI starting at {self.workbench_path}") + else: + logger.info("VetiverAPI starting...") + @app.get("/", include_in_schema=False) def docs_redirect(): @@ -261,7 +269,14 @@ def run(self, port: int = 8000, host: str = "127.0.0.1", **kw): >>> v_api.run() # doctest: +SKIP """ _jupyter_nb() - uvicorn.run(self.app, port=port, host=host, **kw) + self.workbench_path = get_workbench_path(port) + + if self.workbench_path: + uvicorn.run( + self.app, port=port, host=host, root_path=self.workbench_path, **kw + ) + else: + uvicorn.run(self.app, port=port, host=host, **kw) def _custom_openapi(self): import vetiver diff --git a/vetiver/tests/test_xgboost.py b/vetiver/tests/test_xgboost.py index 425898a..ed54c86 100644 --- a/vetiver/tests/test_xgboost.py +++ b/vetiver/tests/test_xgboost.py @@ -5,10 +5,17 @@ from vetiver.data import mtcars # noqa from vetiver.handlers.xgboost import XGBoostHandler # noqa import numpy as np # noqa +import sys # noqa from fastapi.testclient import TestClient # noqa import vetiver # noqa +# hack since xgboost 2.0 dropped 3.7 support +if sys.version_info[0] == 3 and sys.version_info[1] < 8: + PREDICT_VALUE = 21.064373016357422 +else: + PREDICT_VALUE = 19.963224411010742 + @pytest.fixture def xgb_model(): @@ -57,7 +64,7 @@ def test_vetiver_build(vetiver_client): response = vetiver.predict(endpoint=vetiver_client, data=data) - assert response.iloc[0, 0] == 21.064373016357422 + assert response.iloc[0, 0] == PREDICT_VALUE assert len(response) == 1 @@ -66,7 +73,7 @@ def test_batch(vetiver_client): response = vetiver.predict(endpoint=vetiver_client, data=data) - assert response.iloc[0, 0] == 21.064373016357422 + assert response.iloc[0, 0] == PREDICT_VALUE assert len(response) == 3 @@ -75,7 +82,7 @@ def test_no_ptype(vetiver_client_check_ptype_false): response = vetiver.predict(endpoint=vetiver_client_check_ptype_false, data=data) - assert response.iloc[0, 0] == 21.064373016357422 + assert response.iloc[0, 0] == PREDICT_VALUE assert len(response) == 1 diff --git a/vetiver/utils.py b/vetiver/utils.py index dd7ecb1..c910aea 100644 --- a/vetiver/utils.py +++ b/vetiver/utils.py @@ -1,6 +1,8 @@ import nest_asyncio import warnings import sys +import os +import subprocess from types import SimpleNamespace no_notebook = False @@ -14,8 +16,8 @@ def _jupyter_nb(): if not no_notebook: warnings.warn( - "WARNING: Jupyter Notebooks are not considered stable environments " - "for production code" + "You may be running from a notebook environment. Jupyter Notebooks are " + "not considered stable environments for production code" ) nest_asyncio.apply() else: @@ -31,3 +33,23 @@ def inform(log, msg): if not modelcard_options.quiet: print(msg, file=sys.stderr) + + +def get_workbench_path(port): + # check to see if in Posit Workbench, pulled from FastAPI section of user guide + # https://docs.posit.co/ide/server-pro/user/vs-code/guide/proxying-web-servers.html#running-fastapi-with-uvicorn # noqa + + if "RS_SERVER_URL" in os.environ and os.environ["RS_SERVER_URL"]: + path = ( + subprocess.run( + f"echo $(/usr/lib/rstudio-server/bin/rserver-url -l {port})", + stdout=subprocess.PIPE, + shell=True, + ) + .stdout.decode() + .strip() + ) + # subprocess is run, new URL given + return path + else: + return None