Skip to content

Commit

Permalink
Merge pull request #735 from yuvipanda/groups-main
Browse files Browse the repository at this point in the history
Move group management from generic to base oauthenticator
  • Loading branch information
manics authored Jun 12, 2024
2 parents 10adfa9 + 80c41bd commit 72f95e7
Show file tree
Hide file tree
Showing 4 changed files with 246 additions and 107 deletions.
138 changes: 40 additions & 98 deletions oauthenticator/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@
"""

import os
from functools import reduce

from jupyterhub.auth import LocalAuthenticator
from jupyterhub.traitlets import Callable
from tornado.httpclient import AsyncHTTPClient
from traitlets import Bool, Dict, Set, Unicode, Union, default
from traitlets import Bool, Dict, Unicode, Union, default, observe

from .oauth2 import OAuthenticator

Expand All @@ -22,44 +21,54 @@ def _login_service_default(self):
[Unicode(os.environ.get('OAUTH2_GROUPS_KEY', 'groups')), Callable()],
config=True,
help="""
Userdata groups claim key from returned json for USERDATA_URL.
.. deprecated:: 17.0
Can be a string key name (use periods for nested keys), or a callable
that accepts the returned json (as a dict) and returns the groups list.
Use :attr:`auth_state_groups_key` instead.
This configures how group membership in the upstream provider is determined
for use by `allowed_groups`, `admin_groups`, etc. If `manage_groups` is True,
this will also determine users' _JupyterHub_ group membership.
""",
)
allowed_groups = Set(
Unicode(),
config=True,
help="""
Allow members of selected groups to sign in.
.. versionchanged:: 17.0
When configuring this you may need to configure `claim_groups_key` as
well as it determines the key in the `userdata_url` response that is
assumed to list the groups a user is a member of.
:attr:`manage_groups` is now required to be `True` to use this functionality
""",
)

admin_groups = Set(
Unicode(),
config=True,
help="""
Allow members of selected groups to sign in and consider them as
JupyterHub admins.
# Initialize value of auth_state_groups_key based on what is in claim_groups_key
@default('auth_state_groups_key')
def _auth_state_groups_key_default(self):
if callable(self.claim_groups_key):
# Automatically wrap the claim_groups_key call so it gets what it thinks it should get
return lambda auth_state: self.claim_groups_key(
auth_state[self.user_auth_state_key]
)
else:
return f"{self.user_auth_state_key}.{self.claim_groups_key}"

# propagate any changes to claim_groups_key to auth_state_groups_key
@observe("claim_groups_key")
def _claim_groups_key_changed(self, change):
# Emit a deprecation warning directly, without using _deprecated_oauth_aliases,
# as it is not a direct replacement for this functionality
self.log.warning(
"{cls}.claim_groups_key is deprecated since OAuthenticator 17.0, use {cls}.auth_state_groups_key instead".format(
cls=self.__class__.__name__,
)
)

If this is set and a user isn't part of one of these groups or listed in
`admin_users`, a user signing in will have their admin status revoked.
if change.new:
if not self.manage_groups:
raise ValueError(
f'{change.owner.__class__.__name__}.{change.name} requires {change.owner.__class__.__name__}.manage_groups to also be set'
)

When configuring this you may need to configure `claim_groups_key` as
well as it determines the key in the `userdata_url` response that is
assumed to list the groups a user is a member of.
""",
)
if callable(change.new):
# Automatically wrap the claim_gorups_key call so it gets what it thinks it should get
self.auth_state_groups_key = lambda auth_state: self.claim_groups_key(
auth_state[self.user_auth_state_key]
)
else:
self.auth_state_groups_key = (
f"{self.user_auth_state_key}.{self.claim_groups_key}"
)

@default("http_client")
def _default_http_client(self):
Expand Down Expand Up @@ -100,73 +109,6 @@ def _default_http_client(self):
""",
)

def get_user_groups(self, user_info):
"""
Returns a set of groups the user belongs to based on claim_groups_key
and provided user_info.
- If claim_groups_key is a callable, it is meant to return the groups
directly.
- If claim_groups_key is a nested dictionary key like
"permissions.groups", this function returns
user_info["permissions"]["groups"].
Note that this method is introduced by GenericOAuthenticator and not
present in the base class.
"""
if callable(self.claim_groups_key):
return set(self.claim_groups_key(user_info))
try:
return set(reduce(dict.get, self.claim_groups_key.split("."), user_info))
except TypeError:
self.log.error(
f"The claim_groups_key {self.claim_groups_key} does not exist in the user token"
)
return set()

async def update_auth_model(self, auth_model):
"""
Sets admin status to True or False if `admin_groups` is configured and
the user isn't part of `admin_users` or `admin_groups`. Note that
leaving it at None makes users able to retain an admin status while
setting it to False makes it be revoked.
Also populates groups if `manage_groups` is set.
"""
if self.manage_groups or self.admin_groups:
user_info = auth_model["auth_state"][self.user_auth_state_key]
user_groups = self.get_user_groups(user_info)

if self.manage_groups:
auth_model["groups"] = sorted(user_groups)

if auth_model["admin"]:
# auth_model["admin"] being True means the user was in admin_users
return auth_model

