Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix username collision #6

Merged
merged 11 commits into from
Oct 30, 2023
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ To set up a development environment for this repository:
1. Do a development install with pip

```bash
pip install --editable ".[test]"
pip install --editable ".[dev]"
```

1. Set up pre-commit hooks for automatic code formatting, etc.
Expand Down
46 changes: 45 additions & 1 deletion multiauthenticator/multiauthenticator.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
from jupyterhub.utils import url_path_join
from traitlets import List

PREFIX_SEPARATOR = ":"


class URLScopeMixin:
"""Mixin class that adds the"""
Expand All @@ -51,6 +53,14 @@ def get_handlers(self, app):
]


def removeprefix(self: str, prefix: str) -> str:
"""PEP-0616 implementation to stay compatible with Python < 3.9"""
if self.startswith(prefix):
return self[len(prefix) :]
else:
return self[:]


class MultiAuthenticator(Authenticator):
"""Wrapper class that allows to use more than one authentication provider
for JupyterHub"""
Expand All @@ -69,12 +79,46 @@ def __init__(self, *arg, **kwargs):
class WrapperAuthenticator(URLScopeMixin, authenticator_klass):
url_scope = url_scope_authenticator

@property
def username_prefix(self):
return f"{getattr(self, 'service_name', self.login_service)}{PREFIX_SEPARATOR}"

async def authenticate(self, handler, data=None, **kwargs):
response = await super().authenticate(handler, data, **kwargs)
if response is None:
return None
elif type(response) == str:
return self.username_prefix + response
else:
response["name"] = self.username_prefix + response["name"]
return response

def check_allowed(self, username, authentication=None):
if not username.startswith(self.username_prefix):
return False

return super().check_allowed(
removeprefix(username, self.username_prefix), authentication
)

def check_blocked_users(self, username, authentication=None):
if not username.startswith(self.username_prefix):
return False

return super().check_blocked_users(
removeprefix(username, self.username_prefix), authentication
)

service_name = authenticator_configuration.pop("service_name", None)

authenticator = WrapperAuthenticator(**authenticator_configuration)

if service_name:
if service_name is not None:
if PREFIX_SEPARATOR in service_name:
raise ValueError(f"Service name cannot contain {PREFIX_SEPARATOR}")
authenticator.service_name = service_name
elif PREFIX_SEPARATOR in authenticator.login_service:
raise ValueError(f"Login service cannot contain {PREFIX_SEPARATOR}")

self._authenticators.append(authenticator)

Expand Down
127 changes: 120 additions & 7 deletions multiauthenticator/tests/test_multiauthenticator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,16 @@
#
# SPDX-License-Identifier: BSD-3-Clause
"""Test module for the MultiAuthenticator class"""
import pytest

from jupyterhub.auth import DummyAuthenticator
from jupyterhub.auth import PAMAuthenticator
from oauthenticator import OAuthenticator
from oauthenticator.github import GitHubOAuthenticator
from oauthenticator.gitlab import GitLabOAuthenticator
from oauthenticator.google import GoogleOAuthenticator

from ..multiauthenticator import PREFIX_SEPARATOR
from ..multiauthenticator import MultiAuthenticator


Expand Down Expand Up @@ -82,7 +87,7 @@ def test_same_authenticators():
GoogleOAuthenticator,
"/mygoogle",
{
"login_service": "My Google",
"service_name": "My Google",
"client_id": "yyyyy",
"client_secret": "yyyyy",
"oauth_callback_url": "http://example.com/hub/mygoogle/oauth_callback",
Expand All @@ -92,7 +97,7 @@ def test_same_authenticators():
GoogleOAuthenticator,
"/othergoogle",
{
"login_service": "Other Google",
"service_name": "Other Google",
"client_id": "xxxx",
"client_secret": "xxxx",
"oauth_callback_url": "http://example.com/hub/othergoogle/oauth_callback",
Expand All @@ -109,9 +114,9 @@ def test_same_authenticators():
for path, handler in handlers:
assert isinstance(handler.authenticator, GoogleOAuthenticator)
if "mygoogle" in path:
assert handler.authenticator.login_service == "My Google"
assert handler.authenticator.service_name == "My Google"
elif "othergoogle" in path:
assert handler.authenticator.login_service == "Other Google"
assert handler.authenticator.service_name == "Other Google"
else:
raise ValueError(f"Unknown path: {path}")

Expand Down Expand Up @@ -171,7 +176,6 @@ def test_extra_configuration():
{
"service_name": "PAM",
"allowed_users": allowed_users,
"not_existing": "boom",
},
),
]
Expand All @@ -182,5 +186,114 @@ def test_extra_configuration():
for authenticator in multi_authenticator._authenticators:
assert authenticator.allowed_users == allowed_users

