From 3a39025ed2c0fb924ceaa739f0eac56ef8bd4500 Mon Sep 17 00:00:00 2001 From: Jan-Lukas Wynen Date: Mon, 29 Apr 2024 14:00:19 +0200 Subject: [PATCH 1/3] Simplify and fix credentials classes --- src/scitacean/client.py | 42 ++++---- src/scitacean/testing/client.py | 8 +- src/scitacean/util/credentials.py | 153 +++++++++++++++++++----------- tests/util/credentials_test.py | 87 +++++++++++++++++ 4 files changed, 210 insertions(+), 80 deletions(-) create mode 100644 tests/util/credentials_test.py diff --git a/src/scitacean/client.py b/src/scitacean/client.py index 102c2e7c..047ef5c2 100644 --- a/src/scitacean/client.py +++ b/src/scitacean/client.py @@ -25,7 +25,7 @@ from .logging import get_logger from .pid import PID from .typing import DownloadConnection, FileTransfer, UploadConnection -from .util.credentials import ExpiringToken, SecretStr, StrStorage +from .util.credentials import SecretStr, Token class Client: @@ -61,7 +61,7 @@ def from_token( cls, *, url: str, - token: str | StrStorage, + token: str | SecretStr | Token, file_transfer: FileTransfer | None = None, ) -> Client: """Create a new client and authenticate with a token. @@ -91,8 +91,8 @@ def from_credentials( cls, *, url: str, - username: str | StrStorage, - password: str | StrStorage, + username: str | SecretStr, + password: str | SecretStr, file_transfer: FileTransfer | None = None, ) -> Client: """Create a new client and authenticate with username and password. @@ -559,23 +559,23 @@ class ScicatClient: def __init__( self, url: str, - token: str | StrStorage | None, + token: str | SecretStr | Token | None, timeout: datetime.timedelta | None, ): # Need to add a final / self._base_url = url[:-1] if url.endswith("/") else url self._timeout = datetime.timedelta(seconds=10) if timeout is None else timeout - self._token: StrStorage | None = ( - ExpiringToken.from_jwt(SecretStr(token)) - if isinstance(token, str) - else token + self._token = ( + token + if isinstance(token, Token) or token is None + else Token.from_jwt(token, denial_period=datetime.timedelta(seconds=2)) ) @classmethod def from_token( cls, url: str, - token: str | StrStorage, + token: str | SecretStr | Token, timeout: datetime.timedelta | None = None, ) -> ScicatClient: """Create a new low-level client and authenticate with a token. @@ -600,8 +600,8 @@ def from_token( def from_credentials( cls, url: str, - username: str | StrStorage, - password: str | StrStorage, + username: str | SecretStr, + password: str | SecretStr, timeout: datetime.timedelta | None = None, ) -> ScicatClient: """Create a new low-level client and authenticate with username and password. @@ -623,10 +623,8 @@ def from_credentials( : A new low-level client. """ - if not isinstance(username, StrStorage): - username = SecretStr(username) - if not isinstance(password, StrStorage): - password = SecretStr(password) + username = SecretStr(username) + password = SecretStr(password) return ScicatClient( url=url, token=SecretStr( @@ -1005,7 +1003,7 @@ def _send_to_scicat( self, *, cmd: str, url: str, data: model.BaseModel | None = None ) -> requests.Response: if self._token is not None: - token = self._token.get_str() + token = self._token.expose() headers = {"Authorization": f"Bearer {token}"} else: token = "" @@ -1099,12 +1097,12 @@ def _make_orig_datablock( def _log_in_via_users_login( - url: str, username: StrStorage, password: StrStorage, timeout: datetime.timedelta + url: str, username: SecretStr, password: SecretStr, timeout: datetime.timedelta ) -> requests.Response: # Currently only used for functional accounts. response = requests.post( _url_concat(url, "Users/login"), - json={"username": username.get_str(), "password": password.get_str()}, + json={"username": username.expose(), "password": password.expose()}, stream=False, verify=True, timeout=timeout.seconds, @@ -1117,7 +1115,7 @@ def _log_in_via_users_login( def _log_in_via_auth_msad( - url: str, username: StrStorage, password: StrStorage, timeout: datetime.timedelta + url: str, username: SecretStr, password: SecretStr, timeout: datetime.timedelta ) -> requests.Response: # Used for user accounts. import re @@ -1126,7 +1124,7 @@ def _log_in_via_auth_msad( base_url = re.sub(r"/api/v\d+/?", "", url) response = requests.post( _url_concat(base_url, "auth/msad"), - json={"username": username.get_str(), "password": password.get_str()}, + json={"username": username.expose(), "password": password.expose()}, stream=False, verify=True, timeout=timeout.seconds, @@ -1137,7 +1135,7 @@ def _log_in_via_auth_msad( def _get_token( - url: str, username: StrStorage, password: StrStorage, timeout: datetime.timedelta + url: str, username: SecretStr, password: SecretStr, timeout: datetime.timedelta ) -> str: """Log in using the provided username + password. diff --git a/src/scitacean/testing/client.py b/src/scitacean/testing/client.py index dadb5faf..5e5c9e1f 100644 --- a/src/scitacean/testing/client.py +++ b/src/scitacean/testing/client.py @@ -16,7 +16,7 @@ from ..error import ScicatCommError from ..pid import PID from ..typing import FileTransfer -from ..util.credentials import StrStorage +from ..util.credentials import SecretStr def _conditionally_disabled(func: Callable[..., Any]) -> Callable[..., Any]: @@ -134,7 +134,7 @@ def from_token( cls, *, url: str, - token: str | StrStorage, + token: str | SecretStr, file_transfer: FileTransfer | None = None, ) -> FakeClient: """Create a new fake client. @@ -148,8 +148,8 @@ def from_credentials( cls, *, url: str, - username: str | StrStorage, - password: str | StrStorage, + username: str | SecretStr, + password: str | SecretStr, file_transfer: FileTransfer | None = None, ) -> FakeClient: """Create a new fake client. diff --git a/src/scitacean/util/credentials.py b/src/scitacean/util/credentials.py index 6d11bf1b..57f0f1cc 100644 --- a/src/scitacean/util/credentials.py +++ b/src/scitacean/util/credentials.py @@ -4,41 +4,13 @@ from __future__ import annotations -from datetime import datetime, timedelta, timezone +from datetime import datetime, timedelta from typing import NoReturn from .._internal.jwt import expiry -class StrStorage: - """Base class for storing a string. - - Instances can be nested to combine different specialized features. - """ - - def __init__(self, value: str | StrStorage | None): - self._value = value - - def get_str(self) -> str: - """Return the stored plain str object.""" - if self._value is None: - # If the implementation chooses to not - # store the string in memory. - # The method must be overridden in this case. - raise NotImplementedError("String not available") - - if isinstance(self._value, StrStorage): - return self._value.get_str() - return self._value - - def __str__(self) -> str: - return str(self._value) - - def __repr__(self) -> str: - return f"{type(self).__name__}({self._value!r})" - - -class SecretStr(StrStorage): +class SecretStr: """Minimize the risk of exposing a secret. Stores a string and blocks the most common pathways @@ -50,8 +22,19 @@ class SecretStr(StrStorage): still be leaked through introspection methods. """ - def __init__(self, value: str | StrStorage): - super().__init__(value) + def __init__(self, value: str | SecretStr) -> None: + """Initialize from a plain str or other SecretStr. + + Parameters + ---------- + value: + The string to store. + """ + self._value = value if isinstance(value, str) else value.expose() + + def expose(self) -> str: + """Return the stored plain str object.""" + return self._value def __str__(self) -> str: return "***" @@ -61,45 +44,109 @@ def __repr__(self) -> str: # prevent pickling def __reduce_ex__(self, protocol: object) -> NoReturn: - raise TypeError("SecretStr must not be pickled") + raise TypeError( + "SecretStr must not be pickled to avoid storing or sharing " + "it accidentally." + ) -class ExpiringToken(StrStorage): - """A JWT token that expires after some time.""" +class Token(SecretStr): + """A SciCat token that may expire after some time.""" def __init__( self, + value: str | SecretStr | Token, *, - value: str | StrStorage, - expires_at: datetime, + expires_at: datetime | None, denial_period: timedelta | None = None, - ): - super().__init__(value) - if denial_period is None: - denial_period = timedelta(seconds=2) - self._expires_at = expires_at - denial_period + ) -> None: + """Initialize from a plain or secret string or other token. + + Parameters + ---------- + value: + The string of the token to store. + If a ``Token`` object, the expiry date is overridden by ``expires_at``. + expires_at: + Datetime after which the token is no longer valid. + If ``None``, the token is assumed to never expire. + denial_period: + If given, the token will be treated as expired after + ``expires_at - denial_period``. + This is useful to give an operation enough leeway to finish before the + token actually expires. + """ + super().__init__(value.expose() if isinstance(value, Token) else value) + if expires_at is None: + self._expires_at = None + else: + if denial_period is None: + self._expires_at = expires_at + else: + self._expires_at = expires_at - denial_period self._check_expiry() @classmethod - def from_jwt(cls, value: str | StrStorage) -> ExpiringToken: - """Create a new ExpiringToken from a JSON web token.""" - value_str = value if isinstance(value, str) else value.get_str() + def from_jwt( + cls, + value: str | SecretStr, + denial_period: timedelta | None = None, + ) -> Token: + """Create a new ExpiringToken from a JSON web token. + + Parameters + ---------- + value: + A JSON web token. + denial_period: + If given, the token will be treated as expired after + ``expires_at - denial_period``. + This is useful to give an operation enough leeway to finish before the + token actually expires. + + Returns + ------- + : + A ``Token`` object that contains ``value`` and is + set up with an expiry date parsed from the JWT. + """ + value_str = value if isinstance(value, str) else value.expose() try: expires_at = expiry(value_str) except ValueError: - expires_at = datetime.now(tz=timezone.utc) + timedelta(weeks=100) + expires_at = None return cls( value=value, expires_at=expires_at, + denial_period=denial_period, ) - def get_str(self) -> str: - """Return the stored plain str object.""" + def expose(self) -> str: + """Return the stored plain str object. + + Returns + ------- + : + A plain string with the token. + + Raises + ------ + RuntimeError + If the token has expired. + """ self._check_expiry() - return super().get_str() + return super().expose() + + @property + def expires_at(self) -> datetime | None: + """Return the expiration date including denial period.""" + return self._expires_at def _check_expiry(self) -> None: - if datetime.now(tz=self._expires_at.tzinfo) > self._expires_at: + if ( + self._expires_at is not None + and datetime.now(tz=self._expires_at.tzinfo) > self._expires_at + ): raise RuntimeError( "SciCat login has expired. You need to create a new client either by " "logging in through `Client.from_credentials` or by getting a new " @@ -107,7 +154,5 @@ def _check_expiry(self) -> None: ) def __repr__(self) -> str: - return ( - f"TimeLimitedStr(expires_at={self._expires_at.isoformat()}, " - f"value={self._value!r}" - ) + expires = self.expires_at.isoformat() if self.expires_at is not None else None + return f"Token(***, expires_at={expires})" diff --git a/tests/util/credentials_test.py b/tests/util/credentials_test.py new file mode 100644 index 00000000..3eab5342 --- /dev/null +++ b/tests/util/credentials_test.py @@ -0,0 +1,87 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2024 SciCat Project (https://github.com/SciCatProject/scitacean) + +import pickle +from datetime import datetime, timedelta, timezone + +import pytest + +from scitacean.util.credentials import SecretStr, Token + + +def test_secret_str_from_str_expose(): + secret_str = SecretStr("something hidden") + assert secret_str.expose() == "something hidden" + + +def test_secret_str_from_secret_str_expose(): + secret_str = SecretStr(SecretStr("don't tell!")) + assert secret_str.expose() == "don't tell!" + + +def test_secret_str_hides_content(): + secret_str = SecretStr("something hidden") + assert "something hidden" not in str(secret_str) + assert "hidden" not in str(secret_str) + assert "something hidden" not in repr(secret_str) + assert "hidden" not in repr(secret_str) + + +def test_secret_str_cannot_be_pickled(): + secret_str = SecretStr("something hidden") + with pytest.raises(TypeError, match="pickle"): + pickle.dumps(secret_str) + + +@pytest.mark.parametrize( + "expires_at", [None, datetime.now(tz=timezone.utc) + timedelta(seconds=100)] +) +def test_token_from_str_expose(expires_at): + token = Token("something hidden", expires_at=expires_at) + assert token.expose() == "something hidden" + + +def test_token_from_secret_str_expose(): + token = Token(SecretStr("don't tell!"), expires_at=None) + assert token.expose() == "don't tell!" + + +def test_token_from_token_expose(): + token = Token(Token("double-T", expires_at=None), expires_at=None) + assert token.expose() == "double-T" + + +def test_token_hides_content(): + token = Token("something hidden", expires_at=None) + assert "something hidden" not in str(token) + assert "hidden" not in str(token) + assert "something hidden" not in repr(token) + assert "hidden" not in repr(token) + + +def test_token_cannot_be_pickled(): + token = Token("something hidden", expires_at=None) + with pytest.raises(TypeError, match="pickle"): + pickle.dumps(token) + + +def test_token_cannot_init_if_expired(): + with pytest.raises(RuntimeError, match="expired"): + Token("token", expires_at=datetime.now(tz=timezone.utc) - timedelta(seconds=1)) + + +def test_token_cannot_expose_if_expired(): + # Circumvent the check in __init__ + token = Token( + "token", expires_at=datetime.now(tz=timezone.utc) + timedelta(seconds=100) + ) + token._expires_at = datetime.now(tz=timezone.utc) - timedelta(seconds=1) + with pytest.raises(RuntimeError, match="expired"): + token.expose() + + +def test_token_expires_at_includes_denial_period(): + # Need a time in the future even within the denial period. + base_expires_at = datetime.now(tz=timezone.utc) + timedelta(seconds=100) + token = Token("", expires_at=base_expires_at, denial_period=timedelta(seconds=10)) + assert token.expires_at == base_expires_at - timedelta(seconds=10) From 480a96309f886084281121eb1bb66cf450e9a97d Mon Sep 17 00:00:00 2001 From: Jan-Lukas Wynen Date: Mon, 29 Apr 2024 15:06:33 +0200 Subject: [PATCH 2/3] Add method to request new token --- src/scitacean/client.py | 29 +++++++++++++++++++++++------ src/scitacean/testing/client.py | 7 +++++++ tests/client/client_test.py | 19 +++++++++++++++++++ 3 files changed, 49 insertions(+), 6 deletions(-) diff --git a/src/scitacean/client.py b/src/scitacean/client.py index 047ef5c2..4dfe6424 100644 --- a/src/scitacean/client.py +++ b/src/scitacean/client.py @@ -85,7 +85,6 @@ def from_token( file_transfer=file_transfer, ) - # TODO rename to login? and provide logout? @classmethod def from_credentials( cls, @@ -565,11 +564,7 @@ def __init__( # Need to add a final / self._base_url = url[:-1] if url.endswith("/") else url self._timeout = datetime.timedelta(seconds=10) if timeout is None else timeout - self._token = ( - token - if isinstance(token, Token) or token is None - else Token.from_jwt(token, denial_period=datetime.timedelta(seconds=2)) - ) + self._token = _wrap_token(token) @classmethod def from_token( @@ -999,6 +994,20 @@ def validate_dataset_model( if not response["valid"]: raise ValueError(f"Dataset {dset} did not pass validation in SciCat.") + def renew_login(self) -> None: + """Request and assign a new SciCat token. + + Can be used to prolong a login session before a token expires. + The new token is assigned to the client and is used for all future operations. + + Raises :class:`scitacean.ScicatCommError` if renewal fails. + In this case, the old token will not be replaced. + """ + response = self._call_endpoint( + cmd="post", url="users/jwt", operation="renew_login" + ) + self._token = _wrap_token(response["jwt"]) + def _send_to_scicat( self, *, cmd: str, url: str, data: model.BaseModel | None = None ) -> requests.Response: @@ -1162,6 +1171,14 @@ def _get_token( raise ScicatLoginError(response.content) +def _wrap_token(token: str | SecretStr | Token | None) -> Token: + match token: + case str() | SecretStr(): + return Token.from_jwt(token, denial_period=datetime.timedelta(seconds=2)) + case Token() | None: + return token + + FileSelector = ( bool | str | list[str] | tuple[str] | re.Pattern[str] | Callable[[File], bool] ) diff --git a/src/scitacean/testing/client.py b/src/scitacean/testing/client.py index 5e5c9e1f..a25c0d98 100644 --- a/src/scitacean/testing/client.py +++ b/src/scitacean/testing/client.py @@ -288,6 +288,13 @@ def validate_dataset_model( # Models were locally validated on construction, assume they are valid. pass + @_conditionally_disabled + def renew_login(self) -> None: + """Request a new SciCat token. + + Does nothing because FakeScicatClient does not use authentication. + """ + def _model_dict(mod: model.BaseModel) -> dict[str, Any]: return { diff --git a/tests/client/client_test.py b/tests/client/client_test.py index d11acfd2..4ee81a56 100644 --- a/tests/client/client_test.py +++ b/tests/client/client_test.py @@ -127,3 +127,22 @@ def test_detects_expired_token_get_dataset(scicat_access, require_scicat_backend time.sleep(0.5) with pytest.raises(RuntimeError, match="SciCat login has expired"): client.get_dataset(INITIAL_DATASETS["public"].pid) # type: ignore[arg-type] + + +def test_renew_login_real(scicat_access, require_scicat_backend): + client = Client.from_credentials( + url=scicat_access.url, **scicat_access.user.credentials + ) + initial = client.scicat._token + # Wait long enough to see a difference in the token because the + # expiration date has seconds-resolution. + time.sleep(1) + client.scicat.renew_login() + renewed = client.scicat._token + assert renewed.expires_at > initial.expires_at + + +def test_renew_login_fake(): + client = FakeClient.without_login(url="") + # Does nothing + client.scicat.renew_login() From f5789e6740164ed6c2e55b7bd28ce094c0ab07fa Mon Sep 17 00:00:00 2001 From: Jan-Lukas Wynen Date: Mon, 29 Apr 2024 16:23:10 +0200 Subject: [PATCH 3/3] Automatically renew token in operations --- src/scitacean/client.py | 80 ++++++++++++++++++++++++----- src/scitacean/testing/client.py | 2 + tests/client/client_test.py | 4 ++ tests/client/dataset_client_test.py | 41 +++++++++++++++ 4 files changed, 115 insertions(+), 12 deletions(-) diff --git a/src/scitacean/client.py b/src/scitacean/client.py index 4dfe6424..9adcaa8f 100644 --- a/src/scitacean/client.py +++ b/src/scitacean/client.py @@ -34,8 +34,9 @@ class Client: Clients hold all information needed to communicate with a SciCat instance and a filesystem that holds data files (via ``file_transfer``). - Use :func:`Client.from_token` or :func:`Client.from_credentials` to initialize - a client instead of the constructor directly. + Use :func:`Client.from_token`, :func:`Client.from_credentials`, or + :func:`Client.without_login` to initialize a client instead + of the constructor directly. See the user guide for typical usage patterns. In particular, `Downloading Datasets <../../user-guide/downloading.ipynb>`_ @@ -50,8 +51,8 @@ def __init__( ): """Initialize a client. - Do not use directly, instead use :func:`Client.from_token` - or :func:`Client.from_credentials`! + Do not use directly, instead use :func:`Client.from_token`, + :func:`Client.from_credentials`, or :func:`Client.without_login`! """ self._client = client self._file_transfer = file_transfer @@ -63,6 +64,7 @@ def from_token( url: str, token: str | SecretStr | Token, file_transfer: FileTransfer | None = None, + auto_renew_period: datetime.timedelta | None = datetime.timedelta(seconds=30), ) -> Client: """Create a new client and authenticate with a token. @@ -74,6 +76,9 @@ def from_token( User token to authenticate with SciCat. file_transfer: Handler for down-/uploads of files. + auto_renew_period: + If not ``None``, the SciCat login is renewed in operations + that happen within this time delta of the login expiration time. Returns ------- @@ -81,7 +86,11 @@ def from_token( A new client. """ return Client( - client=ScicatClient.from_token(url=url, token=token), + client=ScicatClient.from_token( + url=url, + token=token, + auto_renew_period=auto_renew_period, + ), file_transfer=file_transfer, ) @@ -93,6 +102,7 @@ def from_credentials( username: str | SecretStr, password: str | SecretStr, file_transfer: FileTransfer | None = None, + auto_renew_period: datetime.timedelta | None = datetime.timedelta(seconds=30), ) -> Client: """Create a new client and authenticate with username and password. @@ -107,6 +117,9 @@ def from_credentials( Password of the user. file_transfer: Handler for down-/uploads of files. + auto_renew_period: + If not ``None``, the SciCat login is renewed in operations + that happen within this time delta of the login expiration time. Returns ------- @@ -115,14 +128,20 @@ def from_credentials( """ return Client( client=ScicatClient.from_credentials( - url=url, username=username, password=password + url=url, + username=username, + password=password, + auto_renew_period=auto_renew_period, ), file_transfer=file_transfer, ) @classmethod def without_login( - cls, *, url: str, file_transfer: FileTransfer | None = None + cls, + *, + url: str, + file_transfer: FileTransfer | None = None, ) -> Client: """Create a new client without authentication. @@ -142,7 +161,8 @@ def without_login( A new client. """ return Client( - client=ScicatClient.without_login(url=url), file_transfer=file_transfer + client=ScicatClient.without_login(url=url), + file_transfer=file_transfer, ) @property @@ -560,11 +580,13 @@ def __init__( url: str, token: str | SecretStr | Token | None, timeout: datetime.timedelta | None, + auto_renew_period: datetime.timedelta | None = datetime.timedelta(seconds=30), ): # Need to add a final / self._base_url = url[:-1] if url.endswith("/") else url self._timeout = datetime.timedelta(seconds=10) if timeout is None else timeout self._token = _wrap_token(token) + self._auto_renew_period = auto_renew_period @classmethod def from_token( @@ -572,6 +594,7 @@ def from_token( url: str, token: str | SecretStr | Token, timeout: datetime.timedelta | None = None, + auto_renew_period: datetime.timedelta | None = datetime.timedelta(seconds=30), ) -> ScicatClient: """Create a new low-level client and authenticate with a token. @@ -583,13 +606,21 @@ def from_token( User token to authenticate with SciCat. timeout: Timeout for all API requests. + auto_renew_period: + If not ``None``, the SciCat login is renewed in operations + that happen within this time delta of the login expiration time. Returns ------- : A new low-level client. """ - return ScicatClient(url=url, token=token, timeout=timeout) + return ScicatClient( + url=url, + token=token, + timeout=timeout, + auto_renew_period=auto_renew_period, + ) @classmethod def from_credentials( @@ -598,6 +629,7 @@ def from_credentials( username: str | SecretStr, password: str | SecretStr, timeout: datetime.timedelta | None = None, + auto_renew_period: datetime.timedelta | None = datetime.timedelta(seconds=30), ) -> ScicatClient: """Create a new low-level client and authenticate with username and password. @@ -612,6 +644,9 @@ def from_credentials( Password of the user. timeout: Timeout for all API requests. + auto_renew_period: + If not ``None``, the SciCat login is renewed in operations + that happen within this time delta of the login expiration time. Returns ------- @@ -631,6 +666,7 @@ def from_credentials( ) ), timeout=timeout, + auto_renew_period=auto_renew_period, ) @classmethod @@ -654,7 +690,9 @@ def without_login( : A new low-level client. """ - return ScicatClient(url=url, token=None, timeout=timeout) + return ScicatClient( + url=url, token=None, timeout=timeout, auto_renew_period=None + ) def get_dataset_model( self, pid: PID, strict_validation: bool = False @@ -1004,10 +1042,24 @@ def renew_login(self) -> None: In this case, the old token will not be replaced. """ response = self._call_endpoint( - cmd="post", url="users/jwt", operation="renew_login" + cmd="post", url="users/jwt", operation="renew_login", renew_login=False ) self._token = _wrap_token(response["jwt"]) + def _renew_login_if_needed(self, operation: str) -> None: + if ( + self._token is not None + and self._token.expires_at is not None + and self._auto_renew_period is not None + ): + if self._token.expires_at + self._auto_renew_period > datetime.datetime.now( + tz=self._token.expires_at.tzinfo + ): + get_logger().info( + "Renewing SciCat login during operation '%s'", operation + ) + self.renew_login() + def _send_to_scicat( self, *, cmd: str, url: str, data: model.BaseModel | None = None ) -> requests.Response: @@ -1050,7 +1102,11 @@ def _call_endpoint( url: str, data: model.BaseModel | None = None, operation: str, + renew_login: bool = True, ) -> Any: + if renew_login: + self._renew_login_if_needed(operation) + full_url = _url_concat(self._base_url, url) logger = get_logger() logger.info("Calling SciCat API at %s for operation '%s'", full_url, operation) @@ -1171,7 +1227,7 @@ def _get_token( raise ScicatLoginError(response.content) -def _wrap_token(token: str | SecretStr | Token | None) -> Token: +def _wrap_token(token: str | SecretStr | Token | None) -> Token | None: match token: case str() | SecretStr(): return Token.from_jwt(token, denial_period=datetime.timedelta(seconds=2)) diff --git a/src/scitacean/testing/client.py b/src/scitacean/testing/client.py index a25c0d98..f10aeb48 100644 --- a/src/scitacean/testing/client.py +++ b/src/scitacean/testing/client.py @@ -136,6 +136,7 @@ def from_token( url: str, token: str | SecretStr, file_transfer: FileTransfer | None = None, + auto_renew_period: datetime.timedelta | None = datetime.timedelta(seconds=30), ) -> FakeClient: """Create a new fake client. @@ -151,6 +152,7 @@ def from_credentials( username: str | SecretStr, password: str | SecretStr, file_transfer: FileTransfer | None = None, + auto_renew_period: datetime.timedelta | None = datetime.timedelta(seconds=30), ) -> FakeClient: """Create a new fake client. diff --git a/tests/client/client_test.py b/tests/client/client_test.py index 4ee81a56..8fe71188 100644 --- a/tests/client/client_test.py +++ b/tests/client/client_test.py @@ -139,6 +139,10 @@ def test_renew_login_real(scicat_access, require_scicat_backend): time.sleep(1) client.scicat.renew_login() renewed = client.scicat._token + assert initial is not None + assert initial.expires_at is not None + assert renewed is not None + assert renewed.expires_at is not None assert renewed.expires_at > initial.expires_at diff --git a/tests/client/dataset_client_test.py b/tests/client/dataset_client_test.py index 0380619b..ce66976a 100644 --- a/tests/client/dataset_client_test.py +++ b/tests/client/dataset_client_test.py @@ -2,6 +2,9 @@ # Copyright (c) 2024 SciCat Project (https://github.com/SciCatProject/scitacean) # mypy: disable-error-code="arg-type, index" +import time +from datetime import timedelta + import pydantic import pytest from dateutil.parser import parse as parse_date @@ -141,3 +144,41 @@ def test_get_broken_dataset_strict_validation(real_client, require_scicat_backen dset = INITIAL_DATASETS["partially-broken"] with pytest.raises(pydantic.ValidationError): real_client.get_dataset(dset.pid, strict_validation=True) + + +def test_get_dataset_renews_login(scicat_access, require_scicat_backend): + # The test backend is configured to create tokens that expire after 1h. + # So pick a renewal period that guarantees that the token is renewed. + real_client = Client.from_credentials( + url=scicat_access.url, + **scicat_access.user.credentials, + auto_renew_period=timedelta(hours=1), + ) + initial = real_client.scicat._token + # Wait long enough to see a difference in the token because the + # expiration date has seconds-resolution. + time.sleep(1) + _ = real_client.get_dataset(INITIAL_DATASETS["derived"].pid) + renewed = real_client.scicat._token + assert initial is not None + assert initial.expires_at is not None + assert renewed is not None + assert renewed.expires_at is not None + assert renewed.expires_at > initial.expires_at + + +def test_get_dataset_disabled_login_renewal(scicat_access, require_scicat_backend): + real_client = Client.from_credentials( + url=scicat_access.url, **scicat_access.user.credentials, auto_renew_period=None + ) + initial = real_client.scicat._token + # Wait long enough to see a difference in the token because the + # expiration date has seconds-resolution. + time.sleep(1) + _ = real_client.get_dataset(INITIAL_DATASETS["derived"].pid) + not_renewed = real_client.scicat._token + assert initial is not None + assert initial.expires_at is not None + assert not_renewed is not None + assert not_renewed.expires_at is not None + assert not_renewed.expose() == initial.expose()