Skip to content

Commit

Permalink
Move to flax (#35)
Browse files Browse the repository at this point in the history
* Update dependencies: remove distrax/numpyro/haiku and add flax/tfp-jax
* Move to Flax
* Move GP and else to experimental
* Remove forecasting notebook
  • Loading branch information
dirmeier authored Sep 3, 2023
1 parent 9d261c9 commit 4840d00
Show file tree
Hide file tree
Showing 78 changed files with 3,320 additions and 3,289 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -153,3 +153,5 @@ cython_debug/

# vscode
.vscode/

.data/
45 changes: 28 additions & 17 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,42 +7,53 @@
[![documentation](https://readthedocs.org/projects/ramsey/badge/?version=latest)](https://ramsey.readthedocs.io/en/latest/?badge=latest)
[![version](https://img.shields.io/pypi/v/ramsey.svg?colorB=black&style=flat)](https://pypi.org/project/ramsey/)

> Probabilistic modelling using Haiku and JAX
> Probabilistic modelling using JAX
## About

Ramsey is a library for probabilistic modelling using [Haiku](https://github.com/deepmind/dm-haiku) and [JAX](https://github.com/google/jax).
It builds upon the same module system that Haiku is using and is hence fully compatible with Haiku's and NumPyro's API.
Ramsey is a library for probabilistic modelling using [JAX](https://github.com/google/jax),
[Flax](https://github.com/google/flax) and [NumPyro](https://github.com/pyro-ppl/numpyro).
It offers high quality implementations of neural processes, Gaussian processes, Bayesian time series and state-space models, clustering processes,
and everything else Bayesian.

Ramsey makes use of

- Flax`s module system for models with trainable parameters (such as neural or Gaussian processes),
- NumPyro for models where parameters are endowed with prior distributions (such as Gaussian processes, Bayesian neural networks, ARMA models)

and is hence aimed at being fully compatible with both of them.

## Example usage

Ramsey uses to Haiku's module system to construct probabilistic models
and define parameters. For instance, a simple neural process can be constructed like this:
You can, for instance, construct a simple neural process like this:

```python
import haiku as hk
import jax.random as random
from jax import random as jr

from ramsey import NP
from ramsey import NP, MLP
from ramsey.data import sample_from_sine_function

def neural_process(**kwargs):
def get_neural_process():
dim = 128
np = NP(
decoder=hk.nets.MLP([dim] * 3 + [2]),
decoder=MLP([dim] * 3 + [2]),
latent_encoder=(
hk.nets.MLP([dim] * 3), hk.nets.MLP([dim, dim * 2])
MLP([dim] * 3), MLP([dim, dim * 2])
)
)
return np(**kwargs)
return np

key = random.PRNGKey(23)
(x, y), _ = sample_from_sine_function(key)
key = jr.PRNGKey(23)
data = sample_from_sine_function(key)

neural_process = hk.transform(neural_process)
params = neural_process.init(key, x_context=x, y_context=y, x_target=x)
neural_process = get_neural_process()
params = neural_process.init(key, x_context=data.x, y_context=data.y, x_target=data.x)
```

The neural process takes a decoder and a set of two latent encoders as argument. All of these are typically MLPs, but
Ramsey is flexible enough that you can change them, for instance, to CNNs or RNNs. Once the model is defined, you can initialize
its parameters just like in Flax.

## Installation

To install from PyPI, call:
Expand All @@ -58,7 +69,7 @@ command line:
pip install git+https://github.com/ramsey-devs/ramsey@<RELEASE>
```

See also the installation instructions for [Haiku](https://github.com/deepmind/dm-haiku) and [JAX](https://github.com/google/jax), if
See also the installation instructions for [JAX](https://github.com/google/jax), if
you plan to use Ramsey on GPU/TPU.

## Contributing
Expand Down
7 changes: 4 additions & 3 deletions docs/_static/theme.css
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
html[data-theme="dark"], html[data-theme="light"] {
--pst-color-primary: #0048bc;
/*--pst-color-primary: #204a87;*/
/*--pst-color-link: #204a87;*/
/*--sd-color-primary: #204a87;*/
}

h1 > code > span {
font-family: var(--pst-font-family-monospace);
color: #0048bc;
/*color: #204a87;*/
font-weight: 700;
}

Expand All @@ -15,7 +17,6 @@ nav > li > a > code.literal {
border: 0px;
}


nav.bd-links p.caption {
text-transform: uppercase;
}
Expand Down
19 changes: 13 additions & 6 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,25 @@
"examples/*py"
]

intersphinx_mapping = {
"haiku": ("https://dm-haiku.readthedocs.io/en/latest/", None),
"jax": ("https://jax.readthedocs.io/en/latest/", None),
"numpyro": ("https://num.pyro.ai/en/stable/", None),
}
# intersphinx_mapping = {
# "haiku": ("https://dm-haiku.readthedocs.io/en/latest/", None),
# "jax": ("https://jax.readthedocs.io/en/latest/", None),
# "numpyro": ("https://num.pyro.ai/en/stable/", None),
# }

html_theme = "sphinx_book_theme"

html_theme_options = {
"repository_url": "https://github.com/ramsey-devs/ramsey",
"use_repository_button": True,
"use_download_button": False,

}
# html_sidebars = {
# "**": ["sbt-sidebar-nav.html"]
# }
html_theme_options = {
"extra_navbar": ""
}

html_title = "Ramsey"
html_title = "Ramsey 🚀"
7 changes: 7 additions & 0 deletions docs/examples.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Examples
========

.. note::

More example code can be found on GitHub in `examples <https://github.com/ramsey-devs/ramsey/tree/main/examples>`_.
The examples are executable from the command line, so forking/cloning the code suffices to run them.
4 changes: 0 additions & 4 deletions docs/getting_started.rst

This file was deleted.

77 changes: 45 additions & 32 deletions docs/index.rst
Original file line number Diff line number Diff line change
@@ -1,42 +1,51 @@
:github_url: https://github.com/ramsey-devs/ramsey/

Ramsey: probabilistic modelling using Haiku
===========================================
👋 Welcome to Ramsey!
=====================

Ramsey is a library for probabilistic modelling using `Haiku <https://github.com/deepmind/dm-haiku>`_ and `JAX <https://github.com/google/jax>`_.
It builds upon the same module system that Haiku is using and is hence fully compatible with its API. Ramsey implements **probabilistic** models, such as neural processes, Gaussian processes,
Bayesian neural networks, Bayesian timeseries models and state-space-models, and more.
Ramsey is a library for probabilistic modelling using `JAX <https://github.com/google/jax>`_ ,
`Flax <https://github.com/google/flax>`_ and `NumPyro <https://github.com/pyro-ppl/numpyro>`_.
It offers high quality implementations of neural processes, Gaussian processes, Bayesian time series and state-space models, clustering processes,
and everything else Bayesian.

Example
-------
Ramsey makes use of

- Flax`s module system for models with trainable parameters (such as neural or Gaussian processes),
- NumPyro for models where parameters are endowed with prior distributions (such as Gaussian processes, Bayesian neural networks, etc.)

and is hence aimed at being fully compatible with both of them.

Ramsey uses to Haiku's module system to construct probabilistic models
and define parameters. For instance, a simple neural process can be constructed like this:
Example usage
-------------

You can, for instance, construct a simple neural process like this:

.. code-block:: python
import haiku as hk
import jax.random as random
from ramsey.data import sample_from_sinus_function
from ramsey.models import NP
from jax import random as jr
from ramsey import NP, MLP
from ramsey.data import sample_from_sine_function
def neural_process(**kwargs):
def get_neural_process():
dim = 128
np = NP(
decoder=hk.nets.MLP([dim] * 3 + [2]),
decoder=MLP([dim] * 3 + [2]),
latent_encoder=(
hk.nets.MLP([dim] * 3), hk.nets.MLP([dim, dim * 2])
MLP([dim] * 3), MLP([dim, dim * 2])
)
)
return np(**kwargs)
return np
(x, y), _ = sample_from_sinus_function(random.PRNGKey(0))
key = jr.PRNGKey(23)
data = sample_from_sine_function(key)
neural_process = hk.transform(neural_process)
params = neural_process.init(
random.PRNGKey(1), x_context=x, y_context=y, x_target=x
)
neural_process = get_neural_process()
params = neural_process.init(key, x_context=data.x, y_context=data.y, x_target=data.x)
The neural process takes a decoder and a set of two latent encoders as argument. All of these are typically MLPs, but
Ramsey is flexible enough that you can change them, for instance, to CNNs or RNNs. Once the model is defined, you can initialize
its parameters just like in Flax.

Why Ramsey
----------
Expand All @@ -61,7 +70,7 @@ command line:
pip install git+https://github.com/ramsey-devs/ramsey@<RELEASE>
See also the installation instructions for `Haiku <https://github.com/deepmind/dm-haiku>`_ and `JAX <https://github.com/google/jax>`_, if
See also the installation instructions for `JAX <https://github.com/google/jax>`_, if
you plan to use Ramsey on GPU/TPU.

Contributing
Expand All @@ -83,29 +92,33 @@ License

Ramsey is licensed under the Apache 2.0 License.


.. toctree::
:maxdepth: 1
:hidden:

Home <self>
🏠 Home <self>
📰 News <news>

.. toctree::
:caption: Tutorials
:caption: 🎓 Tutorials
:maxdepth: 1
:hidden:

notebooks/neural_process
notebooks/forecasting

.. toctree::
:caption: API
:caption: 🎓 Example code
:maxdepth: 1
:hidden:

examples

.. toctree::
:caption: 🧱 API
:maxdepth: 1
:hidden:

ramsey
ramsey.attention
ramsey.contrib
ramsey.data
ramsey.experimental
ramsey.family
ramsey.kernels
ramsey.train
14 changes: 14 additions & 0 deletions docs/news.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
📰 News
=======

.. note::

With :code:`Haiku` having gone into a maintenance mode after the merger of Google Brain and Deepmind,
Ramsey will now be using :code:`Flax` as default neural network library.
Starting from version :code:`v0.2.0` all :code:`Haiku` code will be replaced with :code:`Flax` code. Sadge.

.. note::

Starting from version :code:`v0.2.0` on, experimental and possibly non-permant code will be put into submodule :code:`ramsey.experimental`.
Hence you can expect that the main code base won't change or be subject to API breaking changes, while the experimental code possibly can or even get
deleted in the future.
Loading

0 comments on commit 4840d00

Please sign in to comment.