diff --git a/docs/source/tutorials/provider-specific-setup/providers/azuread.md b/docs/source/tutorials/provider-specific-setup/providers/azuread.md index 7528159e..09359d69 100644 --- a/docs/source/tutorials/provider-specific-setup/providers/azuread.md +++ b/docs/source/tutorials/provider-specific-setup/providers/azuread.md @@ -31,3 +31,20 @@ AzureAdOAuthenticator expands OAuthenticator with the following config that may be relevant to read more about in the configuration reference: - {attr}`.AzureAdOAuthenticator.tenant_id` + +## Loading user groups + +The `AzureAdOAuthenticator` can load the group-membership of users from the access token. +This is done by setting the `AzureAdOAuthenticator.groups_claim` to the name of the claim that contains the +group-membership. + +```python +c.JupyterHub.authenticator_class = "azuread" + +# {...} other settings (see above) + +c.AzureAdOAuthenticator.manage_groups = True +c.AzureAdOAuthenticator.user_groups_claim = 'groups' # this is the default +``` + +This requires Azure AD to be configured to include the group-membership in the access token. diff --git a/oauthenticator/azuread.py b/oauthenticator/azuread.py index 399c704b..48359864 100644 --- a/oauthenticator/azuread.py +++ b/oauthenticator/azuread.py @@ -21,6 +21,16 @@ def _login_service_default(self): def _username_claim_default(self): return "name" + user_groups_claim = Unicode( + "groups", + config=True, + help=""" + Name of claim containing user group memberships. + + Will populate JupyterHub groups if Authenticator.manage_groups is True. + """, + ) + tenant_id = Unicode( config=True, help=""" @@ -44,6 +54,15 @@ def _authorize_url_default(self): def _token_url_default(self): return f"https://login.microsoftonline.com/{self.tenant_id}/oauth2/token" + async def update_auth_model(self, auth_model, **kwargs): + auth_model = await super().update_auth_model(auth_model, **kwargs) + + if getattr(self, "manage_groups", False): + user_info = auth_model["auth_state"][self.user_auth_state_key] + auth_model["groups"] = user_info[self.user_groups_claim] + + return auth_model + async def token_to_user(self, token_info): id_token = token_info['id_token'] decoded = jwt.decode( diff --git a/oauthenticator/tests/test_azuread.py b/oauthenticator/tests/test_azuread.py index 46d9395b..b3537dd5 100644 --- a/oauthenticator/tests/test_azuread.py +++ b/oauthenticator/tests/test_azuread.py @@ -7,6 +7,7 @@ from unittest import mock import jwt +import pytest from pytest import fixture, mark from traitlets.config import Config @@ -44,6 +45,17 @@ def user_model(tenant_id, client_id, name): "tid": tenant_id, "nonce": "123523", "aio": "Df2UVXL1ix!lMCWMSOJBcFatzcGfvFGhjKv8q5g0x732dR5MB5BisvGQO7YWByjd8iQDLq!eGbIDakyp5mnOrcdqHeYSnltepQmRp6AIZ8jY", + "groups": [ + "96000b2c-7333-4f6e-a2c3-e7608fa2d131", + "a992b3d5-1966-4af4-abed-6ef021417be4", + "ceb90a42-030f-44f1-a0c7-825b572a3b07", + ], + # different from 'groups' for tests + "grp": [ + "96000b2c-7333-4f6e-a2c3", + "a992b3d5-1966-4af4-abed", + "ceb90a42-030f-44f1-a0c7", + ], }, os.urandom(5), ) @@ -103,6 +115,23 @@ def user_model(tenant_id, client_id, name): True, None, ), + # test user_groups_claim + ( + "30", + {"allow_all": True, "manage_groups": True}, + True, + None, + ), + ( + "31", + { + "allow_all": True, + "manage_groups": True, + "user_groups_claim": "grp", + }, + True, + None, + ), ], ) async def test_azuread( @@ -119,6 +148,12 @@ async def test_azuread( c.AzureAdOAuthenticator.client_id = str(uuid.uuid1()) c.AzureAdOAuthenticator.client_secret = str(uuid.uuid1()) authenticator = AzureAdOAuthenticator(config=c) + manage_groups = False + if "manage_groups" in class_config: + if hasattr(authenticator, "manage_groups"): + manage_groups = authenticator.manage_groups + else: + pytest.skip("manage_groups requires jupyterhub 2.2") handled_user_model = user_model( tenant_id=authenticator.tenant_id, @@ -130,7 +165,10 @@ async def test_azuread( if expect_allowed: assert auth_model - assert set(auth_model) == {"name", "admin", "auth_state"} + expected_keys = {"name", "admin", "auth_state"} + if manage_groups: + expected_keys.add("groups") + assert set(auth_model) == expected_keys assert auth_model["admin"] == expect_admin auth_state = auth_model["auth_state"] assert json.dumps(auth_state) @@ -138,6 +176,9 @@ async def test_azuread( user_info = auth_state[authenticator.user_auth_state_key] assert user_info["aud"] == authenticator.client_id assert auth_model["name"] == user_info[authenticator.username_claim] + if manage_groups: + groups = auth_model['groups'] + assert groups == user_info[authenticator.user_groups_claim] else: assert auth_model == None