diff --git a/gazu/__init__.py b/gazu/__init__.py index 158d7b4..1bd44b3 100644 --- a/gazu/__init__.py +++ b/gazu/__init__.py @@ -87,22 +87,8 @@ def log_out(client=raw.default_client): return tokens -def refresh_token(client=raw.default_client): - headers = {"User-Agent": "CGWire Gazu %s" % __version__} - if "refresh_token" in client.tokens: - headers["Authorization"] = "Bearer %s" % client.tokens["refresh_token"] - - response = client.session.get( - raw.get_full_url("auth/refresh-token", client=client), - headers=headers, - ) - raw.check_status(response, "auth/refresh-token") - - tokens = response.json() - - client.tokens["access_token"] = tokens["access_token"] - - return tokens +def refresh_access_token(client=raw.default_client): + return client.refresh_access_token() def get_event_host(client=raw.default_client): diff --git a/gazu/client.py b/gazu/client.py index bc029f9..2e37fa4 100644 --- a/gazu/client.py +++ b/gazu/client.py @@ -36,32 +36,107 @@ def __init__( host, ssl_verify=True, cert=None, - automatic_refresh_token=False, + use_refresh_token=True, callback_not_authenticated=None, + tokens={"access_token": None, "refresh_token": None}, + access_token=None, + refresh_token=None, ): - self.tokens = {"access_token": "", "refresh_token": ""} + self.tokens = tokens + if access_token: + self.access_token = access_token + if refresh_token: + self.refresh_token = refresh_token + self.use_refresh_token = use_refresh_token + self.callback_not_authenticated = callback_not_authenticated + self.session = requests.Session() self.session.verify = ssl_verify self.session.cert = cert self.host = host self.event_host = host - self.automatic_refresh_token = automatic_refresh_token - self.callback_not_authenticated = callback_not_authenticated + + @property + def access_token(self): + return self.tokens.get("access_token", None) + + @access_token.setter + def access_token(self, token): + self.tokens["access_token"] = token + + @property + def refresh_token(self): + return self.tokens.get("refresh_token", None) + + @refresh_token.setter + def refresh_token(self, token): + self.tokens["refresh_token"] = token + + def refresh_access_token(self): + """ + Refresh access tokens for this client. + + Returns: + dict: The new access token. + """ + response = self.session.get( + get_full_url("auth/refresh-token", client=self), + headers={ + "User-Agent": "CGWire Gazu " + __version__, + "Authorization": "Bearer " + self.refresh_token, + }, + ) + check_status(response, "auth/refresh-token") + tokens = response.json() + + self.access_token = tokens["access_token"] + self.refresh_token = None + + return tokens + + def make_auth_header(self): + """ + Make headers required to authenticate. + + Returns: + dict: Headers required to authenticate. + """ + headers = {"User-Agent": "CGWire Gazu " + __version__} + + if self.access_token: + headers["Authorization"] = "Bearer " + self.access_token + + return headers def create_client( host, ssl_verify=True, cert=None, - automatic_refresh_token=False, + use_refresh_token=False, callback_not_authenticated=None, + **kwargs ): + """ + Create a client with given parameters. + + Args: + host (str): The host to use for requests. + ssl_verify (bool): Whether to verify SSL certificates. + cert (str): Path to a client certificate. + use_refresh_token (bool): Whether to automatically refresh tokens. + callback_not_authenticated (function): Function to call when not authenticated. + + Returns: + KitsuClient: The created client. + """ return KitsuClient( host, ssl_verify, cert=cert, - automatic_refresh_token=automatic_refresh_token, + use_refresh_token=use_refresh_token, callback_not_authenticated=callback_not_authenticated, + **kwargs ) @@ -81,8 +156,13 @@ def create_client( def host_is_up(client=default_client): """ + Check if the host is up. + + Args: + client (KitsuClient): The client to use for the request. + Returns: - True if the host is up. + bool: True if the host is up. """ try: response = client.session.head(client.host) @@ -94,8 +174,12 @@ def host_is_up(client=default_client): def host_is_valid(client=default_client): """ Check if the host is valid by simulating a fake login. + + Args: + client (KitsuClient): The client to use for the request. + Returns: - True if the host is valid. + bool: True if the host is valid. """ if not host_is_up(client): return False @@ -107,14 +191,23 @@ def host_is_valid(client=default_client): def get_host(client=default_client): """ + Get client.host. + + Args: + client (KitsuClient): The client to use for the request. + Returns: - Host on which requests are sent. + str: The host of the client. """ return client.host def get_api_url_from_host(client=default_client): """ + Get the API url from the host. + + Args: + client (KitsuClient): The client to use for the request. Returns: Zou url, retrieved from host. """ @@ -123,8 +216,14 @@ def get_api_url_from_host(client=default_client): def set_host(new_host, client=default_client): """ + Set the host for the client. + + Args: + new_host (str): The new host to set. + client (KitsuClient): The client to use for the request. + Returns: - Set currently configured host on which requests are sent. + str: The new host. """ client.host = new_host return client.host @@ -132,16 +231,27 @@ def set_host(new_host, client=default_client): def get_event_host(client=default_client): """ + Get the host on which listening for events. + + Args: + client (KitsuClient): The client to use for the request. + Returns: - Host on which listening for events. + str: The event host. """ return client.event_host or client.host def set_event_host(new_host, client=default_client): """ + Set the host on which listening for events. + + Args: + new_host (str): The new host to set. + client (KitsuClient): The client to use for the request. + Returns: - Set currently configured host on which listening for events. + str: The new event host. """ client.event_host = new_host return client.event_host @@ -153,6 +263,10 @@ def set_tokens(new_tokens, client=default_client): Args: new_tokens (dict): Tokens to use for authentication. + client (KitsuClient): The client to use for the request. + + Returns: + dict: The stored tokens. """ client.tokens = new_tokens return client.tokens @@ -160,13 +274,15 @@ def set_tokens(new_tokens, client=default_client): def make_auth_header(client=default_client): """ + Make headers required to authenticate. + + Args: + client (KitsuClient): The client to use for the request. + Returns: - Headers required to authenticate. + dict: Headers required to authenticate. """ - headers = {"User-Agent": "CGWire Gazu %s" % __version__} - if "access_token" in client.tokens: - headers["Authorization"] = "Bearer %s" % client.tokens["access_token"] - return headers + return client.make_auth_header() def url_path_join(*items): @@ -176,12 +292,17 @@ def url_path_join(*items): Args: items (list): Path elements + + Returns: + str: The joined path. """ return "/".join([item.lstrip("/").rstrip("/") for item in items]) def get_full_url(path, client=default_client): """ + Join host url with given path. + Args: path (str): The path to integrate to host url. @@ -219,6 +340,12 @@ def get(path, json_response=True, params=None, client=default_client): """ Run a get request toward given path for configured host. + Args: + path (str): The path to query. + json_response (bool): Whether to return a json response. + params (dict): The parameters to pass to the request. + client (KitsuClient): The client to use for the request. + Returns: The request result. """ @@ -243,6 +370,11 @@ def post(path, data, client=default_client): """ Run a post request toward given path for configured host. + Args: + path (str): The path to query. + data (dict): The data to post. + client (KitsuClient): The client to use for the request. + Returns: The request result. """ @@ -270,6 +402,11 @@ def put(path, data, client=default_client): """ Run a put request toward given path for configured host. + Args: + path (str): The path to query. + data (dict): The data to put. + client (KitsuClient): The client to use for the request. + Returns: The request result. """ @@ -291,6 +428,11 @@ def delete(path, params=None, client=default_client): """ Run a delete request toward given path for configured host. + Args: + path (str): The path to query. + params (dict): The parameters to pass to the request. + client (KitsuClient): The client to use for the request. + Returns: The request result. """ @@ -320,7 +462,7 @@ def get_message_from_response( default_message: str - An optional default value to revert to if no message is detected. Returns: - The message of a given response, or a default message - if any. + str: The message to display to the user. """ message = default_message message_json = response.json() @@ -340,6 +482,8 @@ def check_status(request, path, client=None): Args: request (Request): The request to validate. + path (str): The path of the request. + client (KitsuClient): The client to use for the request. Returns: int: Status code @@ -369,19 +513,18 @@ def check_status(request, path, client=None): ) elif status_code in [401, 422]: try: - if client is not None and client.automatic_refresh_token: - from . import refresh_token - - refresh_token(client=client) - + if ( + client + and client.refresh_token + and client.use_refresh_token + and request.json()["message"] == "Signature has expired" + ): + client.refresh_access_token() return status_code, True else: raise NotAuthenticatedException(path) except NotAuthenticatedException: - if ( - client is not None - and client.callback_not_authenticated is not None - ): + if client and client.callback_not_authenticated: retry = client.callback_not_authenticated(client, path) if retry: return status_code, True @@ -410,6 +553,8 @@ def fetch_all( """ Args: path (str): The path for which we want to retrieve all entries. + params (dict): The parameters to pass to the request. + client (KitsuClient): The client to use for the request. paginated (bool): Will query entries page by page. limit (int): Limit the number of entries per page. @@ -453,6 +598,8 @@ def fetch_first(path, params=None, client=default_client): """ Args: path (str): The path for which we want to retrieve the first entry. + params (dict): The parameters to pass to the request. + client (KitsuClient): The client to use for the request. Returns: dict: The first entry for which a model is required. @@ -472,6 +619,8 @@ def fetch_one(model_name, id, params=None, client=default_client): Args: model_name (str): Model type name. id (str): Model instance ID. + params (dict): The parameters to pass to the request. + client (KitsuClient): The client to use for the request. Returns: dict: The model instance matching id and model name. @@ -486,8 +635,9 @@ def create(model_name, data, client=default_client): Create an entry for given model and data. Args: - model (str): The model type involved - data (str): The data to use for creation + model_name (str): The model type involved. + data (str): The data to use for creation. + client (KitsuClient): The client to use for the request. Returns: dict: Created entry @@ -500,9 +650,10 @@ def update(model_name, model_id, data, client=default_client): Update an entry for given model, id and data. Args: - model (str): The model type involved - id (str): The target model id - data (dict): The data to update + model_name (str): The model type involved. + model_id (str): The target model id. + data (dict): The data to update. + client (KitsuClient): The client to use for the request. Returns: dict: Updated entry @@ -519,6 +670,9 @@ def upload(path, file_path, data={}, extra_files=[], client=default_client): Args: path (str): The url path to upload file. file_path (str): The file location on the hard drive. + data (dict): The data to send with the file. + extra_files (list): List of extra files to upload. + client (KitsuClient): The client to use for the request. Returns: Response: Request response object. @@ -548,6 +702,17 @@ def upload(path, file_path, data={}, extra_files=[], client=default_client): def _build_file_dict(file_path, extra_files): + """ + Build a dictionary of files to upload. + + Args: + file_path (str): The file location on the hard drive. + extra_files (list): List of extra files to upload. + + Returns: + dict: The dictionary of files to upload. + """ + files = {"file": open(file_path, "rb")} i = 2 for file_path in extra_files: @@ -563,6 +728,8 @@ def download(path, file_path, params=None, client=default_client): Args: path (str): The url path to download file from. file_path (str): The location to store the file on the hard drive. + params (dict): The parameters to pass to the request. + client (KitsuClient): The client to use for the request. Returns: Response: Request response object. @@ -582,6 +749,14 @@ def download(path, file_path, params=None, client=default_client): def get_file_data_from_url(url, full=False, client=default_client): """ Return data found at given url. + + Args: + url (str): The url to fetch data from. + full (bool): Whether to use full url. + client (KitsuClient): The client to use for the request. + + Returns: + bytes: The data found at the given url. """ if not full: url = get_full_url(url) @@ -598,15 +773,26 @@ def get_file_data_from_url(url, full=False, client=default_client): def import_data(model_name, data, client=default_client): """ + Import data for given model. + Args: - model_name (str): The data model to import - data (dict): The data to import + model_name (str): The data model to import. + data (dict): The data to import. + client (KitsuClient): The client to use for the request. + + Returns: + dict: The imported data. """ return post("/import/kitsu/%s" % model_name, data, client=client) def get_api_version(client=default_client): """ + Get the current version of the API. + + Args: + client (KitsuClient): The client to use for the request. + Returns: str: Current version of the API. """ @@ -615,6 +801,11 @@ def get_api_version(client=default_client): def get_current_user(client=default_client): """ + Get the current user. + + Args: + client (KitsuClient): The client to use for the request. + Returns: dict: User database information for user linked to auth tokens. """ diff --git a/gazu/events.py b/gazu/events.py index 9223fb0..13592e6 100644 --- a/gazu/events.py +++ b/gazu/events.py @@ -52,11 +52,13 @@ def init( Returns: Event client that will be able to set listeners. """ - params = {"ssl_verify": ssl_verify} + params = { + "ssl_verify": ssl_verify, + "reconnection": reconnection, + "logger": logger, + } params.update(kwargs) - event_client = socketio.Client( - logger=logger, reconnection=reconnection, **params - ) + event_client = socketio.Client(**params) event_client.on("connect_error", connect_error) event_client.register_namespace(EventsNamespace("/events")) event_client.connect(get_event_host(client), make_auth_header()) diff --git a/tests/test_client.py b/tests/test_client.py index e208407..5f07ae2 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -70,7 +70,10 @@ def test_set_tokens(self): pass def test_make_auth_header(self): - pass + self.assertEqual( + raw.default_client.make_auth_header(), + raw.make_auth_header(), + ) def test_url_path_join(self): root = raw.get_host() @@ -255,6 +258,7 @@ def test_version(self): def test_make_auth_token(self): tokens = {"access_token": "token_test"} + raw.set_tokens(tokens) self.assertEqual( raw.make_auth_header(), @@ -437,7 +441,7 @@ def test_init_refresh_token(self): "auth/refresh-token", text={"access_token": "tokentest1"}, ) - gazu.refresh_token() + gazu.refresh_access_token() self.assertEqual( raw.default_client.tokens["access_token"], "tokentest1" )