Skip to content

Commit 7d80a66

Browse files
authored
Merge pull request #644 from microsoft/fix/azure-auth-derive-scope-from-hostname
fix(authentication-azure): derive token scope from hostname, not netloc.
2 parents 9158456 + af8d841 commit 7d80a66

2 files changed

Lines changed: 119 additions & 9 deletions

File tree

packages/authentication/azure/kiota_authentication_azure/azure_identity_access_token_provider.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -96,21 +96,18 @@ async def get_authorization_token(
9696
decoded_bytes = base64.b64decode(additional_authentication_context[self.CLAIMS_KEY])
9797
decoded_claim = decoded_bytes.decode("utf-8")
9898

99-
if not self._scopes:
100-
self._scopes = [f"{parsed_url.scheme}://{parsed_url.netloc}/.default"]
101-
span.set_attribute(self.SCOPES, ",".join(self._scopes))
99+
# Derive the scope per-call from the request hostname.
100+
scopes = self._resolve_scopes(parsed_url, span)
101+
span.set_attribute(self.SCOPES, ",".join(scopes))
102102
span.set_attribute(self.ADDITIONAL_CLAIMS_PROVIDED, bool(self._options))
103103

104104
if self._options:
105105
result = self._credentials.get_token(
106-
*self._scopes,
107-
claims=decoded_claim,
108-
enable_cae=self._is_cae_enabled,
109-
**self._options
106+
*scopes, claims=decoded_claim, enable_cae=self._is_cae_enabled, **self._options
110107
)
111108
else:
112109
result = self._credentials.get_token(
113-
*self._scopes, claims=decoded_claim, enable_cae=self._is_cae_enabled
110+
*scopes, claims=decoded_claim, enable_cae=self._is_cae_enabled
114111
)
115112

116113
if inspect.isawaitable(result):
@@ -127,3 +124,24 @@ def get_allowed_hosts_validator(self) -> AllowedHostsValidator:
127124
AllowedHostsValidator: The allowed hosts validator.
128125
"""
129126
return self._allowed_hosts_validator
127+
128+
def _resolve_scopes(self, parsed_url, span) -> list[str]:
129+
"""Return the scopes to pass to `get_token` for this request.
130+
131+
Caller-supplied scopes are returned verbatim. Otherwise a default
132+
`.default` scope is derived from the request hostname only, so that
133+
userinfo (`user:password@`) and ports (which Entra ID rejects for
134+
`.default` scopes) are never copied into the scope or telemetry.
135+
IPv6 literal brackets stripped by `urlparse` are re-added.
136+
"""
137+
if self._scopes:
138+
return self._scopes
139+
hostname = parsed_url.hostname
140+
if not hostname:
141+
span.set_attribute(self.IS_VALID_URL, False)
142+
exc = HTTPError("Valid url scheme and host required")
143+
span.record_exception(exc)
144+
raise exc
145+
if ":" in hostname:
146+
hostname = f"[{hostname}]"
147+
return [f"{parsed_url.scheme}://{hostname}/.default"]

packages/authentication/azure/tests/test_azure_identity_access_token_provider.py

Lines changed: 93 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,4 +83,96 @@ async def test_get_authorization_token_localhost():
8383
token_provider = AzureIdentityAccessTokenProvider(DummySyncAzureTokenCredential(), None)
8484
token = await token_provider.get_authorization_token('HTTP://LOCALHOST:8080')
8585
assert token
86-
86+
87+
88+
class RecordingSyncAzureTokenCredential(DummySyncAzureTokenCredential):
89+
"""Sync credential that records the scopes passed to get_token."""
90+
91+
def __init__(self):
92+
self.received_scopes: list[tuple[str, ...]] = []
93+
94+
def get_token(self, *scopes, **kwargs):
95+
self.received_scopes.append(scopes)
96+
return super().get_token(*scopes, **kwargs)
97+
98+
99+
@pytest.mark.asyncio
100+
async def test_derived_scope_strips_userinfo_and_port():
101+
"""The default `.default` scope passed to `get_token` must be derived
102+
from the hostname only — never include userinfo or
103+
a `:port` (which Entra ID rejects for `.default` scopes).
104+
"""
105+
credential = RecordingSyncAzureTokenCredential()
106+
token_provider = AzureIdentityAccessTokenProvider(credential, None)
107+
108+
await token_provider.get_authorization_token(
109+
'https://alice:secret@graph.microsoft.com:8443/v1.0/me'
110+
)
111+
112+
assert credential.received_scopes == [('https://graph.microsoft.com/.default',)]
113+
114+
115+
@pytest.mark.asyncio
116+
async def test_derived_scope_is_not_cached_across_hosts():
117+
"""The first URL's derived scope must not be reused for later URLs.
118+
119+
Previously the scope was assigned to `self._scopes`, making it sticky for
120+
the lifetime of the provider instance and causing tokens to be requested
121+
for the wrong audience after the first call.
122+
"""
123+
credential = RecordingSyncAzureTokenCredential()
124+
token_provider = AzureIdentityAccessTokenProvider(credential, None)
125+
126+
await token_provider.get_authorization_token('https://graph.microsoft.com/v1.0/me')
127+
await token_provider.get_authorization_token('https://graph.microsoft.us/v1.0/me')
128+
129+
assert credential.received_scopes == [
130+
('https://graph.microsoft.com/.default',),
131+
('https://graph.microsoft.us/.default',),
132+
]
133+
# Provider must not have cached derived scopes into `_scopes`.
134+
assert token_provider._scopes == []
135+
136+
137+
@pytest.mark.asyncio
138+
async def test_explicit_scopes_are_respected():
139+
credential = RecordingSyncAzureTokenCredential()
140+
token_provider = AzureIdentityAccessTokenProvider(
141+
credential, None, scopes=['https://graph.microsoft.com/.default']
142+
)
143+
144+
await token_provider.get_authorization_token('https://graph.microsoft.com/v1.0/me')
145+
await token_provider.get_authorization_token('https://graph.microsoft.us/v1.0/me')
146+
147+
assert credential.received_scopes == [
148+
('https://graph.microsoft.com/.default',),
149+
('https://graph.microsoft.com/.default',),
150+
]
151+
152+
153+
@pytest.mark.asyncio
154+
async def test_derived_scope_rejects_url_without_hostname():
155+
"""A URI whose netloc has no hostname (e.g. `https://@/path`) must not
156+
silently derive a scope like `https://None/.default`; it must raise.
157+
"""
158+
credential = RecordingSyncAzureTokenCredential()
159+
token_provider = AzureIdentityAccessTokenProvider(credential, None)
160+
161+
with pytest.raises(Exception):
162+
await token_provider.get_authorization_token('https://@/path')
163+
assert credential.received_scopes == []
164+
165+
166+
@pytest.mark.asyncio
167+
async def test_derived_scope_brackets_ipv6_hostname():
168+
"""`urlparse` strips brackets from IPv6 literals; the derived scope
169+
must re-add them so the resulting URL is syntactically valid.
170+
"""
171+
credential = RecordingSyncAzureTokenCredential()
172+
token_provider = AzureIdentityAccessTokenProvider(credential, None)
173+
174+
await token_provider.get_authorization_token('https://[2001:db8::1]/v1.0/me')
175+
176+
assert credential.received_scopes == [('https://[2001:db8::1]/.default',)]
177+
178+

0 commit comments

Comments
 (0)