Skip to content

Commit 09de44b

Browse files
authored
Merge pull request #1815 from alan-turing-institute/unique_list_validator
Add UniqueList annotated type
2 parents 8e76a2f + 7e0e7a3 commit 09de44b

File tree

6 files changed

+62
-21
lines changed

6 files changed

+62
-21
lines changed

data_safe_haven/config/config.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
b64decode,
2727
b64encode,
2828
)
29+
from data_safe_haven.functions.validators import validate_unique_list
2930
from data_safe_haven.utility import (
3031
DatabaseSystem,
3132
LoggingSingleton,
@@ -39,6 +40,7 @@
3940
Guid,
4041
IpAddress,
4142
TimeZone,
43+
UniqueList,
4244
)
4345

4446
from .context_settings import Context
@@ -150,7 +152,9 @@ def update(
150152

151153

152154
class ConfigSectionSRE(BaseModel, validate_assignment=True):
153-
databases: list[DatabaseSystem] = Field(..., default_factory=list[DatabaseSystem])
155+
databases: UniqueList[DatabaseSystem] = Field(
156+
..., default_factory=list[DatabaseSystem]
157+
)
154158
data_provider_ip_addresses: list[IpAddress] = Field(
155159
..., default_factory=list[IpAddress]
156160
)
@@ -164,17 +168,6 @@ class ConfigSectionSRE(BaseModel, validate_assignment=True):
164168
)
165169
software_packages: SoftwarePackageCategory = SoftwarePackageCategory.NONE
166170

167-
@field_validator("databases")
168-
@classmethod
169-
def all_databases_must_be_unique(
170-
cls, v: list[DatabaseSystem | str]
171-
) -> list[DatabaseSystem]:
172-
v_ = [DatabaseSystem(d) for d in v]
173-
if len(v_) != len(set(v_)):
174-
msg = "all databases must be unique"
175-
raise ValueError(msg)
176-
return v_
177-
178171
def update(
179172
self,
180173
*,
@@ -258,9 +251,7 @@ def all_sre_indices_must_be_unique(
258251
cls, v: dict[str, ConfigSectionSRE]
259252
) -> dict[str, ConfigSectionSRE]:
260253
indices = [s.index for s in v.values()]
261-
if len(indices) != len(set(indices)):
262-
msg = "all SRE indices must be unique"
263-
raise ValueError(msg)
254+
validate_unique_list(indices)
264255
return v
265256

266257
@property

data_safe_haven/functions/validators.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import ipaddress
22
import re
3+
from collections.abc import Hashable
4+
from typing import TypeVar
35

46
import fqdn
57
import pytz
@@ -57,3 +59,13 @@ def validate_timezone(timezone: str) -> str:
5759
msg = "Expected valid timezone, for example 'Europe/London'."
5860
raise ValueError(msg)
5961
return timezone
62+
63+
64+
TH = TypeVar("TH", bound=Hashable)
65+
66+
67+
def validate_unique_list(items: list[TH]) -> list[TH]:
68+
if len(items) != len(set(items)):
69+
msg = "All items must be unique."
70+
raise ValueError(msg)
71+
return items

data_safe_haven/utility/annotated_types.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from typing import Annotated
1+
from collections.abc import Hashable
2+
from typing import Annotated, TypeAlias, TypeVar
23

34
from pydantic import Field
45
from pydantic.functional_validators import AfterValidator
@@ -11,6 +12,7 @@
1112
validate_fqdn,
1213
validate_ip_address,
1314
validate_timezone,
15+
validate_unique_list,
1416
)
1517

1618
AzureShortName = Annotated[str, Field(min_length=1, max_length=24)]
@@ -22,3 +24,9 @@
2224
Guid = Annotated[str, AfterValidator(validate_aad_guid)]
2325
IpAddress = Annotated[str, AfterValidator(validate_ip_address)]
2426
TimeZone = Annotated[str, AfterValidator(validate_timezone)]
27+
TH = TypeVar("TH", bound=Hashable)
28+
# type UniqueList[TH] = Annotated[list[TH], AfterValidator(validate_unique_list)]
29+
# mypy doesn't support PEP695 type statements
30+
UniqueList: TypeAlias = Annotated[ # noqa:UP040
31+
list[TH], AfterValidator(validate_unique_list)
32+
]

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,15 +91,15 @@ all = [
9191

9292
[tool.hatch.envs.test]
9393
dependencies = [
94-
"pytest~=7.4"
94+
"pytest~=8.1"
9595
]
9696
pre-install-commands = ["pip install -r requirements.txt"]
9797

9898
[tool.hatch.envs.test.scripts]
9999
test = "pytest {args:-vvv tests}"
100100

101101
[tool.black]
102-
target-version = ["py311", "py312"]
102+
target-version = ["py312"]
103103

104104
[tool.ruff.lint]
105105
select = [

tests/config/test_config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def test_constructor_defaults(self, remote_desktop_config):
137137
assert sre_config.software_packages == SoftwarePackageCategory.NONE
138138

139139
def test_all_databases_must_be_unique(self):
140-
with pytest.raises(ValueError, match="all databases must be unique"):
140+
with pytest.raises(ValueError, match="All items must be unique."):
141141
ConfigSectionSRE(
142142
index=1,
143143
databases=[DatabaseSystem.POSTGRESQL, DatabaseSystem.POSTGRESQL],
@@ -239,7 +239,7 @@ def test_constructor(self, context, azure_config, pulumi_config, shm_config):
239239
def test_all_sre_indices_must_be_unique(
240240
self, context, azure_config, pulumi_config, shm_config
241241
):
242-
with pytest.raises(ValueError, match="all SRE indices must be unique"):
242+
with pytest.raises(ValueError, match="All items must be unique."):
243243
sre_config_1 = ConfigSectionSRE(index=1)
244244
sre_config_2 = ConfigSectionSRE(index=1)
245245
Config(

tests/functions/test_validators.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
import pytest
22

3-
from data_safe_haven.functions.validators import validate_aad_guid, validate_fqdn
3+
from data_safe_haven.functions.validators import (
4+
validate_aad_guid,
5+
validate_fqdn,
6+
validate_unique_list,
7+
)
8+
from data_safe_haven.utility.enums import DatabaseSystem
49

510

611
class TestValidateAadGuid:
@@ -53,3 +58,28 @@ def test_validate_fqdn_fail(self, fqdn):
5358
ValueError, match="Expected valid fully qualified domain name"
5459
):
5560
validate_fqdn(fqdn)
61+
62+
63+
class TestValidateUniqueList:
64+
@pytest.mark.parametrize(
65+
"items",
66+
[
67+
[1, 2, 3],
68+
["a", 5, len],
69+
],
70+
)
71+
def test_validate_unique_list(self, items):
72+
validate_unique_list(items)
73+
74+
@pytest.mark.parametrize(
75+
"items",
76+
[
77+
[DatabaseSystem.POSTGRESQL, DatabaseSystem.POSTGRESQL],
78+
[DatabaseSystem.POSTGRESQL, 2, DatabaseSystem.POSTGRESQL],
79+
[1, 1],
80+
["abc", "abc"],
81+
],
82+
)
83+
def test_validate_unique_list_fail(self, items):
84+
with pytest.raises(ValueError, match="All items must be unique."):
85+
validate_unique_list(items)

0 commit comments

Comments
 (0)