Skip to content

Commit

Permalink
Merge pull request #758 from consideRatio/pr/groups-tests
Browse files Browse the repository at this point in the history
Various fixes for allowed_groups and admin_groups
  • Loading branch information
minrk authored Sep 3, 2024
2 parents d2aac2d + 9e5520d commit ad4034c
Show file tree
Hide file tree
Showing 18 changed files with 615 additions and 126 deletions.
2 changes: 2 additions & 0 deletions docs/source/tutorials/general-setup.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ projects' authenticator classes.
- {attr}`.OAuthenticator.allow_all`
- {attr}`.OAuthenticator.allow_existing_users`
- {attr}`.OAuthenticator.allowed_users`
- {attr}`.OAuthenticator.allowed_groups`
- {attr}`.OAuthenticator.admin_users`
- {attr}`.OAuthenticator.admin_groups`

Your authenticator class may have unique config, so in the end it can look
something like this:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,14 @@ be relevant to read more about in the configuration reference:
## Loading user groups

The `AzureAdOAuthenticator` can load the group-membership of users from the access token.
This is done by setting the `AzureAdOAuthenticator.groups_claim` to the name of the claim that contains the
group-membership.

```python
c.JupyterHub.authenticator_class = "azuread"

# {...} other settings (see above)

c.AzureAdOAuthenticator.manage_groups = True
c.AzureAdOAuthenticator.user_groups_claim = 'groups' # this is the default
c.AzureAdOAuthenticator.auth_state_groups_key = "user.groups" # this is the default
```

This requires Azure AD to be configured to include the group-membership in the access token.
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ c.GenericOAuthenticator.userdata_url = "https://accounts.example.com/auth/realms
#
c.GenericOAuthenticator.scope = ["openid", "email", "groups"]
c.GenericOAuthenticator.username_claim = "email"
c.GenericOAuthenticator.claim_groups_key = "groups"
c.GenericOAuthenticator.auth_state_groups_key = "oauth_user.groups"

# Authorization
# -------------
Expand Down
26 changes: 14 additions & 12 deletions oauthenticator/azuread.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,26 @@ def _username_claim_default(self):
return "name"

user_groups_claim = Unicode(
"groups",
"",
config=True,
help="""
Name of claim containing user group memberships.
.. deprecated:: 17.0
Will populate JupyterHub groups if Authenticator.manage_groups is True.
Use :attr:`auth_state_groups_key` instead.
""",
)

@default('auth_state_groups_key')
def _auth_state_groups_key_default(self):
key = "user.groups"
if self.user_groups_claim:
key = f"{self.user_auth_state_key}.{self.user_groups_claim}"
cls = self.__class__.__name__
self.log.warning(
f"{cls}.user_groups_claim is deprecated in OAuthenticator 17. Use {cls}.auth_state_groups_key = {key!r}"
)
return key

tenant_id = Unicode(
config=True,
help="""
Expand All @@ -55,15 +66,6 @@ def _authorize_url_default(self):
def _token_url_default(self):
return f"https://login.microsoftonline.com/{self.tenant_id}/oauth2/token"

async def update_auth_model(self, auth_model, **kwargs):
auth_model = await super().update_auth_model(auth_model, **kwargs)

if getattr(self, "manage_groups", False):
user_info = auth_model["auth_state"][self.user_auth_state_key]
auth_model["groups"] = user_info[self.user_groups_claim]

return auth_model

async def token_to_user(self, token_info):
id_token = token_info['id_token']
decoded = jwt.decode(
Expand Down
2 changes: 1 addition & 1 deletion oauthenticator/globus.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ async def update_auth_model(self, auth_model):
to False makes it be revoked.
"""
user_groups = set()
if self.allowed_globus_groups or self.admin_globus_groups:
if self.allowed_globus_groups or self.admin_globus_groups or self.manage_groups:
tokens = self.get_globus_tokens(auth_model["auth_state"]["token_response"])
user_groups = await self._fetch_users_groups(tokens)
# sets are not JSONable, cast to list for auth_state
Expand Down
33 changes: 23 additions & 10 deletions oauthenticator/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1103,7 +1103,7 @@ def build_auth_state_dict(self, token_info, user_info):
Called by the :meth:`oauthenticator.OAuthenticator.authenticate`
.. versionchanged:: 17.0
This method be async.
This method may be async.
"""

# We know for sure the `access_token` key exists, oterwise we would have errored out already
Expand Down Expand Up @@ -1133,6 +1133,8 @@ async def get_user_groups(self, auth_state: dict):
Returns a set of groups the user belongs to based on auth_state_groups_key
and provided auth_state.
Only called when :attr:`manage_groups` is True.
- If auth_state_groups_key is a callable, it returns the list of groups directly.
Callable may be async.
- If auth_state_groups_key is a nested dictionary key like
Expand Down Expand Up @@ -1168,11 +1170,23 @@ async def update_auth_model(self, auth_model):
- `name`: the normalized username
- `admin`: the admin status (True/False/None), where None means it
should be unchanged.
- `auth_state`: the dictionary of of auth state
returned by :meth:`oauthenticator.OAuthenticator.build_auth_state_dict`
- `auth_state`: the auth state dictionary,
returned by :meth:`oauthenticator.OAuthenticator.build_auth_state_dict`
Called by the :meth:`oauthenticator.OAuthenticator.authenticate`
"""
# NOTE: this base implementation should _not_ be updated to do anything
# subclasses should have full control without calling super()
return auth_model

async def _apply_managed_groups(self, auth_model):
"""Applies managed_groups logic
Called after `update_auth_model` to populate the `groups` field.
Only called if `manage_groups` is True.
The public method for subclasses to override is `.get_user_groups`.
"""
if self.manage_groups:
auth_state = auth_model["auth_state"]
user_groups = self.get_user_groups(auth_state)
Expand Down Expand Up @@ -1244,7 +1258,10 @@ async def authenticate(self, handler, data=None, **kwargs):

# update the auth_model with info to later authorize the user in
# check_allowed, such as admin status and group memberships
return await self.update_auth_model(auth_model)
auth_model = await self.update_auth_model(auth_model)
if self.manage_groups:
auth_model = await self._apply_managed_groups(auth_model)
return auth_model

async def check_allowed(self, username, auth_model):
"""
Expand Down Expand Up @@ -1289,12 +1306,8 @@ async def check_allowed(self, username, auth_model):
return True

# allow users who are members of allowed_groups
if self.manage_groups and self.allowed_groups:
auth_state = auth_model["auth_state"]
user_groups = self.get_user_groups(auth_state)
if isawaitable(user_groups):
user_groups = await user_groups
if any(user_groups & self.allowed_groups):
if self.manage_groups and self.allowed_groups and auth_model.get("groups"):
if set(auth_model["groups"]) & self.allowed_groups:
return True

# users should be explicitly allowed via config, otherwise they aren't
Expand Down
60 changes: 5 additions & 55 deletions oauthenticator/openshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,18 @@

from jupyterhub.auth import LocalAuthenticator
from tornado.httpclient import HTTPClient, HTTPRequest
from traitlets import Bool, Set, Unicode, default
from traitlets import Bool, Unicode, default

from oauthenticator.oauth2 import OAuthenticator


class OpenShiftOAuthenticator(OAuthenticator):
user_auth_state_key = "openshift_user"

@default("auth_state_groups_key")
def _auth_state_groups_key_default(self):
return "openshift_user.groups"

@default("scope")
def _scope_default(self):
return ["user:info"]
Expand Down Expand Up @@ -45,24 +49,6 @@ def _http_request_kwargs_default(self):
""",
)

