Skip to content

Commit e3bbafb

Browse files
committed
v2: Additional probability distributions
Add additional probability distributions as required for PEtab-dev/PEtab#595. See PEtab-dev#374.
1 parent 87cec8c commit e3bbafb

File tree

1 file changed

+270
-16
lines changed

1 file changed

+270
-16
lines changed

petab/v1/distributions.py

+270-16
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,19 @@
33
from __future__ import annotations
44

55
import abc
6+
from typing import Any
67

78
import numpy as np
8-
from scipy.stats import laplace, norm, uniform
9+
from scipy.stats import (
10+
cauchy,
11+
chi2,
12+
expon,
13+
gamma,
14+
laplace,
15+
norm,
16+
rayleigh,
17+
uniform,
18+
)
919

1020
__all__ = [
1121
"Distribution",
@@ -277,6 +287,21 @@ def _inverse_transform_sample(self, shape) -> np.ndarray | float:
277287
)
278288
return self._ppf_transformed_untruncated(uniform_sample)
279289

290+
def _repr(self, pars: dict[str, Any] = None) -> str:
291+
"""Return a string representation of the distribution."""
292+
pars = ", ".join(f"{k}={v}" for k, v in pars.items()) if pars else ""
293+
294+
if self._logbase is False:
295+
log = ""
296+
elif self._logbase == np.exp(1):
297+
log = ", log=True"
298+
else:
299+
log = f", log={self._logbase}"
300+
301+
trunc = f", trunc={self._trunc}" if self._trunc else ""
302+
303+
return f"{self.__class__.__name__}({pars}{log}{trunc})"
304+
280305

