diff --git a/pytest.ini b/pytest.ini index 03d50d80e1..69caec1e48 100644 --- a/pytest.ini +++ b/pytest.ini @@ -27,4 +27,5 @@ filterwarnings = # ignore jax deprecation warnings ignore:jax.* is deprecated:DeprecationWarning + norecursedirs = .git amici_models build doc documentation matlab models ThirdParty amici sdist examples diff --git a/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb b/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb index 848f6521ff..d693c2594a 100644 --- a/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb +++ b/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb @@ -457,6 +457,8 @@ "my = jax_problem._my[ic, :]\n", "iys = jax_problem._iys[ic, :]\n", "iy_trafos = jax_problem._iy_trafos[ic, :]\n", + "ops = jax_problem._op_numeric[ic, :]\n", + "nps = jax_problem._np_numeric[ic, :]\n", "\n", "# Load parameters for the specified condition\n", "p = jax_problem.load_parameters(simulation_condition[0])\n", @@ -472,6 +474,8 @@ " my=jnp.array(my),\n", " iys=jnp.array(iys),\n", " iy_trafos=jnp.array(iy_trafos),\n", + " ops=jnp.array(ops),\n", + " nps=jnp.array(nps),\n", " solver=diffrax.Kvaerno5(),\n", " controller=diffrax.PIDController(atol=1e-8, rtol=1e-8),\n", " steady_state_event=diffrax.steady_state_event(),\n", diff --git a/python/sdist/amici/constants.py b/python/sdist/amici/constants.py index 74b365889c..346dc1c9ab 100644 --- a/python/sdist/amici/constants.py +++ b/python/sdist/amici/constants.py @@ -34,3 +34,5 @@ class SymbolId(str, enum.Enum): SIGMAZ = "sigmaz" LLHZ = "llhz" LLHRZ = "llhrz" + NOISE_PARAMETER = "noise_parameter" + OBSERVABLE_PARAMETER = "observable_parameter" diff --git a/python/sdist/amici/de_model.py b/python/sdist/amici/de_model.py index 8ad2e7a998..463f78f927 100644 --- a/python/sdist/amici/de_model.py +++ b/python/sdist/amici/de_model.py @@ -35,6 +35,8 @@ LogLikelihoodY, LogLikelihoodZ, LogLikelihoodRZ, + NoiseParameter, + ObservableParameter, Expression, ConservationLaw, Event, @@ -226,6 +228,8 @@ def __init__( self._log_likelihood_ys: list[LogLikelihoodY] = [] self._log_likelihood_zs: list[LogLikelihoodZ] = [] self._log_likelihood_rzs: list[LogLikelihoodRZ] = [] + self._noise_parameters: list[NoiseParameter] = [] + self._observable_parameters: list[ObservableParameter] = [] self._expressions: list[Expression] = [] self._conservation_laws: list[ConservationLaw] = [] self._events: list[Event] = [] @@ -273,6 +277,8 @@ def __init__( "sigmay": self.sigma_ys, "sigmaz": self.sigma_zs, "h": self.events, + "np": self.noise_parameters, + "op": self.observable_parameters, } self._value_prototype: dict[str, Callable] = { "p": self.parameters, @@ -385,6 +391,14 @@ def log_likelihood_rzs(self) -> list[LogLikelihoodRZ]: """Get all event observable regularization log likelihoods.""" return self._log_likelihood_rzs + def noise_parameters(self) -> list[NoiseParameter]: + """Get all noise parameters.""" + return self._noise_parameters + + def observable_parameters(self) -> list[ObservableParameter]: + """Get all observable parameters.""" + return self._observable_parameters + def is_ode(self) -> bool: """Check if model is ODE model.""" return len(self._algebraic_equations) == 0 @@ -565,6 +579,8 @@ def add_component( ConservationLaw, Event, EventObservable, + NoiseParameter, + ObservableParameter, }: raise ValueError(f"Invalid component type {type(component)}") @@ -1087,6 +1103,26 @@ def _generate_symbol(self, name: str) -> None: """ if name in self._variable_prototype: components = self._variable_prototype[name]() + # ensure placeholder parameters are consistently and correctly ordered + # we want that components are ordered by their placeholder index + if name == "op": + components = sorted( + components, + key=lambda x: int( + str(strip_pysb(x.get_id())).replace( + "observableParameter", "" + ) + ), + ) + if name == "np": + components = sorted( + components, + key=lambda x: int( + str(strip_pysb(x.get_id())).replace( + "noiseParameter", "" + ) + ), + ) self._syms[name] = sp.Matrix( [comp.get_id() for comp in components] ) diff --git a/python/sdist/amici/de_model_components.py b/python/sdist/amici/de_model_components.py index bc93f44b87..30624dbc9e 100644 --- a/python/sdist/amici/de_model_components.py +++ b/python/sdist/amici/de_model_components.py @@ -607,6 +607,46 @@ def __init__( super().__init__(identifier, name, value) +class NoiseParameter(ModelQuantity): + """ + A NoiseParameter is an input variable for the computation of ``sigma`` that can be specified in a data-point + specific manner, abbreviated by ``np``. Only used for jax models. + """ + + def __init__(self, identifier: sp.Symbol, name: str): + """ + Create a new Expression instance. + + :param identifier: + unique identifier of the NoiseParameter + + :param name: + individual name of the NoiseParameter (does not need to be + unique) + """ + super().__init__(identifier, name, 0.0) + + +class ObservableParameter(ModelQuantity): + """ + A NoiseParameter is an input variable for the computation of ``y`` that can be specified in a data-point specific + manner, abbreviated by ``op``. Only used for jax models. + """ + + def __init__(self, identifier: sp.Symbol, name: str): + """ + Create a new Expression instance. + + :param identifier: + unique identifier of the ObservableParameter + + :param name: + individual name of the ObservableParameter (does not need to be + unique) + """ + super().__init__(identifier, name, 0.0) + + class LogLikelihood(ModelQuantity): """ A LogLikelihood defines the distance between measurements and @@ -751,4 +791,6 @@ def get_trigger_time(self) -> sp.Float: SymbolId.LLHRZ: LogLikelihoodRZ, SymbolId.EXPRESSION: Expression, SymbolId.EVENT: Event, + SymbolId.NOISE_PARAMETER: NoiseParameter, + SymbolId.OBSERVABLE_PARAMETER: ObservableParameter, } diff --git a/python/sdist/amici/jax/jax.template.py b/python/sdist/amici/jax/jax.template.py index 5d5521d222..f78561fd55 100644 --- a/python/sdist/amici/jax/jax.template.py +++ b/python/sdist/amici/jax/jax.template.py @@ -64,28 +64,30 @@ def _tcl(self, x, p): return TPL_TOTAL_CL_RET - def _y(self, t, x, p, tcl): + def _y(self, t, x, p, tcl, op): TPL_X_SYMS = x TPL_P_SYMS = p TPL_W_SYMS = self._w(t, x, p, tcl) + TPL_OP_SYMS = op TPL_Y_EQ return TPL_Y_RET - def _sigmay(self, y, p): + def _sigmay(self, y, p, np): TPL_P_SYMS = p TPL_Y_SYMS = y + TPL_NP_SYMS = np TPL_SIGMAY_EQ return TPL_SIGMAY_RET - def _nllh(self, t, x, p, tcl, my, iy): - y = self._y(t, x, p, tcl) + def _nllh(self, t, x, p, tcl, my, iy, op, np): + y = self._y(t, x, p, tcl, op) TPL_Y_SYMS = y - TPL_SIGMAY_SYMS = self._sigmay(y, p) + TPL_SIGMAY_SYMS = self._sigmay(y, p, np) TPL_JY_EQ diff --git a/python/sdist/amici/jax/model.py b/python/sdist/amici/jax/model.py index 616431dd94..da5b2f9e56 100644 --- a/python/sdist/amici/jax/model.py +++ b/python/sdist/amici/jax/model.py @@ -43,7 +43,7 @@ class JAXModel(eqx.Module): Path to the JAX model file. """ - MODEL_API_VERSION = "0.0.2" + MODEL_API_VERSION = "0.0.3" api_version: str jax_py_file: Path @@ -77,7 +77,7 @@ def _w( self, t: jt.Float[jt.Array, ""], x: jt.Float[jt.Array, "nxs"], - pk: jt.Float[jt.Array, "np"], + p: jt.Float[jt.Array, "np"], tcl: jt.Float[jt.Array, "ncl"], ) -> jt.Float[jt.Array, "nw"]: """ @@ -85,7 +85,7 @@ def _w( :param t: time point :param x: state vector - :param pk: parameters + :param p: parameters :param tcl: total values for conservation laws :return: Expression values. @@ -93,11 +93,11 @@ def _w( ... @abstractmethod - def _x0(self, pk: jt.Float[jt.Array, "np"]) -> jt.Float[jt.Array, "nx"]: + def _x0(self, p: jt.Float[jt.Array, "np"]) -> jt.Float[jt.Array, "nx"]: """ Compute the initial state vector. - :param pk: parameters + :param p: parameters """ ... @@ -133,14 +133,14 @@ def _x_rdata( @abstractmethod def _tcl( - self, x: jt.Float[jt.Array, "nx"], pk: jt.Float[jt.Array, "np"] + self, x: jt.Float[jt.Array, "nx"], p: jt.Float[jt.Array, "np"] ) -> jt.Float[jt.Array, "ncl"]: """ Compute the total values for conservation laws. :param x: state vector - :param pk: + :param p: parameters :return: total values for conservation laws @@ -152,8 +152,9 @@ def _y( self, t: jt.Float[jt.Scalar, ""], x: jt.Float[jt.Array, "nxs"], - pk: jt.Float[jt.Array, "np"], + p: jt.Float[jt.Array, "np"], tcl: jt.Float[jt.Array, "ncl"], + op: jt.Float[jt.Array, "ny"], ) -> jt.Float[jt.Array, "ny"]: """ Compute the observables. @@ -162,10 +163,12 @@ def _y( time point :param x: state vector - :param pk: + :param p: parameters :param tcl: total values for conservation laws + :param op: + observables parameters :return: observables """ @@ -173,15 +176,20 @@ def _y( @abstractmethod def _sigmay( - self, y: jt.Float[jt.Array, "ny"], pk: jt.Float[jt.Array, "np"] + self, + y: jt.Float[jt.Array, "ny"], + p: jt.Float[jt.Array, "np"], + np: jt.Float[jt.Array, "ny"], ) -> jt.Float[jt.Array, "ny"]: """ Compute the standard deviations of the observables. :param y: observables - :param pk: + :param p: parameters + :param np: + noise parameters :return: standard deviations of the observables """ @@ -192,10 +200,12 @@ def _nllh( self, t: jt.Float[jt.Scalar, ""], x: jt.Float[jt.Array, "nxs"], - pk: jt.Float[jt.Array, "np"], + p: jt.Float[jt.Array, "np"], tcl: jt.Float[jt.Array, "ncl"], my: jt.Float[jt.Array, ""], iy: jt.Int[jt.Array, ""], + op: jt.Float[jt.Array, "ny"], + np: jt.Float[jt.Array, "ny"], ) -> jt.Float[jt.Scalar, ""]: """ Compute the negative log-likelihood of the observable for the specified observable index. @@ -204,7 +214,7 @@ def _nllh( time point :param x: state vector - :param pk: + :param p: parameters :param tcl: total values for conservation laws @@ -212,6 +222,10 @@ def _nllh( observed data :param iy: observable index + :param op: + observables parameters + :param np: + noise parameters :return: log-likelihood of the observable """ @@ -377,6 +391,8 @@ def _nllhs( tcl: jt.Float[jt.Array, "ncl"], mys: jt.Float[jt.Array, "nt"], iys: jt.Int[jt.Array, "nt"], + ops: jt.Float[jt.Array, "nt *nop"], + nps: jt.Float[jt.Array, "nt *nnp"], ) -> jt.Float[jt.Array, "nt"]: """ Compute the negative log-likelihood for each observable. @@ -393,11 +409,15 @@ def _nllhs( observed data :param iys: observable indices + :param ops: + observables parameters + :param nps: + noise parameters :return: negative log-likelihoods of the observables """ - return jax.vmap(self._nllh, in_axes=(0, 0, None, None, 0, 0))( - ts, xs, p, tcl, mys, iys + return jax.vmap(self._nllh, in_axes=(0, 0, None, None, 0, 0, 0, 0))( + ts, xs, p, tcl, mys, iys, ops, nps ) def _ys( @@ -407,6 +427,7 @@ def _ys( p: jt.Float[jt.Array, "np"], tcl: jt.Float[jt.Array, "ncl"], iys: jt.Float[jt.Array, "nt"], + ops: jt.Float[jt.Array, "nt *nop"], ) -> jt.Int[jt.Array, "nt"]: """ Compute the observables. @@ -421,13 +442,17 @@ def _ys( total values for conservation laws :param iys: observable indices + :param ops: + observables parameters :return: observables """ return jax.vmap( - lambda t, x, p, tcl, iy: self._y(t, x, p, tcl).at[iy].get(), - in_axes=(0, 0, None, None, 0), - )(ts, xs, p, tcl, iys) + lambda t, x, p, tcl, iy, op: self._y(t, x, p, tcl, op) + .at[iy] + .get(), + in_axes=(0, 0, None, None, 0, 0), + )(ts, xs, p, tcl, iys, ops) def _sigmays( self, @@ -436,6 +461,8 @@ def _sigmays( p: jt.Float[jt.Array, "np"], tcl: jt.Float[jt.Array, "ncl"], iys: jt.Int[jt.Array, "nt"], + ops: jt.Float[jt.Array, "nt *nop"], + nps: jt.Float[jt.Array, "nt *nnp"], ): """ Compute the standard deviations of the observables. @@ -450,15 +477,21 @@ def _sigmays( total values for conservation laws :param iys: observable indices + :param ops: + observables parameters + :param nps: + noise parameters :return: standard deviations of the observables """ return jax.vmap( - lambda t, x, p, tcl, iy: self._sigmay(self._y(t, x, p, tcl), p) + lambda t, x, p, tcl, iy, op, np: self._sigmay( + self._y(t, x, p, tcl, op), p, np + ) .at[iy] .get(), - in_axes=(0, 0, None, None, 0), - )(ts, xs, p, tcl, iys) + in_axes=(0, 0, None, None, 0, 0, 0), + )(ts, xs, p, tcl, iys, ops, nps) @eqx.filter_jit def simulate_condition( @@ -469,6 +502,8 @@ def simulate_condition( my: jt.Float[jt.Array, "nt"], iys: jt.Int[jt.Array, "nt"], iy_trafos: jt.Int[jt.Array, "nt"], + ops: jt.Float[jt.Array, "nt *nop"], + nps: jt.Float[jt.Array, "nt *nnp"], solver: diffrax.AbstractSolver, controller: diffrax.AbstractStepSizeController, adjoint: diffrax.AbstractAdjoint, @@ -497,13 +532,12 @@ def simulate_condition( observed data :param iys: indices of the observables according to ordering in :ivar observable_ids: - :param x_preeq: - initial state vector for pre-equilibration. If not provided, the initial state vector is computed using - :meth:`_x0`. - :param mask_reinit: - mask for re-initialization. If `True`, the corresponding state variable is re-initialized. - :param x_reinit: - re-initialized state vector. If not provided, the state vector is not re-initialized. + :param iy_trafos: + indices of transformations for observables + :param ops: + observables parameters + :param nps: + noise parameters :param solver: ODE solver :param controller: @@ -515,13 +549,20 @@ def simulate_condition( event function for steady state. See :func:`diffrax.steady_state_event` for details. :param max_steps: maximum number of solver steps - :param ret: - which output to return. See :class:`ReturnValue` for available options. + :param x_preeq: + initial state vector for pre-equilibration. If not provided, the initial state vector is computed using + :meth:`_x0`. + :param mask_reinit: + mask for re-initialization. If `True`, the corresponding state variable is re-initialized. + :param x_reinit: + re-initialized state vector. If not provided, the state vector is not re-initialized. :param ts_mask: mask to remove (padded) time points. If `True`, the corresponding time point is used for the evaluation of the output. Only applied if ret is ReturnValue.llh, ReturnValue.nllhs, ReturnValue.res, or ReturnValue.chi2. + :param ret: + which output to return. See :class:`ReturnValue` for available options. :return: - output according to `ret` and statistics + output according to `ret` and general results/statistics """ if x_preeq.shape[0]: x = x_preeq @@ -578,7 +619,7 @@ def simulate_condition( x = jnp.concatenate((x_dyn, x_posteq), axis=0) - nllhs = self._nllhs(ts, x, p, tcl, my, iys) + nllhs = self._nllhs(ts, x, p, tcl, my, iys, ops, nps) nllhs = jnp.where(ts_mask, nllhs, 0.0) llh = -jnp.sum(nllhs) @@ -598,9 +639,9 @@ def simulate_condition( elif ret == ReturnValue.x_solver: output = x elif ret == ReturnValue.y: - output = self._ys(ts, x, p, tcl, iys) + output = self._ys(ts, x, p, tcl, iys, ops) elif ret == ReturnValue.sigmay: - output = self._sigmays(ts, x, p, tcl, iys) + output = self._sigmays(ts, x, p, tcl, iys, ops, nps) elif ret == ReturnValue.x0: output = self._x_rdata(x[0, :], tcl) elif ret == ReturnValue.x0_solver: @@ -616,10 +657,10 @@ def simulate_condition( .at[iy_trafo] .get(), ) - ys_obj = obs_trafo(self._ys(ts, x, p, tcl, iys), iy_trafos) + ys_obj = obs_trafo(self._ys(ts, x, p, tcl, iys, ops), iy_trafos) m_obj = obs_trafo(my, iy_trafos) if ret == ReturnValue.chi2: - sigma_obj = self._sigmays(ts, x, p, tcl, iys) + sigma_obj = self._sigmays(ts, x, p, tcl, iys, ops, nps) chi2 = jnp.square((ys_obj - m_obj) / sigma_obj) chi2 = jnp.where(ts_mask, chi2, 0.0) output = jnp.sum(chi2) diff --git a/python/sdist/amici/jax/ode_export.py b/python/sdist/amici/jax/ode_export.py index 4329195441..0ad7e48ed9 100644 --- a/python/sdist/amici/jax/ode_export.py +++ b/python/sdist/amici/jax/ode_export.py @@ -194,7 +194,18 @@ def _generate_jax_code(self) -> None: "x_rdata", "total_cl", ) - sym_names = ("p", "x", "tcl", "w", "my", "y", "sigmay", "x_rdata") + sym_names = ( + "p", + "np", + "op", + "x", + "tcl", + "w", + "my", + "y", + "sigmay", + "x_rdata", + ) indent = 8 diff --git a/python/sdist/amici/jax/petab.py b/python/sdist/amici/jax/petab.py index c47a00e1e3..8ac110fbc2 100644 --- a/python/sdist/amici/jax/petab.py +++ b/python/sdist/amici/jax/petab.py @@ -1,8 +1,9 @@ """PEtab wrappers for JAX models.""" "" +import copy import shutil from numbers import Number -from collections.abc import Iterable +from collections.abc import Sized, Iterable from pathlib import Path from collections.abc import Callable @@ -88,6 +89,12 @@ class JAXProblem(eqx.Module): _iys: np.ndarray _iy_trafos: np.ndarray _ts_masks: np.ndarray + _op_numeric: np.ndarray + _op_mask: np.ndarray + _op_indices: np.ndarray + _np_numeric: np.ndarray + _np_mask: np.ndarray + _np_indices: np.ndarray _petab_measurement_indices: np.ndarray _petab_problem: petab.Problem @@ -113,6 +120,12 @@ def __init__(self, model: JAXModel, petab_problem: petab.Problem): self._iy_trafos, self._ts_masks, self._petab_measurement_indices, + self._op_numeric, + self._op_mask, + self._op_indices, + self._np_numeric, + self._np_mask, + self._np_indices, ) = self._get_measurements(scs) self.parameters = self._get_nominal_parameter_values() @@ -169,13 +182,22 @@ def _get_parameter_mappings( Dictionary mapping simulation conditions to parameter mappings. """ scs = list(set(simulation_conditions.values.flatten())) + petab_problem = copy.deepcopy(self._petab_problem) + # remove observable and noise parameters from measurement dataframe as we are mapping them elsewhere + petab_problem.measurement_df.drop( + columns=[petab.OBSERVABLE_PARAMETERS, petab.NOISE_PARAMETERS], + inplace=True, + errors="ignore", + ) mappings = create_parameter_mapping( - petab_problem=self._petab_problem, + petab_problem=petab_problem, simulation_conditions=[ {petab.SIMULATION_CONDITION_ID: sc} for sc in scs ], scaled_parameters=False, + allow_timepoint_specific_numeric_noise_parameters=True, ) + # fill in dummy variables for mapping in mappings: for sim_var, value in mapping.map_sim_var.items(): if isinstance(value, Number) and not np.isfinite(value): @@ -192,6 +214,12 @@ def _get_measurements( np.ndarray, np.ndarray, np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, ]: """ Get measurements for the model based on the provided simulation conditions. @@ -208,9 +236,38 @@ def _get_measurements( - observable transformations indices - measurement masks - data indices (index in petab measurement dataframe). + - numeric values for observable parameter overrides + - non-numeric mask for observable parameter overrides + - parameter indices (problem parameters) for observable parameter overrides + - numeric values for noise parameter overrides + - non-numeric mask for noise parameter overrides + - parameter indices (problem parameters) for noise parameter overrides """ measurements = dict() petab_indices = dict() + + n_pars = dict() + for col in [petab.OBSERVABLE_PARAMETERS, petab.NOISE_PARAMETERS]: + n_pars[col] = 0 + if col in self._petab_problem.measurement_df: + if np.issubdtype( + self._petab_problem.measurement_df[col].dtype, np.number + ): + n_pars[col] = 1 - int( + self._petab_problem.measurement_df[col].isna().all() + ) + else: + n_pars[col] = ( + self._petab_problem.measurement_df[col] + .str.split(petab.C.PARAMETER_SEPARATOR) + .apply( + lambda x: len(x) + if isinstance(x, Sized) + else 1 - int(pd.isna(x)) + ) + .max() + ) + for _, simulation_condition in simulation_conditions.iterrows(): query = " & ".join( [f"{k} == '{v}'" for k, v in simulation_condition.items()] @@ -249,94 +306,184 @@ def _get_measurements( else: iy_trafos = np.zeros_like(iys) + parameter_overrides_par_indices = dict() + parameter_overrides_numeric_vals = dict() + parameter_overrides_mask = dict() + + def get_parameter_override(x): + if ( + x in self._petab_problem.parameter_df.index + and not self._petab_problem.parameter_df.loc[ + x, petab.ESTIMATE + ] + ): + return self._petab_problem.parameter_df.loc[ + x, petab.NOMINAL_VALUE + ] + return x + + for col in [petab.OBSERVABLE_PARAMETERS, petab.NOISE_PARAMETERS]: + if col not in m or m[col].isna().all(): + mat_numeric = jnp.ones((len(m), n_pars[col])) + par_mask = np.zeros_like(mat_numeric, dtype=bool) + par_index = np.zeros_like(mat_numeric, dtype=int) + elif np.issubdtype(m[col].dtype, np.number): + mat_numeric = np.expand_dims(m[col].values, axis=1) + par_mask = np.zeros_like(mat_numeric, dtype=bool) + par_index = np.zeros_like(mat_numeric, dtype=int) + else: + split_vals = m[col].str.split(petab.C.PARAMETER_SEPARATOR) + list_vals = split_vals.apply( + lambda x: [get_parameter_override(y) for y in x] + if isinstance(x, list) + else [] + if pd.isna(x) + else [ + x + ] # every string gets transformed to lists, so this is already a float + ) + vals = list_vals.apply( + lambda x: np.pad( + x, + (0, n_pars[col] - len(x)), + mode="constant", + constant_values=1.0, + ) + ) + mat = np.stack(vals) + # deconstruct such that we can reconstruct mapped parameter overrides via vectorized operations + # mat = np.where(par_mask, map(lambda ip: p.at[ip], par_index), mat_numeric) + par_index = np.vectorize( + lambda x: self.parameter_ids.index(x) + if x in self.parameter_ids + else -1 + )(mat) + # map out numeric values + par_mask = par_index != -1 + # remove non-numeric values + mat[par_mask] = 0.0 + mat_numeric = mat.astype(float) + # replace dummy index with some valid index + par_index[~par_mask] = 0 + + parameter_overrides_numeric_vals[col] = mat_numeric + parameter_overrides_mask[col] = par_mask + parameter_overrides_par_indices[col] = par_index + measurements[tuple(simulation_condition)] = ( - ts_dyn, - ts_posteq, - my, - iys, - iy_trafos, + ts_dyn, # 0 + ts_posteq, # 1 + my, # 2 + iys, # 3 + iy_trafos, # 4 + parameter_overrides_numeric_vals[ + petab.OBSERVABLE_PARAMETERS + ], # 5 + parameter_overrides_mask[petab.OBSERVABLE_PARAMETERS], # 6 + parameter_overrides_par_indices[ + petab.OBSERVABLE_PARAMETERS + ], # 7 + parameter_overrides_numeric_vals[petab.NOISE_PARAMETERS], # 8 + parameter_overrides_mask[petab.NOISE_PARAMETERS], # 9 + parameter_overrides_par_indices[petab.NOISE_PARAMETERS], # 10 ) petab_indices[tuple(simulation_condition)] = tuple(index.tolist()) # compute maximum lengths - n_ts_dyn = max( - len(ts_dyn) for ts_dyn, _, _, _, _ in measurements.values() - ) - n_ts_posteq = max( - len(ts_posteq) for _, ts_posteq, _, _, _ in measurements.values() - ) + n_ts_dyn = max(len(mv[0]) for mv in measurements.values()) + n_ts_posteq = max(len(mv[1]) for mv in measurements.values()) # pad with last value and stack ts_dyn = np.stack( [ - np.pad(x, (0, n_ts_dyn - len(x)), mode="edge") - for x, _, _, _, _ in measurements.values() + np.pad(mv[0], (0, n_ts_dyn - len(mv[0])), mode="edge") + for mv in measurements.values() ] ) ts_posteq = np.stack( [ - np.pad(x, (0, n_ts_posteq - len(x)), mode="edge") - for _, x, _, _, _ in measurements.values() + np.pad(mv[1], (0, n_ts_posteq - len(mv[1])), mode="edge") + for mv in measurements.values() ] ) - def pad_measurement(x_dyn, x_peq, n_ts_dyn, n_ts_posteq): + def pad_measurement(x_dyn, x_peq): + # only pad first axis + pad_width_dyn = tuple( + [(0, n_ts_dyn - len(x_dyn))] + [(0, 0)] * (x_dyn.ndim - 1) + ) + pad_width_peq = tuple( + [(0, n_ts_posteq - len(x_peq))] + [(0, 0)] * (x_peq.ndim - 1) + ) return np.concatenate( ( - np.pad(x_dyn, (0, n_ts_dyn - len(x_dyn)), mode="edge"), - np.pad(x_peq, (0, n_ts_posteq - len(x_peq)), mode="edge"), + np.pad(x_dyn, pad_width_dyn, mode="edge"), + np.pad(x_peq, pad_width_peq, mode="edge"), ) ) - my = np.stack( - [ - pad_measurement( - x[: len(tdyn)], x[len(tdyn) :], n_ts_dyn, n_ts_posteq - ) - for tdyn, tpeq, x, _, _ in measurements.values() - ] - ) - iys = np.stack( - [ - pad_measurement( - x[: len(tdyn)], x[len(tdyn) :], n_ts_dyn, n_ts_posteq - ) - for tdyn, tpeq, _, x, _ in measurements.values() - ] - ) - iy_trafos = np.stack( - [ - pad_measurement( - x[: len(tdyn)], x[len(tdyn) :], n_ts_dyn, n_ts_posteq - ) - for tdyn, tpeq, _, _, x in measurements.values() - ] - ) + def pad_and_stack(output_index: int): + return np.stack( + [ + pad_measurement( + mv[output_index][: len(mv[0])], + mv[output_index][len(mv[0]) :], + ) + for mv in measurements.values() + ] + ) + + my = pad_and_stack(2) + iys = pad_and_stack(3) + iy_trafos = pad_and_stack(4) + op_numeric = pad_and_stack(5) + op_mask = pad_and_stack(6) + op_indices = pad_and_stack(7) + np_numeric = pad_and_stack(8) + np_mask = pad_and_stack(9) + np_indices = pad_and_stack(10) ts_masks = np.stack( [ np.concatenate( ( - np.pad(np.ones_like(tdyn), (0, n_ts_dyn - len(tdyn))), np.pad( - np.ones_like(tpeq), (0, n_ts_posteq - len(tpeq)) + np.ones_like(mv[0]), (0, n_ts_dyn - len(mv[0])) + ), + np.pad( + np.ones_like(mv[1]), (0, n_ts_posteq - len(mv[1])) ), ) ) - for tdyn, tpeq, _, _, _ in measurements.values() + for mv in measurements.values() ] ).astype(bool) petab_indices = np.stack( [ pad_measurement( - idx[: len(tdyn)], idx[len(tdyn) :], n_ts_dyn, n_ts_posteq + np.array(idx[: len(mv[0])]), + np.array(idx[len(mv[0]) :]), ) - for (tdyn, tpeq, _, _, _), idx in zip( + for mv, idx in zip( measurements.values(), petab_indices.values() ) ] ) - return ts_dyn, ts_posteq, my, iys, iy_trafos, ts_masks, petab_indices + return ( + ts_dyn, + ts_posteq, + my, + iys, + iy_trafos, + ts_masks, + petab_indices, + op_numeric, + op_mask, + op_indices, + np_numeric, + np_mask, + np_indices, + ) def get_all_simulation_conditions(self) -> tuple[tuple[str, ...], ...]: simulation_conditions = ( @@ -549,21 +696,76 @@ def update_parameters(self, p: jt.Float[jt.Array, "np"]) -> "JAXProblem": return eqx.tree_at(lambda p: p.parameters, self, p) def _prepare_conditions( - self, conditions: Iterable[str] + self, + conditions: list[str], + op_numeric: np.ndarray | None = None, + op_mask: np.ndarray | None = None, + op_indices: np.ndarray | None = None, + np_numeric: np.ndarray | None = None, + np_mask: np.ndarray | None = None, + np_indices: np.ndarray | None = None, ) -> tuple[ - jt.Float[jt.Array, "np"], # noqa: F821 + jt.Float[jt.Array, "nc np"], # noqa: F821, F722 jt.Bool[jt.Array, "nx"], # noqa: F821 jt.Float[jt.Array, "nx"], # noqa: F821 + jt.Float[jt.Array, "nc nt nop"], # noqa: F821, F722 + jt.Float[jt.Array, "nc nt nnp"], # noqa: F821, F722 ]: """ Prepare conditions for simulation. :param conditions: Simulation conditions to prepare. + :param op_numeric: + Numeric values for observable parameter overrides. If None, no overrides are used. + :param op_mask: + Mask for observable parameter overrides. True for free parameter overrides, False for numeric values. + :param op_indices: + Free parameter indices (wrt. `self.parameters`) for observable parameter overrides. + :param np_numeric: + Numeric values for noise parameter overrides. If None, no overrides are used. + :param np_mask: + Mask for noise parameter overrides. True for free parameter overrides, False for numeric values. + :param np_indices: + Free parameter indices (wrt. `self.parameters`) for noise parameter overrides. :return: - Tuple of parameter arrays, reinitialisation masks and reinitialisation values. + Tuple of parameter arrays, reinitialisation masks and reinitialisation values, observable parameters and + noise parameters. """ p_array = jnp.stack([self.load_parameters(sc) for sc in conditions]) + unscaled_parameters = jnp.stack( + [ + jax_unscale( + self.parameters[ip], + self._petab_problem.parameter_df.loc[ + p_id, petab.PARAMETER_SCALE + ], + ) + for ip, p_id in enumerate(self.parameter_ids) + ] + ) + + if op_numeric is not None and op_numeric.size: + op_array = jnp.where( + op_mask, + jax.vmap( + jax.vmap(jax.vmap(lambda ip: unscaled_parameters[ip])) + )(op_indices), + op_numeric, + ) + else: + op_array = jnp.zeros((*self._ts_masks.shape[:2], 0)) + + if np_numeric is not None and np_numeric.size: + np_array = jnp.where( + np_mask, + jax.vmap( + jax.vmap(jax.vmap(lambda ip: unscaled_parameters[ip])) + )(np_indices), + np_numeric, + ) + else: + np_array = jnp.zeros((*self._ts_masks.shape[:2], 0)) mask_reinit_array = jnp.stack( [ @@ -577,7 +779,7 @@ def _prepare_conditions( for sc, p in zip(conditions, p_array) ] ) - return p_array, mask_reinit_array, x_reinit_array + return p_array, mask_reinit_array, x_reinit_array, op_array, np_array @eqx.filter_vmap( in_axes={ @@ -593,6 +795,8 @@ def run_simulation( my: np.ndarray, iys: np.ndarray, iy_trafos: np.ndarray, + ops: jt.Float[jt.Array, "nt *nop"], # noqa: F821, F722 + nps: jt.Float[jt.Array, "nt *nnp"], # noqa: F821, F722 mask_reinit: jt.Bool[jt.Array, "nx"], # noqa: F821, F722 x_reinit: jt.Float[jt.Array, "nx"], # noqa: F821, F722 solver: diffrax.AbstractSolver, @@ -620,6 +824,10 @@ def run_simulation( (Padded) observable indices :param iy_trafos: (Padded) observable transformations indices + :param ops: + (Padded) observable parameters + :param nps: + (Padded) noise parameters :param mask_reinit: Mask for states that need reinitialisation :param x_reinit: @@ -650,6 +858,8 @@ def run_simulation( my=jax.lax.stop_gradient(jnp.array(my)), iys=jax.lax.stop_gradient(jnp.array(iys)), iy_trafos=jax.lax.stop_gradient(jnp.array(iy_trafos)), + nps=nps, + ops=ops, x_preeq=x_preeq, mask_reinit=jax.lax.stop_gradient(mask_reinit), x_reinit=x_reinit, @@ -699,8 +909,16 @@ def run_simulations( Output value and condition specific results and statistics. Results and statistics are returned as a dict with arrays with the leading dimension corresponding to the simulation conditions. """ - p_array, mask_reinit_array, x_reinit_array = self._prepare_conditions( - simulation_conditions + p_array, mask_reinit_array, x_reinit_array, op_array, np_array = ( + self._prepare_conditions( + simulation_conditions, + self._op_numeric, + self._op_mask, + self._op_indices, + self._np_numeric, + self._np_mask, + self._np_indices, + ) ) return self.run_simulation( p_array, @@ -709,6 +927,8 @@ def run_simulations( self._my, self._iys, self._iy_trafos, + op_array, + np_array, mask_reinit_array, x_reinit_array, solver, @@ -779,8 +999,8 @@ def run_preequilibrations( ], max_steps: jnp.int_, ): - p_array, mask_reinit_array, x_reinit_array = self._prepare_conditions( - simulation_conditions + p_array, mask_reinit_array, x_reinit_array, _, _ = ( + self._prepare_conditions(simulation_conditions, None, None) ) return self.run_preequilibration( p_array, diff --git a/python/sdist/amici/petab/pysb_import.py b/python/sdist/amici/petab/pysb_import.py index 32de3d6666..b426ce424b 100644 --- a/python/sdist/amici/petab/pysb_import.py +++ b/python/sdist/amici/petab/pysb_import.py @@ -16,6 +16,7 @@ from petab.v1.C import CONDITION_NAME, NOISE_FORMULA, OBSERVABLE_FORMULA from petab.v1.models.pysb_model import PySBModel +from ..import_utils import strip_pysb from ..logging import get_logger, log_execution_time, set_log_level from . import PREEQ_INDICATOR_ID from .import_helpers import ( @@ -28,7 +29,7 @@ def _add_observation_model( - pysb_model: pysb.Model, petab_problem: petab.Problem + pysb_model: pysb.Model, petab_problem: petab.Problem, jax: bool = False ): """Extend PySB model by observation model as defined in the PEtab observables table""" @@ -39,22 +40,45 @@ def _add_observation_model( for comp in pysb_model.components if isinstance(comp, sp.Symbol) } - for formula in [ - *petab_problem.observable_df[OBSERVABLE_FORMULA], - *petab_problem.observable_df[NOISE_FORMULA], - ]: - sym = sp.sympify(formula, locals=local_syms) - for s in sym.free_symbols: - if not isinstance(s, pysb.Component): - p = pysb.Parameter(str(s), 1.0) - pysb_model.add_component(p) - local_syms[sp.Symbol.__str__(p)] = p + obs_df = petab_problem.observable_df.copy() + for col, placeholder_pattern in ( + (OBSERVABLE_FORMULA, r"^(observableParameter\d+)_\w+$"), + (NOISE_FORMULA, r"^(noiseParameter\d+)_\w+$"), + ): + for ir, formula in petab_problem.observable_df[col].items(): + if not isinstance(formula, str): + continue + + changed_formula = False + sym = sp.sympify(formula, locals=local_syms) + for s in sym.free_symbols: + if not isinstance(s, pysb.Component): + if jax: + name = re.sub(placeholder_pattern, r"\1", str(s)) + else: + name = str(s) + p = pysb.Parameter(name, 1.0) + pysb_model.add_component(p) + + # placeholders for multiple observables are mapped to the same symbol, so only add to local_syms + # when necessary + if name not in local_syms: + local_syms[name] = p + + # replace placeholder with parameter + if jax and name != str(s): + changed_formula = True + sym = sym.subs(s, local_syms[name]) + + # update forum + if jax and changed_formula: + obs_df.at[ir, col] = str(strip_pysb(sym)) # add observables and sigmas to pysb model for observable_id, observable_formula, noise_formula in zip( - petab_problem.observable_df.index, - petab_problem.observable_df[OBSERVABLE_FORMULA], - petab_problem.observable_df[NOISE_FORMULA], + obs_df.index, + obs_df[OBSERVABLE_FORMULA], + obs_df[NOISE_FORMULA], strict=True, ): obs_symbol = sp.sympify(observable_formula, locals=local_syms) @@ -210,7 +234,7 @@ def import_model_pysb( name=petab_problem.model.model_id, ) - _add_observation_model(pysb_model, petab_problem) + _add_observation_model(pysb_model, petab_problem, jax) # generate species for the _original_ model pysb.bng.generate_equations(petab_problem.model.model) fixed_parameters = _add_initialization_variables(pysb_model, petab_problem) @@ -274,6 +298,7 @@ def import_model_pysb( observables=observables, sigmas=sigmas, noise_distributions=noise_distrs, + pysb_model_has_obs_and_noise=True, **kwargs, ) return @@ -289,5 +314,6 @@ def import_model_pysb( sigmas=sigmas, constant_parameters=constant_parameters, noise_distributions=noise_distrs, + pysb_model_has_obs_and_noise=True, **kwargs, ) diff --git a/python/sdist/amici/petab/sbml_import.py b/python/sdist/amici/petab/sbml_import.py index e605a9cc80..9f5345d082 100644 --- a/python/sdist/amici/petab/sbml_import.py +++ b/python/sdist/amici/petab/sbml_import.py @@ -1,4 +1,6 @@ import logging +import re + import math import os import tempfile @@ -147,7 +149,7 @@ def _workaround_initial_states( def _workaround_observable_parameters( - observables, sigmas, sbml_model, output_parameter_defaults + observables, sigmas, sbml_model, output_parameter_defaults, jax=False ): # TODO: adding extra output parameters is currently not supported, # so we add any output parameters to the SBML model. @@ -165,7 +167,25 @@ def _workaround_observable_parameters( ) for free_sym in free_syms: sym = str(free_sym) - if ( + if jax and (m := re.match(r"(noiseParameter\d+)_(\w+)", sym)): + # group1 is the noise parameter, group2 is the observable, don't add to sbml but replace with generic + # noise parameter + sigmas[m.group(2)] = str( + sp.sympify(sigmas[m.group(2)], locals=_clash).subs( + free_sym, sp.Symbol(m.group(1)) + ) + ) + elif jax and ( + m := re.match(r"(observableParameter\d+)_(\w+)", sym) + ): + # group1 is the noise parameter, group2 is the observable, don't add to sbml but replace with generic + # observable parameter + observables[m.group(2)]["formula"] = str( + sp.sympify( + observables[m.group(2)]["formula"], locals=_clash + ).subs(free_sym, sp.Symbol(m.group(1))) + ) + elif ( sbml_model.getElementBySId(sym) is None and sym != "time" and sym not in observables @@ -317,7 +337,8 @@ def import_model_sbml( ) ) if ( - petab_problem.measurement_df is not None + not jax + and petab_problem.measurement_df is not None and petab.lint.measurement_table_has_timepoint_specific_mappings( petab_problem.measurement_df, allow_scalar_numeric_noise_parameters=allow_n_noise_pars, @@ -346,8 +367,9 @@ def import_model_sbml( ) _workaround_observable_parameters( - observables, sigmas, sbml_model, output_parameter_defaults + observables, sigmas, sbml_model, output_parameter_defaults, jax=jax ) + if not jax: fixed_parameters = _workaround_initial_states( petab_problem=petab_problem, diff --git a/python/sdist/amici/pysb_import.py b/python/sdist/amici/pysb_import.py index b84fadea44..1812906858 100644 --- a/python/sdist/amici/pysb_import.py +++ b/python/sdist/amici/pysb_import.py @@ -8,6 +8,7 @@ import itertools import logging import os +import re import sys from pathlib import Path from typing import ( @@ -33,6 +34,7 @@ SigmaY, ) from .de_model import DEModel +from .de_model_components import NoiseParameter, ObservableParameter from .import_utils import ( _get_str_symbol_identifiers, _parse_special_functions, @@ -62,6 +64,7 @@ def pysb2jax( # See https://github.com/AMICI-dev/AMICI/pull/1672 cache_simplify: bool = False, model_name: str | None = None, + pysb_model_has_obs_and_noise: bool = False, ): r""" Generate AMICI jax files for the provided model. @@ -118,6 +121,9 @@ def pysb2jax( :param model_name: Name for the generated model module. If None, :attr:`pysb.Model.name` will be used. + + :param pysb_model_has_obs_and_noise: + if set to ``True``, the pysb model is expected to have extra observables and noise variables added """ if observables is None: observables = [] @@ -137,6 +143,8 @@ def pysb2jax( simplify=simplify, cache_simplify=cache_simplify, verbose=verbose, + jax=True, + pysb_model_has_obs_and_noise=pysb_model_has_obs_and_noise, ) from amici.jax.ode_export import ODEExporter @@ -168,6 +176,7 @@ def pysb2amici( cache_simplify: bool = False, generate_sensitivity_code: bool = True, model_name: str | None = None, + pysb_model_has_obs_and_noise: bool = False, ): r""" Generate AMICI C++ files for the provided model. @@ -245,6 +254,9 @@ def pysb2amici( :param model_name: Name for the generated model module. If None, :attr:`pysb.Model.name` will be used. + + :param pysb_model_has_obs_and_noise: + if set to ``True``, the pysb model is expected to have extra observables and noise variables added """ if observables is None: observables = [] @@ -267,6 +279,7 @@ def pysb2amici( simplify=simplify, cache_simplify=cache_simplify, verbose=verbose, + pysb_model_has_obs_and_noise=pysb_model_has_obs_and_noise, ) exporter = DEExporter( ode_model, @@ -300,6 +313,8 @@ def ode_model_from_pysb_importer( # See https://github.com/AMICI-dev/AMICI/pull/1672 cache_simplify: bool = False, verbose: int | bool = False, + jax: bool = False, + pysb_model_has_obs_and_noise: bool = False, ) -> DEModel: """ Creates an :class:`amici.DEModel` instance from a :class:`pysb.Model` @@ -335,6 +350,12 @@ def ode_model_from_pysb_importer( :param verbose: verbosity level for logging, True/False default to :attr:`logging.DEBUG`/:attr:`logging.ERROR` + :param jax: + if set to ``True``, the generated model will be compatible with JAX export + + :param pysb_model_has_obs_and_noise: + if set to ``True``, the pysb model is expected to have extra observables and noise variables added + :return: New DEModel instance according to pysbModel """ @@ -357,14 +378,24 @@ def ode_model_from_pysb_importer( pysb.bng.generate_equations(model, verbose=verbose) _process_pysb_species(model, ode) - _process_pysb_parameters(model, ode, constant_parameters) + _process_pysb_parameters(model, ode, constant_parameters, jax) if compute_conservation_laws: _process_pysb_conservation_laws(model, ode) _process_pysb_observables( - model, ode, observables, sigmas, noise_distributions + model, + ode, + observables, + sigmas, + noise_distributions, + pysb_model_has_obs_and_noise, ) _process_pysb_expressions( - model, ode, observables, sigmas, noise_distributions + model, + ode, + observables, + sigmas, + noise_distributions, + pysb_model_has_obs_and_noise, ) ode._has_quadratic_nllh = not noise_distributions or all( noise_distr in ["normal", "lin-normal", "log-normal", "log10-normal"] @@ -510,7 +541,10 @@ def _process_pysb_species(pysb_model: pysb.Model, ode_model: DEModel) -> None: @log_execution_time("processing PySB parameters", logger) def _process_pysb_parameters( - pysb_model: pysb.Model, ode_model: DEModel, constant_parameters: list[str] + pysb_model: pysb.Model, + ode_model: DEModel, + constant_parameters: list[str], + jax: bool = False, ) -> None: """ Converts pysb parameters into Parameters or Constants and adds them to @@ -522,16 +556,26 @@ def _process_pysb_parameters( :param constant_parameters: list of Parameters that should be constants + :param jax: + if set to ``True``, the generated model will be compatible JAX export + :param ode_model: DEModel instance """ for par in pysb_model.parameters: + args = [par, f"{par.name}"] if par.name in constant_parameters: comp = Constant + args.append(par.value) + elif jax and re.match(r"noiseParameter\d+", par.name): + comp = NoiseParameter + elif jax and re.match(r"observableParameter\d+", par.name): + comp = ObservableParameter else: comp = Parameter + args.append(par.value) - ode_model.add_component(comp(par, f"{par.name}", par.value)) + ode_model.add_component(comp(*args)) @log_execution_time("processing PySB expressions", logger) @@ -541,6 +585,7 @@ def _process_pysb_expressions( observables: list[str], sigmas: dict[str, str], noise_distributions: dict[str, str | Callable] | None = None, + pysb_model_has_obs_and_noise: bool = False, ) -> None: r""" Converts pysb expressions/observables into Observables (with @@ -565,6 +610,9 @@ def _process_pysb_expressions( :param ode_model: DEModel instance + + :param pysb_model_has_obs_and_noise: + if set to ``True``, the pysb model is expected to have extra observables and noise variables added """ # we no longer expand expressions here. pysb/bng guarantees that # they are ordered according to their dependency and we can @@ -594,6 +642,7 @@ def _process_pysb_expressions( observables, sigmas, noise_distributions, + pysb_model_has_obs_and_noise, ) @@ -606,6 +655,7 @@ def _add_expression( observables: list[str], sigmas: dict[str, str], noise_distributions: dict[str, str | Callable] | None = None, + pysb_model_has_obs_and_noise: bool = False, ): """ Adds expressions to the ODE model given and adds observables/sigmas if @@ -634,10 +684,18 @@ def _add_expression( :param ode_model: see :py:func:`_process_pysb_expressions` + + :param pysb_model_has_obs_and_noise: + if set to ``True``, the pysb model is expected to have extra observables and noise variables added """ - ode_model.add_component( - Expression(sym, name, _parse_special_functions(expr)) - ) + if not pysb_model_has_obs_and_noise or name not in observables: + if name in list(sigmas.values()): + component = SigmaY + else: + component = Expression + ode_model.add_component( + component(sym, name, _parse_special_functions(expr)) + ) if name in observables: noise_dist = ( @@ -646,17 +704,22 @@ def _add_expression( else "normal" ) - y = sp.Symbol(f"{name}") + y = sp.Symbol(name) trafo = noise_distribution_to_observable_transformation(noise_dist) - obs = Observable(y, name, sym, transformation=trafo) - ode_model.add_component(obs) - - sigma_name, sigma_value = _get_sigma_name_and_value( - pysb_model, name, sigmas + # note that this is a bit iffy since we are potentially using the same symbolic identifier in expressions (w) + # and observables (y). This is not a problem as there currently are no model functions that use both. If this + # changes, I would expect symbol redefinition warnings in CPP models and overwriting in JAX models, but as both + # symbols refer to the same symbolic entity, this should not be a problem (untested) + obs = Observable( + y, name, _parse_special_functions(expr), transformation=trafo ) + ode_model.add_component(obs) - sigma = sp.Symbol(sigma_name) - ode_model.add_component(SigmaY(sigma, f"{sigma_name}", sigma_value)) + sigma = _get_sigma(pysb_model, name, sigmas) + if not pysb_model_has_obs_and_noise: + ode_model.add_component( + SigmaY(sigma, f"sigma_{name}", sp.Float(1.0)) + ) cost_fun_str = noise_distribution_to_cost_function(noise_dist)(name) my = generate_measurement_symbol(obs.get_id()) @@ -677,9 +740,9 @@ def _add_expression( ) -def _get_sigma_name_and_value( +def _get_sigma( pysb_model: pysb.Model, obs_name: str, sigmas: dict[str, str] -) -> tuple[str, sp.Basic]: +) -> sp.Symbol: """ Tries to extract standard deviation symbolic identifier and formula for a given observable name from the pysb model and if no specification is @@ -696,26 +759,17 @@ def _get_sigma_name_and_value( sigmas :return: - tuple containing symbolic identifier and formula for the specified - observable + symbolic variable representing the standard deviation of the observable """ if obs_name in sigmas: sigma_name = sigmas[obs_name] - try: - # find corresponding Expression instance - sigma_expr = next( - x for x in pysb_model.expressions if x.name == sigma_name - ) - except StopIteration: - raise ValueError( - f"value of sigma {obs_name} is not a " f"valid expression." - ) - sigma_value = sigma_expr.expand_expr() + if sigma_name in pysb_model.expressions.keys(): + return pysb_model.expressions[sigma_name] + raise ValueError( + f"value of sigma {obs_name} is not a valid expression." + ) else: - sigma_name = f"sigma_{obs_name}" - sigma_value = sp.sympify(1.0) - - return sigma_name, sigma_value + return sp.Symbol(f"sigma_{obs_name}") @log_execution_time("processing PySB observables", logger) @@ -725,6 +779,7 @@ def _process_pysb_observables( observables: list[str], sigmas: dict[str, str], noise_distributions: dict[str, str | Callable] | None = None, + pysb_model_has_obs_and_noise: bool = False, ) -> None: """ Converts :class:`pysb.core.Observable` into @@ -746,6 +801,9 @@ def _process_pysb_observables( :param noise_distributions: see :func:`amici.pysb_import.pysb2amici` + + :param pysb_model_has_obs_and_noise: + if set to ``True``, the pysb model is expected to have extra observables and noise variables added """ # only add those pysb observables that occur in the added # Observables as expressions @@ -759,6 +817,7 @@ def _process_pysb_observables( observables, sigmas, noise_distributions, + pysb_model_has_obs_and_noise, ) diff --git a/python/sdist/amici/sbml_import.py b/python/sdist/amici/sbml_import.py index 29a9608c4b..f4a3b6b4aa 100644 --- a/python/sdist/amici/sbml_import.py +++ b/python/sdist/amici/sbml_import.py @@ -1985,6 +1985,42 @@ def _process_observables( self.symbols[SymbolId.OBSERVABLE], "eventObservable" ) + if sigmas: + noise_pars = list( + { + name + for sigma in sigmas.values() + for symbol in self._sympy_from_sbml_math( + sigma + ).free_symbols + if re.match(r"noiseParameter\d+$", (name := str(symbol))) + } + ) + else: + noise_pars = [] + self.symbols[SymbolId.NOISE_PARAMETER] = { + symbol_with_assumptions(np): {"name": np} for np in noise_pars + } + + if observables: + observable_pars = list( + { + name + for obs in observables.values() + for symbol in self._sympy_from_sbml_math( + obs["formula"] + ).free_symbols + if re.match( + r"observableParameter\d+$", (name := str(symbol)) + ) + } + ) + else: + observable_pars = [] + self.symbols[SymbolId.OBSERVABLE_PARAMETER] = { + symbol_with_assumptions(op): {"name": op} for op in observable_pars + } + self._process_log_likelihood(sigmas, noise_distributions) @log_execution_time("processing SBML event observables", logger) diff --git a/python/tests/test_jax.py b/python/tests/test_jax.py index 78fa026cfc..a826962db0 100644 --- a/python/tests/test_jax.py +++ b/python/tests/test_jax.py @@ -194,6 +194,8 @@ def check_fields_jax( "ts_posteq": jnp.array(ts_posteq), "my": jnp.array(my), "iys": jnp.array(iys), + "ops": jnp.zeros((*my.shape[:2], 0)), + "nps": jnp.zeros((*my.shape[:2], 0)), "iy_trafos": jnp.array(iy_trafos), "x_preeq": jnp.array([]), "solver": diffrax.Kvaerno5(), diff --git a/tests/benchmark-models/test_petab_benchmark.py b/tests/benchmark-models/test_petab_benchmark.py index 2f3fbb433a..a34f14dd29 100644 --- a/tests/benchmark-models/test_petab_benchmark.py +++ b/tests/benchmark-models/test_petab_benchmark.py @@ -5,6 +5,7 @@ for a subset of the benchmark problems. """ +import copy from functools import partial from pathlib import Path @@ -245,17 +246,18 @@ def benchmark_problem(request): the benchmark problem collection.""" problem_id = request.param petab_problem = benchmark_models_petab.get_problem(problem_id) + flat_petab_problem = copy.deepcopy(petab_problem) if measurement_table_has_timepoint_specific_mappings( petab_problem.measurement_df, ): - petab.flatten_timepoint_specific_output_overrides(petab_problem) + petab.flatten_timepoint_specific_output_overrides(flat_petab_problem) # Setup AMICI objects. amici_model = import_petab_problem( - petab_problem, + flat_petab_problem, model_output_dir=benchmark_outdir / problem_id, ) - return problem_id, petab_problem, amici_model + return problem_id, flat_petab_problem, petab_problem, amici_model @pytest.mark.filterwarnings( @@ -272,7 +274,9 @@ def test_jax_llh(benchmark_problem): jax.config.update("jax_enable_x64", True) from beartype import beartype - problem_id, petab_problem, amici_model = benchmark_problem + problem_id, flat_petab_problem, petab_problem, amici_model = ( + benchmark_problem + ) amici_solver = amici_model.getSolver() cur_settings = settings[problem_id] @@ -282,7 +286,7 @@ def test_jax_llh(benchmark_problem): simulate_amici = partial( simulate_petab, - petab_problem=petab_problem, + petab_problem=flat_petab_problem, amici_model=amici_model, solver=amici_solver, scaled_parameters=True, @@ -294,7 +298,7 @@ def test_jax_llh(benchmark_problem): problem_parameters = None if problem_id in problems_for_gradient_check: - point = petab_problem.x_nominal_free_scaled + point = flat_petab_problem.x_nominal_free_scaled for _ in range(20): amici_solver.setSensitivityMethod(amici.SensitivityMethod.adjoint) amici_solver.setSensitivityOrder(amici.SensitivityOrder.first) @@ -306,7 +310,9 @@ def test_jax_llh(benchmark_problem): ) point += point_noise # avoid small gradients at nominal value - problem_parameters = dict(zip(petab_problem.x_free_ids, point)) + problem_parameters = dict( + zip(flat_petab_problem.x_free_ids, point) + ) r_amici = simulate_amici( problem_parameters=problem_parameters, @@ -372,7 +378,7 @@ def test_nominal_parameters_llh(benchmark_problem): Also check that the simulation time is within the reference range. """ - problem_id, petab_problem, amici_model = benchmark_problem + problem_id, petab_problem, _, amici_model = benchmark_problem if problem_id not in problems_for_llh_check: pytest.skip("Excluded from log-likelihood check.") @@ -526,7 +532,7 @@ def test_nominal_parameters_llh(benchmark_problem): def test_benchmark_gradient( benchmark_problem, scale, sensitivity_method, request ): - problem_id, petab_problem, amici_model = benchmark_problem + problem_id, petab_problem, _, amici_model = benchmark_problem if problem_id not in problems_for_gradient_check: pytest.skip("Excluded from gradient check.") diff --git a/tests/petab_test_suite/test_petab_suite.py b/tests/petab_test_suite/test_petab_suite.py index 4fcbe0b631..d79140cd54 100755 --- a/tests/petab_test_suite/test_petab_suite.py +++ b/tests/petab_test_suite/test_petab_suite.py @@ -54,7 +54,7 @@ def _test_case(case, model_type, version, jax): problem = petab.Problem.from_yaml(yaml_file) # compile amici model - if case.startswith("0006"): + if case.startswith("0006") and not jax: petab.flatten_timepoint_specific_output_overrides(problem) model_name = ( f"petab_{model_type}_test_case_{case}" f"_{version.replace('.', '_')}" @@ -115,7 +115,7 @@ def _test_case(case, model_type, version, jax): gt_chi2 = solution[petabtests.CHI2] gt_llh = solution[petabtests.LLH] gt_simulation_dfs = solution[petabtests.SIMULATION_DFS] - if case.startswith("0006"): + if case.startswith("0006") and not jax: # account for flattening gt_simulation_dfs[0].loc[:, petab.OBSERVABLE_ID] = ( "obs_a__10__c0",