allowed_groups = Set(
config=True,
help="""
Allow members of selected OpenShift groups to sign in.
""",
)

admin_groups = Set(
config=True,
help="""
Allow members of selected OpenShift groups to sign in and consider them
as JupyterHub admins.
If this is set and a user isn't part of one of these groups or listed in
`admin_users`, a user signing in will have their admin status revoked.
""",
)

openshift_auth_api_url = Unicode(
config=True,
help="""
Expand Down Expand Up @@ -158,42 +144,6 @@ def user_info_to_username(self, user_info):
"""
return user_info['metadata']['name']

async def update_auth_model(self, auth_model):
"""
Sets admin status to True or False if `admin_groups` is configured and
the user isn't part of `admin_users`. Note that leaving it at None makes
users able to retain an admin status while setting it to False makes it
be revoked.
"""
if auth_model["admin"]:
# auth_model["admin"] being True means the user was in admin_users
return auth_model

if self.admin_groups:
# admin status should in this case be True or False, not None
user_info = auth_model["auth_state"][self.user_auth_state_key]
user_groups = set(user_info["groups"])
auth_model["admin"] = bool(user_groups & self.admin_groups)

return auth_model

async def check_allowed(self, username, auth_model):
"""
Overrides OAuthenticator.check_allowed to also allow users part of
`allowed_groups`.
"""
if await super().check_allowed(username, auth_model):
return True

if self.allowed_groups:
user_info = auth_model["auth_state"][self.user_auth_state_key]
user_groups = set(user_info["groups"])
if user_groups & self.allowed_groups:
return True

# users should be explicitly allowed via config, otherwise they aren't
return False


class LocalOpenShiftOAuthenticator(LocalAuthenticator, OpenShiftOAuthenticator):
"""A version that mixes in local system user creation"""
47 changes: 46 additions & 1 deletion oauthenticator/tests/test_auth0.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def user_model():
return {
"email": "[email protected]",
"name": "user1",
"groups": ["group1"],
}


Expand Down Expand Up @@ -62,6 +63,47 @@ def user_model():
True,
True,
),
# common tests with allowed_groups and manage_groups
(
"20",
{
"allowed_groups": {"group1"},
"auth_state_groups_key": "auth0_user.groups",
"manage_groups": True,
},
True,
None,
),
(
"21",
{
"allowed_groups": {"test-user-not-in-group"},
"auth_state_groups_key": "auth0_user.groups",
"manage_groups": True,
},
False,
None,
),
(
"22",
{
"admin_groups": {"group1"},
"auth_state_groups_key": "auth0_user.groups",
"manage_groups": True,
},
True,
True,
),
(
"23",
{
"admin_groups": {"test-user-not-in-group"},
"auth_state_groups_key": "auth0_user.groups",
"manage_groups": True,
},
False,
False,
),
],
)
async def test_auth0(
Expand All @@ -84,7 +126,10 @@ async def test_auth0(

if expect_allowed:
assert auth_model
assert set(auth_model) == {"name", "admin", "auth_state"}
if authenticator.manage_groups:
assert set(auth_model) == {"name", "admin", "auth_state", "groups"}
else:
assert set(auth_model) == {"name", "admin", "auth_state"}
assert auth_model["admin"] == expect_admin
auth_state = auth_model["auth_state"]
assert json.dumps(auth_state)
Expand Down
Loading

0 comments on commit ad4034c

Please sign in to comment.