Skip to content

Commit

Permalink
No flattening of timepoint specific overrides in jax (#2641)
Browse files Browse the repository at this point in the history
* first working implementation

* bugfixes

* vectorisation go brrrrr

* add doc, workaround Zhao_QuantBiol2020

* fixup

* fixup

* fixup

* fixup

* Update ExampleJaxPEtab.ipynb

* fix notebook

* fix pysb

* fixup

* fixup petab 0006 jax

* fix noise parameters

* fixup noise model

* fixup

* fixup sigma

* remove workaround

* Apply suggestions from code review

Co-authored-by: Daniel Weindl <[email protected]>

* Update de_model.py

* Update pytest.ini

* Update de_model.py

* Update pysb_import.py

* fixup

---------

Co-authored-by: Daniel Weindl <[email protected]>
  • Loading branch information
FFroehlich and dweindl authored Feb 5, 2025
1 parent c9f698d commit b7bdf63
Show file tree
Hide file tree
Showing 16 changed files with 674 additions and 164 deletions.
1 change: 1 addition & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,5 @@ filterwarnings =
# ignore jax deprecation warnings
ignore:jax.* is deprecated:DeprecationWarning


norecursedirs = .git amici_models build doc documentation matlab models ThirdParty amici sdist examples
4 changes: 4 additions & 0 deletions python/examples/example_jax_petab/ExampleJaxPEtab.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,8 @@
"my = jax_problem._my[ic, :]\n",
"iys = jax_problem._iys[ic, :]\n",
"iy_trafos = jax_problem._iy_trafos[ic, :]\n",
"ops = jax_problem._op_numeric[ic, :]\n",
"nps = jax_problem._np_numeric[ic, :]\n",
"\n",
"# Load parameters for the specified condition\n",
"p = jax_problem.load_parameters(simulation_condition[0])\n",
Expand All @@ -472,6 +474,8 @@
" my=jnp.array(my),\n",
" iys=jnp.array(iys),\n",
" iy_trafos=jnp.array(iy_trafos),\n",
" ops=jnp.array(ops),\n",
" nps=jnp.array(nps),\n",
" solver=diffrax.Kvaerno5(),\n",
" controller=diffrax.PIDController(atol=1e-8, rtol=1e-8),\n",
" steady_state_event=diffrax.steady_state_event(),\n",
Expand Down
2 changes: 2 additions & 0 deletions python/sdist/amici/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,5 @@ class SymbolId(str, enum.Enum):
SIGMAZ = "sigmaz"
LLHZ = "llhz"
LLHRZ = "llhrz"
NOISE_PARAMETER = "noise_parameter"
OBSERVABLE_PARAMETER = "observable_parameter"
36 changes: 36 additions & 0 deletions python/sdist/amici/de_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
LogLikelihoodY,
LogLikelihoodZ,
LogLikelihoodRZ,
NoiseParameter,
ObservableParameter,
Expression,
ConservationLaw,
Event,
Expand Down Expand Up @@ -226,6 +228,8 @@ def __init__(
self._log_likelihood_ys: list[LogLikelihoodY] = []
self._log_likelihood_zs: list[LogLikelihoodZ] = []
self._log_likelihood_rzs: list[LogLikelihoodRZ] = []
self._noise_parameters: list[NoiseParameter] = []
self._observable_parameters: list[ObservableParameter] = []
self._expressions: list[Expression] = []
self._conservation_laws: list[ConservationLaw] = []
self._events: list[Event] = []
Expand Down Expand Up @@ -273,6 +277,8 @@ def __init__(
"sigmay": self.sigma_ys,
"sigmaz": self.sigma_zs,
"h": self.events,
"np": self.noise_parameters,
"op": self.observable_parameters,
}
self._value_prototype: dict[str, Callable] = {
"p": self.parameters,
Expand Down Expand Up @@ -385,6 +391,14 @@ def log_likelihood_rzs(self) -> list[LogLikelihoodRZ]:
"""Get all event observable regularization log likelihoods."""
return self._log_likelihood_rzs

def noise_parameters(self) -> list[NoiseParameter]:
"""Get all noise parameters."""
return self._noise_parameters

def observable_parameters(self) -> list[ObservableParameter]:
"""Get all observable parameters."""
return self._observable_parameters

def is_ode(self) -> bool:
"""Check if model is ODE model."""
return len(self._algebraic_equations) == 0
Expand Down Expand Up @@ -565,6 +579,8 @@ def add_component(
ConservationLaw,
Event,
EventObservable,
NoiseParameter,
ObservableParameter,
}:
raise ValueError(f"Invalid component type {type(component)}")

Expand Down Expand Up @@ -1087,6 +1103,26 @@ def _generate_symbol(self, name: str) -> None:
"""
if name in self._variable_prototype:
components = self._variable_prototype[name]()
# ensure placeholder parameters are consistently and correctly ordered
# we want that components are ordered by their placeholder index
if name == "op":
components = sorted(
components,
key=lambda x: int(
str(strip_pysb(x.get_id())).replace(
"observableParameter", ""
)
),
)
if name == "np":
components = sorted(
components,
key=lambda x: int(
str(strip_pysb(x.get_id())).replace(
"noiseParameter", ""
)
),
)
self._syms[name] = sp.Matrix(
[comp.get_id() for comp in components]
)
Expand Down
42 changes: 42 additions & 0 deletions python/sdist/amici/de_model_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,46 @@ def __init__(
super().__init__(identifier, name, value)


class NoiseParameter(ModelQuantity):
"""
A NoiseParameter is an input variable for the computation of ``sigma`` that can be specified in a data-point
specific manner, abbreviated by ``np``. Only used for jax models.
"""

def __init__(self, identifier: sp.Symbol, name: str):
"""
Create a new Expression instance.
:param identifier:
unique identifier of the NoiseParameter
:param name:
individual name of the NoiseParameter (does not need to be
unique)
"""
super().__init__(identifier, name, 0.0)


class ObservableParameter(ModelQuantity):
"""
A NoiseParameter is an input variable for the computation of ``y`` that can be specified in a data-point specific
manner, abbreviated by ``op``. Only used for jax models.
"""

def __init__(self, identifier: sp.Symbol, name: str):
"""
Create a new Expression instance.
:param identifier:
unique identifier of the ObservableParameter
:param name:
individual name of the ObservableParameter (does not need to be
unique)
"""
super().__init__(identifier, name, 0.0)


class LogLikelihood(ModelQuantity):
"""
A LogLikelihood defines the distance between measurements and
Expand Down Expand Up @@ -751,4 +791,6 @@ def get_trigger_time(self) -> sp.Float:
SymbolId.LLHRZ: LogLikelihoodRZ,
SymbolId.EXPRESSION: Expression,
SymbolId.EVENT: Event,
SymbolId.NOISE_PARAMETER: NoiseParameter,
SymbolId.OBSERVABLE_PARAMETER: ObservableParameter,
}
12 changes: 7 additions & 5 deletions python/sdist/amici/jax/jax.template.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,28 +64,30 @@ def _tcl(self, x, p):

return TPL_TOTAL_CL_RET

def _y(self, t, x, p, tcl):
def _y(self, t, x, p, tcl, op):
TPL_X_SYMS = x
TPL_P_SYMS = p
TPL_W_SYMS = self._w(t, x, p, tcl)
TPL_OP_SYMS = op

TPL_Y_EQ

return TPL_Y_RET

def _sigmay(self, y, p):
def _sigmay(self, y, p, np):
TPL_P_SYMS = p

TPL_Y_SYMS = y
TPL_NP_SYMS = np

TPL_SIGMAY_EQ

return TPL_SIGMAY_RET

def _nllh(self, t, x, p, tcl, my, iy):
y = self._y(t, x, p, tcl)
def _nllh(self, t, x, p, tcl, my, iy, op, np):
y = self._y(t, x, p, tcl, op)
TPL_Y_SYMS = y
TPL_SIGMAY_SYMS = self._sigmay(y, p)
TPL_SIGMAY_SYMS = self._sigmay(y, p, np)

TPL_JY_EQ

Expand Down
Loading

0 comments on commit b7bdf63

Please sign in to comment.