Skip to content

Commit

Permalink
apply manage_groups behavior outside update_auth_model
Browse files Browse the repository at this point in the history
always applied last, not overrideable by subclasses.

Subclasses govern this behavior via `get_user_groups`
  • Loading branch information
minrk committed Sep 3, 2024
1 parent da158fe commit d9ac90e
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 32 deletions.
26 changes: 14 additions & 12 deletions oauthenticator/azuread.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,26 @@ def _username_claim_default(self):
return "name"

user_groups_claim = Unicode(
"groups",
"",
config=True,
help="""
Name of claim containing user group memberships.
.. deprecated:: 17.0
Will populate JupyterHub groups if Authenticator.manage_groups is True.
Use :attr:`auth_state_groups_key` instead.
""",
)

@default('auth_state_groups_key')
def _auth_state_groups_key_default(self):
key = ""
if self.user_groups_claim:
key = f"{self.user_auth_state_key}.{self.user_groups_claim}"
cls = self.__class__.__name__
self.log.warning(
f"{cls}.user_groups_claim is deprecated in OAuthenticator 17. Use {cls}.auth_state_groups_key={key!r}"
)
return key

tenant_id = Unicode(
config=True,
help="""
Expand All @@ -55,15 +66,6 @@ def _authorize_url_default(self):
def _token_url_default(self):
return f"https://login.microsoftonline.com/{self.tenant_id}/oauth2/token"

async def update_auth_model(self, auth_model, **kwargs):
auth_model = await super().update_auth_model(auth_model, **kwargs)

if getattr(self, "manage_groups", False):
user_info = auth_model["auth_state"][self.user_auth_state_key]
auth_model["groups"] = user_info[self.user_groups_claim]

return auth_model

async def token_to_user(self, token_info):
id_token = token_info['id_token']
decoded = jwt.decode(
Expand Down
2 changes: 1 addition & 1 deletion oauthenticator/globus.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ async def update_auth_model(self, auth_model):
to False makes it be revoked.
"""
user_groups = set()
if self.allowed_globus_groups or self.admin_globus_groups:
if self.allowed_globus_groups or self.admin_globus_groups or self.manage_groups:
tokens = self.get_globus_tokens(auth_model["auth_state"]["token_response"])
user_groups = await self._fetch_users_groups(tokens)
# sets are not JSONable, cast to list for auth_state
Expand Down
33 changes: 23 additions & 10 deletions oauthenticator/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1103,7 +1103,7 @@ def build_auth_state_dict(self, token_info, user_info):
Called by the :meth:`oauthenticator.OAuthenticator.authenticate`
.. versionchanged:: 17.0
This method be async.
This method may be async.
"""

# We know for sure the `access_token` key exists, oterwise we would have errored out already
Expand Down Expand Up @@ -1133,6 +1133,8 @@ async def get_user_groups(self, auth_state: dict):
Returns a set of groups the user belongs to based on auth_state_groups_key
and provided auth_state.
Only called when :attr:`manage_groups` is True.
- If auth_state_groups_key is a callable, it returns the list of groups directly.
Callable may be async.
- If auth_state_groups_key is a nested dictionary key like
Expand Down Expand Up @@ -1168,11 +1170,23 @@ async def update_auth_model(self, auth_model):
- `name`: the normalized username
- `admin`: the admin status (True/False/None), where None means it
should be unchanged.
- `auth_state`: the dictionary of of auth state
returned by :meth:`oauthenticator.OAuthenticator.build_auth_state_dict`
- `auth_state`: the auth state dictionary,
returned by :meth:`oauthenticator.OAuthenticator.build_auth_state_dict`
Called by the :meth:`oauthenticator.OAuthenticator.authenticate`
"""
# NOTE: this base implementation should _not_ be updated to do anything
# subclasses should have full control without calling super()
return auth_model

async def _apply_managed_groups(self, auth_model):
"""Applies managed_groups logic
Called after `update_auth_model` to populate the `groups` field.
Only called if `manage_groups` is True.
The public method for subclasses to override is `.get_user_groups`.
"""
if self.manage_groups:
auth_state = auth_model["auth_state"]
user_groups = self.get_user_groups(auth_state)
Expand Down Expand Up @@ -1244,7 +1258,10 @@ async def authenticate(self, handler, data=None, **kwargs):

# update the auth_model with info to later authorize the user in
# check_allowed, such as admin status and group memberships
return await self.update_auth_model(auth_model)
auth_model = await self.update_auth_model(auth_model)
if self.manage_groups:
auth_model = await self._apply_managed_groups(auth_model)
return auth_model

async def check_allowed(self, username, auth_model):
"""
Expand Down Expand Up @@ -1289,12 +1306,8 @@ async def check_allowed(self, username, auth_model):
return True

# allow users who are members of allowed_groups
if self.manage_groups and self.allowed_groups:
auth_state = auth_model["auth_state"]
user_groups = self.get_user_groups(auth_state)
if isawaitable(user_groups):
user_groups = await user_groups
if any(user_groups & self.allowed_groups):
if self.manage_groups and self.allowed_groups and auth_model.get("groups"):
if set(auth_model["groups"]) & self.allowed_groups:
return True

# users should be explicitly allowed via config, otherwise they aren't
Expand Down
14 changes: 10 additions & 4 deletions oauthenticator/tests/test_azuread.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,11 @@ def user_model(tenant_id, client_id, name):
# test user_groups_claim
(
"30",
{"allow_all": True, "manage_groups": True},
{
"allow_all": True,
"auth_state_groups_key": "user.groups",
"manage_groups": True,
},
True,
None,
),
Expand All @@ -128,7 +132,7 @@ def user_model(tenant_id, client_id, name):
{
"allow_all": True,
"manage_groups": True,
"user_groups_claim": "grp",
"auth_state_groups_key": "user.grp",
},
True,
None,
Expand Down Expand Up @@ -220,9 +224,11 @@ async def test_azuread(
assert auth_model["name"] == user_info[authenticator.username_claim]
if manage_groups:
groups = auth_model['groups']
assert groups == user_info[authenticator.user_groups_claim]
assert (
groups == user_info[authenticator.auth_state_groups_key.rsplit(".")[-1]]
)
else:
assert auth_model == None
assert auth_model is None


async def test_tenant_id_from_env():
Expand Down
10 changes: 5 additions & 5 deletions oauthenticator/tests/test_mediawiki.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def post_token(request, context):
"20",
{
"allowed_groups": {"group1"},
"auth_state_groups_key": "mediawiki_user.groups",
"auth_state_groups_key": "MEDIAWIKI_USER_IDENTITY.groups",
"manage_groups": True,
},
True,
Expand All @@ -90,7 +90,7 @@ def post_token(request, context):
"21",
{
"allowed_groups": {"test-user-not-in-group"},
"auth_state_groups_key": "mediawiki_user.groups",
"auth_state_groups_key": "MEDIAWIKI_USER_IDENTITY.groups",
"manage_groups": True,
},
False,
Expand All @@ -100,7 +100,7 @@ def post_token(request, context):
"22",
{
"admin_groups": {"group1"},
"auth_state_groups_key": "mediawiki_user.groups",
"auth_state_groups_key": "MEDIAWIKI_USER_IDENTITY.groups",
"manage_groups": True,
},
True,
Expand All @@ -110,7 +110,7 @@ def post_token(request, context):
"23",
{
"admin_groups": {"test-user-not-in-group"},
"auth_state_groups_key": "mediawiki_user.groups",
"auth_state_groups_key": "MEDIAWIKI_USER_IDENTITY.groups",
"manage_groups": True,
},
False,
Expand Down Expand Up @@ -155,7 +155,7 @@ async def test_mediawiki(
user_info = auth_state[authenticator.user_auth_state_key]
assert auth_model["name"] == user_info[authenticator.username_claim]
else:
assert auth_model == None
assert auth_model is None


async def test_login_redirect(mediawiki):
Expand Down

0 comments on commit d9ac90e

Please sign in to comment.