Skip to content

Commit 2337e2d

Browse files
dweindldilpath
andauthored
v2: More validation (#381)
* Check priors (related to #374) * Check observables * Fix missing prior parameters after v1->v2 conversion of uniform priors * Fix some pydantic validation / serialization * Fix style --------- Co-authored-by: Dilan Pathirana <[email protected]>
1 parent 84fbdba commit 2337e2d

File tree

5 files changed

+219
-32
lines changed

5 files changed

+219
-32
lines changed

petab/v2/core.py

+55-16
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,11 @@ def _not_nan(v: float, info: ValidationInfo) -> float:
7373

7474

7575
def _convert_nan_to_none(v):
76+
"""Convert NaN or "" to None."""
7677
if isinstance(v, float) and np.isnan(v):
7778
return None
79+
if isinstance(v, str) and v == "":
80+
return None
7881
return v
7982

8083

@@ -503,9 +506,17 @@ class ExperimentPeriod(BaseModel):
503506
@field_validator("condition_ids", mode="before")
504507
@classmethod
505508
def _validate_ids(cls, condition_ids):
509+
if condition_ids in [None, "", [], [""]]:
510+
# unspecified, or "use-model-as-is"
511+
return []
512+
506513
for condition_id in condition_ids:
514+
# The empty condition ID for "use-model-as-is" has been handled
515+
# above. Having a combination of empty and non-empty IDs is an
516+
# error, since the targets of conditions to be combined must be
517+
# disjoint.
507518
if not is_valid_identifier(condition_id):
508-
raise ValueError(f"Invalid ID: {condition_id}")
519+
raise ValueError(f"Invalid {C.CONDITION_ID}: `{condition_id}'")
509520
return condition_ids
510521

511522

@@ -854,17 +865,23 @@ class Parameter(BaseModel):
854865
#: Parameter ID.
855866
id: str = Field(alias=C.PARAMETER_ID)
856867
#: Lower bound.
857-
lb: float | None = Field(alias=C.LOWER_BOUND, default=None)
868+
lb: Annotated[float | None, BeforeValidator(_convert_nan_to_none)] = Field(
869+
alias=C.LOWER_BOUND, default=None
870+
)
858871
#: Upper bound.
859-
ub: float | None = Field(alias=C.UPPER_BOUND, default=None)
872+
ub: Annotated[float | None, BeforeValidator(_convert_nan_to_none)] = Field(
873+
alias=C.UPPER_BOUND, default=None
874+
)
860875
#: Nominal value.
861-
nominal_value: float | None = Field(alias=C.NOMINAL_VALUE, default=None)
876+
nominal_value: Annotated[
877+
float | None, BeforeValidator(_convert_nan_to_none)
878+
] = Field(alias=C.NOMINAL_VALUE, default=None)
862879
#: Is the parameter to be estimated?
863880
estimate: bool = Field(alias=C.ESTIMATE, default=True)
864881
#: Type of parameter prior distribution.
865-
prior_distribution: PriorDistribution | None = Field(
866-
alias=C.PRIOR_DISTRIBUTION, default=None
867-
)
882+
prior_distribution: Annotated[
883+
PriorDistribution | None, BeforeValidator(_convert_nan_to_none)
884+
] = Field(alias=C.PRIOR_DISTRIBUTION, default=None)
868885
#: Prior distribution parameters.
869886
prior_parameters: list[float] = Field(
870887
alias=C.PRIOR_PARAMETERS, default_factory=list
@@ -889,8 +906,18 @@ def _validate_id(cls, v):
889906

890907
@field_validator("prior_parameters", mode="before")
891908
@classmethod
892-
def _validate_prior_parameters(cls, v):
909+
def _validate_prior_parameters(
910+
cls, v: str | list[str] | float | None | np.ndarray
911+
):
912+
if v is None:
913+
return []
914+
915+
if isinstance(v, float) and np.isnan(v):
916+
return []
917+
893918
if isinstance(v, str):
919+
if v == "":
920+
return []
894921
v = v.split(C.PARAMETER_SEPARATOR)
895922
elif not isinstance(v, Sequence):
896923
v = [v]
@@ -899,7 +926,7 @@ def _validate_prior_parameters(cls, v):
899926

900927
@field_validator("estimate", mode="before")
901928
@classmethod
902-
def _validate_estimate_before(cls, v):
929+
def _validate_estimate_before(cls, v: bool | str):
903930
if isinstance(v, bool):
904931
return v
905932

@@ -918,12 +945,17 @@ def _validate_estimate_before(cls, v):
918945
def _serialize_estimate(self, estimate: bool, _info):
919946
return str(estimate).lower()
920947

921-
@field_validator("lb", "ub", "nominal_value")
922-
@classmethod
923-
def _convert_nan_to_none(cls, v):
924-
if isinstance(v, float) and np.isnan(v):
925-
return None
926-
return v
948+
@field_serializer("prior_distribution")
949+
def _serialize_prior_distribution(
950+
self, prior_distribution: PriorDistribution | None, _info
951+
):
952+
if prior_distribution is None:
953+
return ""
954+
return str(prior_distribution)
955+
956+
@field_serializer("prior_parameters")
957+
def _serialize_prior_parameters(self, prior_parameters: list[str], _info):
958+
return C.PARAMETER_SEPARATOR.join(prior_parameters)
927959

928960
@model_validator(mode="after")
929961
def _validate(self) -> Self:
@@ -952,7 +984,7 @@ def _validate(self) -> Self:
952984

953985
@property
954986
def prior_dist(self) -> Distribution:
955-
"""Get the pior distribution of the parameter."""
987+
"""Get the prior distribution of the parameter."""
956988
if self.estimate is False:
957989
raise ValueError(f"Parameter `{self.id}' is not estimated.")
958990

@@ -980,6 +1012,13 @@ def prior_dist(self) -> Distribution:
9801012
"transformation."
9811013
)
9821014
return cls(*self.prior_parameters, trunc=[self.lb, self.ub])
1015+
1016+
if cls == Uniform:
1017+
# `Uniform.__init__` does not accept the `trunc` parameter
1018+
low = max(self.prior_parameters[0], self.lb)
1019+
high = min(self.prior_parameters[1], self.ub)
1020+
return cls(low, high, log=log)
1021+
9831022
return cls(*self.prior_parameters, log=log, trunc=[self.lb, self.ub])
9841023

9851024

0 commit comments

Comments
 (0)