281306
class Normal(Distribution):
282307
"""A (log-)normal distribution.
@@ -307,16 +332,7 @@ def __init__(
307332
super().__init__(log=log, trunc=trunc)
308333

309334
def __repr__(self):
310-
if self._logbase is False:
311-
log = ""
312-
if self._logbase == np.exp(1):
313-
log = ", log=True"
314-
else:
315-
log = f", log={self._logbase}"
316-
317-
trunc = f", trunc={self._trunc}" if self._trunc else ""
318-
319-
return f"Normal(loc={self._loc}, scale={self._scale}{log}{trunc})"
335+
return self._repr({"loc": self._loc, "scale": self._scale})
320336

321337
def _sample(self, shape=None) -> np.ndarray | float:
322338
return np.random.normal(loc=self._loc, scale=self._scale, size=shape)
@@ -366,8 +382,7 @@ def __init__(
366382
super().__init__(log=log)
367383

368384
def __repr__(self):
369-
log = f", log={self._logbase}" if self._logbase else ""
370-
return f"Uniform(low={self._low}, high={self._high}{log})"
385+
return self._repr({"low": self._low, "high": self._high})
371386

372387
def _sample(self, shape=None) -> np.ndarray | float:
373388
return np.random.uniform(low=self._low, high=self._high, size=shape)
@@ -411,9 +426,7 @@ def __init__(
411426
super().__init__(log=log, trunc=trunc)
412427

413428
def __repr__(self):
414-
trunc = f", trunc={self._trunc}" if self._trunc else ""
415-
log = f", log={self._logbase}" if self._logbase else ""
416-
return f"Laplace(loc={self._loc}, scale={self._scale}{trunc}{log})"
429+
return self._repr({"loc": self._loc, "scale": self._scale})
417430

418431
def _sample(self, shape=None) -> np.ndarray | float:
419432
return np.random.laplace(loc=self._loc, scale=self._scale, size=shape)
@@ -436,3 +449,244 @@ def loc(self) -> float:
436449
def scale(self) -> float:
437450
"""The scale parameter of the underlying distribution."""
438451
return self._scale
452+
453+
454+
class Cauchy(Distribution):
455+
"""Cauchy distribution.
456+
457+
A (possibly truncated) `Cauchy distribution
458+
<https://en.wikipedia.org/wiki/Cauchy_distribution>`__.
459+
460+
:param loc: The location parameter of the distribution.
461+
:param scale: The scale parameter of the distribution.
462+
:param trunc: The truncation limits of the distribution.
463+
``None`` if the distribution is not truncated.
464+
If the distribution is log-scaled, the truncation limits are expected
465+
to be on the same log scale.
466+
:param log: If ``True``, the distribution is transformed to a log-Cauchy
467+
distribution. If a float, the distribution is transformed to a
468+
log-Cauchy distribution with the given log-base.
469+
If ``False``, no transformation is applied.
470+
If a transformation is applied, the location and scale parameters
471+
are the location and scale of the underlying Cauchy distribution.
472+
"""
473+
474+
def __init__(
475+
self,
476+
loc: float,
477+
scale: float,
478+
trunc: tuple[float, float] | None = None,
479+
log: bool | float = False,
480+
):
481+
self._loc = loc
482+
self._scale = scale
483+
super().__init__(log=log, trunc=trunc)
484+
485+
def __repr__(self):
486+
return self._repr({"loc": self._loc, "scale": self._scale})
487+
488+
def _pdf_untransformed_untruncated(self, x) -> np.ndarray | float:
489+
return cauchy.pdf(x, loc=self._loc, scale=self._scale)
490+
491+
def _cdf_untransformed_untruncated(self, x) -> np.ndarray | float:
492+
return cauchy.cdf(x, loc=self._loc, scale=self._scale)
493+
494+
def _ppf_untransformed_untruncated(self, q) -> np.ndarray | float:
495+
return cauchy.ppf(q, loc=self._loc, scale=self._scale)
496+
497+
@property
498+
def loc(self) -> float:
499+
"""The location parameter of the underlying distribution."""
500+
return self._loc
501+
502+
@property
503+
def scale(self) -> float:
504+
"""The scale parameter of the underlying distribution."""
505+
return self._scale
506+
507+
508+
class ChiSquare(Distribution):
509+
"""Chi-squared distribution.
510+
511+
A (possibly truncated) `Chi-squared distribution
512+
<https://en.wikipedia.org/wiki/Chi-squared_distribution>`__.
513+
514+
:param dof: The degrees of freedom parameter of the distribution.
515+
:param trunc: The truncation limits of the distribution.
516+
``None`` if the distribution is not truncated.
517+
If the distribution is log-scaled, the truncation limits are expected
518+
to be on the same log scale.
519+
:param log: If ``True``, the distribution is transformed to a
520+
log-Chi-squared distribution.
521+
If a float, the distribution is transformed to a
522+
log-Chi-squared distribution with the given log-base.
523+
If ``False``, no transformation is applied.
524+
If a transformation is applied, the degrees of freedom parameter
525+
is the degrees of freedom of the underlying Chi-squared distribution.
526+
"""
527+
528+
def __init__(
529+
self,
530+
dof: int,
531+
trunc: tuple[float, float] | None = None,
532+
log: bool | float = False,
533+
):
534+
if not dof.is_integer() or dof < 1:
535+
raise ValueError(
536+
f"`dof' must be a positive integer, but was `{dof}'."
537+
)
538+
539+
self._dof = dof
540+
super().__init__(log=log, trunc=trunc)
541+
542+
def __repr__(self):
543+
return self._repr({"dof": self._dof})
544+
545+
def _pdf_untransformed_untruncated(self, x) -> np.ndarray | float:
546+
return chi2.pdf(x, df=self._dof)
547+
548+
def _cdf_untransformed_untruncated(self, x) -> np.ndarray | float:
549+
return chi2.cdf(x, df=self._dof)
550+
551+
def _ppf_untransformed_untruncated(self, q) -> np.ndarray | float:
552+
return chi2.ppf(q, df=self._dof)
553+
554+
@property
555+
def dof(self) -> int:
556+
"""The degrees of freedom parameter."""
557+
return self._dof
558+
559+
560+
class Exponential(Distribution):
561+
"""Exponential distribution.
562+
563+
A (possibly truncated) `Exponential distribution
564+
<https://en.wikipedia.org/wiki/Exponential_distribution>`__.
565+
566+
:param scale: The scale parameter of the distribution.
567+
:param trunc: The truncation limits of the distribution.
568+
``None`` if the distribution is not truncated.
569+
"""
570+
571+
def __init__(
572+
self,
573+
scale: float,
574+
trunc: tuple[float, float] | None = None,
575+
):
576+
self._scale = scale
577+
super().__init__(log=False, trunc=trunc)
578+
579+
def __repr__(self):
580+
return self._repr({"scale": self._scale})
581+
582+
def _pdf_untransformed_untruncated(self, x) -> np.ndarray | float:
583+
return expon.pdf(x, scale=self._scale)
584+
585+
def _cdf_untransformed_untruncated(self, x) -> np.ndarray | float:
586+
return expon.cdf(x, scale=self._scale)
587+
588+
def _ppf_untransformed_untruncated(self, q) -> np.ndarray | float:
589+
return expon.ppf(q, scale=self._scale)
590+
591+
@property
592+
def scale(self) -> float:
593+
"""The scale parameter of the underlying distribution."""
594+
return self._scale
595+
596+
597+
class Gamma(Distribution):
598+
"""Gamma distribution.
599+
600+
A (possibly truncated) `Gamma distribution
601+
<https://en.wikipedia.org/wiki/Gamma_distribution>`__.
602+
603+
:param shape: The shape parameter of the distribution.
604+
:param scale: The scale parameter of the distribution.
605+
:param trunc: The truncation limits of the distribution.
606+
``None`` if the distribution is not truncated.
607+
:param log: If ``True``, the distribution is transformed to a
608+
log-Gamma distribution.
609+
If a float, the distribution is transformed to a
610+
log-Gamma distribution with the given log-base.
611+
If ``False``, no transformation is applied.
612+
If a transformation is applied, the shape and scale parameters
613+
are the shape and scale of the underlying Gamma distribution.
614+
"""
615+
616+
def __init__(
617+
self,
618+
shape: float,
619+
scale: float,
620+
trunc: tuple[float, float] | None = None,
621+
log: bool | float = False,
622+
):
623+
self._shape = shape
624+
self._scale = scale
625+
super().__init__(log=log, trunc=trunc)
626+
627+
def __repr__(self):
628+
return self._repr({"shape": self._shape, "scale": self._scale})
629+
630+
def _pdf_untransformed_untruncated(self, x) -> np.ndarray | float:
631+
return gamma.pdf(x, a=self._shape, scale=self._scale)
632+
633+
def _cdf_untransformed_untruncated(self, x) -> np.ndarray | float:
634+
return gamma.cdf(x, a=self._shape, scale=self._scale)
635+
636+
def _ppf_untransformed_untruncated(self, q) -> np.ndarray | float:
637+
return gamma.ppf(q, a=self._shape, scale=self._scale)
638+
639+
@property
640+
def shape(self) -> float:
641+
"""The shape parameter of the underlying distribution."""
642+
return self._shape
643+
644+
@property
645+
def scale(self) -> float:
646+
"""The scale parameter of the underlying distribution."""
647+
return self._scale
648+
649+
650+
class Rayleigh(Distribution):
651+
"""Rayleigh distribution.
652+
653+
A (possibly truncated) `Rayleigh distribution
654+
<https://en.wikipedia.org/wiki/Rayleigh_distribution>`__.
655+
656+
:param scale: The scale parameter of the distribution.
657+
:param trunc: The truncation limits of the distribution.
658+
``None`` if the distribution is not truncated.
659+
:param log: If ``True``, the distribution is transformed to a
660+
log-Rayleigh distribution.
661+
If a float, the distribution is transformed to a
662+
log-Rayleigh distribution with the given log-base.
663+
If ``False``, no transformation is applied.
664+
If a transformation is applied, the scale parameter
665+
is the scale of the underlying Rayleigh distribution.
666+
"""
667+
668+
def __init__(
669+
self,
670+
scale: float,
671+
trunc: tuple[float, float] | None = None,
672+
log: bool | float = False,
673+
):
674+
self._scale = scale
675+
super().__init__(log=log, trunc=trunc)
676+
677+
def __repr__(self):
678+
return self._repr({"scale": self._scale})
679+
680+
def _pdf_untransformed_untruncated(self, x) -> np.ndarray | float:
681+
return rayleigh.pdf(x, scale=self._scale)
682+
683+
def _cdf_untransformed_untruncated(self, x) -> np.ndarray | float:
684+
return rayleigh.cdf(x, scale=self._scale)
685+
686+
def _ppf_untransformed_untruncated(self, q) -> np.ndarray | float:
687+
return rayleigh.ppf(q, scale=self._scale)
688+
689+
@property
690+
def scale(self) -> float:
691+
"""The scale parameter of the underlying distribution."""
692+
return self._scale

0 commit comments

Comments
 (0)