Skip to content

Commit 414e91d

Browse files
author
hpal
committed
make optional oauth param configurable including audience and resource
remove scope from this. fix build error due to dict usage add another test for empty oauth properties change name for the method
1 parent 36b56eb commit 414e91d

File tree

2 files changed

+60
-0
lines changed

2 files changed

+60
-0
lines changed

pyiceberg/catalog/rest.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,8 @@ class Endpoints:
105105
CREDENTIAL = "credential"
106106
GRANT_TYPE = "grant_type"
107107
SCOPE = "scope"
108+
AUDIENCE = "audience"
109+
RESOURCE = "resource"
108110
TOKEN_EXCHANGE = "urn:ietf:params:oauth:grant-type:token-exchange"
109111
SEMICOLON = ":"
110112
KEY = "key"
@@ -289,12 +291,25 @@ def auth_url(self) -> str:
289291
else:
290292
return self.url(Endpoints.get_token, prefixed=False)
291293

294+
def _extract_optional_oauth_params(self) -> Dict[str, str]:
295+
set_of_optional_params = {AUDIENCE, RESOURCE}
296+
optional_oauth_param = {}
297+
for param in set_of_optional_params:
298+
if param_value := self.properties.get(param):
299+
optional_oauth_param[param] = param_value
300+
301+
return optional_oauth_param
302+
292303
def _fetch_access_token(self, session: Session, credential: str) -> str:
293304
if SEMICOLON in credential:
294305
client_id, client_secret = credential.split(SEMICOLON)
295306
else:
296307
client_id, client_secret = None, credential
297308
data = {GRANT_TYPE: CLIENT_CREDENTIALS, CLIENT_ID: client_id, CLIENT_SECRET: client_secret, SCOPE: CATALOG_SCOPE}
309+
310+
optional_oauth_params = self._extract_optional_oauth_params()
311+
data.update(optional_oauth_params)
312+
298313
response = session.post(
299314
url=self.auth_url, data=data, headers={**session.headers, "Content-type": "application/x-www-form-urlencoded"}
300315
)

tests/catalog/test_rest.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@
4646
TEST_CREDENTIALS = "client:secret"
4747
TEST_AUTH_URL = "https://auth-endpoint/"
4848
TEST_TOKEN = "some_jwt_token"
49+
TEST_AUDIENCE = "test_audience"
50+
TEST_RESOURCE = "test_resource"
51+
4952
TEST_HEADERS = {
5053
"Content-type": "application/json",
5154
"X-Client-Version": "0.14.1",
@@ -136,6 +139,48 @@ def test_token_200_without_optional_fields(rest_mock: Mocker) -> None:
136139
)
137140

138141

142+
def test_token_with_optional_oauth_params(rest_mock: Mocker) -> None:
143+
mock_request = rest_mock.post(
144+
f"{TEST_URI}v1/oauth/tokens",
145+
json={
146+
"access_token": TEST_TOKEN,
147+
"token_type": "Bearer",
148+
"expires_in": 86400,
149+
"issued_token_type": "urn:ietf:params:oauth:token-type:access_token",
150+
},
151+
status_code=200,
152+
request_headers=OAUTH_TEST_HEADERS,
153+
)
154+
assert (
155+
RestCatalog(
156+
"rest", uri=TEST_URI, credential=TEST_CREDENTIALS, audience=TEST_AUDIENCE, resource=TEST_RESOURCE
157+
)._session.headers["Authorization"]
158+
== f"Bearer {TEST_TOKEN}"
159+
)
160+
assert TEST_AUDIENCE in mock_request.last_request.text
161+
assert TEST_RESOURCE in mock_request.last_request.text
162+
163+
164+
def test_token_with_optional_oauth_params_as_empty(rest_mock: Mocker) -> None:
165+
mock_request = rest_mock.post(
166+
f"{TEST_URI}v1/oauth/tokens",
167+
json={
168+
"access_token": TEST_TOKEN,
169+
"token_type": "Bearer",
170+
"expires_in": 86400,
171+
"issued_token_type": "urn:ietf:params:oauth:token-type:access_token",
172+
},
173+
status_code=200,
174+
request_headers=OAUTH_TEST_HEADERS,
175+
)
176+
assert (
177+
RestCatalog("rest", uri=TEST_URI, credential=TEST_CREDENTIALS, audience="", resource="")._session.headers["Authorization"]
178+
== f"Bearer {TEST_TOKEN}"
179+
)
180+
assert TEST_AUDIENCE not in mock_request.last_request.text
181+
assert TEST_RESOURCE not in mock_request.last_request.text
182+
183+
139184
def test_token_200_w_auth_url(rest_mock: Mocker) -> None:
140185
rest_mock.post(
141186
TEST_AUTH_URL,

0 commit comments

Comments
 (0)