if self.admin_groups:
# admin status should in this case be True or False, not None
auth_model["admin"] = bool(user_groups & self.admin_groups)

return auth_model

async def check_allowed(self, username, auth_model):
"""
Overrides the OAuthenticator.check_allowed to also allow users part of
`allowed_groups`.
"""
if await super().check_allowed(username, auth_model):
return True

if self.allowed_groups:
user_info = auth_model["auth_state"][self.user_auth_state_key]
user_groups = self.get_user_groups(user_info)
if any(user_groups & self.allowed_groups):
return True

# users should be explicitly allowed via config, otherwise they aren't
return False


class LocalGenericOAuthenticator(LocalAuthenticator, GenericOAuthenticator):
"""A version that mixes in local system user creation"""
104 changes: 103 additions & 1 deletion oauthenticator/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import json
import os
import uuid
from functools import reduce
from urllib.parse import quote, urlencode, urlparse, urlunparse

import jwt
Expand All @@ -20,7 +21,19 @@
from tornado.httpclient import AsyncHTTPClient, HTTPClientError, HTTPRequest
from tornado.httputil import url_concat
from tornado.log import app_log
from traitlets import Any, Bool, Callable, Dict, List, Unicode, Union, default, validate
from traitlets import (
Any,
Bool,
Callable,
Dict,
List,
Set,
Unicode,
Union,
default,
observe,
validate,
)


def guess_callback_uri(protocol, host, hub_server_url):
Expand Down Expand Up @@ -316,6 +329,54 @@ class OAuthenticator(Authenticator):
""",
)

allowed_groups = Set(
Unicode(),
config=True,
help="""
Allow members of selected groups to log in.
Requires `manage_groups` to also be `True`.
""",
)

admin_groups = Set(
Unicode(),
config=True,
help="""
Allow members of selected groups to sign in and consider them as
JupyterHub admins.
If this is set and a user isn't part of one of these groups or listed in
`admin_users`, a user signing in will have their admin status revoked.
Requires `manage_groups` to also be `True`.
""",
)

auth_state_groups_key = Union(
[Unicode(), Callable()],
config=True,
help="""
Determine groups this user belongs based on contents of auth_state.
Can be a string key name (use periods for nested keys), or a callable
that accepts the auth state (as a dict) and returns the groups list.
Requires `manage_groups` to also be `True`.
""",
)

@observe("allowed_groups", "admin_groups", "auth_state_groups_key")
def _requires_manage_groups(self, change):
"""
Validate that group management keys are only set when manage_groups is also True
"""
if change.new:
if not self.manage_groups:
raise ValueError(
f'{change.owner.__class__.__name__}.{change.name} requires {change.owner.__class__.__name__}.manage_groups to also be set'
)

authorize_url = Unicode(
config=True,
help="""
Expand Down Expand Up @@ -1025,6 +1086,28 @@ def build_auth_state_dict(self, token_info, user_info):
self.user_auth_state_key: user_info,
}

def get_user_groups(self, auth_state: dict):
"""
Returns a set of groups the user belongs to based on auth_state_groups_key
and provided auth_state.
- If auth_state_groups_key is a callable, it returns the list of groups directly.
- If auth_state_groups_key is a nested dictionary key like
"permissions.groups", this function returns
auth_state["permissions"]["groups"].
"""
if callable(self.auth_state_groups_key):
return set(self.auth_state_groups_key(auth_state))
try:
return set(
reduce(dict.get, self.auth_state_groups_key.split("."), auth_state)
)
except TypeError:
self.log.error(
f"The auth_state_groups_key {self.auth_state_groups_key} does not exist in the auth_model. Available keys are: {auth_state.keys()}"
)
return set()

async def update_auth_model(self, auth_model):
"""
Updates and returns the `auth_model` dict.
Expand All @@ -1040,6 +1123,17 @@ async def update_auth_model(self, auth_model):
Called by the :meth:`oauthenticator.OAuthenticator.authenticate`
"""
if self.manage_groups:
auth_state = auth_model["auth_state"]
user_groups = self.get_user_groups(auth_state)

auth_model["groups"] = sorted(user_groups)

if self.admin_groups:
if not auth_model["admin"]:
# auth_model["admin"] being True means the user was in admin_users
# so their group membership should not affect their admin status
auth_model["admin"] = bool(user_groups & self.admin_groups)
return auth_model

async def authenticate(self, handler, data=None, **kwargs):
Expand Down Expand Up @@ -1125,6 +1219,13 @@ async def check_allowed(self, username, auth_model):
if username in self.allowed_users:
return True

# allow users who are members of allowed_groups
if self.manage_groups and self.allowed_groups:
auth_state = auth_model["auth_state"]
user_groups = self.get_user_groups(auth_state)
if any(user_groups & self.allowed_groups):
return True

# users should be explicitly allowed via config, otherwise they aren't
return False

Expand Down Expand Up @@ -1177,6 +1278,7 @@ def __init__(self, **kwargs):
self.observe(
self._deprecated_oauth_trait, names=list(self._deprecated_oauth_aliases)
)

super().__init__(**kwargs)


Expand Down
Loading

0 comments on commit 72f95e7

Please sign in to comment.