if isinstance(authenticator, PAMAuthenticator):
assert not hasattr(authenticator, "not_existing")

def test_username_prefix():
MultiAuthenticator.authenticators = [
(
GitLabOAuthenticator,
"/gitlab",
{
"client_id": "xxxx",
"client_secret": "xxxx",
"oauth_callback_url": "http://example.com/hub/gitlab/oauth_callback",
},
),
(PAMAuthenticator, "/pam", {"service_name": "PAM"}),
]

multi_authenticator = MultiAuthenticator()
assert len(multi_authenticator._authenticators) == 2
assert (
multi_authenticator._authenticators[0].username_prefix
== f"GitLab{PREFIX_SEPARATOR}"
)
assert (
multi_authenticator._authenticators[1].username_prefix
== f"PAM{PREFIX_SEPARATOR}"
)


@pytest.mark.asyncio
async def test_authenticated_username_prefix():
MultiAuthenticator.authenticators = [
(DummyAuthenticator, "/pam", {"service_name": "Dummy"}),
]

multi_authenticator = MultiAuthenticator()
assert len(multi_authenticator._authenticators) == 1
username = await multi_authenticator._authenticators[0].authenticate(
None, {"username": "test"}
)
assert username == f"Dummy{PREFIX_SEPARATOR}test"


def test_username_prefix_checks():
MultiAuthenticator.authenticators = [
(PAMAuthenticator, "/pam", {"service_name": "PAM", "allowed_users": {"test"}}),
(
PAMAuthenticator,
"/pam",
{"service_name": "PAM2", "blocked_users": {"test2"}},
),
]

multi_authenticator = MultiAuthenticator()
assert len(multi_authenticator._authenticators) == 2
authenticator = multi_authenticator._authenticators[0]

assert authenticator.check_allowed("test") == False
assert authenticator.check_allowed("PAM:test") == True
assert (
authenticator.check_blocked_users("test") == False
) # Even if no block list, it does not have the correct prefix
assert authenticator.check_blocked_users("PAM:test") == True

authenticator = multi_authenticator._authenticators[1]
assert authenticator.check_allowed("test2") == False
assert (
authenticator.check_allowed("PAM2:test2") == True
) # Because allowed_users is empty
assert authenticator.check_blocked_users("test2") == False
assert authenticator.check_blocked_users("PAM2:test2") == False


@pytest.fixture(params=[f"test me{PREFIX_SEPARATOR}", f"second{PREFIX_SEPARATOR} test"])
def invalid_name(request):
yield request.param


def test_username_prefix_validation_with_service_name(invalid_name):
MultiAuthenticator.authenticators = [
(
PAMAuthenticator,
"/pam",
{"service_name": invalid_name, "allowed_users": {"test"}},
),
]

with pytest.raises(ValueError) as excinfo:
MultiAuthenticator()

assert f"Service name cannot contain {PREFIX_SEPARATOR}" in str(excinfo.value)


def test_username_prefix_validation_with_login_service(invalid_name):
class MyAuthenticator(OAuthenticator):
login_service = invalid_name

MultiAuthenticator.authenticators = [
(
MyAuthenticator,
"/myauth",
{
"client_id": "xxxx",
"client_secret": "xxxx",
"oauth_callback_url": "http://example.com/myauth/oauth_callback",
},
),
]

with pytest.raises(ValueError) as excinfo:
MultiAuthenticator()

assert f"Login service cannot contain {PREFIX_SEPARATOR}" in str(excinfo.value)
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ dependencies = [
]

[project.optional-dependencies]
test = ["pytest", "pytest-cov"]
test = ["pytest", "pytest-cov", "pytest-asyncio"]
dev = ["pre-commit", "jupyterhub-multiauthenticator[test]"]

[tool.setuptools]
packages = ["multiauthenticator"]
Expand Down