diff --git a/docs/source/session_auth/middleware.rst b/docs/source/session_auth/middleware.rst index be897a4f..8ebde60f 100644 --- a/docs/source/session_auth/middleware.rst +++ b/docs/source/session_auth/middleware.rst @@ -88,6 +88,13 @@ follows: ------------------------------------------------------------------------------- +``excluded_paths`` +------------------ + +This works identically to token auth - see :ref:`excluded_paths`. + +------------------------------------------------------------------------------- + Source ------ diff --git a/docs/source/token_auth/middleware.rst b/docs/source/token_auth/middleware.rst index 0ce1361e..26236002 100644 --- a/docs/source/token_auth/middleware.rst +++ b/docs/source/token_auth/middleware.rst @@ -71,6 +71,8 @@ You'll have to run the migrations for this to work correctly. ``TokenAuthBackend`` -------------------- +.. _excluded_paths: + ``excluded_paths`` ~~~~~~~~~~~~~~~~~~ diff --git a/piccolo_api/session_auth/middleware.py b/piccolo_api/session_auth/middleware.py index 6f90dc50..43d682d4 100644 --- a/piccolo_api/session_auth/middleware.py +++ b/piccolo_api/session_auth/middleware.py @@ -14,6 +14,7 @@ from piccolo_api.session_auth.tables import SessionsBase from piccolo_api.shared.auth import UnauthenticatedUser, User +from piccolo_api.shared.auth.excluded_paths import check_excluded_paths class SessionsAuthBackend(AuthenticationBackend): @@ -31,6 +32,7 @@ def __init__( active_only: bool = True, increase_expiry: t.Optional[timedelta] = None, allow_unauthenticated: bool = False, + excluded_paths: t.Optional[t.Sequence[str]] = None, ): """ :param auth_table: @@ -43,22 +45,26 @@ def __init__( The name of the session cookie. Override this if it clashes with other cookies in your application. :param admin_only: - If True, users which aren't admins will be rejected. + If ``True``, users which aren't admins will be rejected. :param superuser_only: - If True, users which aren't superusers will be rejected. + If ``True``, users which aren't superusers will be rejected. :param active_only: - If True, users which aren't active will be rejected. + If ``True``, users which aren't active will be rejected. :param increase_expiry: If set, the session expiry will be increased by this amount on each request, if it's close to expiry. This allows sessions to have a short expiry date, whilst also providing a good user experience. :param allow_unauthenticated: - If True, when a matching user session can't be found, the request + If ``True``, when a matching user session can't be found, the request still continues, but an unauthenticated user is added to the scope. It's then up to the application's endpoints to check if a user is authenticated or not using ``request.user.is_authenticated``. If - False, the request is automatically rejected if a user session + ``False``, the request is automatically rejected if a user session can't be found. + :param excluded_paths: + These paths don't require a session cookie - useful if you want to + exclude a few URLs, such as docs. + """ # noqa: E501 super().__init__() self.auth_table = auth_table @@ -69,7 +75,9 @@ def __init__( self.active_only = active_only self.increase_expiry = increase_expiry self.allow_unauthenticated = allow_unauthenticated + self.excluded_paths = excluded_paths or [] + @check_excluded_paths async def authenticate( self, conn: HTTPConnection ) -> t.Optional[t.Tuple[AuthCredentials, BaseUser]]: diff --git a/piccolo_api/shared/auth/excluded_paths.py b/piccolo_api/shared/auth/excluded_paths.py new file mode 100644 index 00000000..cc86565b --- /dev/null +++ b/piccolo_api/shared/auth/excluded_paths.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +import functools +import typing as t + +from starlette.authentication import AuthCredentials, AuthenticationBackend +from starlette.requests import HTTPConnection + +from piccolo_api.shared.auth import UnauthenticatedUser + + +def check_excluded_paths(authenticate_func: t.Callable): + + @functools.wraps(authenticate_func) + async def authenticate(self: AuthenticationBackend, conn: HTTPConnection): + conn_path = dict(conn) + + excluded_paths = getattr(self, "excluded_paths", None) + + if excluded_paths is None: + raise ValueError("excluded_paths isn't defined") + + for excluded_path in excluded_paths: + if excluded_path.endswith("*"): + if ( + conn_path["raw_path"] + .decode("utf-8") + .startswith(excluded_path.rstrip("*")) + ): + return ( + AuthCredentials(scopes=[]), + UnauthenticatedUser(), + ) + else: + if conn_path["path"] == excluded_path: + return ( + AuthCredentials(scopes=[]), + UnauthenticatedUser(), + ) + + return await authenticate_func(self=self, conn=conn) + + return authenticate diff --git a/piccolo_api/token_auth/middleware.py b/piccolo_api/token_auth/middleware.py index c9f2c019..b9dedd29 100644 --- a/piccolo_api/token_auth/middleware.py +++ b/piccolo_api/token_auth/middleware.py @@ -13,7 +13,8 @@ ) from starlette.requests import HTTPConnection -from piccolo_api.shared.auth import UnauthenticatedUser, User +from piccolo_api.shared.auth import User +from piccolo_api.shared.auth.excluded_paths import check_excluded_paths from piccolo_api.token_auth.tables import TokenAuth @@ -90,7 +91,9 @@ def __init__( :param token_auth_provider: Used to verify that a token is correct. :param excluded_paths: - These paths don't require a token. + These paths don't require a token - useful if you want to + exclude a few URLs, such as docs. + """ super().__init__() self.token_auth_provider = token_auth_provider @@ -104,29 +107,11 @@ def extract_token(self, header: str) -> str: return token + @check_excluded_paths async def authenticate( self, conn: HTTPConnection ) -> t.Optional[t.Tuple[AuthCredentials, BaseUser]]: auth_header = conn.headers.get("Authorization", None) - conn_path = dict(conn) - - for excluded_path in self.excluded_paths: - if excluded_path.endswith("*"): - if ( - conn_path["raw_path"] - .decode("utf-8") - .startswith(excluded_path.rstrip("*")) - ): - return ( - AuthCredentials(scopes=[]), - UnauthenticatedUser(), - ) - else: - if conn_path["path"] == excluded_path: - return ( - AuthCredentials(scopes=[]), - UnauthenticatedUser(), - ) if not auth_header: raise AuthenticationError("The Authorization header is missing.") diff --git a/tests/session_auth/test_session.py b/tests/session_auth/test_session.py index 78318123..f2edb08d 100644 --- a/tests/session_auth/test_session.py +++ b/tests/session_auth/test_session.py @@ -707,6 +707,105 @@ def test_wrong_cookie_value(self): ) +############################################################################### + +EXCLUDED_PATHS_APP = Router( + routes=[ + Route("/", EchoEndpoint), + Route( + "/foo/", + EchoEndpoint, + ), + Route( + "/foo/1/", + EchoEndpoint, + ), + Route( + "/bar/", + EchoEndpoint, + ), + Route( + "/bar/1/", + EchoEndpoint, + ), + ] +) + + +class TestExcludedPaths(SessionTestCase): + """ + Make sure that if `excluded_paths` is set, then the middleware allows the + request to continue without a cookie. + """ + + def create_user_and_session(self): + user = BaseUser( + **self.credentials, active=True, admin=True, superuser=True + ) + user.save().run_sync() + SessionsBase.create_session_sync(user_id=user.id) + + def setUp(self): + super().setUp() + + # Add a session to the database to make it more realistic. + self.create_user_and_session() + + def test_excluded_paths(self): + """ + Make sure that only the `excluded_paths` are accessible + """ + app = AuthenticationMiddleware( + EXCLUDED_PATHS_APP, + SessionsAuthBackend( + allow_unauthenticated=False, + excluded_paths=["/foo/"], + ), + ) + client = TestClient(app) + + for path in ("/", "/foo/1/", "/bar/", "/bar/1/"): + response = client.get(path) + self.assertEqual(response.status_code, 400) + self.assertEqual(response.content, b"No session cookie found.") + + response = client.get("/foo/") + assert response.status_code == 200 + self.assertDictEqual( + response.json(), + {"is_unauthenticated_user": True, "is_authenticated": False}, + ) + + def test_excluded_paths_wildcard(self): + """ + Make sure that wildcard paths work correctly. + """ + app = AuthenticationMiddleware( + EXCLUDED_PATHS_APP, + SessionsAuthBackend( + allow_unauthenticated=False, + excluded_paths=["/foo/*"], + ), + ) + client = TestClient(app) + + for path in ("/", "/bar/", "/bar/1/"): + response = client.get(path) + self.assertEqual(response.status_code, 400) + self.assertEqual(response.content, b"No session cookie found.") + + for path in ("/foo/", "/foo/1/"): + response = client.get(path) + self.assertEqual(response.status_code, 200) + self.assertDictEqual( + response.json(), + {"is_unauthenticated_user": True, "is_authenticated": False}, + ) + + +############################################################################### + + class TestHooks(SessionTestCase): def test_hooks(self): # TODO Replace these with mocks ...