diff --git a/data_safe_haven/allowlist/__init__.py b/data_safe_haven/allowlist/__init__.py
new file mode 100644
index 0000000000..dc6eac9282
--- /dev/null
+++ b/data_safe_haven/allowlist/__init__.py
@@ -0,0 +1,3 @@
+from .allowlist import Allowlist
+
+__all__ = ["Allowlist"]
diff --git a/data_safe_haven/allowlist/allowlist.py b/data_safe_haven/allowlist/allowlist.py
new file mode 100644
index 0000000000..aa0654f751
--- /dev/null
+++ b/data_safe_haven/allowlist/allowlist.py
@@ -0,0 +1,100 @@
+from __future__ import annotations
+
+from difflib import unified_diff
+from typing import Self
+
+from data_safe_haven.config import Context
+from data_safe_haven.exceptions import DataSafeHavenAzureStorageError
+from data_safe_haven.external import AzureSdk
+from data_safe_haven.infrastructure import SREProjectManager
+from data_safe_haven.types import AllowlistRepository
+
+
+class Allowlist:
+    """Allowlist for packages."""
+
+    def __init__(
+        self,
+        repository: AllowlistRepository,
+        sre_stack: SREProjectManager,
+        allowlist: str | None = None,
+    ):
+        self.repository = repository
+        self.sre_resource_group = sre_stack.output("sre_resource_group")
+        self.storage_account_name = sre_stack.output("data")[
+            "storage_account_data_configuration_name"
+        ]
+        self.share_name = sre_stack.output("allowlist_share_name")
+        self.filename = sre_stack.output("allowlist_share_filenames")[repository.value]
+        self.allowlist = str(allowlist) if allowlist else ""
+
+    @classmethod
+    def from_remote(
+        cls: type[Self],
+        *,
+        context: Context,
+        repository: AllowlistRepository,
+        sre_stack: SREProjectManager,
+    ) -> Self:
+        azure_sdk = AzureSdk(subscription_name=context.subscription_name)
+        allowlist = cls(repository=repository, sre_stack=sre_stack)
+        try:
+            share_file = azure_sdk.download_share_file(
+                allowlist.filename,
+                allowlist.sre_resource_group,
+                allowlist.storage_account_name,
+                allowlist.share_name,
+            )
+            allowlist.allowlist = share_file
+            return allowlist
+        except DataSafeHavenAzureStorageError as exc:
+            msg = f"Storage account '{allowlist.storage_account_name}' does not exist."
+ raise DataSafeHavenAzureStorageError(msg) from exc + + @classmethod + def remote_exists( + cls: type[Self], + context: Context, + *, + repository: AllowlistRepository, + sre_stack: SREProjectManager, + ) -> bool: + # Get the Azure SDK + azure_sdk = AzureSdk(subscription_name=context.subscription_name) + + allowlist = cls(repository=repository, sre_stack=sre_stack) + + # Get the file share name + share_list_exists = azure_sdk.file_share_exists( + allowlist.filename, + allowlist.sre_resource_group, + allowlist.storage_account_name, + allowlist.share_name, + ) + return share_list_exists + + def upload( + self, + context: Context, + ) -> None: + # Get the Azure SDK + azure_sdk = AzureSdk(subscription_name=context.subscription_name) + + azure_sdk.upload_file_share( + self.allowlist, + self.filename, + self.sre_resource_group, + self.storage_account_name, + self.share_name, + ) + + def diff(self, other: Allowlist) -> list[str]: + diff = list( + unified_diff( + self.allowlist.splitlines(), + other.allowlist.splitlines(), + fromfile="remote", + tofile="local", + ) + ) + return diff diff --git a/data_safe_haven/commands/allowlist.py b/data_safe_haven/commands/allowlist.py new file mode 100644 index 0000000000..4c9d66f19a --- /dev/null +++ b/data_safe_haven/commands/allowlist.py @@ -0,0 +1,183 @@ +"""Command group for managing package allowlists""" + +from pathlib import Path +from typing import Annotated, Optional + +import typer + +from data_safe_haven import console +from data_safe_haven.allowlist import Allowlist +from data_safe_haven.config import ContextManager, DSHPulumiConfig, SREConfig +from data_safe_haven.exceptions import DataSafeHavenConfigError, DataSafeHavenError +from data_safe_haven.infrastructure import SREProjectManager +from data_safe_haven.logging import get_logger +from data_safe_haven.types import AllowlistRepository + +allowlist_command_group = typer.Typer() + + +@allowlist_command_group.command() +def show( + name: Annotated[ + str, + typer.Argument(help="Name of SRE to show allowlist for."), + ], + repository: Annotated[ + AllowlistRepository, + typer.Argument(help="Name of the repository to show the allowlist for."), + ], + file: Annotated[ + Optional[str], # noqa: UP007 + typer.Option(help="File path to write the allowlist to."), + ] = None, +) -> None: + """Print the current package allowlist""" + logger = get_logger() + + try: + context = ContextManager.from_file().assert_context() + except DataSafeHavenConfigError as exc: + logger.critical( + "No context is selected. Use `dsh context add` to create a context " + "or `dsh context switch` to select one." + ) + raise typer.Exit(1) from exc + + sre_config = SREConfig.from_remote_by_name(context, name) + + # Load Pulumi config + pulumi_config = DSHPulumiConfig.from_remote(context) + + if sre_config.name not in pulumi_config.project_names: + msg = f"Could not load Pulumi settings for '{sre_config.name}'. Have you deployed the SRE?" + logger.error(msg) + raise typer.Exit(1) + + sre_stack = SREProjectManager( + context=context, + config=sre_config, + pulumi_config=pulumi_config, + ) + + try: + allowlist = Allowlist.from_remote( + context=context, repository=repository, sre_stack=sre_stack + ) + except DataSafeHavenError as exc: + logger.critical( + "No allowlist is configured. Use `dsh allowlist upload` to create one." 
+ ) + raise typer.Exit(1) from exc + + if file: + with open(file, "w") as f: + f.write(allowlist.allowlist) + else: + console.print(allowlist.allowlist) + + +@allowlist_command_group.command() +def template( + repository: Annotated[ + AllowlistRepository, + typer.Argument(help="Name of the repository to show the allowlist for."), + ], + file: Annotated[ + Optional[Path], # noqa: UP007 + typer.Option(help="File path to write allowlist template to."), + ] = None, +) -> None: + """Print a template for the package allowlist""" + + template_path = Path( + "data_safe_haven/resources", + "software_repositories", + "allowlists", + f"{repository.value}.allowlist", + ) + with open(template_path) as f: + example_allowlist = f.read() + if file: + with open(file, "w") as f: + f.write(example_allowlist) + raise typer.Exit() + else: + console.print(example_allowlist) + + +@allowlist_command_group.command() +def upload( + name: Annotated[ + str, + typer.Argument(help="Name of SRE to upload the allowlist for."), + ], + file: Annotated[ + Path, + typer.Argument(help="Path to the allowlist file to upload."), + ], + repository: Annotated[ + AllowlistRepository, + typer.Argument(help="Repository type of the allowlist."), + ], + force: Annotated[ # noqa: FBT002 + bool, + typer.Option(help="Skip check for existing remote allowlist."), + ] = False, +) -> None: + """Upload a package allowlist""" + context = ContextManager.from_file().assert_context() + logger = get_logger() + + if file.is_file(): + with open(file) as f: + allowlist = f.read() + else: + logger.critical(f"Allowlist file '{file}' not found.") + raise typer.Exit(1) + sre_config = SREConfig.from_remote_by_name(context, name) + + # Load Pulumi config + pulumi_config = DSHPulumiConfig.from_remote(context) + + if sre_config.name not in pulumi_config.project_names: + msg = f"Could not load Pulumi settings for '{sre_config.name}'. Have you deployed the SRE?" + logger.error(msg) + raise typer.Exit(1) + + sre_stack = SREProjectManager( + context=context, + config=sre_config, + pulumi_config=pulumi_config, + ) + + local_allowlist = Allowlist( + repository=repository, sre_stack=sre_stack, allowlist=allowlist + ) + + if not force and Allowlist.remote_exists( + context=context, + repository=repository, + sre_stack=sre_stack, + ): + remote_allowlist = Allowlist.from_remote( + context=context, + repository=repository, + sre_stack=sre_stack, + ) + if allow_diff := remote_allowlist.diff(local_allowlist): + for line in list(filter(None, "\n".join(allow_diff).splitlines())): + logger.info(line) + if not console.confirm( + f"An allowlist already exists for {repository.name}. 
Do you want to overwrite it?", + default_to_yes=True, + ): + raise typer.Exit() + else: + console.print("No changes, won't upload allowlist.") + raise typer.Exit() + try: + logger.info(f"Uploading allowlist for {repository.name} to {sre_config.name}") + local_allowlist.upload(context=context) + except DataSafeHavenError as exc: + logger.error(f"Failed to upload allowlist: {exc}") + raise typer.Exit(1) from exc diff --git a/data_safe_haven/commands/cli.py b/data_safe_haven/commands/cli.py index 29c6dcd0f7..ce8cd4e266 100644 --- a/data_safe_haven/commands/cli.py +++ b/data_safe_haven/commands/cli.py @@ -7,6 +7,7 @@ from data_safe_haven import __version__, console from data_safe_haven.logging import set_console_level, show_console_level +from .allowlist import allowlist_command_group from .config import config_command_group from .context import context_command_group from .pulumi import pulumi_command_group @@ -64,6 +65,11 @@ def callback( # Register command groups +application.add_typer( + allowlist_command_group, + name="allowlist", + help="Manage the package allowlists of a Data Safe Haven deployment.", +) application.add_typer( users_command_group, name="users", diff --git a/data_safe_haven/external/api/azure_sdk.py b/data_safe_haven/external/api/azure_sdk.py index 3ae0c4c3de..637a5161f1 100644 --- a/data_safe_haven/external/api/azure_sdk.py +++ b/data_safe_haven/external/api/azure_sdk.py @@ -60,6 +60,7 @@ ) from azure.storage.blob import BlobClient, BlobServiceClient from azure.storage.filedatalake import DataLakeServiceClient +from azure.storage.fileshare import ShareClient, ShareServiceClient from data_safe_haven.exceptions import ( DataSafeHavenAzureAPIAuthenticationError, @@ -162,6 +163,113 @@ def blob_exists( ) return exists + def list_shares( + self, + resource_group_name: str, + storage_account_name: str, + ) -> list[str]: + """List all shares in a container + + Returns: + List[str]: The list of share names + """ + + share_client = self.share_service_client( + resource_group_name=resource_group_name, + storage_account_name=storage_account_name, + ) + share_list = share_client.list_shares() + return list(share_list) + + def share_client( + self, resource_group_name: str, storage_account_name: str, file_share_name: str + ) -> ShareClient: + + share_service_client = self.share_service_client( + resource_group_name, storage_account_name + ) + share_client = share_service_client.get_share_client(share=file_share_name) + return share_client + + def share_service_client( + self, resource_group_name: str, storage_account_name: str + ) -> ShareServiceClient: + storage_account_keys = self.get_storage_account_keys( + resource_group_name, storage_account_name + ) + + share_service_client = ShareServiceClient( + account_url=f"https://{storage_account_name}.file.core.windows.net", + credential=storage_account_keys[0].value, + ) + return share_service_client + + def download_share_file( + self, + file_name: str, + resource_group_name: str, + storage_account_name: str, + file_share_name: str, + ) -> str: + """Download a share file from Azure storage + + Returns: + str: The contents of the share + + Raises: + DataSafeHavenAzureError if the share could not be downloaded + """ + try: + # Get the share client + share_client = self.share_client( + resource_group_name, + storage_account_name, + file_share_name, + ) + share_file_client = share_client.get_file_client(file_name) + # Download the requested file + share_content = share_file_client.download_file(encoding="utf-8").readall() + 
self.logger.debug( + f"Downloaded file [green]{file_name}[/] from share storage.", + ) + return str(share_content) + except (AzureError, DataSafeHavenAzureStorageError) as exc: + msg = f"Share file '{file_name}' could not be downloaded from '{storage_account_name}'." + raise DataSafeHavenAzureError(msg) from exc + + def file_share_exists( + self, + file_name: str, + resource_group_name: str, + storage_account_name: str, + storage_share_name: str, + ) -> bool: + """Find out whether a file share exists in Azure storage + + Returns: + bool: Whether or not the file share exists + """ + + if not self.storage_exists(storage_account_name): + msg = f"Storage account '{storage_account_name}' could not be found." + raise DataSafeHavenAzureStorageError(msg) + + try: + share_client = self.share_client( + resource_group_name, + storage_account_name, + storage_share_name, + ) + share_file_client = share_client.get_file_client(file_name) + exists = bool(share_file_client.exists()) + except DataSafeHavenAzureStorageError: + exists = False + response = "exists" if exists else "does not exist" + self.logger.debug( + f"File [green]{file_name}[/] {response} in file share.", + ) + return exists + def blob_service_client( self, resource_group_name: str, @@ -1382,3 +1490,40 @@ def upload_blob( except (AzureError, DataSafeHavenAzureStorageError) as exc: msg = f"Blob file '{blob_name}' could not be uploaded to '{storage_account_name}'." raise DataSafeHavenAzureError(msg) from exc + + def upload_file_share( + self, + file_data: str, + file_name: str, + resource_group_name: str, + storage_account_name: str, + file_share_name: str, + ) -> None: + """Upload a file to Azure file share + + Returns: + None + + Raises: + DataSafeHavenAzureError if the file could not be uploaded + """ + try: + # Get the share client + share_client = self.share_client( + resource_group_name, + storage_account_name, + file_share_name, + ) + share_file_client = share_client.get_file_client(file_name) + # Upload the created file + share_file_client.upload_file( + file_data, + ) + self.logger.debug( + f"Uploaded file [green]{file_name}[/] to file share.", + ) + except (AzureError, DataSafeHavenAzureStorageError) as exc: + msg = ( + f"File '{file_name}' could not be uploaded to '{storage_account_name}'." 
+ ) + raise DataSafeHavenAzureError(msg) from exc diff --git a/data_safe_haven/infrastructure/programs/declarative_sre.py b/data_safe_haven/infrastructure/programs/declarative_sre.py index 02cc59b71c..7da8d5481f 100644 --- a/data_safe_haven/infrastructure/programs/declarative_sre.py +++ b/data_safe_haven/infrastructure/programs/declarative_sre.py @@ -427,8 +427,17 @@ def __call__(self) -> None: ) # Export values for later use + pulumi.export( + "allowlist_share_name", + user_services.software_repositories.allowlist_file_share_name, + ) + pulumi.export( + "allowlist_share_filenames", + user_services.software_repositories.allowlist_file_names, + ) pulumi.export("data", data.exports) pulumi.export("ldap", ldap_group_names) pulumi.export("remote_desktop", remote_desktop.exports) pulumi.export("sre_fqdn", networking.sre_fqdn) + pulumi.export("sre_resource_group", resource_group.name) pulumi.export("workspaces", workspaces.exports) diff --git a/data_safe_haven/infrastructure/programs/sre/data.py b/data_safe_haven/infrastructure/programs/sre/data.py index 10732670f5..9833433a1e 100644 --- a/data_safe_haven/infrastructure/programs/sre/data.py +++ b/data_safe_haven/infrastructure/programs/sre/data.py @@ -815,4 +815,5 @@ def __init__( self.exports = { "key_vault_name": key_vault.name, "password_user_database_admin_secret": kvs_password_user_database_admin.name, + "storage_account_data_configuration_name": storage_account_data_configuration.name, } diff --git a/data_safe_haven/infrastructure/programs/sre/software_repositories.py b/data_safe_haven/infrastructure/programs/sre/software_repositories.py index dd5ac81bfb..39844b66d5 100644 --- a/data_safe_haven/infrastructure/programs/sre/software_repositories.py +++ b/data_safe_haven/infrastructure/programs/sre/software_repositories.py @@ -124,36 +124,38 @@ def __init__( ) # Upload Nexus allowlists - cran_reader = FileReader( - resources_path / "software_repositories" / "allowlists" / "cran.allowlist" - ) - FileShareFile( + cran_allowlist = FileShareFile( f"{self._name}_file_share_cran_allowlist", FileShareFileProps( - destination_path=cran_reader.name, + destination_path="cran.allowlist", share_name=file_share_nexus_allowlists.name, - file_contents=cran_reader.file_contents(), + file_contents="", storage_account_key=props.storage_account_key, storage_account_name=props.storage_account_name, ), opts=ResourceOptions.merge( - child_opts, ResourceOptions(parent=file_share_nexus_allowlists) + child_opts, + ResourceOptions( + parent=file_share_nexus_allowlists, + ignore_changes=["file_contents"], + ), ), ) - pypi_reader = FileReader( - resources_path / "software_repositories" / "allowlists" / "pypi.allowlist" - ) - FileShareFile( + pypi_allowlist = FileShareFile( f"{self._name}_file_share_pypi_allowlist", FileShareFileProps( - destination_path=pypi_reader.name, + destination_path="pypi.allowlist", share_name=file_share_nexus_allowlists.name, - file_contents=pypi_reader.file_contents(), + file_contents="", storage_account_key=props.storage_account_key, storage_account_name=props.storage_account_name, ), opts=ResourceOptions.merge( - child_opts, ResourceOptions(parent=file_share_nexus_allowlists) + child_opts, + ResourceOptions( + parent=file_share_nexus_allowlists, + ignore_changes=["file_contents"], + ), ), ) @@ -344,3 +346,8 @@ def __init__( # Register outputs self.hostname = hostname + self.allowlist_file_share_name = file_share_nexus_allowlists.name + self.allowlist_file_names = { + "cran": cran_allowlist.destination_path, + "pypi": 
pypi_allowlist.destination_path, + } diff --git a/data_safe_haven/types/__init__.py b/data_safe_haven/types/__init__.py index bfe1f6898a..56394b734c 100644 --- a/data_safe_haven/types/__init__.py +++ b/data_safe_haven/types/__init__.py @@ -14,6 +14,7 @@ UniqueList, ) from .enums import ( + AllowlistRepository, AzureDnsZoneNames, AzureSdkCredentialScope, AzureServiceTag, @@ -31,6 +32,7 @@ from .types import PathType __all__ = [ + "AllowlistRepository", "AzureDnsZoneNames", "AzureLocation", "AzurePremiumFileShareSize", diff --git a/data_safe_haven/types/enums.py b/data_safe_haven/types/enums.py index 17d5dda8e3..bbacadfaba 100644 --- a/data_safe_haven/types/enums.py +++ b/data_safe_haven/types/enums.py @@ -211,3 +211,11 @@ class SoftwarePackageCategory(str, Enum): ANY = "any" PRE_APPROVED = "pre-approved" NONE = "none" + + +@verify(UNIQUE) +class AllowlistRepository(str, Enum): + """Repositories for which allowlists are maintained.""" + + CRAN = "cran" + PYPI = "pypi" diff --git a/docs/source/management/allowlist.md b/docs/source/management/allowlist.md new file mode 100644 index 0000000000..d346934b71 --- /dev/null +++ b/docs/source/management/allowlist.md @@ -0,0 +1,55 @@ +# Managing allowlists + +For Tier 3 SREs, the Python and R software packages that users are allowed to download from the PyPI and CRAN repositories are restricted. +Connection to PyPI and CRAN is achieved using [Sonatype Nexus Repository](https://www.sonatype.com/products/sonatype-nexus-repository). + +Packages must be explicitly added to the allowlist for the relevant repository before the users can download the package. +Packages not on the allowlist are blocked. + +An allowlist is a plain text file, with the name of each allowed package on its own line. + +```{important} +The user must also be able to download any dependencies of any package on the allowlist. +You should ensure that any such dependencies are also added to the allowlist. + +For example, a minimal CRAN allowlist that permits the user to install the packages `data.table`, `DBI`, and `RPostgres` would be as below. + +:::{code} text +bit64 +blob +data.table +DBI +hms +lubridate +RPostgres +withr +::: + +This includes the requested packages and their dependencies. +``` + +## Viewing allowlists + +To view the current allowlist for a given repository, use {typer}`dsh allowlist show` + +```{code} shell +dsh allowlist show YOUR_SRE_NAME REPOSITORY_NAME +``` + +## Uploading and updating an allowlist + +To upload an allowlist, use {typer}`dsh allowlist upload`. + +```{code} shell +dsh allowlist upload YOUR_SRE_NAME PATH_TO_ALLOWLIST_FILE REPOSITORY_NAME +``` + +The local allowlist file does not need to have a specific name. + +## Example allowlists + +Example allowlists for PyPI and CRAN can be generated using {typer}`dsh allowlist template` + +```{code} shell +dsh allowlist template REPOSITORY_NAME +``` diff --git a/docs/source/management/index.md b/docs/source/management/index.md index f8cd8ac0e0..ee9e08aa85 100644 --- a/docs/source/management/index.md +++ b/docs/source/management/index.md @@ -3,10 +3,11 @@ :::{toctree} :hidden: -user.md -sre.md +allowlist.md data.md logs.md +sre.md +user.md ::: Running a secure and productive Data Safe Haven requires a manager to conduct tasks which support users and to monitor the correct operation of the TRE. 
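As a worked example of the commands documented in `docs/source/management/allowlist.md` above, a typical allowlist workflow might look like the following sketch (the SRE name `sandbox` and the local file name `cran.allowlist` are illustrative; `--force` can be passed to `upload` to skip the check for an existing remote allowlist):

```{code} shell
# Write a starting CRAN allowlist template to a local file
dsh allowlist template cran --file cran.allowlist

# Edit cran.allowlist (one package name per line, including dependencies),
# then upload it to the SRE; a diff and a confirmation prompt are shown if a
# remote allowlist already exists
dsh allowlist upload sandbox cran.allowlist cran

# Check what is now configured on the remote
dsh allowlist show sandbox cran
```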
diff --git a/docs/source/reference/allowlist.md b/docs/source/reference/allowlist.md new file mode 100644 index 0000000000..3f89402ec4 --- /dev/null +++ b/docs/source/reference/allowlist.md @@ -0,0 +1,10 @@ +# `allowlist` + +`dsh allowlist` commands are used to manage which packages are allowed to be downloaded from software repositories in Tier 3 SREs. For an explanation of the tiering system, see [Security Objectives](https://data-safe-haven.readthedocs.io/en/latest/design/security/objectives.html). + +:::{typer} data_safe_haven.commands.allowlist:allowlist_command_group +:width: 65 +:prog: dsh allowlist +:show-nested: +:make-sections: +::: diff --git a/docs/source/reference/index.md b/docs/source/reference/index.md index 6bcb99fd2e..95a80c35f1 100644 --- a/docs/source/reference/index.md +++ b/docs/source/reference/index.md @@ -3,6 +3,7 @@ :::{toctree} :hidden: +allowlist.md config.md context.md users.md @@ -25,6 +26,9 @@ All commands begin with `dsh`. The subcommands can be used to manage various aspects of a Data Safe Haven deployment. For further detail on each subcommand, navigate to the relevant page. +[Allowlist](allowlist.md) +: Management of the package allowlists for Tier 3 SREs + [Config](config.md) : Management of the configuration files used to define SHMs and SREs diff --git a/tests/allowlist/test_allowlist.py b/tests/allowlist/test_allowlist.py new file mode 100644 index 0000000000..e19563fedf --- /dev/null +++ b/tests/allowlist/test_allowlist.py @@ -0,0 +1,133 @@ +from pytest import fixture + +from data_safe_haven.allowlist import Allowlist +from data_safe_haven.external import AzureSdk +from data_safe_haven.provisioning.sre_provisioning_manager import SREProjectManager +from data_safe_haven.types import AllowlistRepository + + +@fixture +def mock_project_output(request): + if request == "allowlist_share_filenames": + return { + "cran": "cran.allowlist", + "pypi": "pypi.allowlist", + } + elif request == "data": + return {"storage_account_data_configuration_name": "test"} + elif request == "sre_resource_group": + return "test" + + +class TestAllowlist: + def test_from_remote( + self, + mocker, + context, + sre_project_manager, + mock_project_output, + ) -> None: + + repository = AllowlistRepository.CRAN + mocker.patch.object( + SREProjectManager, + "output", + wraps=mock_project_output, + ) + mocker.patch.object( + AzureSdk, + "download_share_file", + return_value="tidyverse\ndplyr\nnumpy", + ) + result = Allowlist.from_remote( + context=context, + sre_stack=sre_project_manager, + repository=repository, + ).allowlist + assert "dplyr" in result + + def test_remote_exists( + self, mocker, context, sre_project_manager, mock_project_output + ) -> None: + mocker.patch.object( + AzureSdk, + "file_share_exists", + return_value=True, + ) + mocker.patch.object( + SREProjectManager, + "output", + wraps=mock_project_output, + ) + + exists = Allowlist.remote_exists( + context, + sre_stack=sre_project_manager, + repository=AllowlistRepository.CRAN, + ) + + assert isinstance(exists, bool) + assert exists + + def test_remote_diff( + self, + mocker, + context, + sre_project_manager, + mock_project_output, + ) -> None: + mocker.patch.object( + SREProjectManager, + "output", + wraps=mock_project_output, + ) + mocker.patch.object( + AzureSdk, "download_share_file", return_value="tidyverse\ndplyr\nnumpy" + ) + + local_allowlist = Allowlist( + sre_stack=sre_project_manager, + repository=AllowlistRepository.CRAN, + allowlist="tidyverse\ndplyr\nnumpy\npandas", + ) + remote_allowlist = 
Allowlist.from_remote( + context=context, + sre_stack=sre_project_manager, + repository=AllowlistRepository.CRAN, + ) + + diff = remote_allowlist.diff(local_allowlist) + + assert isinstance(diff, list) + assert "+pandas" in diff + + def test_remote_diff_no_change( + self, + mocker, + context, + sre_project_manager, + mock_project_output, + ) -> None: + mocker.patch.object( + SREProjectManager, + "output", + wraps=mock_project_output, + ) + mocker.patch.object( + AzureSdk, "download_share_file", return_value="tidyverse\ndplyr\nnumpy" + ) + local_allowlist = Allowlist( + sre_stack=sre_project_manager, + repository=AllowlistRepository.CRAN, + allowlist="tidyverse\ndplyr\nnumpy", + ) + remote_allowlist = Allowlist.from_remote( + context=context, + sre_stack=sre_project_manager, + repository=AllowlistRepository.CRAN, + ) + + diff = remote_allowlist.diff(local_allowlist) + + assert isinstance(diff, list) + assert not diff diff --git a/tests/commands/test_allowlist.py b/tests/commands/test_allowlist.py new file mode 100644 index 0000000000..28b468047c --- /dev/null +++ b/tests/commands/test_allowlist.py @@ -0,0 +1,195 @@ +from pytest import fixture, mark + +from data_safe_haven.allowlist import Allowlist +from data_safe_haven.commands.allowlist import allowlist_command_group +from data_safe_haven.external import AzureSdk +from data_safe_haven.infrastructure import SREProjectManager +from data_safe_haven.types import AllowlistRepository + + +@fixture +def mock_allowlist(mocker, sre_project_manager, mock_project_output) -> Allowlist: + mocker.patch.object( + SREProjectManager, + "output", + wraps=mock_project_output, + ) + allow = Allowlist( + repository=AllowlistRepository.CRAN, + sre_stack=sre_project_manager, + allowlist="tidyverse\ndplyr\nnumpy", + ) + return allow + + +@fixture +def allowlist_file(mock_allowlist, tmp_path): + allowlist_file_path = tmp_path / "allowlist.txt" + with open(allowlist_file_path, "w") as f: + f.write(mock_allowlist.allowlist) + return allowlist_file_path + + +@fixture +def mock_project_output(request): + if request == "allowlist_share_filenames": + return { + "cran": "cran.allowlist", + "pypi": "pypi.allowlist", + } + elif request == "data": + return {"storage_account_data_configuration_name": "test"} + elif request == "sre_resource_group": + return "test" + + +class TestShowAllowlist: + def test_show( + self, + mocker, + runner, + mock_azuresdk_get_credential, # noqa: ARG002 + mock_azuresdk_get_subscription, # noqa: ARG002 + mock_sre_config_from_remote, # noqa: ARG002 + mock_pulumi_config_no_key_from_remote, # noqa: ARG002, + mock_allowlist, + ) -> None: + sre_name = "sandbox" + repository = "cran" + mocker.patch.object(Allowlist, "from_remote", return_value=mock_allowlist) + result = runner.invoke( + allowlist_command_group, + ["show", sre_name, repository], + ) + assert result.exit_code == 0 + assert "tidyverse\ndplyr\nnumpy" in result.output + + +class TestTemplateAllowlist: + @mark.parametrize( + "repository", + [ + "cran", + "pypi", + ], + ) + def test_template(self, runner, repository) -> None: + + result = runner.invoke( + allowlist_command_group, + ["template", repository], + ) + assert result.exit_code == 0 + if repository == "cran": + assert "DBI\nMASS" in result.output + elif repository == "pypi": + assert "numpy\npackaging" in result.output + + +class TestUploadAllowlist: + @mark.parametrize( + "repository", + [ + "cran", + "pypi", + ], + ) + def test_upload_no_remote( + self, + mocker, + runner, + repository, + allowlist_file, + 
mock_azuresdk_get_subscription, # noqa: ARG002 + mock_pulumi_config_no_key_from_remote, # noqa: ARG002 + mock_sre_config_from_remote, # noqa: ARG002 + mock_azuresdk_get_credential, # noqa: ARG002 + mock_project_output, + ) -> None: + sre_name = "sandbox" + mocker.patch.object( + SREProjectManager, + "output", + wraps=mock_project_output, + ) + + mocker.patch.object(Allowlist, "remote_exists", return_value=False) + mocker.patch.object(AzureSdk, "upload_file_share", return_value=None) + + result = runner.invoke( + allowlist_command_group, + ["upload", sre_name, str(allowlist_file), repository], + ) + assert result.exit_code == 0 + + @mark.parametrize( + "repository", + [ + "cran", + "pypi", + ], + ) + def test_upload_remote_exists_no_diff( + self, + mocker, + runner, + repository, + allowlist_file, + mock_allowlist, + mock_azuresdk_get_subscription, # noqa: ARG002 + mock_project_output, + mock_pulumi_config_no_key_from_remote, # noqa: ARG002 + mock_sre_config_from_remote, # noqa: ARG002 + mock_azuresdk_get_credential, # noqa: ARG002 + ) -> None: + sre_name = "sandbox" + mocker.patch.object( + SREProjectManager, + "output", + wraps=mock_project_output, + ) + mocker.patch.object(Allowlist, "remote_exists", return_value=True) + mocker.patch.object(Allowlist, "from_remote", return_value=mock_allowlist) + mocker.patch.object(Allowlist, "diff", return_value=[]) + + result = runner.invoke( + allowlist_command_group, + ["upload", sre_name, str(allowlist_file), repository], + ) + assert result.exit_code == 0 + assert "No changes, won't upload allowlist." in result.output + + def test_upload_remote_exists_with_diff( + self, + mocker, + runner, + allowlist_file, + mock_allowlist, + mock_azuresdk_get_subscription, # noqa: ARG002 + mock_pulumi_config_no_key_from_remote, # noqa: ARG002 + mock_sre_config_from_remote, # noqa: ARG002 + mock_azuresdk_get_credential, # noqa: ARG002 + mock_project_output, + ) -> None: + sre_name = "sandbox" + repository = "cran" + mocker.patch.object( + SREProjectManager, + "output", + wraps=mock_project_output, + ) + mocker.patch.object(AzureSdk, "upload_file_share", return_value=None) + mocker.patch.object(Allowlist, "remote_exists", return_value=True) + mocker.patch.object(Allowlist, "from_remote", return_value=mock_allowlist) + mocker.patch.object(Allowlist, "diff", return_value=["-numpy", "+pandas"]) + + result = runner.invoke( + allowlist_command_group, + ["upload", sre_name, str(allowlist_file), repository], + input="y\n", + ) + + assert "-numpy" in result.output + assert result.exit_code == 0 + assert "An allowlist already exists" in result.output + assert "Uploading allowlist for CRAN to sandbox" in result.output diff --git a/tests/external/api/test_azure_sdk.py b/tests/external/api/test_azure_sdk.py index 1cb2bd2e95..e76a34472d 100644 --- a/tests/external/api/test_azure_sdk.py +++ b/tests/external/api/test_azure_sdk.py @@ -55,6 +55,70 @@ def mock_blob_client( ) +@fixture +def mock_share_client(monkeypatch): + class MockShareFileClient: + def __init__(self, file_name): + self.file_name = file_name + + def exists(self): + if self.file_name == "exists": + return True + else: + return False + + class MockShareClient: + def __init__( + self, + resource_group_name, + storage_account_name, + file_share_name, + ): + self.resource_group_name = resource_group_name + self.storage_account_name = storage_account_name + self.file_share_name = file_share_name + + def get_file_client(self, file_name): + return MockShareFileClient( + file_name, + ) + + def mock_share_client( + self, # 
noqa: ARG001 + resource_group_name, + storage_account_name, + file_share_name, + ): + return MockShareClient( + resource_group_name, + storage_account_name, + file_share_name, + ) + + monkeypatch.setattr( + data_safe_haven.external.api.azure_sdk.AzureSdk, + "share_client", + mock_share_client, + ) + + +@fixture +def mock_share_service_client(monkeypatch): + class MockShareServiceClient: + def __init__(self, resource_group_name, storage_account_name): + self.resource_group_name = resource_group_name + self.storage_account_name = storage_account_name + + def list_shares(self): + return ["file_share_name", "file_share_name2"] + + monkeypatch.setattr( + data_safe_haven.external.api.azure_sdk.AzureSdk, + "share_service_client", + MockShareServiceClient, + ) + + @fixture def mock_key_client(monkeypatch): class MockKeyClient: @@ -236,6 +300,25 @@ def test_blob_does_not_exist( "storage_account", ) + def test_file_share_exists( + self, mock_share_client, mock_storage_exists # noqa: ARG002 + ): + sdk = AzureSdk("subscription name") + exists = sdk.file_share_exists( + "exists", "resource_group", "storage_account", "file_share_name" + ) + assert isinstance(exists, bool) + assert exists + + mock_storage_exists.assert_called_once_with( + "storage_account", + ) + + def test_file_share_list(self, mock_share_service_client): # noqa: ARG002 + sdk = AzureSdk("subscription name") + shares = sdk.list_shares("resource_group", "storage_account") + assert shares == ["file_share_name", "file_share_name2"] + def test_get_keyvault_key(self, mock_key_client): # noqa: ARG002 sdk = AzureSdk("subscription name") key = sdk.get_keyvault_key("exists", "key vault name")