Skip to content

Commit 6a586a7

Browse files
committed
Use networkx instead of causalgraphicalmodels
1 parent fd76466 commit 6a586a7

20 files changed

+290
-276
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@ I am a fan of the book [*Statistical Rethinking*](https://xcelab.net/rm/statisti
1313

1414
## Installation
1515

16-
The following tools are used for some analysis and visualizations: [arviz](https://arviz-devs.github.io/arviz/) for [posteriors](https://en.wikipedia.org/wiki/Posterior_probability), [causalgraphicalmodels](https://github.com/ijmbarr/causalgraphicalmodels) and [daft](https://docs.daft-pgm.org/en/latest/) for [causal graphs](https://en.wikipedia.org/wiki/Causal_graph), and (optional) [ete3](http://etetoolkit.org/) for [phylogenetic trees](https://en.wikipedia.org/wiki/Phylogenetic_tree).
16+
The following tools are used for some analysis and visualizations: [arviz](https://arviz-devs.github.io/arviz/) for [posteriors](https://en.wikipedia.org/wiki/Posterior_probability), [networkx](https://networkx.org/) and [daft](https://docs.daft-pgm.org/en/latest/) for [causal graphs](https://en.wikipedia.org/wiki/Causal_graph), and (optional) [ete3](http://etetoolkit.org/) for [phylogenetic trees](https://en.wikipedia.org/wiki/Phylogenetic_tree).
1717

1818
```sh
19-
pip install numpyro arviz causalgraphicalmodels daft
19+
pip install numpyro arviz daft networkx
2020
```
2121

2222
## Excercises

notebooks/00_preface.ipynb

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
"metadata": {},
1414
"outputs": [],
1515
"source": [
16-
"!pip install -q numpyro arviz causalgraphicalmodels daft"
16+
"!pip install -q numpyro arviz"
1717
]
1818
},
1919
{
@@ -231,14 +231,14 @@
231231
},
232232
"source": [
233233
"```sh\n",
234-
"pip install numpyro arviz causalgraphicalmodels daft\n",
234+
"pip install numpyro arviz daft networkx\n",
235235
"```"
236236
]
237237
}
238238
],
239239
"metadata": {
240240
"kernelspec": {
241-
"display_name": "Python 3",
241+
"display_name": "Python 3 (ipykernel)",
242242
"language": "python",
243243
"name": "python3"
244244
},
@@ -252,7 +252,7 @@
252252
"name": "python",
253253
"nbconvert_exporter": "python",
254254
"pygments_lexer": "ipython3",
255-
"version": "3.8.8"
255+
"version": "3.11.6"
256256
},
257257
"varInspector": {
258258
"cols": {

notebooks/01_the_golem_of_prague.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
],
1818
"metadata": {
1919
"kernelspec": {
20-
"display_name": "Python 3",
20+
"display_name": "Python 3 (ipykernel)",
2121
"language": "python",
2222
"name": "python3"
2323
},
@@ -31,7 +31,7 @@
3131
"name": "python",
3232
"nbconvert_exporter": "python",
3333
"pygments_lexer": "ipython3",
34-
"version": "3.8.8"
34+
"version": "3.11.6"
3535
},
3636
"varInspector": {
3737
"cols": {

notebooks/02_small_worlds_and_large_worlds.ipynb

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
"metadata": {},
2222
"outputs": [],
2323
"source": [
24-
"!pip install -q numpyro arviz causalgraphicalmodels daft"
24+
"!pip install -q numpyro arviz"
2525
]
2626
},
2727
{
@@ -224,7 +224,7 @@
224224
"params = svi_result.params\n",
225225
"\n",
226226
"# display summary of quadratic approximation\n",
227-
"samples = guide.sample_posterior(random.PRNGKey(1), params, (1000,))\n",
227+
"samples = guide.sample_posterior(random.PRNGKey(1), params, sample_shape=(1000,))\n",
228228
"numpyro.diagnostics.print_summary(samples, prob=0.89, group_by_chain=False)"
229229
]
230230
},
@@ -323,7 +323,7 @@
323323
],
324324
"metadata": {
325325
"kernelspec": {
326-
"display_name": "Python 3",
326+
"display_name": "Python 3 (ipykernel)",
327327
"language": "python",
328328
"name": "python3"
329329
},
@@ -337,7 +337,7 @@
337337
"name": "python",
338338
"nbconvert_exporter": "python",
339339
"pygments_lexer": "ipython3",
340-
"version": "3.8.8"
340+
"version": "3.11.6"
341341
}
342342
},
343343
"nbformat": 4,

