From c8ca4b13ffdca5b7d1a6059ad049f6fda72eda44 Mon Sep 17 00:00:00 2001 From: Alexandre Girard Date: Mon, 8 Jan 2024 17:40:48 -0800 Subject: [PATCH] :bug: fix declarative oauth initialization (#32967) Co-authored-by: girarda --- .../sources/declarative/auth/oauth.py | 72 +++++++++++-------- .../requests_native_auth/abstract_oauth.py | 45 ++++++------ .../sources/declarative/auth/test_oauth.py | 25 ++++++- .../test_model_to_component_factory.py | 14 ++-- 4 files changed, 97 insertions(+), 59 deletions(-) diff --git a/airbyte-cdk/python/airbyte_cdk/sources/declarative/auth/oauth.py b/airbyte-cdk/python/airbyte_cdk/sources/declarative/auth/oauth.py index 4e83c570be6e..d858677b6324 100644 --- a/airbyte-cdk/python/airbyte_cdk/sources/declarative/auth/oauth.py +++ b/airbyte-cdk/python/airbyte_cdk/sources/declarative/auth/oauth.py @@ -46,8 +46,8 @@ class DeclarativeOauth2Authenticator(AbstractOauth2Authenticator, DeclarativeAut refresh_token: Optional[Union[InterpolatedString, str]] = None scopes: Optional[List[str]] = None token_expiry_date: Optional[Union[InterpolatedString, str]] = None - _token_expiry_date: pendulum.DateTime = field(init=False, repr=False, default=None) - token_expiry_date_format: str = None + _token_expiry_date: Optional[pendulum.DateTime] = field(init=False, repr=False, default=None) + token_expiry_date_format: Optional[str] = None token_expiry_is_time_of_expiration: bool = False access_token_name: Union[InterpolatedString, str] = "access_token" expires_in_name: Union[InterpolatedString, str] = "expires_in" @@ -55,65 +55,79 @@ class DeclarativeOauth2Authenticator(AbstractOauth2Authenticator, DeclarativeAut grant_type: Union[InterpolatedString, str] = "refresh_token" message_repository: MessageRepository = NoopMessageRepository() - def __post_init__(self, parameters: Mapping[str, Any]): - self.token_refresh_endpoint = InterpolatedString.create(self.token_refresh_endpoint, parameters=parameters) - self.client_id = InterpolatedString.create(self.client_id, parameters=parameters) - self.client_secret = InterpolatedString.create(self.client_secret, parameters=parameters) + def __post_init__(self, parameters: Mapping[str, Any]) -> None: + super().__init__() + self._token_refresh_endpoint = InterpolatedString.create(self.token_refresh_endpoint, parameters=parameters) + self._client_id = InterpolatedString.create(self.client_id, parameters=parameters) + self._client_secret = InterpolatedString.create(self.client_secret, parameters=parameters) if self.refresh_token is not None: - self.refresh_token = InterpolatedString.create(self.refresh_token, parameters=parameters) + self._refresh_token = InterpolatedString.create(self.refresh_token, parameters=parameters) + else: + self._refresh_token = None self.access_token_name = InterpolatedString.create(self.access_token_name, parameters=parameters) self.expires_in_name = InterpolatedString.create(self.expires_in_name, parameters=parameters) self.grant_type = InterpolatedString.create(self.grant_type, parameters=parameters) self._refresh_request_body = InterpolatedMapping(self.refresh_request_body or {}, parameters=parameters) - self._token_expiry_date = ( - pendulum.parse(InterpolatedString.create(self.token_expiry_date, parameters=parameters).eval(self.config)) + self._token_expiry_date: pendulum.DateTime = ( + pendulum.parse(InterpolatedString.create(self.token_expiry_date, parameters=parameters).eval(self.config)) # type: ignore # pendulum.parse returns a datetime in this context if self.token_expiry_date - else pendulum.now().subtract(days=1) + else pendulum.now().subtract(days=1) # type: ignore # substract does not have type hints ) - self._access_token = None + self._access_token: Optional[str] = None # access_token is initialized by a setter - if self.get_grant_type() == "refresh_token" and self.refresh_token is None: + if self.get_grant_type() == "refresh_token" and self._refresh_token is None: raise ValueError("OAuthAuthenticator needs a refresh_token parameter if grant_type is set to `refresh_token`") def get_token_refresh_endpoint(self) -> str: - return self.token_refresh_endpoint.eval(self.config) + refresh_token: str = self._token_refresh_endpoint.eval(self.config) + if not refresh_token: + raise ValueError("OAuthAuthenticator was unable to evaluate token_refresh_endpoint parameter") + return refresh_token def get_client_id(self) -> str: - return self.client_id.eval(self.config) + client_id: str = self._client_id.eval(self.config) + if not client_id: + raise ValueError("OAuthAuthenticator was unable to evaluate client_id parameter") + return client_id def get_client_secret(self) -> str: - return self.client_secret.eval(self.config) + client_secret: str = self._client_secret.eval(self.config) + if not client_secret: + raise ValueError("OAuthAuthenticator was unable to evaluate client_secret parameter") + return client_secret def get_refresh_token(self) -> Optional[str]: - return None if self.refresh_token is None else self.refresh_token.eval(self.config) + return None if self._refresh_token is None else self._refresh_token.eval(self.config) - def get_scopes(self) -> [str]: - return self.scopes + def get_scopes(self) -> List[str]: + return self.scopes or [] - def get_access_token_name(self) -> InterpolatedString: - return self.access_token_name.eval(self.config) + def get_access_token_name(self) -> str: + return self.access_token_name.eval(self.config) # type: ignore # eval returns a string in this context - def get_expires_in_name(self) -> InterpolatedString: - return self.expires_in_name.eval(self.config) + def get_expires_in_name(self) -> str: + return self.expires_in_name.eval(self.config) # type: ignore # eval returns a string in this context - def get_grant_type(self) -> InterpolatedString: - return self.grant_type.eval(self.config) + def get_grant_type(self) -> str: + return self.grant_type.eval(self.config) # type: ignore # eval returns a string in this context def get_refresh_request_body(self) -> Mapping[str, Any]: - return self._refresh_request_body.eval(self.config) + return self._refresh_request_body.eval(self.config) # type: ignore # eval should return a Mapping in this context def get_token_expiry_date(self) -> pendulum.DateTime: - return self._token_expiry_date + return self._token_expiry_date # type: ignore # _token_expiry_date is a pendulum.DateTime. It is never None despite what mypy thinks - def set_token_expiry_date(self, value: Union[str, int]): + def set_token_expiry_date(self, value: Union[str, int]) -> None: self._token_expiry_date = self._parse_token_expiration_date(value) @property def access_token(self) -> str: + if self._access_token is None: + raise ValueError("access_token is not set") return self._access_token @access_token.setter - def access_token(self, value: str): + def access_token(self, value: str) -> None: self._access_token = value @property @@ -130,5 +144,5 @@ class DeclarativeSingleUseRefreshTokenOauth2Authenticator(SingleUseRefreshTokenO Declarative version of SingleUseRefreshTokenOauth2Authenticator which can be used in declarative connectors. """ - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) diff --git a/airbyte-cdk/python/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py b/airbyte-cdk/python/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py index 22e2caa6a2e8..0dd450413dd4 100644 --- a/airbyte-cdk/python/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py +++ b/airbyte-cdk/python/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py @@ -45,7 +45,7 @@ def __init__( self._refresh_token_error_key = refresh_token_error_key self._refresh_token_error_values = refresh_token_error_values - def __call__(self, request: requests.Request) -> requests.Request: + def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest: """Attach the HTTP headers required to authenticate on the HTTP request""" request.headers.update(self.get_auth_header()) return request @@ -65,7 +65,7 @@ def get_access_token(self) -> str: def token_has_expired(self) -> bool: """Returns True if the token is expired""" - return pendulum.now() > self.get_token_expiry_date() + return pendulum.now() > self.get_token_expiry_date() # type: ignore # this is always a bool despite what mypy thinks def build_refresh_request_body(self) -> Mapping[str, Any]: """ @@ -80,7 +80,7 @@ def build_refresh_request_body(self) -> Mapping[str, Any]: "refresh_token": self.get_refresh_token(), } - if self.get_scopes: + if self.get_scopes(): payload["scopes"] = self.get_scopes() if self.get_refresh_request_body(): @@ -93,7 +93,10 @@ def build_refresh_request_body(self) -> Mapping[str, Any]: def _wrap_refresh_token_exception(self, exception: requests.exceptions.RequestException) -> bool: try: - exception_content = exception.response.json() + if exception.response is not None: + exception_content = exception.response.json() + else: + return False except JSONDecodeError: return False return ( @@ -109,15 +112,16 @@ def _wrap_refresh_token_exception(self, exception: requests.exceptions.RequestEx ), max_time=300, ) - def _get_refresh_access_token_response(self): + def _get_refresh_access_token_response(self) -> Any: try: response = requests.request(method="POST", url=self.get_token_refresh_endpoint(), data=self.build_refresh_request_body()) self._log_response(response) response.raise_for_status() return response.json() except requests.exceptions.RequestException as e: - if e.response.status_code == 429 or e.response.status_code >= 500: - raise DefaultBackoffException(request=e.response.request, response=e.response) + if e.response is not None: + if e.response.status_code == 429 or e.response.status_code >= 500: + raise DefaultBackoffException(request=e.response.request, response=e.response) if self._wrap_refresh_token_exception(e): message = "Refresh token is invalid or expired. Please re-authenticate from Sources//Settings." raise AirbyteTracedException(internal_message=message, message=message, failure_type=FailureType.config_error) @@ -147,7 +151,7 @@ def _parse_token_expiration_date(self, value: Union[str, int]) -> pendulum.DateT raise ValueError( f"Invalid token expiry date format {self.token_expiry_date_format}; a string representing the format is required." ) - return pendulum.from_format(value, self.token_expiry_date_format) + return pendulum.from_format(str(value), self.token_expiry_date_format) else: return pendulum.now().add(seconds=int(float(value))) @@ -192,7 +196,7 @@ def get_token_expiry_date(self) -> pendulum.DateTime: """Expiration date of the access token""" @abstractmethod - def set_token_expiry_date(self, value: Union[str, int]): + def set_token_expiry_date(self, value: Union[str, int]) -> None: """Setter for access token expiration date""" @abstractmethod @@ -228,14 +232,15 @@ def _message_repository(self) -> Optional[MessageRepository]: """ return _NOOP_MESSAGE_REPOSITORY - def _log_response(self, response: requests.Response): - self._message_repository.log_message( - Level.DEBUG, - lambda: format_http_message( - response, - "Refresh token", - "Obtains access token", - self._NO_STREAM_NAME, - is_auxiliary=True, - ), - ) + def _log_response(self, response: requests.Response) -> None: + if self._message_repository: + self._message_repository.log_message( + Level.DEBUG, + lambda: format_http_message( + response, + "Refresh token", + "Obtains access token", + self._NO_STREAM_NAME, + is_auxiliary=True, + ), + ) diff --git a/airbyte-cdk/python/unit_tests/sources/declarative/auth/test_oauth.py b/airbyte-cdk/python/unit_tests/sources/declarative/auth/test_oauth.py index 0992d0e331bc..bd019d374987 100644 --- a/airbyte-cdk/python/unit_tests/sources/declarative/auth/test_oauth.py +++ b/airbyte-cdk/python/unit_tests/sources/declarative/auth/test_oauth.py @@ -81,7 +81,6 @@ def test_refresh_with_encode_config_params(self): "client_id": base64.b64encode(config["client_id"].encode("utf-8")).decode(), "client_secret": base64.b64encode(config["client_secret"].encode("utf-8")).decode(), "refresh_token": None, - "scopes": None, } assert body == expected @@ -104,7 +103,6 @@ def test_refresh_with_decode_config_params(self): "client_id": "some_client_id", "client_secret": "some_client_secret", "refresh_token": None, - "scopes": None, } assert body == expected @@ -126,7 +124,6 @@ def test_refresh_without_refresh_token(self): "client_id": "some_client_id", "client_secret": "some_client_secret", "refresh_token": None, - "scopes": None, } assert body == expected @@ -278,6 +275,28 @@ def test_set_token_expiry_date_no_format(self, mocker, expires_in_response, next assert "access_token" == token assert oauth.get_token_expiry_date() == pendulum.parse(next_day) + def test_error_handling(self, mocker): + oauth = DeclarativeOauth2Authenticator( + token_refresh_endpoint="{{ config['refresh_endpoint'] }}", + client_id="{{ config['client_id'] }}", + client_secret="{{ config['client_secret'] }}", + refresh_token="{{ config['refresh_token'] }}", + config=config, + scopes=["scope1", "scope2"], + refresh_request_body={ + "custom_field": "{{ config['custom_field'] }}", + "another_field": "{{ config['another_field'] }}", + "scopes": ["no_override"], + }, + parameters={}, + ) + resp.status_code = 400 + mocker.patch.object(resp, "json", return_value={"access_token": "access_token", "expires_in": 123}) + mocker.patch.object(requests, "request", side_effect=mock_request, autospec=True) + with pytest.raises(requests.exceptions.HTTPError) as e: + oauth.refresh_access_token() + assert e.value.errno == 400 + def mock_request(method, url, data): if url == "refresh_end": diff --git a/airbyte-cdk/python/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py b/airbyte-cdk/python/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py index 15f774879c60..08cea962086e 100644 --- a/airbyte-cdk/python/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py +++ b/airbyte-cdk/python/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py @@ -303,10 +303,10 @@ def test_interpolate_config(): ) assert isinstance(authenticator, DeclarativeOauth2Authenticator) - assert authenticator.client_id.eval(input_config) == "some_client_id" - assert authenticator.client_secret.string == "some_client_secret" - assert authenticator.token_refresh_endpoint.eval(input_config) == "https://api.sendgrid.com/v3/auth" - assert authenticator.refresh_token.eval(input_config) == "verysecrettoken" + assert authenticator._client_id.eval(input_config) == "some_client_id" + assert authenticator._client_secret.string == "some_client_secret" + assert authenticator._token_refresh_endpoint.eval(input_config) == "https://api.sendgrid.com/v3/auth" + assert authenticator._refresh_token.eval(input_config) == "verysecrettoken" assert authenticator._refresh_request_body.mapping == {"body_field": "yoyoyo", "interpolated_body_field": "{{ config['apikey'] }}"} assert authenticator.get_refresh_request_body() == {"body_field": "yoyoyo", "interpolated_body_field": "verysecrettoken"} @@ -332,9 +332,9 @@ def test_interpolate_config_with_token_expiry_date_format(): assert isinstance(authenticator, DeclarativeOauth2Authenticator) assert authenticator.token_expiry_date_format == "%Y-%m-%d %H:%M:%S.%f+00:00" assert authenticator.token_expiry_is_time_of_expiration - assert authenticator.client_id.eval(input_config) == "some_client_id" - assert authenticator.client_secret.string == "some_client_secret" - assert authenticator.token_refresh_endpoint.eval(input_config) == "https://api.sendgrid.com/v3/auth" + assert authenticator._client_id.eval(input_config) == "some_client_id" + assert authenticator._client_secret.string == "some_client_secret" + assert authenticator._token_refresh_endpoint.eval(input_config) == "https://api.sendgrid.com/v3/auth" def test_single_use_oauth_branch():