Skip to content

Commit b43f5b8

Browse files
dweindlm-philippsdilpath
authored
Implement proper truncation for prior distributions (#335)
Previously, when sampled startpoints were outside the bounds, their value was set to the upper/lower bounds. This may put too much probability mass on the bounds. With these changes, we properly sample from the respective truncated distributions. Closes #330. This also evaluates all priors on the model parameter scale (instead of `parameterScale` scale, see PEtab-dev/PEtab#402). --------- Co-authored-by: Maren Philipps <[email protected]> Co-authored-by: Dilan Pathirana <[email protected]>
1 parent 258abc9 commit b43f5b8

File tree

9 files changed

+633
-201
lines changed

9 files changed

+633
-201
lines changed

doc/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@
8787
nb_execution_mode = "force"
8888
nb_execution_raise_on_error = True
8989
nb_execution_show_tb = True
90+
nb_execution_timeout = 90 # max. seconds/cell
9091

9192
source_suffix = {
9293
".rst": "restructuredtext",

doc/example/distributions.ipynb

Lines changed: 107 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -23,59 +23,99 @@
2323
},
2424
{
2525
"cell_type": "code",
26-
"execution_count": null,
2726
"id": "initial_id",
2827
"metadata": {
2928
"collapsed": true
3029
},
31-
"outputs": [],
3230
"source": [
3331
"import matplotlib.pyplot as plt\n",
3432
"import numpy as np\n",
3533
"import seaborn as sns\n",
3634
"\n",
3735
"from petab.v1.C import *\n",
36+
"from petab.v1.parameters import unscale\n",
3837
"from petab.v1.priors import Prior\n",
3938
"\n",
4039
"sns.set_style(None)\n",
4140
"\n",
4241
"\n",
43-
"def plot(prior: Prior, ax=None):\n",
42+
"def plot(prior: Prior):\n",
4443
" \"\"\"Visualize a distribution.\"\"\"\n",
44+
" fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))\n",
45+
" sample = prior.sample(20_000, x_scaled=True)\n",
46+
"\n",
47+
" fig.suptitle(str(prior))\n",
48+
"\n",
49+
" plot_single(prior, ax=ax1, sample=sample, scaled=False)\n",
50+
" plot_single(prior, ax=ax2, sample=sample, scaled=True)\n",
51+
" plt.tight_layout()\n",
52+
" plt.show()\n",
53+
"\n",
54+
"\n",
55+
"def plot_single(\n",
56+
" prior: Prior, scaled: bool = False, ax=None, sample: np.ndarray | None = None\n",
57+
"):\n",
58+
" fig = None\n",
4559
" if ax is None:\n",
4660
" fig, ax = plt.subplots()\n",
4761
"\n",
48-
" sample = prior.sample(10000)\n",
62+
" if sample is None:\n",
63+
" sample = prior.sample(20_000, x_scaled=True)\n",
4964
"\n",
50-
" # pdf\n",
65+
" # assuming scaled sample\n",
66+
" if not scaled:\n",
67+
" sample = unscale(sample, prior.transformation)\n",
68+
" bounds = prior.bounds\n",
69+
" else:\n",
70+
" bounds = (\n",
71+
" (prior.lb_scaled, prior.ub_scaled)\n",
72+
" if prior.bounds is not None\n",
73+
" else None\n",
74+
" )\n",
75+
"\n",
76+
" # plot pdf\n",
5177
" xmin = min(\n",
52-
" sample.min(),\n",
53-
" prior.lb_scaled if prior.bounds is not None else sample.min(),\n",
78+
" sample.min(), bounds[0] if prior.bounds is not None else sample.min()\n",
5479
" )\n",
5580
" xmax = max(\n",
56-
" sample.max(),\n",
57-
" prior.ub_scaled if prior.bounds is not None else sample.max(),\n",
81+
" sample.max(), bounds[1] if prior.bounds is not None else sample.max()\n",
5882
" )\n",
83+
" padding = 0.1 * (xmax - xmin)\n",
84+
" xmin -= padding\n",
85+
" xmax += padding\n",
5986
" x = np.linspace(xmin, xmax, 500)\n",
60-
" y = prior.pdf(x)\n",
87+
" y = prior.pdf(x, x_scaled=scaled, rescale=scaled)\n",
6188
" ax.plot(x, y, color=\"red\", label=\"pdf\")\n",
6289
"\n",
6390
" sns.histplot(sample, stat=\"density\", ax=ax, label=\"sample\")\n",
6491
"\n",
65-
" # bounds\n",
92+
" # plot bounds\n",
6693
" if prior.bounds is not None:\n",
67-
" for bound in (prior.lb_scaled, prior.ub_scaled):\n",
94+
" for bound in bounds:\n",
6895
" if bound is not None and np.isfinite(bound):\n",
6996
" ax.axvline(bound, color=\"black\", linestyle=\"--\", label=\"bound\")\n",
7097
"\n",
71-
" ax.set_title(str(prior))\n",
72-
" ax.set_xlabel(\"Parameter value on the parameter scale\")\n",
98+
" if fig is not None:\n",
99+
" ax.set_title(str(prior))\n",
100+
"\n",
101+
" if scaled:\n",
102+
" ax.set_xlabel(\n",
103+
" f\"Parameter value on parameter scale ({prior.transformation})\"\n",
104+
" )\n",
105+
" ax.set_ylabel(\"Rescaled density\")\n",
106+
" else:\n",
107+
" ax.set_xlabel(\"Parameter value\")\n",
108+
"\n",
73109
" ax.grid(False)\n",
74110
" handles, labels = ax.get_legend_handles_labels()\n",
75111
" unique_labels = dict(zip(labels, handles, strict=False))\n",
76112
" ax.legend(unique_labels.values(), unique_labels.keys())\n",
77-
" plt.show()"
78-
]
113+
"\n",
114+
" if fig is not None:\n",
115+
" plt.show()"
116+
],
117+
"outputs": [],
118+
"execution_count": null
79119
},
80120
{
81121
"cell_type": "markdown",
@@ -85,38 +125,38 @@
85125
},
86126
{
87127
"cell_type": "code",
88-
"execution_count": null,
89128
"id": "4f09e50a3db06d9f",
90129
"metadata": {},
91-
"outputs": [],
92130
"source": [
93-
"plot(Prior(UNIFORM, (0, 1)))\n",
94-
"plot(Prior(NORMAL, (0, 1)))\n",
95-
"plot(Prior(LAPLACE, (0, 1)))\n",
96-
"plot(Prior(LOG_NORMAL, (0, 1)))\n",
97-
"plot(Prior(LOG_LAPLACE, (1, 0.5)))"
98-
]
131+
"plot_single(Prior(UNIFORM, (0, 1)))\n",
132+
"plot_single(Prior(NORMAL, (0, 1)))\n",
133+
"plot_single(Prior(LAPLACE, (0, 1)))\n",
134+
"plot_single(Prior(LOG_NORMAL, (0, 1)))\n",
135+
"plot_single(Prior(LOG_LAPLACE, (1, 0.5)))"
136+
],
137+
"outputs": [],
138+
"execution_count": null
99139
},
100140
{
101141
"cell_type": "markdown",
102142
"id": "dab4b2d1e0f312d8",
103143
"metadata": {},
104-
"source": "If a parameter scale is specified (`parameterScale=lin|log|log10` not a `parameterScale*`-type distribution), the sample is transformed accordingly (but not the distribution parameters):\n"
144+
"source": "If a parameter scale is specified (`parameterScale=lin|log|log10`), the distribution parameters are used as is without applying the `parameterScale` to them. The exceptions are the `parameterScale*`-type distributions, as explained below. In the context of PEtab prior distributions, `parameterScale` will only be used for the start point sampling for optimization, where the sample will be transformed accordingly. This is demonstrated below. The left plot always shows the prior distribution for unscaled parameter values, and the right plot shows the prior distribution for scaled parameter values. Note that in the objective function, the prior is always on the unscaled parameters.\n"
105145
},
106146
{
107147
"cell_type": "code",
108-
"execution_count": null,
109148
"id": "f6192c226f179ef9",
110149
"metadata": {},
111-
"outputs": [],
112150
"source": [
113151
"plot(Prior(NORMAL, (10, 2), transformation=LIN))\n",
114152
"plot(Prior(NORMAL, (10, 2), transformation=LOG))\n",
115153
"\n",
116154
"# Note that the log-normal distribution is different\n",
117155
"# from a log-transformed normal distribution:\n",
118156
"plot(Prior(LOG_NORMAL, (10, 2), transformation=LIN))"
119-
]
157+
],
158+
"outputs": [],
159+
"execution_count": null
120160
},
121161
{
122162
"cell_type": "markdown",
@@ -126,53 +166,69 @@
126166
},
127167
{
128168
"cell_type": "code",
129-
"execution_count": null,
130169
"id": "34c95268e8921070",
131170
"metadata": {},
132-
"outputs": [],
133171
"source": [
134172
"plot(Prior(LOG_NORMAL, (10, 2), transformation=LOG))\n",
135173
"plot(Prior(PARAMETER_SCALE_NORMAL, (10, 2)))"
136-
]
174+
],
175+
"outputs": [],
176+
"execution_count": null
137177
},
138178
{
139179
"cell_type": "markdown",
140180
"id": "263c9fd31156a4d5",
141181
"metadata": {},
142-
"source": "Prior distributions can also be defined on the parameter scale by using the types `parameterScaleUniform`, `parameterScaleNormal` or `parameterScaleLaplace`. In these cases, 1) the distribution parameter are interpreted on the transformed parameter scale, and 2) a sample from the given distribution is used directly, without applying any transformation according to `parameterScale` (this implies, that for `parameterScale=lin`, there is no difference between `parameterScaleUniform` and `uniform`):"
182+
"source": "Prior distributions can also be defined on the scaled parameters (i.e., transformed according to `parameterScale`) by using the types `parameterScaleUniform`, `parameterScaleNormal` or `parameterScaleLaplace`. In these cases, the distribution parameters are interpreted on the transformed parameter scale (but not the parameter bounds, see below). This implies that for `parameterScale=lin`, there is no difference between `parameterScaleUniform` and `uniform`."
143183
},
144184
{
145185
"cell_type": "code",
146-
"execution_count": null,
147186
"id": "5ca940bc24312fc6",
148187
"metadata": {},
149-
"outputs": [],
150188
"source": [
189+
"# different, because transformation!=LIN\n",
151190
"plot(Prior(UNIFORM, (0.01, 2), transformation=LOG10))\n",
152191
"plot(Prior(PARAMETER_SCALE_UNIFORM, (0.01, 2), transformation=LOG10))\n",
153192
"\n",
193+
"# same, because transformation=LIN\n",
154194
"plot(Prior(UNIFORM, (0.01, 2), transformation=LIN))\n",
155195
"plot(Prior(PARAMETER_SCALE_UNIFORM, (0.01, 2), transformation=LIN))"
156-
]
196+
],
197+
"outputs": [],
198+
"execution_count": null
157199
},
158200
{
159201
"cell_type": "markdown",
160202
"id": "b1a8b17d765db826",
161203
"metadata": {},
162-
"source": "To prevent the sampled parameters from exceeding the bounds, the sampled parameters are clipped to the bounds. The bounds are defined in the parameter table. Note that the current implementation does not support sampling from a truncated distribution. Instead, the samples are clipped to the bounds. This may introduce unwanted bias, and thus, should only be used with caution (i.e., the bounds should be chosen wide enough):"
204+
"source": "The given distributions are truncated at the bounds defined in the parameter table:"
163205
},
164206
{
165207
"cell_type": "code",
166-
"execution_count": null,
167208
"id": "4ac42b1eed759bdd",
168209
"metadata": {},
169-
"outputs": [],
170210
"source": [
211+
"plot(Prior(NORMAL, (0, 1), bounds=(-2, 2)))\n",
212+
"plot(Prior(UNIFORM, (0, 1), bounds=(0.1, 0.9)))\n",
213+
"plot(Prior(UNIFORM, (1e-8, 1), bounds=(0.1, 0.9), transformation=LOG10))\n",
214+
"plot(Prior(LAPLACE, (0, 1), bounds=(-0.5, 0.5)))\n",
171215
"plot(\n",
172-
" Prior(NORMAL, (0, 1), bounds=(-4, 4))\n",
173-
") # negligible clipping-bias at 4 sigma\n",
174-
"plot(Prior(UNIFORM, (0, 1), bounds=(0.1, 0.9))) # significant clipping-bias"
175-
]
216+
" Prior(\n",
217+
" PARAMETER_SCALE_UNIFORM,\n",
218+
" (-3, 1),\n",
219+
" bounds=(1e-2, 1),\n",
220+
" transformation=LOG10,\n",
221+
" )\n",
222+
")"
223+
],
224+
"outputs": [],
225+
"execution_count": null
226+
},
227+
{
228+
"metadata": {},
229+
"cell_type": "markdown",
230+
"source": "Within the bounds, this rescales the probability density by a constant factor compared to the non-truncated version (https://en.wikipedia.org/wiki/Truncated_distribution), such that the probability density still integrates to 1.",
231+
"id": "67de0cace55617a2"
176232
},
177233
{
178234
"cell_type": "markdown",
@@ -182,22 +238,24 @@
182238
},
183239
{
184240
"cell_type": "code",
185-
"execution_count": null,
186241
"id": "581e1ac431860419",
187242
"metadata": {},
188-
"outputs": [],
189243
"source": [
190-
"plot(Prior(NORMAL, (10, 1), bounds=(6, 14), transformation=\"log10\"))\n",
244+
"plot(Prior(NORMAL, (10, 1), bounds=(6, 11), transformation=\"log10\"))\n",
191245
"plot(\n",
192246
" Prior(\n",
193247
" PARAMETER_SCALE_NORMAL,\n",
194-
" (10, 1),\n",
195-
" bounds=(10**6, 10**14),\n",
248+
" (2, 1),\n",
249+
" bounds=(10**0, 10**3),\n",
196250
" transformation=\"log10\",\n",
197251
" )\n",
198252
")\n",
199-
"plot(Prior(LAPLACE, (10, 2), bounds=(6, 14)))"
200-
]
253+
"plot(Prior(LAPLACE, (10, 2), bounds=(6, 14)))\n",
254+
"plot(Prior(LOG_LAPLACE, (1, 0.5), bounds=(0.5, 8)))\n",
255+
"plot(Prior(LOG_NORMAL, (2, 1), bounds=(0.5, 8)))"
256+
],
257+
"outputs": [],
258+
"execution_count": null
201259
}
202260
],
203261
"metadata": {

petab/v1/C.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,13 @@
208208
PARAMETER_SCALE_LAPLACE,
209209
]
210210

211+
#: parameterScale*-type prior distributions
212+
PARAMETER_SCALE_PRIOR_TYPES = [
213+
PARAMETER_SCALE_UNIFORM,
214+
PARAMETER_SCALE_NORMAL,
215+
PARAMETER_SCALE_LAPLACE,
216+
]
217+
211218
#: Supported noise distributions
212219
NOISE_MODELS = [NORMAL, LAPLACE]
213220

0 commit comments

Comments
 (0)