diff --git a/src/posit/auth.py b/src/posit/auth.py index 546281b5..44bdc00e 100644 --- a/src/posit/auth.py +++ b/src/posit/auth.py @@ -3,7 +3,7 @@ class Auth(AuthBase): - def __init__(self, key) -> None: + def __init__(self, key: str) -> None: self.key = key def __call__(self, r: PreparedRequest) -> PreparedRequest: diff --git a/src/posit/client.py b/src/posit/client.py index 6e760313..a5ecae2c 100644 --- a/src/posit/client.py +++ b/src/posit/client.py @@ -1,22 +1,36 @@ from requests import Session from typing import Optional +from . import hooks + from .auth import Auth from .config import ConfigBuilder +from .users import Users class Client: + users: Users + def __init__( - self, endpoint: Optional[str] = None, api_key: Optional[str] = None + self, + api_key: Optional[str] = None, + endpoint: Optional[str] = None, ) -> None: builder = ConfigBuilder() - builder.set_api_key(api_key) - builder.set_endpoint(endpoint) + if api_key: + builder.set_api_key(api_key) + if endpoint: + builder.set_endpoint(endpoint) self._config = builder.build() + + if self._config.api_key is None: + raise ValueError("Invalid value for 'api_key': Must be a non-empty string.") + if self._config.endpoint is None: + raise ValueError( + "Invalid value for 'endpoint': Must be a non-empty string." + ) + self._session = Session() + self._session.hooks["response"].append(hooks.handle_errors) self._session.auth = Auth(self._config.api_key) - - def get(self, endpoint: str, *args, **kwargs): # pragma: no cover - return self._session.request( - "GET", f"{self._config.endpoint}/{endpoint}", *args, **kwargs - ) + self.users = Users(self._config.endpoint, self._session) diff --git a/src/posit/client_test.py b/src/posit/client_test.py index 1c3ba960..d10e534a 100644 --- a/src/posit/client_test.py +++ b/src/posit/client_test.py @@ -1,3 +1,5 @@ +import pytest + from unittest.mock import MagicMock, Mock, patch from .client import Client @@ -24,3 +26,31 @@ def test_init(self, Auth: MagicMock, ConfigBuilder: MagicMock, Session: MagicMoc Session.assert_called_once() Auth.assert_called_once_with(api_key) assert client._config == config + + @patch("posit.client.ConfigBuilder") + def test_init_without_api_key(self, ConfigBuilder: MagicMock): + api_key = None + endpoint = "http://foo.bar" + config = Mock() + config.api_key = api_key + config.endpoint = endpoint + builder = ConfigBuilder.return_value + builder.set_api_key = Mock() + builder.set_endpoint = Mock() + builder.build = Mock(return_value=config) + with pytest.raises(ValueError): + Client(api_key=api_key, endpoint=endpoint) + + @patch("posit.client.ConfigBuilder") + def test_init_without_endpoint(self, ConfigBuilder: MagicMock): + api_key = "foobar" + endpoint = None + config = Mock() + config.api_key = api_key + config.endpoint = endpoint + builder = ConfigBuilder.return_value + builder.set_api_key = Mock() + builder.set_endpoint = Mock() + builder.build = Mock(return_value=config) + with pytest.raises(ValueError): + Client(api_key=api_key, endpoint=endpoint) diff --git a/src/posit/config.py b/src/posit/config.py index 534305df..ecbfc131 100644 --- a/src/posit/config.py +++ b/src/posit/config.py @@ -1,4 +1,5 @@ import os +import dataclasses from abc import ABC, abstractmethod from dataclasses import dataclass @@ -20,10 +21,22 @@ def get_value(self, key: str) -> Optional[str]: class EnvironmentConfigProvider(ConfigProvider): def get_value(self, key: str) -> Optional[str]: if key == "api_key": - return os.environ.get("CONNECT_API_KEY") + value = os.environ.get("CONNECT_API_KEY") + if value: + return value + if value == "": + raise ValueError( + "Invalid value for 'CONNECT_API_KEY': Must be a non-empty string." + ) if key == "endpoint": - return os.environ.get("CONNECT_SERVER") + value = os.environ.get("CONNECT_SERVER") + if value: + return os.path.join(value, "__api__") + if value == "": + raise ValueError( + "Invalid value for 'CONNECT_SERVER': Must be a non-empty string." + ) return None @@ -36,7 +49,8 @@ def __init__( self._providers = providers def build(self) -> Config: - for key in Config.__annotations__: + for field in dataclasses.fields(Config): + key = field.name if not getattr(self._config, key): setattr( self._config, @@ -47,8 +61,8 @@ def build(self) -> Config: ) return self._config - def set_api_key(self, api_key: Optional[str]): + def set_api_key(self, api_key: str): self._config.api_key = api_key - def set_endpoint(self, endpoint: Optional[str]): + def set_endpoint(self, endpoint: str): self._config.endpoint = endpoint diff --git a/src/posit/config_test.py b/src/posit/config_test.py index 35cb58ff..10e431d2 100644 --- a/src/posit/config_test.py +++ b/src/posit/config_test.py @@ -1,3 +1,5 @@ +import pytest + from unittest.mock import Mock, patch from .config import Config, ConfigBuilder, EnvironmentConfigProvider @@ -10,11 +12,33 @@ def test_get_api_key(self): api_key = provider.get_value("api_key") assert api_key == "foobar" + @patch.dict("os.environ", {"CONNECT_API_KEY": ""}) + def test_get_api_key_empty(self): + provider = EnvironmentConfigProvider() + with pytest.raises(ValueError): + provider.get_value("api_key") + + def test_get_api_key_miss(self): + provider = EnvironmentConfigProvider() + api_key = provider.get_value("api_key") + assert api_key is None + @patch.dict("os.environ", {"CONNECT_SERVER": "http://foo.bar"}) def test_get_endpoint(self): provider = EnvironmentConfigProvider() endpoint = provider.get_value("endpoint") - assert endpoint == "http://foo.bar" + assert endpoint == "http://foo.bar/__api__" + + @patch.dict("os.environ", {"CONNECT_SERVER": ""}) + def test_get_endpoint_empty(self): + provider = EnvironmentConfigProvider() + with pytest.raises(ValueError): + provider.get_value("endpoint") + + def test_get_endpoint_miss(self): + provider = EnvironmentConfigProvider() + endpoint = provider.get_value("endpoint") + assert endpoint is None def test_get_value_miss(self): provider = EnvironmentConfigProvider() diff --git a/src/posit/errors.py b/src/posit/errors.py new file mode 100644 index 00000000..40f777e7 --- /dev/null +++ b/src/posit/errors.py @@ -0,0 +1,11 @@ +class ClientError(Exception): + def __init__( + self, error_code: int, error_message: str, http_status: int, http_message: str + ): + self.error_code = error_code + self.error_message = error_message + self.http_status = http_status + self.http_message = http_message + super().__init__( + f"{error_message} (Error Code: {error_code}, HTTP Status: {http_status} {http_message})" + ) diff --git a/src/posit/errors_test.py b/src/posit/errors_test.py new file mode 100644 index 00000000..99efcc9d --- /dev/null +++ b/src/posit/errors_test.py @@ -0,0 +1,20 @@ +import pytest + +from .errors import ClientError + + +class TestClientError: + def test(self): + error_code = 0 + error_message = "foo" + http_status = 404 + http_message = "Foo Bar" + with pytest.raises( + ClientError, match=r"foo \(Error Code: 0, HTTP Status: 404 Foo Bar\)" + ): + raise ClientError( + error_code=error_code, + error_message=error_message, + http_status=http_status, + http_message=http_message, + ) diff --git a/src/posit/hooks.py b/src/posit/hooks.py new file mode 100644 index 00000000..7a41d532 --- /dev/null +++ b/src/posit/hooks.py @@ -0,0 +1,15 @@ +from http.client import responses +from requests import Response + +from .errors import ClientError + + +def handle_errors(response: Response, *args, **kwargs) -> Response: + if response.status_code >= 400 and response.status_code < 500: + data = response.json() + error_code = data["code"] + message = data["error"] + http_status = response.status_code + http_status_message = responses[http_status] + raise ClientError(error_code, message, http_status, http_status_message) + return response diff --git a/src/posit/hooks_test.py b/src/posit/hooks_test.py new file mode 100644 index 00000000..1ca5d754 --- /dev/null +++ b/src/posit/hooks_test.py @@ -0,0 +1,20 @@ +import pytest + +from unittest.mock import Mock + +from .hooks import handle_errors + + +class TestHandleErrors: + def test(self): + response = Mock() + response.status_code = 200 + assert handle_errors(response) == response + + def test_client_error(self): + response = Mock() + response.status_code = 400 + response.json = Mock() + response.json.return_value = {"code": 0, "error": "foobar"} + with pytest.raises(Exception): + handle_errors(response) diff --git a/src/posit/users.py b/src/posit/users.py new file mode 100644 index 00000000..ade6a8bb --- /dev/null +++ b/src/posit/users.py @@ -0,0 +1,17 @@ +import os + +from requests import Session, Response + + +class Users: + def __init__(self, endpoint: str, session: Session) -> None: + self._endpoint = endpoint + self._session = session + + def get_user(self, user_id: str) -> Response: + endpoint = os.path.join(self._endpoint, "v1/users", user_id) + return self._session.get(endpoint) + + def get_current_user(self) -> Response: + endpoint = os.path.join(self._endpoint, "v1/user") + return self._session.get(endpoint) diff --git a/src/posit/users_test.py b/src/posit/users_test.py new file mode 100644 index 00000000..b79463e8 --- /dev/null +++ b/src/posit/users_test.py @@ -0,0 +1,21 @@ +from unittest.mock import Mock + +from .users import Users + + +class TestUsers: + def test_get_user(self): + session = Mock() + session.get = Mock(return_value={}) + users = Users(endpoint="http://foo.bar/", session=session) + response = users.get_user(user_id="foo") + assert response == {} + session.get.assert_called_once_with("http://foo.bar/v1/users/foo") + + def test_get_current_user(self): + session = Mock() + session.get = Mock(return_value={}) + users = Users(endpoint="http://foo.bar/", session=session) + response = users.get_current_user() + assert response == {} + session.get.assert_called_once_with("http://foo.bar/v1/user") diff --git a/tinkering.py b/tinkering.py new file mode 100644 index 00000000..d6fb000d --- /dev/null +++ b/tinkering.py @@ -0,0 +1,5 @@ +from posit.client import Client + +client = Client() +res = client.users.get_current_user() +print(res.json())