From 2185bd9a20d69ec7b8d70df872a23f299fc60334 Mon Sep 17 00:00:00 2001 From: Daryna Ishchenko <80129833+darynaishchenko@users.noreply.github.com> Date: Thu, 16 Jan 2025 16:57:25 +0200 Subject: [PATCH] feat(low-code): pass refresh headers to oauth (#219) --- airbyte_cdk/sources/declarative/auth/oauth.py | 8 +++ .../declarative_component_schema.yaml | 8 +++ .../models/declarative_component_schema.py | 11 +++ .../parsers/model_to_component_factory.py | 4 ++ .../requests_native_auth/abstract_oauth.py | 13 ++++ .../http/requests_native_auth/oauth.py | 8 +++ .../sources/declarative/auth/test_oauth.py | 72 ++++++++++++++++++- .../test_requests_native_auth.py | 67 ++++++++++++++++- 8 files changed, 187 insertions(+), 4 deletions(-) diff --git a/airbyte_cdk/sources/declarative/auth/oauth.py b/airbyte_cdk/sources/declarative/auth/oauth.py index 3bfed9c2a..36508fd7e 100644 --- a/airbyte_cdk/sources/declarative/auth/oauth.py +++ b/airbyte_cdk/sources/declarative/auth/oauth.py @@ -39,6 +39,7 @@ class DeclarativeOauth2Authenticator(AbstractOauth2Authenticator, DeclarativeAut token_expiry_date_format str: format of the datetime; provide it if expires_in is returned in datetime instead of seconds token_expiry_is_time_of_expiration bool: set True it if expires_in is returned as time of expiration instead of the number seconds until expiration refresh_request_body (Optional[Mapping[str, Any]]): The request body to send in the refresh request + refresh_request_headers (Optional[Mapping[str, Any]]): The request headers to send in the refresh request grant_type: The grant_type to request for access_token. If set to refresh_token, the refresh_token parameter has to be provided message_repository (MessageRepository): the message repository used to emit logs on HTTP requests """ @@ -61,6 +62,7 @@ class DeclarativeOauth2Authenticator(AbstractOauth2Authenticator, DeclarativeAut expires_in_name: Union[InterpolatedString, str] = "expires_in" refresh_token_name: Union[InterpolatedString, str] = "refresh_token" refresh_request_body: Optional[Mapping[str, Any]] = None + refresh_request_headers: Optional[Mapping[str, Any]] = None grant_type_name: Union[InterpolatedString, str] = "grant_type" grant_type: Union[InterpolatedString, str] = "refresh_token" message_repository: MessageRepository = NoopMessageRepository() @@ -101,6 +103,9 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: self._refresh_request_body = InterpolatedMapping( self.refresh_request_body or {}, parameters=parameters ) + self._refresh_request_headers = InterpolatedMapping( + self.refresh_request_headers or {}, parameters=parameters + ) self._token_expiry_date: pendulum.DateTime = ( pendulum.parse( InterpolatedString.create(self.token_expiry_date, parameters=parameters).eval( @@ -178,6 +183,9 @@ def get_grant_type(self) -> str: def get_refresh_request_body(self) -> Mapping[str, Any]: return self._refresh_request_body.eval(self.config) + def get_refresh_request_headers(self) -> Mapping[str, Any]: + return self._refresh_request_headers.eval(self.config) + def get_token_expiry_date(self) -> pendulum.DateTime: return self._token_expiry_date # type: ignore # _token_expiry_date is a pendulum.DateTime. It is never None despite what mypy thinks diff --git a/airbyte_cdk/sources/declarative/declarative_component_schema.yaml b/airbyte_cdk/sources/declarative/declarative_component_schema.yaml index 272fad750..53547c166 100644 --- a/airbyte_cdk/sources/declarative/declarative_component_schema.yaml +++ b/airbyte_cdk/sources/declarative/declarative_component_schema.yaml @@ -1139,6 +1139,14 @@ definitions: - applicationId: "{{ config['application_id'] }}" applicationSecret: "{{ config['application_secret'] }}" token: "{{ config['token'] }}" + refresh_request_headers: + title: Refresh Request Headers + description: Headers of the request sent to get a new access token. + type: object + additionalProperties: true + examples: + - Authorization: "" + Content-Type: "application/x-www-form-urlencoded" scopes: title: Scopes description: List of scopes that should be granted to the access token. diff --git a/airbyte_cdk/sources/declarative/models/declarative_component_schema.py b/airbyte_cdk/sources/declarative/models/declarative_component_schema.py index 2fc8b6b42..7d81d9afb 100644 --- a/airbyte_cdk/sources/declarative/models/declarative_component_schema.py +++ b/airbyte_cdk/sources/declarative/models/declarative_component_schema.py @@ -571,6 +571,17 @@ class OAuthAuthenticator(BaseModel): ], title="Refresh Request Body", ) + refresh_request_headers: Optional[Dict[str, Any]] = Field( + None, + description="Headers of the request sent to get a new access token.", + examples=[ + { + "Authorization": "", + "Content-Type": "application/x-www-form-urlencoded", + } + ], + title="Refresh Request Headers", + ) scopes: Optional[List[str]] = Field( None, description="List of scopes that should be granted to the access token.", diff --git a/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py b/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py index 4bcd18c1e..593784546 100644 --- a/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py +++ b/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py @@ -1919,6 +1919,9 @@ def create_oauth_authenticator( refresh_request_body=InterpolatedMapping( model.refresh_request_body or {}, parameters=model.parameters or {} ).eval(config), + refresh_request_headers=InterpolatedMapping( + model.refresh_request_headers or {}, parameters=model.parameters or {} + ).eval(config), scopes=model.scopes, token_expiry_date_format=model.token_expiry_date_format, message_repository=self._message_repository, @@ -1938,6 +1941,7 @@ def create_oauth_authenticator( grant_type_name=model.grant_type_name or "grant_type", grant_type=model.grant_type or "refresh_token", refresh_request_body=model.refresh_request_body, + refresh_request_headers=model.refresh_request_headers, refresh_token_name=model.refresh_token_name or "refresh_token", refresh_token=model.refresh_token, scopes=model.scopes, diff --git a/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py b/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py index d70d318fe..753d79269 100644 --- a/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py +++ b/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py @@ -98,6 +98,14 @@ def build_refresh_request_body(self) -> Mapping[str, Any]: return payload + def build_refresh_request_headers(self) -> Mapping[str, Any] | None: + """ + Returns the request headers to set on the refresh request + + """ + headers = self.get_refresh_request_headers() + return headers if headers else None + def _wrap_refresh_token_exception( self, exception: requests.exceptions.RequestException ) -> bool: @@ -128,6 +136,7 @@ def _get_refresh_access_token_response(self) -> Any: method="POST", url=self.get_token_refresh_endpoint(), # type: ignore # returns None, if not provided, but str | bytes is expected. data=self.build_refresh_request_body(), + headers=self.build_refresh_request_headers(), ) if response.ok: response_json = response.json() @@ -254,6 +263,10 @@ def get_expires_in_name(self) -> str: def get_refresh_request_body(self) -> Mapping[str, Any]: """Returns the request body to set on the refresh request""" + @abstractmethod + def get_refresh_request_headers(self) -> Mapping[str, Any]: + """Returns the request headers to set on the refresh request""" + @abstractmethod def get_grant_type(self) -> str: """Returns grant_type specified for requesting access_token""" diff --git a/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py b/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py index 3f3111ce8..f244e6508 100644 --- a/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py +++ b/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py @@ -39,6 +39,7 @@ def __init__( access_token_name: str = "access_token", expires_in_name: str = "expires_in", refresh_request_body: Mapping[str, Any] | None = None, + refresh_request_headers: Mapping[str, Any] | None = None, grant_type_name: str = "grant_type", grant_type: str = "refresh_token", token_expiry_is_time_of_expiration: bool = False, @@ -57,6 +58,7 @@ def __init__( self._access_token_name = access_token_name self._expires_in_name = expires_in_name self._refresh_request_body = refresh_request_body + self._refresh_request_headers = refresh_request_headers self._grant_type_name = grant_type_name self._grant_type = grant_type @@ -101,6 +103,9 @@ def get_expires_in_name(self) -> str: def get_refresh_request_body(self) -> Mapping[str, Any]: return self._refresh_request_body # type: ignore [return-value] + def get_refresh_request_headers(self) -> Mapping[str, Any]: + return self._refresh_request_headers # type: ignore [return-value] + def get_grant_type_name(self) -> str: return self._grant_type_name @@ -149,6 +154,7 @@ def __init__( expires_in_name: str = "expires_in", refresh_token_name: str = "refresh_token", refresh_request_body: Mapping[str, Any] | None = None, + refresh_request_headers: Mapping[str, Any] | None = None, grant_type_name: str = "grant_type", grant_type: str = "refresh_token", client_id_name: str = "client_id", @@ -174,6 +180,7 @@ def __init__( expires_in_name (str, optional): Name of the name of the field that characterizes when the current access token will expire, used to parse the refresh token response. Defaults to "expires_in". refresh_token_name (str, optional): Name of the name of the refresh token field, used to parse the refresh token response. Defaults to "refresh_token". refresh_request_body (Mapping[str, Any], optional): Custom key value pair that will be added to the refresh token request body. Defaults to None. + refresh_request_headers (Mapping[str, Any], optional): Custom key value pair that will be added to the refresh token request headers. Defaults to None. grant_type (str, optional): OAuth grant type. Defaults to "refresh_token". client_id (Optional[str]): The client id to authenticate. If not specified, defaults to credentials.client_id in the config object. client_secret (Optional[str]): The client secret to authenticate. If not specified, defaults to credentials.client_secret in the config object. @@ -220,6 +227,7 @@ def __init__( access_token_name=access_token_name, expires_in_name=expires_in_name, refresh_request_body=refresh_request_body, + refresh_request_headers=refresh_request_headers, grant_type_name=self._grant_type_name, grant_type=grant_type, token_expiry_date_format=token_expiry_date_format, diff --git a/unit_tests/sources/declarative/auth/test_oauth.py b/unit_tests/sources/declarative/auth/test_oauth.py index 4130a9dc8..dc384bb10 100644 --- a/unit_tests/sources/declarative/auth/test_oauth.py +++ b/unit_tests/sources/declarative/auth/test_oauth.py @@ -69,6 +69,42 @@ def test_refresh_request_body(self): } assert body == expected + def test_refresh_request_headers(self): + """ + Request headers should match given configuration. + """ + oauth = DeclarativeOauth2Authenticator( + token_refresh_endpoint="{{ config['refresh_endpoint'] }}", + client_id="{{ config['client_id'] }}", + client_secret="{{ config['client_secret'] }}", + refresh_token="{{ parameters['refresh_token'] }}", + config=config, + token_expiry_date="{{ config['token_expiry_date'] }}", + refresh_request_headers={ + "Authorization": "Basic {{ [config['client_id'], config['client_secret']] | join(':') | base64encode }}", + "Content-Type": "application/x-www-form-urlencoded", + }, + parameters=parameters, + ) + headers = oauth.build_refresh_request_headers() + expected = { + "Authorization": "Basic c29tZV9jbGllbnRfaWQ6c29tZV9jbGllbnRfc2VjcmV0", + "Content-Type": "application/x-www-form-urlencoded", + } + assert headers == expected + + oauth = DeclarativeOauth2Authenticator( + token_refresh_endpoint="{{ config['refresh_endpoint'] }}", + client_id="{{ config['client_id'] }}", + client_secret="{{ config['client_secret'] }}", + refresh_token="{{ parameters['refresh_token'] }}", + config=config, + token_expiry_date="{{ config['token_expiry_date'] }}", + parameters=parameters, + ) + headers = oauth.build_refresh_request_headers() + assert headers is None + def test_refresh_with_encode_config_params(self): oauth = DeclarativeOauth2Authenticator( token_refresh_endpoint="{{ config['refresh_endpoint'] }}", @@ -191,6 +227,36 @@ def test_refresh_access_token(self, mocker): filtered = filter_secrets("access_token") assert filtered == "****" + def test_refresh_access_token_when_headers_provided(self, mocker): + expected_headers = { + "Authorization": "Bearer some_access_token", + "Content-Type": "application/x-www-form-urlencoded", + } + oauth = DeclarativeOauth2Authenticator( + token_refresh_endpoint="{{ config['refresh_endpoint'] }}", + client_id="{{ config['client_id'] }}", + client_secret="{{ config['client_secret'] }}", + refresh_token="{{ config['refresh_token'] }}", + config=config, + scopes=["scope1", "scope2"], + token_expiry_date="{{ config['token_expiry_date'] }}", + refresh_request_headers=expected_headers, + parameters={}, + ) + + resp.status_code = 200 + mocker.patch.object( + resp, "json", return_value={"access_token": "access_token", "expires_in": 1000} + ) + mocked_request = mocker.patch.object( + requests, "request", side_effect=mock_request, autospec=True + ) + token = oauth.refresh_access_token() + + assert ("access_token", 1000) == token + + assert mocked_request.call_args.kwargs["headers"] == expected_headers + def test_refresh_access_token_missing_access_token(self, mocker): oauth = DeclarativeOauth2Authenticator( token_refresh_endpoint="{{ config['refresh_endpoint'] }}", @@ -371,7 +437,9 @@ def test_error_handling(self, mocker): assert e.value.errno == 400 -def mock_request(method, url, data): +def mock_request(method, url, data, headers): if url == "refresh_end": return resp - raise Exception(f"Error while refreshing access token with request: {method}, {url}, {data}") + raise Exception( + f"Error while refreshing access token with request: {method}, {url}, {data}, {headers}" + ) diff --git a/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py b/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py index 1fd2c611d..4d0572f31 100644 --- a/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py +++ b/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py @@ -165,6 +165,38 @@ def test_refresh_request_body(self): } assert body == expected + def test_refresh_request_headers(self): + """ + Request headers should match given configuration. + """ + oauth = Oauth2Authenticator( + token_refresh_endpoint="refresh_end", + client_id="some_client_id", + client_secret="some_client_secret", + refresh_token="some_refresh_token", + token_expiry_date=pendulum.now().add(days=3), + refresh_request_headers={ + "Authorization": "Bearer some_refresh_token", + "Content-Type": "application/x-www-form-urlencoded", + }, + ) + headers = oauth.build_refresh_request_headers() + expected = { + "Authorization": "Bearer some_refresh_token", + "Content-Type": "application/x-www-form-urlencoded", + } + assert headers == expected + + oauth = Oauth2Authenticator( + token_refresh_endpoint="refresh_end", + client_id="some_client_id", + client_secret="some_client_secret", + refresh_token="some_refresh_token", + token_expiry_date=pendulum.now().add(days=3), + ) + headers = oauth.build_refresh_request_headers() + assert headers is None + def test_refresh_request_body_with_keys_override(self): """ Request body should match given configuration. @@ -245,6 +277,35 @@ def test_refresh_access_token(self, mocker): assert isinstance(expires_in, str) assert ("access_token", "2022-04-24T00:00:00Z") == (token, expires_in) + def test_refresh_access_token_when_headers_provided(self, mocker): + expected_headers = { + "Authorization": "Bearer some_access_token", + "Content-Type": "application/x-www-form-urlencoded", + } + oauth = Oauth2Authenticator( + token_refresh_endpoint="refresh_end", + client_id="some_client_id", + client_secret="some_client_secret", + refresh_token="some_refresh_token", + scopes=["scope1", "scope2"], + token_expiry_date=pendulum.now().add(days=3), + refresh_request_headers=expected_headers, + ) + + resp.status_code = 200 + mocker.patch.object( + resp, "json", return_value={"access_token": "access_token", "expires_in": 1000} + ) + mocked_request = mocker.patch.object( + requests, "request", side_effect=mock_request, autospec=True + ) + token, expires_in = oauth.refresh_access_token() + + assert isinstance(expires_in, int) + assert ("access_token", 1000) == (token, expires_in) + + assert mocked_request.call_args.kwargs["headers"] == expected_headers + @pytest.mark.parametrize( "expires_in_response, token_expiry_date_format, expected_token_expiry_date", [ @@ -557,7 +618,9 @@ def test_refresh_access_token(self, mocker, connector_config): ) -def mock_request(method, url, data): +def mock_request(method, url, data, headers): if url == "refresh_end": return resp - raise Exception(f"Error while refreshing access token with request: {method}, {url}, {data}") + raise Exception( + f"Error while refreshing access token with request: {method}, {url}, {data}, {headers}" + )