From d9ac90e7ee3543301f2167c9bd3f501833375a70 Mon Sep 17 00:00:00 2001 From: Min RK Date: Tue, 3 Sep 2024 12:27:49 +0200 Subject: [PATCH] apply manage_groups behavior outside `update_auth_model` always applied last, not overrideable by subclasses. Subclasses govern this behavior via `get_user_groups` --- oauthenticator/azuread.py | 26 ++++++++++---------- oauthenticator/globus.py | 2 +- oauthenticator/oauth2.py | 33 ++++++++++++++++++-------- oauthenticator/tests/test_azuread.py | 14 +++++++---- oauthenticator/tests/test_mediawiki.py | 10 ++++---- 5 files changed, 53 insertions(+), 32 deletions(-) diff --git a/oauthenticator/azuread.py b/oauthenticator/azuread.py index e4e6682b..f6db57db 100644 --- a/oauthenticator/azuread.py +++ b/oauthenticator/azuread.py @@ -23,15 +23,26 @@ def _username_claim_default(self): return "name" user_groups_claim = Unicode( - "groups", + "", config=True, help=""" - Name of claim containing user group memberships. + .. deprecated:: 17.0 - Will populate JupyterHub groups if Authenticator.manage_groups is True. + Use :attr:`auth_state_groups_key` instead. """, ) + @default('auth_state_groups_key') + def _auth_state_groups_key_default(self): + key = "" + if self.user_groups_claim: + key = f"{self.user_auth_state_key}.{self.user_groups_claim}" + cls = self.__class__.__name__ + self.log.warning( + f"{cls}.user_groups_claim is deprecated in OAuthenticator 17. Use {cls}.auth_state_groups_key={key!r}" + ) + return key + tenant_id = Unicode( config=True, help=""" @@ -55,15 +66,6 @@ def _authorize_url_default(self): def _token_url_default(self): return f"https://login.microsoftonline.com/{self.tenant_id}/oauth2/token" - async def update_auth_model(self, auth_model, **kwargs): - auth_model = await super().update_auth_model(auth_model, **kwargs) - - if getattr(self, "manage_groups", False): - user_info = auth_model["auth_state"][self.user_auth_state_key] - auth_model["groups"] = user_info[self.user_groups_claim] - - return auth_model - async def token_to_user(self, token_info): id_token = token_info['id_token'] decoded = jwt.decode( diff --git a/oauthenticator/globus.py b/oauthenticator/globus.py index 266789e9..30a49484 100644 --- a/oauthenticator/globus.py +++ b/oauthenticator/globus.py @@ -342,7 +342,7 @@ async def update_auth_model(self, auth_model): to False makes it be revoked. """ user_groups = set() - if self.allowed_globus_groups or self.admin_globus_groups: + if self.allowed_globus_groups or self.admin_globus_groups or self.manage_groups: tokens = self.get_globus_tokens(auth_model["auth_state"]["token_response"]) user_groups = await self._fetch_users_groups(tokens) # sets are not JSONable, cast to list for auth_state diff --git a/oauthenticator/oauth2.py b/oauthenticator/oauth2.py index 5daa3574..8c62a67c 100644 --- a/oauthenticator/oauth2.py +++ b/oauthenticator/oauth2.py @@ -1103,7 +1103,7 @@ def build_auth_state_dict(self, token_info, user_info): Called by the :meth:`oauthenticator.OAuthenticator.authenticate` .. versionchanged:: 17.0 - This method be async. + This method may be async. """ # We know for sure the `access_token` key exists, oterwise we would have errored out already @@ -1133,6 +1133,8 @@ async def get_user_groups(self, auth_state: dict): Returns a set of groups the user belongs to based on auth_state_groups_key and provided auth_state. + Only called when :attr:`manage_groups` is True. + - If auth_state_groups_key is a callable, it returns the list of groups directly. Callable may be async. - If auth_state_groups_key is a nested dictionary key like @@ -1168,11 +1170,23 @@ async def update_auth_model(self, auth_model): - `name`: the normalized username - `admin`: the admin status (True/False/None), where None means it should be unchanged. - - `auth_state`: the dictionary of of auth state - returned by :meth:`oauthenticator.OAuthenticator.build_auth_state_dict` + - `auth_state`: the auth state dictionary, + returned by :meth:`oauthenticator.OAuthenticator.build_auth_state_dict` Called by the :meth:`oauthenticator.OAuthenticator.authenticate` """ + # NOTE: this base implementation should _not_ be updated to do anything + # subclasses should have full control without calling super() + return auth_model + + async def _apply_managed_groups(self, auth_model): + """Applies managed_groups logic + + Called after `update_auth_model` to populate the `groups` field. + Only called if `manage_groups` is True. + + The public method for subclasses to override is `.get_user_groups`. + """ if self.manage_groups: auth_state = auth_model["auth_state"] user_groups = self.get_user_groups(auth_state) @@ -1244,7 +1258,10 @@ async def authenticate(self, handler, data=None, **kwargs): # update the auth_model with info to later authorize the user in # check_allowed, such as admin status and group memberships - return await self.update_auth_model(auth_model) + auth_model = await self.update_auth_model(auth_model) + if self.manage_groups: + auth_model = await self._apply_managed_groups(auth_model) + return auth_model async def check_allowed(self, username, auth_model): """ @@ -1289,12 +1306,8 @@ async def check_allowed(self, username, auth_model): return True # allow users who are members of allowed_groups - if self.manage_groups and self.allowed_groups: - auth_state = auth_model["auth_state"] - user_groups = self.get_user_groups(auth_state) - if isawaitable(user_groups): - user_groups = await user_groups - if any(user_groups & self.allowed_groups): + if self.manage_groups and self.allowed_groups and auth_model.get("groups"): + if set(auth_model["groups"]) & self.allowed_groups: return True # users should be explicitly allowed via config, otherwise they aren't diff --git a/oauthenticator/tests/test_azuread.py b/oauthenticator/tests/test_azuread.py index 3a83dc8e..31230feb 100644 --- a/oauthenticator/tests/test_azuread.py +++ b/oauthenticator/tests/test_azuread.py @@ -119,7 +119,11 @@ def user_model(tenant_id, client_id, name): # test user_groups_claim ( "30", - {"allow_all": True, "manage_groups": True}, + { + "allow_all": True, + "auth_state_groups_key": "user.groups", + "manage_groups": True, + }, True, None, ), @@ -128,7 +132,7 @@ def user_model(tenant_id, client_id, name): { "allow_all": True, "manage_groups": True, - "user_groups_claim": "grp", + "auth_state_groups_key": "user.grp", }, True, None, @@ -220,9 +224,11 @@ async def test_azuread( assert auth_model["name"] == user_info[authenticator.username_claim] if manage_groups: groups = auth_model['groups'] - assert groups == user_info[authenticator.user_groups_claim] + assert ( + groups == user_info[authenticator.auth_state_groups_key.rsplit(".")[-1]] + ) else: - assert auth_model == None + assert auth_model is None async def test_tenant_id_from_env(): diff --git a/oauthenticator/tests/test_mediawiki.py b/oauthenticator/tests/test_mediawiki.py index 417ec7e5..794ba377 100644 --- a/oauthenticator/tests/test_mediawiki.py +++ b/oauthenticator/tests/test_mediawiki.py @@ -80,7 +80,7 @@ def post_token(request, context): "20", { "allowed_groups": {"group1"}, - "auth_state_groups_key": "mediawiki_user.groups", + "auth_state_groups_key": "MEDIAWIKI_USER_IDENTITY.groups", "manage_groups": True, }, True, @@ -90,7 +90,7 @@ def post_token(request, context): "21", { "allowed_groups": {"test-user-not-in-group"}, - "auth_state_groups_key": "mediawiki_user.groups", + "auth_state_groups_key": "MEDIAWIKI_USER_IDENTITY.groups", "manage_groups": True, }, False, @@ -100,7 +100,7 @@ def post_token(request, context): "22", { "admin_groups": {"group1"}, - "auth_state_groups_key": "mediawiki_user.groups", + "auth_state_groups_key": "MEDIAWIKI_USER_IDENTITY.groups", "manage_groups": True, }, True, @@ -110,7 +110,7 @@ def post_token(request, context): "23", { "admin_groups": {"test-user-not-in-group"}, - "auth_state_groups_key": "mediawiki_user.groups", + "auth_state_groups_key": "MEDIAWIKI_USER_IDENTITY.groups", "manage_groups": True, }, False, @@ -155,7 +155,7 @@ async def test_mediawiki( user_info = auth_state[authenticator.user_auth_state_key] assert auth_model["name"] == user_info[authenticator.username_claim] else: - assert auth_model == None + assert auth_model is None async def test_login_redirect(mediawiki):