notebooks/03_sampling_the_imaginary.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
"metadata": {},
2222
"outputs": [],
2323
"source": [
24-
"!pip install -q numpyro arviz causalgraphicalmodels daft"
24+
"!pip install -q numpyro arviz"
2525
]
2626
},
2727
{
@@ -816,7 +816,7 @@
816816
"name": "python",
817817
"nbconvert_exporter": "python",
818818
"pygments_lexer": "ipython3",
819-
"version": "3.9.10"
819+
"version": "3.11.6"
820820
},
821821
"widgets": {
822822
"application/vnd.jupyter.widget-state+json": {

notebooks/04_geocentric_models.ipynb

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
"metadata": {},
1414
"outputs": [],
1515
"source": [
16-
"!pip install -q numpyro arviz causalgraphicalmodels daft"
16+
"!pip install -q numpyro arviz"
1717
]
1818
},
1919
{
@@ -904,7 +904,7 @@
904904
}
905905
],
906906
"source": [
907-
"samples = m4_1.sample_posterior(random.PRNGKey(1), p4_1, (1000,))\n",
907+
"samples = m4_1.sample_posterior(random.PRNGKey(1), p4_1, sample_shape=(1000,))\n",
908908
"print_summary(samples, 0.89, False)"
909909
]
910910
},
@@ -978,7 +978,7 @@
978978
"svi = SVI(model, m4_2, optim.Adam(1), Trace_ELBO(), height=d2.height.values)\n",
979979
"svi_result = svi.run(random.PRNGKey(0), 2000)\n",
980980
"p4_2 = svi_result.params\n",
981-
"samples = m4_2.sample_posterior(random.PRNGKey(1), p4_2, (1000,))\n",
981+
"samples = m4_2.sample_posterior(random.PRNGKey(1), p4_2, sample_shape=(1000,))\n",
982982
"print_summary(samples, 0.89, False)"
983983
]
984984
},
@@ -1007,7 +1007,7 @@
10071007
}
10081008
],
10091009
"source": [
1010-
"samples = m4_1.sample_posterior(random.PRNGKey(1), p4_1, (1000,))\n",
1010+
"samples = m4_1.sample_posterior(random.PRNGKey(1), p4_1, sample_shape=(1000,))\n",
10111011
"vcov = jnp.cov(jnp.stack(list(samples.values()), axis=0))\n",
10121012
"vcov"
10131013
]
@@ -1064,7 +1064,7 @@
10641064
}
10651065
],
10661066
"source": [
1067-
"post = m4_1.sample_posterior(random.PRNGKey(1), p4_1, (int(1e4),))\n",
1067+
"post = m4_1.sample_posterior(random.PRNGKey(1), p4_1, sample_shape=(int(1e4),))\n",
10681068
"{latent: list(post[latent][:6]) for latent in post}"
10691069
]
10701070
},
@@ -1369,7 +1369,7 @@
13691369
}
13701370
],
13711371
"source": [
1372-
"samples = m4_3.sample_posterior(random.PRNGKey(1), p4_3, (1000,))\n",
1372+
"samples = m4_3.sample_posterior(random.PRNGKey(1), p4_3, sample_shape=(1000,))\n",
13731373
"samples.pop(\"mu\")\n",
13741374
"print_summary(samples, 0.89, False)"
13751375
]
@@ -1429,7 +1429,7 @@
14291429
],
14301430
"source": [
14311431
"az.plot_pair(d2[[\"weight\", \"height\"]].to_dict(orient=\"list\"))\n",
1432-
"post = m4_3.sample_posterior(random.PRNGKey(1), p4_3, (1000,))\n",
1432+
"post = m4_3.sample_posterior(random.PRNGKey(1), p4_3, sample_shape=(1000,))\n",
14331433
"a_map = jnp.mean(post[\"a\"])\n",
14341434
"b_map = jnp.mean(post[\"b\"])\n",
14351435
"x = jnp.linspace(d2.weight.min(), d2.weight.max(), 101)\n",
@@ -1464,7 +1464,7 @@
14641464
}
14651465
],
14661466
"source": [
1467-
"post = m4_3.sample_posterior(random.PRNGKey(1), p4_3, (1000,))\n",
1467+
"post = m4_3.sample_posterior(random.PRNGKey(1), p4_3, sample_shape=(1000,))\n",
14681468
"{latent: list(post[latent].reshape(-1)[:5]) for latent in post}"
14691469
]
14701470
},
@@ -1539,7 +1539,7 @@
15391539
],
15401540
"source": [
15411541
"# extract 20 samples from the posterior\n",
1542-
"post = mN.sample_posterior(random.PRNGKey(1), pN, (20,))\n",
1542+
"post = mN.sample_posterior(random.PRNGKey(1), pN, sample_shape=(20,))\n",
15431543
"\n",
15441544
"# display raw data and sample size\n",
15451545
"ax = az.plot_pair(dN[[\"weight\", \"height\"]].to_dict(orient=\"list\"))\n",
@@ -1568,7 +1568,7 @@
15681568
"metadata": {},
15691569
"outputs": [],
15701570
"source": [
1571-
"post = m4_3.sample_posterior(random.PRNGKey(1), p4_3, (1000,))\n",
1571+
"post = m4_3.sample_posterior(random.PRNGKey(1), p4_3, sample_shape=(1000,))\n",
15721572
"mu_at_50 = post[\"a\"] + post[\"b\"] * (50 - xbar)"
15731573
]
15741574
},
@@ -1797,7 +1797,7 @@
17971797
"metadata": {},
17981798
"outputs": [],
17991799
"source": [
1800-
"post = m4_3.sample_posterior(random.PRNGKey(1), p4_3, (1000,))\n",
1800+
"post = m4_3.sample_posterior(random.PRNGKey(1), p4_3, sample_shape=(1000,))\n",
18011801
"mu_link = lambda weight: post[\"a\"] + post[\"b\"] * (weight - xbar)\n",
18021802
"weight_seq = jnp.arange(start=25, stop=71, step=1)\n",
18031803
"mu = vmap(mu_link)(weight_seq).T\n",
@@ -1924,7 +1924,7 @@
19241924
"metadata": {},
19251925
"outputs": [],
19261926
"source": [
1927-
"post = m4_3.sample_posterior(random.PRNGKey(1), p4_3, (1000,))\n",
1927+
"post = m4_3.sample_posterior(random.PRNGKey(1), p4_3, sample_shape=(1000,))\n",
19281928
"weight_seq = jnp.arange(25, 71)\n",
19291929
"sim_height = vmap(\n",
19301930
" lambda i, weight: dist.Normal(\n",
@@ -2126,7 +2126,7 @@
21262126
}
21272127
],
21282128
"source": [
2129-
"samples = m4_5.sample_posterior(random.PRNGKey(1), p4_5, (1000,))\n",
2129+
"samples = m4_5.sample_posterior(random.PRNGKey(1), p4_5, sample_shape=(1000,))\n",
21302130
"print_summary({k: v for k, v in samples.items() if k != \"mu\"}, 0.89, False)"
21312131
]
21322132
},
@@ -2145,7 +2145,7 @@
21452145
"source": [
21462146
"weight_seq = jnp.linspace(start=-2.2, stop=2, num=30)\n",
21472147
"pred_dat = {\"weight_s\": weight_seq, \"weight_s2\": weight_seq**2}\n",
2148-
"post = m4_5.sample_posterior(random.PRNGKey(1), p4_5, (1000,))\n",
2148+
"post = m4_5.sample_posterior(random.PRNGKey(1), p4_5, sample_shape=(1000,))\n",
21492149
"predictive = Predictive(m4_5.model, post)\n",
21502150
"mu = predictive(random.PRNGKey(2), **pred_dat)[\"mu\"]\n",
21512151
"mu_mean = jnp.mean(mu, 0)\n",
@@ -2479,7 +2479,7 @@
24792479
}
24802480
],
24812481
"source": [
2482-
"post = m4_7.sample_posterior(random.PRNGKey(1), p4_7, (1000,))\n",
2482+
"post = m4_7.sample_posterior(random.PRNGKey(1), p4_7, sample_shape=(1000,))\n",
24832483
"w = jnp.mean(post[\"w\"], 0)\n",
24842484
"plt.subplot(\n",
24852485
" xlim=(d2.year.min(), d2.year.max()),\n",
@@ -2578,7 +2578,7 @@
25782578
"name": "python",
25792579
"nbconvert_exporter": "python",
25802580
"pygments_lexer": "ipython3",
2581-
"version": "3.9.10"
2581+
"version": "3.11.6"
25822582
},
25832583
"toc": {
25842584
"base_numbering": 1,

notebooks/05_the_many_variables_and_the_spurious_waffles.ipynb

Lines changed: 58 additions & 94 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)