Skip to content

Commit 2576270

Browse files
committed
Store problem configuration in Problem
Introduces `Problem.config`, which contains the information from the PEtab YAML file. Sometimes it is convenient to have the original filenames around. Closes PEtab-dev#324.
1 parent 9a4efb4 commit 2576270

File tree

2 files changed

+78
-27
lines changed

2 files changed

+78
-27
lines changed

petab/v1/problem.py

+77-27
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from warnings import warn
1111

1212
import pandas as pd
13+
from pydantic import AnyUrl, BaseModel, Field, RootModel
1314

1415
from . import (
1516
conditions,
@@ -78,6 +79,7 @@ def __init__(
7879
observable_df: pd.DataFrame = None,
7980
mapping_df: pd.DataFrame = None,
8081
extensions_config: dict = None,
82+
config: ProblemConfig = None,
8183
):
8284
self.condition_df: pd.DataFrame | None = condition_df
8385
self.measurement_df: pd.DataFrame | None = measurement_df
@@ -112,6 +114,7 @@ def __init__(
112114

113115
self.model: Model | None = model
114116
self.extensions_config = extensions_config or {}
117+
self.config = config
115118

116119
def __getattr__(self, name):
117120
# For backward-compatibility, allow access to SBML model related
@@ -261,10 +264,14 @@ def from_yaml(
261264
yaml_config: PEtab configuration as dictionary or YAML file name
262265
base_path: Base directory or URL to resolve relative paths
263266
"""
267+
# path to the yaml file
268+
filepath = None
269+
264270
if isinstance(yaml_config, Path):
265271
yaml_config = str(yaml_config)
266272

267273
if isinstance(yaml_config, str):
274+
filepath = yaml_config
268275
if base_path is None:
269276
base_path = get_path_prefix(yaml_config)
270277
yaml_config = yaml.load_yaml(yaml_config)
@@ -296,59 +303,58 @@ def get_path(filename):
296303
DeprecationWarning,
297304
stacklevel=2,
298305
)
306+
config = ProblemConfig(
307+
**yaml_config, base_path=base_path, filepath=filepath
308+
)
309+
problem0 = config.problems[0]
310+
# currently required for handling PEtab v2 in here
311+
problem0_ = yaml_config["problems"][0]
299312

300-
problem0 = yaml_config["problems"][0]
301-
302-
if isinstance(yaml_config[PARAMETER_FILE], list):
313+
if isinstance(config.parameter_file, list):
303314
parameter_df = parameters.get_parameter_df(
304-
[get_path(f) for f in yaml_config[PARAMETER_FILE]]
315+
[get_path(f) for f in config.parameter_file]
305316
)
306317
else:
307318
parameter_df = (
308-
parameters.get_parameter_df(
309-
get_path(yaml_config[PARAMETER_FILE])
310-
)
311-
if yaml_config[PARAMETER_FILE]
319+
parameters.get_parameter_df(get_path(config.parameter_file))
320+
if config.parameter_file
312321
else None
313322
)
314-
315-
if yaml_config[FORMAT_VERSION] in [1, "1", "1.0.0"]:
316-
if len(problem0[SBML_FILES]) > 1:
323+
if config.format_version.root in [1, "1", "1.0.0"]:
324+
if len(problem0.sbml_files) > 1:
317325
# TODO https://github.com/PEtab-dev/libpetab-python/issues/6
318326
raise NotImplementedError(
319327
"Support for multiple models is not yet implemented."
320328
)
321329

322330
model = (
323331
model_factory(
324-
get_path(problem0[SBML_FILES][0]),
332+
get_path(problem0.sbml_files[0]),
325333
MODEL_TYPE_SBML,
326334
model_id=None,
327335
)
328-
if problem0[SBML_FILES]
336+
if problem0.sbml_files
329337
else None
330338
)
331339
else:
332-
if len(problem0[MODEL_FILES]) > 1:
340+
if len(problem0_[MODEL_FILES]) > 1:
333341
# TODO https://github.com/PEtab-dev/libpetab-python/issues/6
334342
raise NotImplementedError(
335343
"Support for multiple models is not yet implemented."
336344
)
337-
if not problem0[MODEL_FILES]:
345+
if not problem0_[MODEL_FILES]:
338346
model = None
339347
else:
340348
model_id, model_info = next(
341-
iter(problem0[MODEL_FILES].items())
349+
iter(problem0_[MODEL_FILES].items())
342350
)
343351
model = model_factory(
344352
get_path(model_info[MODEL_LOCATION]),
345353
model_info[MODEL_LANGUAGE],
346354
model_id=model_id,
347355
)
348356

349-
measurement_files = [
350-
get_path(f) for f in problem0.get(MEASUREMENT_FILES, [])
351-
]
357+
measurement_files = [get_path(f) for f in problem0.measurement_files]
352358
# If there are multiple tables, we will merge them
353359
measurement_df = (
354360
core.concat_tables(
@@ -358,9 +364,7 @@ def get_path(filename):
358364
else None
359365
)
360366

361-
condition_files = [
362-
get_path(f) for f in problem0.get(CONDITION_FILES, [])
363-
]
367+
condition_files = [get_path(f) for f in problem0.condition_files]
364368
# If there are multiple tables, we will merge them
365369
condition_df = (
366370
core.concat_tables(condition_files, conditions.get_condition_df)
@@ -369,7 +373,7 @@ def get_path(filename):
369373
)
370374

371375
visualization_files = [
372-
get_path(f) for f in problem0.get(VISUALIZATION_FILES, [])
376+
get_path(f) for f in problem0.visualization_files
373377
]
374378
# If there are multiple tables, we will merge them
375379
visualization_df = (
@@ -378,17 +382,15 @@ def get_path(filename):
378382
else None
379383
)
380384

381-
observable_files = [
382-
get_path(f) for f in problem0.get(OBSERVABLE_FILES, [])
383-
]
385+
observable_files = [get_path(f) for f in problem0.observable_files]
384386
# If there are multiple tables, we will merge them
385387
observable_df = (
386388
core.concat_tables(observable_files, observables.get_observable_df)
387389
if observable_files
388390
else None
389391
)
390392

391-
mapping_files = [get_path(f) for f in problem0.get(MAPPING_FILES, [])]
393+
mapping_files = [get_path(f) for f in problem0_.get(MAPPING_FILES, [])]
392394
# If there are multiple tables, we will merge them
393395
mapping_df = (
394396
core.concat_tables(mapping_files, mapping.get_mapping_df)
@@ -405,6 +407,7 @@ def get_path(filename):
405407
visualization_df=visualization_df,
406408
mapping_df=mapping_df,
407409
extensions_config=yaml_config.get(EXTENSIONS, {}),
410+
config=config,
408411
)
409412

410413
@staticmethod
@@ -1005,3 +1008,50 @@ def n_priors(self) -> int:
10051008
return 0
10061009

10071010
return self.parameter_df[OBJECTIVE_PRIOR_PARAMETERS].notna().sum()
1011+
1012+
1013+
class VersionNumber(RootModel):
    """A format version number, given either as a string or as an integer.

    The wrapped value is stored in the ``root`` attribute.
    """

    root: str | int
1016+
1017+
class ListOfFiles(RootModel):
    """List of files.

    Root model around a list of file locations (strings or URLs).
    Indexing, ``len()`` and iteration are forwarded to the wrapped list,
    which is stored in the ``root`` attribute.
    """

    root: list[str | AnyUrl] = Field(..., description="List of files.")

    def __getitem__(self, index):
        """Return the entry of the wrapped list at ``index``."""
        return self.root[index]

    def __len__(self):
        """Return the number of entries in the wrapped list."""
        return len(self.root)

    def __iter__(self):
        """Return an iterator over the entries of the wrapped list."""
        return iter(self.root)
1030+
1031+
1032+
class SubProblem(BaseModel):
    """A `problems` object in the PEtab problem configuration.

    Every field holds a list of file locations for one kind of file.
    A field that is not provided defaults to an empty list.
    """

    # Model files in SBML format.
    sbml_files: ListOfFiles = []

    # Table files, one list per table type.
    measurement_files: ListOfFiles = []
    condition_files: ListOfFiles = []
    observable_files: ListOfFiles = []
    visualization_files: ListOfFiles = []
1040+
1041+
1042+
class ProblemConfig(BaseModel):
    """The PEtab problem configuration.

    Holds the content of a PEtab YAML file, plus two bookkeeping fields
    (``filepath`` and ``base_path``). Both are declared with
    ``exclude=True``, so they are left out when the model is serialized.
    """

    filepath: str | AnyUrl | None = Field(
        None,
        description="The path to the PEtab problem configuration.",
        exclude=True,
    )
    base_path: str | AnyUrl | None = Field(
        None,
        description="The base path to resolve relative paths.",
        exclude=True,
    )
    # Pydantic does not validate defaults unless asked to. Without
    # `validate_default=True` an omitted `format_version` would be stored as
    # the plain int `1`, and `config.format_version.root` would raise
    # AttributeError.
    format_version: VersionNumber = Field(default=1, validate_default=True)
    # Either a single parameter file or a list of parameter files; callers
    # check for the list case with `isinstance(config.parameter_file, list)`.
    parameter_file: str | AnyUrl | list[str | AnyUrl] | None = None
    problems: list[SubProblem] = []

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ dependencies = [
2222
"pyyaml",
2323
"jsonschema",
2424
"antlr4-python3-runtime==4.13.1",
25+
"pydantic>=2.10",
2526
]
2627
license = {text = "MIT License"}
2728
authors = [

0 commit comments

Comments
 (0)