Skip to content

Commit 228a958

Browse files
committed
more linting
1 parent 7ec2a9b commit 228a958

File tree

6 files changed

+388
-274
lines changed

6 files changed

+388
-274
lines changed

petab/v2/conditions.py

+9-8
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22

33
from __future__ import annotations
44

5+
from itertools import chain
56
from pathlib import Path
67

78
import pandas as pd
89
import sympy as sp
910

1011
from .. import v2
11-
from ..v1.math import sympify_petab
1212
from .C import *
1313
from .lint import assert_no_leading_trailing_whitespace
1414

@@ -59,10 +59,11 @@ def get_condition_table_free_symbols(problem: v2.Problem) -> set[sp.Basic]:
5959
6060
:returns: Set of free symbols.
6161
"""
62-
if problem.condition_df is None:
63-
return set()
64-
65-
free_symbols = set()
66-
for target_value in problem.condition_df[TARGET_VALUE]:
67-
free_symbols |= sympify_petab(target_value).free_symbols
68-
return free_symbols
62+
return set(
63+
chain.from_iterable(
64+
change.target_value.free_symbols
65+
for condition in problem.conditions_table.conditions
66+
for change in condition.changes
67+
if change.target_value is not None
68+
)
69+
)

petab/v2/core.py

+55-8
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22

33
from __future__ import annotations
44

5+
import re
56
from collections.abc import Sequence
67
from enum import Enum
78
from pathlib import Path
8-
from typing import Annotated
9+
from typing import Annotated, Literal
910

1011
import numpy as np
1112
import pandas as pd
@@ -46,14 +47,20 @@
4647
]
4748

4849

49-
def is_finite_or_neg_inf(v: float, info: ValidationInfo) -> float:
50+
def _is_finite_or_neg_inf(v: float, info: ValidationInfo) -> float:
5051
if not np.isfinite(v) and v != -np.inf:
5152
raise ValueError(
5253
f"{info.field_name} value must be finite or -inf but got {v}"
5354
)
5455
return v
5556

5657

58+
def _not_nan(v: float, info: ValidationInfo) -> float:
59+
if np.isnan(v):
60+
raise ValueError(f"{info.field_name} value must not be nan.")
61+
return v
62+
63+
5764
def _convert_nan_to_none(v):
5865
if isinstance(v, float) and np.isnan(v):
5966
return None
@@ -149,6 +156,11 @@ class Observable(BaseModel):
149156
alias=C.NOISE_DISTRIBUTION, default=NoiseDistribution.NORMAL
150157
)
151158

159+
#: :meta private:
160+
model_config = ConfigDict(
161+
arbitrary_types_allowed=True, populate_by_name=True
162+
)
163+
152164
@field_validator("id")
153165
@classmethod
154166
def _validate_id(cls, v):
@@ -183,10 +195,31 @@ def _sympify(cls, v):
183195

184196
return sympify_petab(v)
185197

186-
#: :meta private:
187-
model_config = ConfigDict(
188-
arbitrary_types_allowed=True, populate_by_name=True
189-
)
198+
def _placeholders(
199+
self, type_: Literal["observable", "noise"]
200+
) -> set[sp.Symbol]:
201+
# TODO: add field validator to check for 1-based consecutive numbering
202+
t = f"{re.escape(type_)}Parameter"
203+
o = re.escape(self.id)
204+
pattern = re.compile(rf"(?:^|\W)({t}\d+_{o})(?=\W|$)")
205+
formula = (
206+
self.formula
207+
if type_ == "observable"
208+
else self.noise_formula
209+
if type_ == "noise"
210+
else None
211+
)
212+
return {s for s in formula.free_symbols if pattern.match(str(s))}
213+
214+
@property
215+
def observable_placeholders(self) -> set[sp.Symbol]:
216+
"""Placeholder symbols for the observable formula."""
217+
return self._placeholders("observable")
218+
219+
@property
220+
def noise_placeholders(self) -> set[sp.Symbol]:
221+
"""Placeholder symbols for the noise formula."""
222+
return self._placeholders("noise")
190223

191224

192225
class ObservablesTable(BaseModel):
@@ -440,7 +473,7 @@ class ExperimentPeriod(BaseModel):
440473
"""
441474

442475
#: The start time of the period in time units as defined in the model.
443-
time: Annotated[float, AfterValidator(is_finite_or_neg_inf)] = Field(
476+
time: Annotated[float, AfterValidator(_is_finite_or_neg_inf)] = Field(
444477
alias=C.TIME
445478
)
446479
#: The ID of the condition to be applied at the start time.
@@ -588,7 +621,9 @@ class Measurement(BaseModel):
588621
#: The time point of the measurement in time units as defined in the model.
589622
time: float = Field(alias=C.TIME)
590623
#: The measurement value.
591-
measurement: float = Field(alias=C.MEASUREMENT)
624+
measurement: Annotated[float, AfterValidator(_not_nan)] = Field(
625+
alias=C.MEASUREMENT
626+
)
592627
#: Values for placeholder parameters in the observable formula.
593628
observable_parameters: list[sp.Basic] = Field(
594629
alias=C.OBSERVABLE_PARAMETERS, default_factory=list
@@ -794,6 +829,13 @@ def __getitem__(self, petab_id: str) -> Mapping:
794829
return mapping
795830
raise KeyError(f"PEtab ID {petab_id} not found")
796831

832+
def get(self, petab_id, default=None):
833+
"""Get a mapping by PEtab ID or return a default value."""
834+
try:
835+
return self[petab_id]
836+
except KeyError:
837+
return default
838+
797839

798840
class Parameter(BaseModel):
799841
"""Parameter definition."""
@@ -893,3 +935,8 @@ def __getitem__(self, item) -> Parameter:
893935
if parameter.id == item:
894936
return parameter
895937
raise KeyError(f"Parameter ID {item} not found")
938+
939+
@property
940+
def n_estimated(self) -> int:
941+
"""Number of estimated parameters."""
942+
return sum(p.estimate for p in self.parameters)

0 commit comments

Comments
 (0)