diff --git a/src/posit/connect/client.py b/src/posit/connect/client.py index 97476ed6..6e005b0e 100644 --- a/src/posit/connect/client.py +++ b/src/posit/connect/client.py @@ -7,8 +7,7 @@ from .auth import Auth from .config import Config -from .resources import CachedResources -from .users import Users, User +from .users import CachedUsers, Users, User class Client: @@ -32,14 +31,24 @@ def __init__( session.auth = Auth(config=self.config) # Add error handling hooks to the session. session.hooks["response"].append(hooks.handle_errors) - - # Initialize the Users instance. - self.users: CachedResources[User] = Users(config=self.config, session=session) # Store the Session object. self.session = session - # Place to cache the server settings + # Internal properties for storing public resources self.server_settings = None + self._current_user: Optional[User] = None + + @property + def me(self) -> User: + if self._current_user is None: + url = urls.append_path(self.config.url, "v1/user") + response = self.session.get(url) + self._current_user = User(**response.json()) + return self._current_user + + @property + def users(self) -> CachedUsers: + return Users(client=self) @property def connect_version(self): diff --git a/src/posit/connect/client_test.py b/src/posit/connect/client_test.py index 7f207b79..cfb56606 100644 --- a/src/posit/connect/client_test.py +++ b/src/posit/connect/client_test.py @@ -45,9 +45,31 @@ def test_init( MockAuth.assert_called_once_with(config=MockConfig.return_value) MockConfig.assert_called_once_with(api_key=api_key, url=url) MockSession.assert_called_once() - MockUsers.assert_called_once_with( - config=MockConfig.return_value, session=MockSession.return_value - ) + + def test_users( + self, + MockUsers: MagicMock, + ): + api_key = "foobar" + url = "http://foo.bar/__api__" + client = Client(api_key=api_key, url=url) + client.users + MockUsers.assert_called_once_with(client=client) + + @patch("posit.connect.client.Session") + @patch("posit.connect.client.User") + def test_me( + self, + User: MagicMock, + Session: MagicMock, + ): + api_key = "foobar" + url = "http://foo.bar/__api__" + client = Client(api_key=api_key, url=url) + User.assert_not_called() + assert client._current_user is None + client.me + User.assert_called_once() def test__del__(self, MockAuth, MockConfig, MockSession, MockUsers): api_key = "foobar" diff --git a/src/posit/connect/users.py b/src/posit/connect/users.py index e58641a6..15a77fd0 100644 --- a/src/posit/connect/users.py +++ b/src/posit/connect/users.py @@ -1,13 +1,13 @@ from __future__ import annotations from datetime import datetime -from typing import Iterator, Callable, List +from typing import Iterator, Callable, List, TYPE_CHECKING -from requests import Session +if TYPE_CHECKING: + from .client import Client from . import urls -from .config import Config from .resources import Resources, Resource, CachedResources # The maximum page size supported by the API. @@ -43,17 +43,14 @@ def get(self, id: str) -> User: class Users(CachedUsers, Resources[User]): - def __init__( - self, config: Config, session: Session, *, page_size=_MAX_PAGE_SIZE - ) -> None: + def __init__(self, client: Client, *, page_size=_MAX_PAGE_SIZE) -> None: if page_size > _MAX_PAGE_SIZE: raise ValueError( f"page_size must be less than or equal to {_MAX_PAGE_SIZE}" ) - super().__init__(config.url) - self.config = config - self.session = session + super().__init__(client.config.url) + self.client = client self.page_size = page_size def fetch(self, index) -> tuple[Iterator[User] | None, bool]: @@ -68,9 +65,9 @@ def fetch(self, index) -> tuple[Iterator[User] | None, bool]: # Define query parameters for pagination. params = {"page_number": page_number, "page_size": self.page_size} # Create the URL for the endpoint. - url = urls.append_path(self.config.url, "v1/users") + url = urls.append_path(self.client.config.url, "v1/users") # Send a GET request to the endpoint with the specified parameters. - response = self.session.get(url, params=params) + response = self.client.session.get(url, params=params) # Convert response to dict json: dict = dict(response.json()) # Parse the JSON response and extract the results. @@ -82,6 +79,6 @@ def fetch(self, index) -> tuple[Iterator[User] | None, bool]: return (users, exhausted) def get(self, id: str) -> User: - url = urls.append_path(self.config.url, f"v1/users/{id}") - response = self.session.get(url) + url = urls.append_path(self.client.config.url, f"v1/users/{id}") + response = self.client.session.get(url) return User(**response.json()) diff --git a/src/posit/connect/users_test.py b/src/posit/connect/users_test.py index e2304214..3a7f7546 100644 --- a/src/posit/connect/users_test.py +++ b/src/posit/connect/users_test.py @@ -6,18 +6,12 @@ @pytest.fixture -def mock_config(): - with patch("posit.connect.users.Config") as mock: - yield mock.return_value - - -@pytest.fixture -def mock_session(): - with patch("posit.connect.users.Session") as mock: +def mock_client(): + with patch("posit.connect.client.Client") as mock: yield mock.return_value class TestUsers: - def test_init(self, mock_config, mock_session): + def test_init(self, mock_client): with pytest.raises(ValueError): - Users(mock_config, mock_session, page_size=9999) + Users(mock_client, page_size=9999) diff --git a/tinkering.py b/tinkering.py index 6054fdff..4273d2ac 100644 --- a/tinkering.py +++ b/tinkering.py @@ -1,9 +1,9 @@ from posit.connect import Client with Client() as client: + print(client.me) print(client.get("v1/users")) print(client.users.get("f55ca95d-ce52-43ed-b31b-48dc4a07fe13")) - users = client.users users = users.find(lambda user: user["first_name"].startswith("T")) users = users.find(lambda user: user["last_name"].startswith("S"))