Skip to content

Commit

Permalink
Merge pull request #191 from rstudio/wb-update
Browse files Browse the repository at this point in the history
  • Loading branch information
isabelizimm authored Sep 14, 2023
2 parents be4f6b6 + d5dce53 commit 5639bf6
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 13 deletions.
31 changes: 23 additions & 8 deletions vetiver/server.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,20 @@
from typing import Callable, List, Union
from urllib.parse import urljoin

import re
import httpx
import json
import pandas as pd
import requests
import uvicorn
import logging
import pandas as pd
from fastapi import FastAPI, Request, testclient
from fastapi.exceptions import RequestValidationError
from fastapi.openapi.utils import get_openapi
from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.responses import PlainTextResponse
from fastapi.responses import HTMLResponse, RedirectResponse, PlainTextResponse
from textwrap import dedent
from warnings import warn
from urllib.parse import urljoin
from typing import Callable, List, Union

from .utils import _jupyter_nb
from .utils import _jupyter_nb, get_workbench_path
from .vetiver_model import VetiverModel
from .meta import VetiverMeta
from .helpers import api_data_to_frame, response_to_frame
Expand Down Expand Up @@ -71,6 +70,7 @@ def __init__(
self.model = model
self.app_factory = app_factory
self.app = app_factory()
self.workbench_path = None

if "check_ptype" in kwargs:
check_prototype = kwargs.pop("check_ptype")
Expand All @@ -93,6 +93,14 @@ def _init_app(self):
app = self.app
app.openapi = self._custom_openapi

@app.on_event("startup")
async def startup_event():
logger = logging.getLogger("uvicorn.error")
if self.workbench_path:
logger.info(f"VetiverAPI starting at {self.workbench_path}")
else:
logger.info("VetiverAPI starting...")

@app.get("/", include_in_schema=False)
def docs_redirect():

Expand Down Expand Up @@ -261,7 +269,14 @@ def run(self, port: int = 8000, host: str = "127.0.0.1", **kw):
>>> v_api.run() # doctest: +SKIP
"""
_jupyter_nb()
uvicorn.run(self.app, port=port, host=host, **kw)
self.workbench_path = get_workbench_path(port)

if self.workbench_path:
uvicorn.run(
self.app, port=port, host=host, root_path=self.workbench_path, **kw
)
else:
uvicorn.run(self.app, port=port, host=host, **kw)

def _custom_openapi(self):
import vetiver
Expand Down
13 changes: 10 additions & 3 deletions vetiver/tests/test_xgboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,17 @@
from vetiver.data import mtcars # noqa
from vetiver.handlers.xgboost import XGBoostHandler # noqa
import numpy as np # noqa
import sys # noqa
from fastapi.testclient import TestClient # noqa

import vetiver # noqa

# hack since xgboost 2.0 dropped 3.7 support
if sys.version_info[0] == 3 and sys.version_info[1] < 8:
PREDICT_VALUE = 21.064373016357422
else:
PREDICT_VALUE = 19.963224411010742


@pytest.fixture
def xgb_model():
Expand Down Expand Up @@ -57,7 +64,7 @@ def test_vetiver_build(vetiver_client):

response = vetiver.predict(endpoint=vetiver_client, data=data)

assert response.iloc[0, 0] == 21.064373016357422
assert response.iloc[0, 0] == PREDICT_VALUE
assert len(response) == 1


Expand All @@ -66,7 +73,7 @@ def test_batch(vetiver_client):

response = vetiver.predict(endpoint=vetiver_client, data=data)

assert response.iloc[0, 0] == 21.064373016357422
assert response.iloc[0, 0] == PREDICT_VALUE
assert len(response) == 3


Expand All @@ -75,7 +82,7 @@ def test_no_ptype(vetiver_client_check_ptype_false):

response = vetiver.predict(endpoint=vetiver_client_check_ptype_false, data=data)

assert response.iloc[0, 0] == 21.064373016357422
assert response.iloc[0, 0] == PREDICT_VALUE
assert len(response) == 1


Expand Down
26 changes: 24 additions & 2 deletions vetiver/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import nest_asyncio
import warnings
import sys
import os
import subprocess
from types import SimpleNamespace

no_notebook = False
Expand All @@ -14,8 +16,8 @@ def _jupyter_nb():

if not no_notebook:
warnings.warn(
"WARNING: Jupyter Notebooks are not considered stable environments "
"for production code"
"You may be running from a notebook environment. Jupyter Notebooks are "
"not considered stable environments for production code"
)
nest_asyncio.apply()
else:
Expand All @@ -31,3 +33,23 @@ def inform(log, msg):

if not modelcard_options.quiet:
print(msg, file=sys.stderr)


def get_workbench_path(port):
# check to see if in Posit Workbench, pulled from FastAPI section of user guide
# https://docs.posit.co/ide/server-pro/user/vs-code/guide/proxying-web-servers.html#running-fastapi-with-uvicorn # noqa

if "RS_SERVER_URL" in os.environ and os.environ["RS_SERVER_URL"]:
path = (
subprocess.run(
f"echo $(/usr/lib/rstudio-server/bin/rserver-url -l {port})",
stdout=subprocess.PIPE,
shell=True,
)
.stdout.decode()
.strip()
)
# subprocess is run, new URL given
return path
else:
return None

0 comments on commit 5639bf6

Please sign in to comment.