Skip to content

Commit

Permalink
feat: adds get_user and get_current_user methods. (#7)
Browse files Browse the repository at this point in the history
Adds a new property to the Client class called 'users', which is an
instance of a User class. The User class is responsible for managing
requests to the /v1/user and /v1/users endpoints.

Additionally, a hook is added to check and parse client errors from
Connect.
  • Loading branch information
tdstein authored Jan 31, 2024
1 parent cd757b8 commit 07cc7bf
Show file tree
Hide file tree
Showing 12 changed files with 206 additions and 15 deletions.
2 changes: 1 addition & 1 deletion src/posit/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@


class Auth(AuthBase):
def __init__(self, key) -> None:
def __init__(self, key: str) -> None:
self.key = key

def __call__(self, r: PreparedRequest) -> PreparedRequest:
Expand Down
30 changes: 22 additions & 8 deletions src/posit/client.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,36 @@
from requests import Session
from typing import Optional

from . import hooks

from .auth import Auth
from .config import ConfigBuilder
from .users import Users


class Client:
users: Users

def __init__(
self, endpoint: Optional[str] = None, api_key: Optional[str] = None
self,
api_key: Optional[str] = None,
endpoint: Optional[str] = None,
) -> None:
builder = ConfigBuilder()
builder.set_api_key(api_key)
builder.set_endpoint(endpoint)
if api_key:
builder.set_api_key(api_key)
if endpoint:
builder.set_endpoint(endpoint)
self._config = builder.build()

if self._config.api_key is None:
raise ValueError("Invalid value for 'api_key': Must be a non-empty string.")
if self._config.endpoint is None:
raise ValueError(
"Invalid value for 'endpoint': Must be a non-empty string."
)

self._session = Session()
self._session.hooks["response"].append(hooks.handle_errors)
self._session.auth = Auth(self._config.api_key)

def get(self, endpoint: str, *args, **kwargs): # pragma: no cover
return self._session.request(
"GET", f"{self._config.endpoint}/{endpoint}", *args, **kwargs
)
self.users = Users(self._config.endpoint, self._session)
30 changes: 30 additions & 0 deletions src/posit/client_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import pytest

from unittest.mock import MagicMock, Mock, patch

from .client import Client
Expand All @@ -24,3 +26,31 @@ def test_init(self, Auth: MagicMock, ConfigBuilder: MagicMock, Session: MagicMoc
Session.assert_called_once()
Auth.assert_called_once_with(api_key)
assert client._config == config

@patch("posit.client.ConfigBuilder")
def test_init_without_api_key(self, ConfigBuilder: MagicMock):
api_key = None
endpoint = "http://foo.bar"
config = Mock()
config.api_key = api_key
config.endpoint = endpoint
builder = ConfigBuilder.return_value
builder.set_api_key = Mock()
builder.set_endpoint = Mock()
builder.build = Mock(return_value=config)
with pytest.raises(ValueError):
Client(api_key=api_key, endpoint=endpoint)

@patch("posit.client.ConfigBuilder")
def test_init_without_endpoint(self, ConfigBuilder: MagicMock):
api_key = "foobar"
endpoint = None
config = Mock()
config.api_key = api_key
config.endpoint = endpoint
builder = ConfigBuilder.return_value
builder.set_api_key = Mock()
builder.set_endpoint = Mock()
builder.build = Mock(return_value=config)
with pytest.raises(ValueError):
Client(api_key=api_key, endpoint=endpoint)
24 changes: 19 additions & 5 deletions src/posit/config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import dataclasses

from abc import ABC, abstractmethod
from dataclasses import dataclass
Expand All @@ -20,10 +21,22 @@ def get_value(self, key: str) -> Optional[str]:
class EnvironmentConfigProvider(ConfigProvider):
def get_value(self, key: str) -> Optional[str]:
if key == "api_key":
return os.environ.get("CONNECT_API_KEY")
value = os.environ.get("CONNECT_API_KEY")
if value:
return value
if value == "":
raise ValueError(
"Invalid value for 'CONNECT_API_KEY': Must be a non-empty string."
)

if key == "endpoint":
return os.environ.get("CONNECT_SERVER")
value = os.environ.get("CONNECT_SERVER")
if value:
return os.path.join(value, "__api__")
if value == "":
raise ValueError(
"Invalid value for 'CONNECT_SERVER': Must be a non-empty string."
)

return None

Expand All @@ -36,7 +49,8 @@ def __init__(
self._providers = providers

def build(self) -> Config:
for key in Config.__annotations__:
for field in dataclasses.fields(Config):
key = field.name
if not getattr(self._config, key):
setattr(
self._config,
Expand All @@ -47,8 +61,8 @@ def build(self) -> Config:
)
return self._config

def set_api_key(self, api_key: Optional[str]):
def set_api_key(self, api_key: str):
self._config.api_key = api_key

def set_endpoint(self, endpoint: Optional[str]):
def set_endpoint(self, endpoint: str):
self._config.endpoint = endpoint
26 changes: 25 additions & 1 deletion src/posit/config_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import pytest

from unittest.mock import Mock, patch

from .config import Config, ConfigBuilder, EnvironmentConfigProvider
Expand All @@ -10,11 +12,33 @@ def test_get_api_key(self):
api_key = provider.get_value("api_key")
assert api_key == "foobar"

@patch.dict("os.environ", {"CONNECT_API_KEY": ""})
def test_get_api_key_empty(self):
provider = EnvironmentConfigProvider()
with pytest.raises(ValueError):
provider.get_value("api_key")

def test_get_api_key_miss(self):
provider = EnvironmentConfigProvider()
api_key = provider.get_value("api_key")
assert api_key is None

@patch.dict("os.environ", {"CONNECT_SERVER": "http://foo.bar"})
def test_get_endpoint(self):
provider = EnvironmentConfigProvider()
endpoint = provider.get_value("endpoint")
assert endpoint == "http://foo.bar"
assert endpoint == "http://foo.bar/__api__"

@patch.dict("os.environ", {"CONNECT_SERVER": ""})
def test_get_endpoint_empty(self):
provider = EnvironmentConfigProvider()
with pytest.raises(ValueError):
provider.get_value("endpoint")

def test_get_endpoint_miss(self):
provider = EnvironmentConfigProvider()
endpoint = provider.get_value("endpoint")
assert endpoint is None

def test_get_value_miss(self):
provider = EnvironmentConfigProvider()
Expand Down
11 changes: 11 additions & 0 deletions src/posit/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
class ClientError(Exception):
def __init__(
self, error_code: int, error_message: str, http_status: int, http_message: str
):
self.error_code = error_code
self.error_message = error_message
self.http_status = http_status
self.http_message = http_message
super().__init__(
f"{error_message} (Error Code: {error_code}, HTTP Status: {http_status} {http_message})"
)
20 changes: 20 additions & 0 deletions src/posit/errors_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import pytest

from .errors import ClientError


class TestClientError:
def test(self):
error_code = 0
error_message = "foo"
http_status = 404
http_message = "Foo Bar"
with pytest.raises(
ClientError, match=r"foo \(Error Code: 0, HTTP Status: 404 Foo Bar\)"
):
raise ClientError(
error_code=error_code,
error_message=error_message,
http_status=http_status,
http_message=http_message,
)
15 changes: 15 additions & 0 deletions src/posit/hooks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from http.client import responses
from requests import Response

from .errors import ClientError


def handle_errors(response: Response, *args, **kwargs) -> Response:
if response.status_code >= 400 and response.status_code < 500:
data = response.json()
error_code = data["code"]
message = data["error"]
http_status = response.status_code
http_status_message = responses[http_status]
raise ClientError(error_code, message, http_status, http_status_message)
return response
20 changes: 20 additions & 0 deletions src/posit/hooks_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import pytest

from unittest.mock import Mock

from .hooks import handle_errors


class TestHandleErrors:
def test(self):
response = Mock()
response.status_code = 200
assert handle_errors(response) == response

def test_client_error(self):
response = Mock()
response.status_code = 400
response.json = Mock()
response.json.return_value = {"code": 0, "error": "foobar"}
with pytest.raises(Exception):
handle_errors(response)
17 changes: 17 additions & 0 deletions src/posit/users.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import os

from requests import Session, Response


class Users:
def __init__(self, endpoint: str, session: Session) -> None:
self._endpoint = endpoint
self._session = session

def get_user(self, user_id: str) -> Response:
endpoint = os.path.join(self._endpoint, "v1/users", user_id)
return self._session.get(endpoint)

def get_current_user(self) -> Response:
endpoint = os.path.join(self._endpoint, "v1/user")
return self._session.get(endpoint)
21 changes: 21 additions & 0 deletions src/posit/users_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from unittest.mock import Mock

from .users import Users


class TestUsers:
def test_get_user(self):
session = Mock()
session.get = Mock(return_value={})
users = Users(endpoint="http://foo.bar/", session=session)
response = users.get_user(user_id="foo")
assert response == {}
session.get.assert_called_once_with("http://foo.bar/v1/users/foo")

def test_get_current_user(self):
session = Mock()
session.get = Mock(return_value={})
users = Users(endpoint="http://foo.bar/", session=session)
response = users.get_current_user()
assert response == {}
session.get.assert_called_once_with("http://foo.bar/v1/user")
5 changes: 5 additions & 0 deletions tinkering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from posit.client import Client

client = Client()
res = client.users.get_current_user()
print(res.json())

0 comments on commit 07cc7bf

Please sign in to comment.