From a3c7370371b38cc0d0b75c2e55b78d9bbcec185c Mon Sep 17 00:00:00 2001 From: Simon Dirmeier Date: Sat, 20 Jul 2024 17:22:51 +0200 Subject: [PATCH] Simplify build system and revert mkdocs (#45) * Move to hatch * Revert: Move documentation from sphinx to mkdocs --- .bumpversion.cfg | 6 - .github/workflows/ci.yaml | 102 ++++++++------ .pre-commit-config.yaml | 11 -- .readthedocs.yaml | 6 +- Makefile | 23 +++- README.md | 13 +- docs/Makefile | 21 +++ docs/_pygments/style.py | 14 ++ docs/_static/theme.css | 84 ++++++++++-- docs/conf.py | 54 ++++++++ docs/{examples.md => examples.rst} | 5 +- docs/index.md | 90 ------------ docs/index.rst | 129 ++++++++++++++++++ docs/{news.md => news.rst} | 7 +- docs/ramsey.data.md | 9 -- docs/ramsey.data.rst | 27 ++++ docs/ramsey.experimental.md | 47 ------- docs/ramsey.experimental.rst | 121 ++++++++++++++++ docs/ramsey.family.md | 9 -- docs/ramsey.family.rst | 25 ++++ docs/ramsey.md | 25 ---- docs/ramsey.nn.md | 9 -- docs/ramsey.nn.rst | 25 ++++ docs/ramsey.rst | 34 +++++ docs/requirements.txt | 17 ++- mkdocs.yaml | 74 ---------- pyproject.toml | 81 ++++++++++- ramsey/__init__.py | 2 +- ramsey/_src/data/dataset_m4.py | 5 +- .../bayesian_linear.py | 10 +- .../bayesian_neural_network.py | 2 +- .../distributions/autoregressive.py | 9 +- .../gaussian_process/gaussian_process.py | 4 +- .../gaussian_process/kernel/non_stationary.py | 10 +- .../gaussian_process/kernel/stationary.py | 16 +-- .../sparse_gaussian_process.py | 8 +- .../recurrent_attentive_neural_process.py | 6 +- .../attentive_neural_process.py | 6 +- ramsey/_src/neural_process/neural_process.py | 6 +- .../neural_process/train_neural_process.py | 6 +- ramsey/_src/nn/MLP.py | 4 +- ramsey/_src/nn/attention/attention.py | 4 +- .../_src/nn/attention/multihead_attention.py | 14 +- setup.py | 68 --------- tox.ini | 48 ------- 45 files changed, 764 insertions(+), 532 deletions(-) delete mode 100644 .bumpversion.cfg create mode 100644 docs/Makefile create mode 100644 docs/_pygments/style.py create mode 100644 docs/conf.py rename docs/{examples.md => examples.rst} (83%) delete mode 100644 docs/index.md create mode 100644 docs/index.rst rename docs/{news.md => news.rst} (94%) delete mode 100644 docs/ramsey.data.md create mode 100644 docs/ramsey.data.rst delete mode 100644 docs/ramsey.experimental.md create mode 100644 docs/ramsey.experimental.rst delete mode 100644 docs/ramsey.family.md create mode 100644 docs/ramsey.family.rst delete mode 100644 docs/ramsey.md delete mode 100644 docs/ramsey.nn.md create mode 100644 docs/ramsey.nn.rst create mode 100644 docs/ramsey.rst delete mode 100644 mkdocs.yaml delete mode 100644 setup.py delete mode 100644 tox.ini diff --git a/.bumpversion.cfg b/.bumpversion.cfg deleted file mode 100644 index dd6dbdd..0000000 --- a/.bumpversion.cfg +++ /dev/null @@ -1,6 +0,0 @@ -[bumpversion] -current_version = 0.2.3 -commit = False -tag = False - -[bumpversion:file:ramsey/__init__.py] diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index c9dc0f9..fa62995 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -7,18 +7,17 @@ on: branches: [ main ] jobs: - build: + precommit: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - uses: actions/setup-python@v3 - uses: pre-commit/action@v3.0.0 - lint: - name: lints + build: runs-on: ubuntu-latest needs: - - build + - precommit strategy: matrix: python-version: [3.11] @@ -30,50 +29,66 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | - pip install tox - - name: Run format, lints and types + pip install hatch + - name: Build package run: | - tox -e format,lints,types + hatch build - test: - name: unit tests + lints: runs-on: ubuntu-latest needs: - - build + - precommit strategy: matrix: python-version: [3.11] steps: - - uses: actions/checkout@v3 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v3 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - pip install tox - - name: Run tests and coverage - run: | - tox -e tests - - name: Upload coverage reports to Codecov - uses: codecov/codecov-action@v3 - env: - CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} - - name: Run codacy-coverage-reporter - uses: codacy/codacy-coverage-reporter-action@v1 - with: - project-token: ${{ secrets.CODACY_PROJECT_TOKEN }} - coverage-reports: coverage.xml + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + pip install hatch + - name: Run lints + run: | + hatch run test:lints + + tests: + runs-on: ubuntu-latest + needs: + - lints + strategy: + matrix: + python-version: [3.11] + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + pip install hatch + - name: Build package + run: | + pip install jaxlib==0.4.24 jax==0.4.24 + - name: Run tests + run: | + hatch run test:tests + - name: Upload coverage reports to Codecov + uses: codecov/codecov-action@v3 + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} examples: - name: examples - runs-on: ubuntu-latest - needs: - - test - strategy: - matrix: - python-version: [3.11] - steps: + runs-on: ubuntu-latest + needs: + - tests + strategy: + matrix: + python-version: [3.11] + steps: - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v3 @@ -81,7 +96,14 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | - pip install tox + pip install hatch + - name: Build package + run: | + pip install jaxlib==0.4.24 jax==0.4.24 - name: Run examples run: | - tox -e examples + hatch run test:examples + - name: Upload coverage reports to Codecov + uses: codecov/codecov-action@v3 + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 39a10b4..1d55a15 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,17 +13,6 @@ repos: - id: requirements-txt-fixer - id: trailing-whitespace -- repo: https://github.com/pycqa/bandit - rev: 1.7.1 - hooks: - - id: bandit - language: python - language_version: python3 - types: [python] - args: ["-c", "pyproject.toml"] - additional_dependencies: ["toml"] - files: "(ramsey|examples)" - - repo: https://github.com/pre-commit/mirrors-mypy rev: v0.910-1 hooks: diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 9841692..3332335 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -5,8 +5,10 @@ build: tools: python: "3.11" -mkdocs: - configuration: mkdocs.yaml +sphinx: + builder: html + configuration: docs/conf.py + fail_on_warning: false python: install: diff --git a/Makefile b/Makefile index cb7749d..6896c94 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,22 @@ -PKG_VERSION=`python setup.py --version` +.PHONY: tag, tests, lints, docs, format, examples + +PKG_VERSION=`hatch version` tag: - git tag -a "v$(PKG_VERSION)" -m "v$(PKG_VERSION)" - git push --tag + git tag -a v${PKG_VERSION} -m v${PKG_VERSION} + git push --tag + +tests: + hatch run test:tests + +lints: + hatch run test:lints + +format: + hatch run test:format + +docs: + cd docs && make html + +examples: + hatch run test:examples diff --git a/README.md b/README.md index ee46552..03c0960 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,9 @@ # Ramsey [![active](https://www.repostatus.org/badges/latest/active.svg)](https://www.repostatus.org/#active) -[![ci](https://github.com/dirmeier/ramsey/actions/workflows/ci.yaml/badge.svg)](https://github.com/dirmeier/ramsey/actions/workflows/ci.yaml) -[![codecov](https://codecov.io/gh/ramsey-devs/ramsey/branch/main/graph/badge.svg?token=dn1xNBSalZ)](https://codecov.io/gh/ramsey-devs/ramsey) -[![Codacy quality](https://app.codacy.com/project/badge/Grade/ed13460537fd4ac099c8534b1d9a0202)](https://app.codacy.com/gh/ramsey-devs/ramsey/dashboard?utm_source=gh&utm_medium=referral&utm_content=&utm_campaign=Badge_grade) +[![ci](https://github.com/ramsey-devs/ramsey/actions/workflows/ci.yaml/badge.svg)](https://github.com/ramsey-devs/ramsey/actions/workflows/ci.yaml) +[![coverage](https://codecov.io/gh/ramsey-devs/ramsey/branch/main/graph/badge.svg?token=dn1xNBSalZ)](https://codecov.io/gh/ramsey-devs/ramsey) +[![quality](https://app.codacy.com/project/badge/Grade/ed13460537fd4ac099c8534b1d9a0202)](https://app.codacy.com/gh/ramsey-devs/ramsey/dashboard?utm_source=gh&utm_medium=referral&utm_content=&utm_campaign=Badge_grade) [![documentation](https://readthedocs.org/projects/ramsey/badge/?version=latest)](https://ramsey.readthedocs.io/en/latest/?badge=latest) [![version](https://img.shields.io/pypi/v/ramsey.svg?colorB=black&style=flat)](https://pypi.org/project/ramsey/) @@ -80,8 +80,11 @@ Contributions in the form of pull requests are more than welcome. A good way to In order to contribute: -1) Install Ramsey and dev dependencies via `pip install -e '.[dev]'`, -2) test your contribution/implementation by calling `tox` on the (Unix) command line before submitting a PR. +1) Clone Ramsey and install the package manager `hatch` via `pip install hatch`, +2) create a new branch locally `git checkout -b feature/my-new-feature` or `git checkout -b issue/fixes-bug`, +3) implement your contribution and ideally a test case, +4) test it by calling `make format`, `make lints` and `make tests` on the (Unix) command line, +5) submit a PR 🙂 ## Why Ramsey diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 0000000..7e56c26 --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,21 @@ +# Minimal makefile for Sphinx documentation + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = ./ +BUILDDIR = build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + rm -rf build + rm -rf source/examples + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/_pygments/style.py b/docs/_pygments/style.py new file mode 100644 index 0000000..9b349aa --- /dev/null +++ b/docs/_pygments/style.py @@ -0,0 +1,14 @@ +from pygments.style import Style +from pygments.token import Error, Text, Whitespace, Other + + +class MyStyle(Style): + background_color = "black" + highlight_color = "#49483e" + + styles = { + Text: "#f8f8f2", # class: '' + Whitespace: "", # class: 'w' + Error: "#960050 bg:#1e0010", # class: 'err' + Other: "", # class 'x' + } diff --git a/docs/_static/theme.css b/docs/_static/theme.css index 9a402ad..f1b2ac9 100644 --- a/docs/_static/theme.css +++ b/docs/_static/theme.css @@ -1,11 +1,79 @@ -:root { - --md-primary-fg-color: #b26679; - --md-primary-bg-color: #ffe9dd; - --pst-color-secondary: rgb(121, 40, 161); - --pst-color-inline-code-links: rgb(121, 40, 161); - --md-accent-fg-color: rgb(121, 40, 161); +html[data-theme="light"] { + --pst-color-primary: rgb(121, 40, 161); + --pst-color-secondary: #b26679; + --pst-color-inline-code-links: #b26679; } -html { - font-size: 130%; +pre > span { + line-height: 20px; +} + +span.kn { + color: rgb(0, 120, 161) !important; +} + +span.ml, span.mi, span.nb { + color: lightcoral !important; +} + +span.k, span.nn { + color: rgb(168, 70, 185) !important; +} + +h1 > code > span { + font-weight: 700 !important; +} + +pre { + border: 0; + font-size: 13px; + background-color: rgb(245, 245, 245) !important; +} + +h1 { + margin-bottom: 50px; +} +h3, h2, h1 { + font-weight: 300 !important; +} + +nav > li > a > code.literal { + padding-top: 0; + padding-bottom: 0; + background-color: white; + border: 0; +} + +nav.bd-links p.caption { + text-transform: uppercase; +} + +code.literal { + background-color: white; + border: 0; + border-radius: 0; +} + +a > code { + font-weight: 575; +} + +a:hover { + text-decoration-thickness: 1px !important; +} + +ul.bd-breadcrumbs li.breadcrumb-item a:hover { + text-decoration-thickness: 1px; +} + +nav.bd-links li > a:hover { + text-decoration-thickness: 1px; +} + +.prev-next-area a p.prev-next-title { + text-decoration: none !important; +} + +button.theme-switch-button { + display: none !important; } diff --git a/docs/conf.py b/docs/conf.py new file mode 100644 index 0000000..25b342b --- /dev/null +++ b/docs/conf.py @@ -0,0 +1,54 @@ +from datetime import date + +project = "Ramsey" +copyright = f"{date.today().year}, the Ramsey developers" +author = "the Ramsey developers" + +extensions = [ + "nbsphinx", + "sphinx.ext.autodoc", + "sphinx_autodoc_typehints", + "sphinx.ext.autosummary", + "sphinx.ext.doctest", + "sphinx.ext.intersphinx", + "sphinx.ext.mathjax", + "sphinx.ext.napoleon", + "sphinx.ext.viewcode", + "sphinx_autodoc_typehints", + "sphinx_copybutton", + "sphinx_math_dollar", + "IPython.sphinxext.ipython_console_highlighting", + "sphinx_design", +] + +templates_path = ["_templates"] +html_static_path = ["_static"] +html_css_files = ["theme.css"] + +autodoc_default_options = { + "member-order": "bysource", + "special-members": True, + "exclude-members": "__repr__, __str__, __weakref__", +} + +exclude_patterns = [ + "_build", + "build", + "Thumbs.db", + ".DS_Store", + "notebooks/.ipynb_checkpoints", + "examples/*ipynb", + "examples/*py", +] + +html_theme = "sphinx_book_theme" + +html_theme_options = { + "repository_url": "https://github.com/ramsey-devs/ramsey", + "use_repository_button": True, + "use_download_button": False, + "use_fullscreen_button": False, + "launch_buttons": {"colab_url": "https://colab.research.google.com"}, +} + +html_title = "Ramsey" diff --git a/docs/examples.md b/docs/examples.rst similarity index 83% rename from docs/examples.md rename to docs/examples.rst index ed4efca..1afa793 100644 --- a/docs/examples.md +++ b/docs/examples.rst @@ -1,5 +1,6 @@ -# Examples +Examples +======== -!!! note +.. note:: More self-contained examples code can be found on GitHub in `examples `_. diff --git a/docs/index.md b/docs/index.md deleted file mode 100644 index 61560a3..0000000 --- a/docs/index.md +++ /dev/null @@ -1,90 +0,0 @@ -# 👋 Welcome to Ramsey! - -*Probabilistic deep learning using JAX* - -Ramsey is a library for probabilistic modelling using [`JAX`](https://github.com/google/jax), -[`Flax`](https://github.com/google/flax) and [`NumPyro`](https://github.com/pyro-ppl/numpyro). -It offers high quality implementations of neural processes, Gaussian processes, Bayesian time series and state-space models, clustering processes, -and everything else Bayesian. - -Ramsey makes use of - -- [`Flax's`](https://github.com/google/flax)s module system for models with trainable parameters (such as neural or Gaussian processes), -- [`NumPyro`](https://github.com/pyro-ppl/numpyro) for models where parameters are endowed with prior distributions (such as Gaussian processes, Bayesian neural networks, etc.) - -and is hence aimed at being fully compatible with both of them. - -## Example usage - -You can, for instance, construct a simple neural process like this: - -``` py -from jax import random as jr - -from ramsey import NP -from ramsey.nn import MLP -from ramsey.data import sample_from_sine_function - -def get_neural_process(): - dim = 128 - np = NP( - decoder=MLP([dim] * 3 + [2]), - latent_encoder=( - MLP([dim] * 3), MLP([dim, dim * 2]) - ) - ) - return np - -key = jr.PRNGKey(23) -data = sample_from_sine_function(key) - -neural_process = get_neural_process() -params = neural_process.init( - key, x_context=data.x, y_context=data.y, x_target=data.x -) -``` - -The neural process takes a decoder and a set of two latent encoders as arguments. All of these are typically MLPs, but -Ramsey is flexible enough that you can change them, for instance, to CNNs or RNNs. Once the model is defined, you can initialize -its parameters just like in Flax. - -## Why Ramsey - -Just as the names of other probabilistic languages are inspired by researchers in the field -(e.g., Stan, Edward, Turing), Ramsey takes its name from one of my favourite philosophers/mathematicians, -[Frank Ramsey](https://plato.stanford.edu/entries/ramsey/). - -## Installation - -To install from PyPI, call: - -```sh -pip install ramsey -``` - -To install the latest GitHub , just call the following on the -command line: - -```shell -pip install git+https://github.com/ramsey-devs/ramsey@ -``` - -See also the installation instructions for [`JAX`](https://github.com/google/jax), if -you plan to use Ramsey on GPU/TPU. - -## Contributing - -Contributions in the form of pull requests are more than welcome. A good way to start is to check out issues labelled -["good first issue"](https://github.com/ramsey-devs/ramsey/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22). - -In order to contribute: - -1. Clone Ramsey, and install it and its dev dependencies via `pip install -e '.[dev]'`, -2. create a new branch locally `git checkout -b feature/my-new-feature` or `git checkout -b issue/fixes-bug`, -3. implement your contribution, -4. test it by calling `tox` on the (Unix) command line, -5. submit a PR 🙂 - -## License - -Ramsey is licensed under the Apache 2.0 License. diff --git a/docs/index.rst b/docs/index.rst new file mode 100644 index 0000000..85fdcd1 --- /dev/null +++ b/docs/index.rst @@ -0,0 +1,129 @@ +:github_url: https://github.com/ramsey-devs/ramsey/ + +👋 Welcome to Ramsey! +===================== + +*Probabilistic deep learning using JAX* + +Ramsey is a library for probabilistic modelling using `JAX `_ , +`Flax `_ and `NumPyro `_. +It offers high quality implementations of neural processes, Gaussian processes, Bayesian time series and state-space models, clustering processes, +and everything else Bayesian. + +Ramsey makes use of + +- Flax`s module system for models with trainable parameters (such as neural or Gaussian processes), +- NumPyro for models where parameters are endowed with prior distributions (such as Gaussian processes, Bayesian neural networks, etc.) + +and is hence aimed at being fully compatible with both of them. + +Example +------- + +You can, for instance, construct a simple neural process like this: + +.. code-block:: python + + from jax import random as jr + + from ramsey import NP + from ramsey.nn import MLP + from ramsey.data import sample_from_sine_function + + def get_neural_process(): + dim = 128 + np = NP( + decoder=MLP([dim] * 3 + [2]), + latent_encoder=( + MLP([dim] * 3), MLP([dim, dim * 2]) + ) + ) + return np + + key = jr.PRNGKey(23) + data = sample_from_sine_function(key) + + neural_process = get_neural_process() + params = neural_process.init(key, x_context=data.x, y_context=data.y, x_target=data.x) + +The neural process takes a decoder and a set of two latent encoders as arguments. All of these are typically MLPs, but +Ramsey is flexible enough that you can change them, for instance, to CNNs or RNNs. Once the model is defined, you can initialize +its parameters just like in Flax. + +Why Ramsey +---------- + +Just as the names of other probabilistic languages are inspired by researchers in the field +(e.g., Stan, Edward, Turing), Ramsey takes its name from one of my favourite philosophers/mathematicians, +`Frank Ramsey `_. + +Installation +------------ + +To install from PyPI, call: + +.. code-block:: bash + + pip install ramsey + +To install the latest GitHub , just call the following on the +command line: + +.. code-block:: bash + + pip install git+https://github.com/ramsey-devs/ramsey@ + +See also the installation instructions for `JAX `_, if +you plan to use Ramsey on GPU/TPU. + +Contributing +------------ + +Contributions in the form of pull requests are more than welcome. A good way to start is to check out issues labelled +`"good first issue" `_. + +In order to contribute: + +1) Clone Ramsey and install it and the package manager hatch via :code:`pip install hatch` :code:`pip install -e '.[dev]'`, +2) create a new branch locally via :code:`git checkout -b feature/my-new-feature` or :code:`git checkout -b issue/fixes-bug`, +3) implement your contribution, +4) test it by calling ``make format``, ``make lints`` and ``make tests`` on the (Unix) command line, +5) submit a PR 🙂 + +License +------- + +Ramsey is licensed under the Apache 2.0 License. + +.. toctree:: + :maxdepth: 1 + :hidden: + + 🏠 Home + 📰 News + +.. toctree:: + :caption: 🎓 Tutorials + :maxdepth: 1 + :hidden: + + notebooks/inference_with_flax_and_numpyro + notebooks/neural_processes + +.. toctree:: + :caption: 🎓 Example code + :maxdepth: 1 + :hidden: + + examples + +.. toctree:: + :caption: 🧱 API + :maxdepth: 2 + :hidden: + + ramsey + ramsey.data + ramsey.experimental + ramsey.family + ramsey.nn diff --git a/docs/news.md b/docs/news.rst similarity index 94% rename from docs/news.md rename to docs/news.rst index d01f535..eba0878 100644 --- a/docs/news.md +++ b/docs/news.rst @@ -1,14 +1,15 @@ -# 📰 News +📰 News +======= *Latest news on the development of Ramsey.* -!!! note +.. note:: With :code:`Haiku` having gone into a maintenance mode after the merger of Google Brain and Deepmind, Ramsey will now be using :code:`Flax` as default neural network library. Starting from version :code:`v0.2.0` all :code:`Haiku` code will be replaced with :code:`Flax` code. Sadge. -!!! note +.. note:: Starting from version :code:`v0.2.0` on, experimental and possibly non-permant code will be put into submodule :code:`ramsey.experimental`. Hence you can expect that the main code base won't change or be subject to API breaking changes, while the experimental code possibly can or even get diff --git a/docs/ramsey.data.md b/docs/ramsey.data.md deleted file mode 100644 index bbb404f..0000000 --- a/docs/ramsey.data.md +++ /dev/null @@ -1,9 +0,0 @@ -# `ramsey.data` - -Functionality for loading and sampling data sets. - -::: ramsey.data.m4_data - -::: ramsey.data.sample_from_gaussian_process - -::: ramsey.data.sample_from_sine_function diff --git a/docs/ramsey.data.rst b/docs/ramsey.data.rst new file mode 100644 index 0000000..64c258c --- /dev/null +++ b/docs/ramsey.data.rst @@ -0,0 +1,27 @@ +``ramsey.data`` +=============== + +.. currentmodule:: ramsey.data + +Functionality for loading and sampling data sets. + +.. autosummary:: + m4_data + sample_from_gaussian_process + sample_from_sine_function + +M4 competition data +------------------- + +.. autofunction:: m4_data + +Gaussian process samples +------------------------ + +.. autofunction:: sample_from_gaussian_process + + +Noisy sine function samples +--------------------------- + +.. autofunction:: sample_from_sine_function diff --git a/docs/ramsey.experimental.md b/docs/ramsey.experimental.md deleted file mode 100644 index 5340e55..0000000 --- a/docs/ramsey.experimental.md +++ /dev/null @@ -1,47 +0,0 @@ -# `ramsey.experimental` - -Experimental modules such as Gaussian processes or Bayesian neural networks. - -!!! note - - Experimental code is not native Ramsey code and subject to change, and might even get deleted in the future. - Better don't build critical code bases around the :code:`ramsey.experimental` submodule. - - -## Distributions - -::: ramsey.experimental.Autoregressive - -## Models - -::: ramsey.experimental.BNN - -::: ramsey.experimental.RANP - -::: ramsey.experimental.GP - -::: ramsey.experimental.SparseGP - -## Modules - -::: ramsey.experimental.BayesianLinear - -## Covariance functions - -::: ramsey.experimental.ExponentiatedQuadratic - -::: ramsey.experimental.exponentiated_quadratic - -::: ramsey.experimental.Linear - -::: ramsey.experimental.linear - -::: ramsey.experimental.Periodic - -::: ramsey.experimental.periodic - -## Train functions - -::: ramsey.experimental.train_gaussian_process - -::: ramsey.experimental.train_sparse_gaussian_process diff --git a/docs/ramsey.experimental.rst b/docs/ramsey.experimental.rst new file mode 100644 index 0000000..904f0fc --- /dev/null +++ b/docs/ramsey.experimental.rst @@ -0,0 +1,121 @@ +``ramsey.experimental`` +======================= + +.. currentmodule:: ramsey.experimental + +Experimental modules such as Gaussian processes or Bayesian neural networks. + +.. warning:: + + Experimental code is not native Ramsey code and subject to change, and might even get deleted in the future. + Better don't build critical code bases around the :code:`ramsey.experimental` submodule. + +Distributions +------------- + +.. autosummary:: + Autoregressive + +Autoregressive +~~~~~~~~~~~~~~ + +.. autoclass:: Autoregressive + :members: __call__ + +Models +------ + +.. autosummary:: + BNN + RANP + GP + SparseGP + +Bayesian neural network +~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: BNN + :members: __call__ + +Recurrent attentive neural process +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: RANP + :members: __call__ + +Gaussian process +~~~~~~~~~~~~~~~~ + +.. autoclass:: GP + :members: __call__ + +Sparse Gaussian process +~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: SparseGP + :members: __call__ + +Modules +------- + +.. autosummary:: + BayesianLinear + +BayesianLinear +~~~~~~~~~~~~~~ + +.. autoclass:: BayesianLinear + :members: __call__ + +Covariance functions +-------------------- + +.. autosummary:: + ExponentiatedQuadratic + Linear + Periodic + exponentiated_quadratic + linear + periodic + +ExponentiatedQuadratic +~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: ExponentiatedQuadratic + :members: __call__ + +.. autofunction:: exponentiated_quadratic + +Linear +~~~~~~ + +.. autoclass:: Linear + :members: __call__ + +.. autofunction:: linear + +Periodic +~~~~~~~~~ + +.. autoclass:: Periodic + :members: __call__ + +.. autofunction:: periodic + +Train functions +--------------- + +.. autosummary:: + train_gaussian_process + train_sparse_gaussian_process + +train_gaussian_process +~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: train_gaussian_process + + +train_sparse_gaussian_process +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: train_sparse_gaussian_process diff --git a/docs/ramsey.family.md b/docs/ramsey.family.md deleted file mode 100644 index 8ae7ebc..0000000 --- a/docs/ramsey.family.md +++ /dev/null @@ -1,9 +0,0 @@ -# `ramsey.family` - -Distributional families for constructing likelihoods and predictive distributions. - -## Exponential families - -::: ramsey.family.Gaussian - -::: ramsey.family.NegativeBinomial diff --git a/docs/ramsey.family.rst b/docs/ramsey.family.rst new file mode 100644 index 0000000..87d77f7 --- /dev/null +++ b/docs/ramsey.family.rst @@ -0,0 +1,25 @@ +``ramsey.family`` +================= + +.. currentmodule:: ramsey.family + +Distributional families for constructing likelihoods and predictive distributions. + +Exponential families +-------------------- + +.. autosummary:: + Gaussian + NegativeBinomial + +Gaussian +~~~~~~~~ + +.. autoclass:: Gaussian + :members: + +NegativeBinomial +~~~~~~~~~~~~~~~~ + +.. autoclass:: NegativeBinomial + :members: diff --git a/docs/ramsey.md b/docs/ramsey.md deleted file mode 100644 index 7e284ab..0000000 --- a/docs/ramsey.md +++ /dev/null @@ -1,25 +0,0 @@ -# `ramsey` - -Module containing all implemented probabilistic models and training functions. - -## Models - -::: ramsey.NP - options: - members: - - __call__ - -::: ramsey.ANP - options: - members: - - __call__ - -::: ramsey.DANP - options: - members: - - __call__ - - -## Functions - -::: ramsey.train_neural_process diff --git a/docs/ramsey.nn.md b/docs/ramsey.nn.md deleted file mode 100644 index ca770e7..0000000 --- a/docs/ramsey.nn.md +++ /dev/null @@ -1,9 +0,0 @@ -# ``ramsey.nn`` - -Neural networks and other modules for building neural processes, Bayesian neural networks, etc. - -## Modules - -::: ramsey.nn.MLP - -::: ramsey.nn.MultiHeadAttention diff --git a/docs/ramsey.nn.rst b/docs/ramsey.nn.rst new file mode 100644 index 0000000..2506175 --- /dev/null +++ b/docs/ramsey.nn.rst @@ -0,0 +1,25 @@ +``ramsey.nn`` +============= + +.. currentmodule:: ramsey.nn + +Neural networks and other modules for building neural processes, Bayesian neural networks, etc. + +Modules +------- + +.. autosummary:: + MLP + MultiHeadAttention + +MLP +~~~ + +.. autoclass:: MLP + :members: __call__ + +MultiHeadAttention +~~~~~~~~~~~~~~~~~~ + +.. autoclass:: MultiHeadAttention + :members: __call__ diff --git a/docs/ramsey.rst b/docs/ramsey.rst new file mode 100644 index 0000000..9a35630 --- /dev/null +++ b/docs/ramsey.rst @@ -0,0 +1,34 @@ +``ramsey`` +========== + +.. currentmodule:: ramsey + +Module containing all implemented probabilistic models and training functions. + +Models +------ + +.. autosummary:: + ANP + DANP + NP + +Neural processes +~~~~~~~~~~~~~~~~ + +.. autoclass:: NP + :members: __call__ + +.. autoclass:: ANP + :members: __call__ + +.. autoclass:: DANP + :members: __call__ + +Train functions +--------------- + +.. autosummary:: + train_neural_process + +.. autofunction:: train_neural_process diff --git a/docs/requirements.txt b/docs/requirements.txt index db96996..b7982a6 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -2,7 +2,16 @@ fybdthemes ipython matplotlib -mkdocs -mkdocs-jupyter -mkdocs-material -mkdocstrings[crystal,python] +nbsphinx +seaborn +session_info +sphinx +sphinx-autobuild +sphinx-book-theme +sphinx-copybutton +sphinx-math-dollar +sphinx_autodoc_typehints +sphinx_design +sphinx_fontawesome +sphinx_gallery +sphinxcontrib-fulltoc diff --git a/mkdocs.yaml b/mkdocs.yaml deleted file mode 100644 index 7367193..0000000 --- a/mkdocs.yaml +++ /dev/null @@ -1,74 +0,0 @@ -site_name: "Ramsey" -site_description: Documentation for Ramsey -site_url: https://ramsey.readthedocs.io -repo_url: https://github.com/ramsey-devs/ramsey -repo_name: ramsey-devs/ramsey -edit_uri: "" - -nav: - - 🏠 Home: 'index.md' - - 📰 News: 'news.md' - - 🎓 Tutorials: - - notebooks/inference_with_flax_and_numpyro.ipynb - - notebooks/neural_processes.ipynb - - 🎓 Example code: - - 'examples.md' - - 🧱 API: - - ramsey: ramsey.md - - ramsey.data: ramsey.data.md - - ramsey.experimental: ramsey.experimental.md - - ramsey.family: ramsey.family.md - - ramsey.nn: ramsey.nn.md - -theme: - name: material - features: - - navigation.instant - - navigation.tracking - - navigation.sections - - navigation.expand - - navigation.top - - content.code.copy - - search.suggest - - search.highlight - - content.code.annotate - icon: - repo: fontawesome/brands/github - - -markdown_extensions: - - admonition - - pymdownx.details - - pymdownx.highlight: - anchor_linenums: true - line_spans: __span - pygments_lang_class: true - - pymdownx.inlinehilite - - pymdownx.snippets - - pymdownx.superfences - - -plugins: - - search - - mkdocs-jupyter - - autorefs - - mkdocstrings: - default_handler: python - handlers: - python: - options: - heading_level: 3 - show_root_heading: true - show_root_full_path: true - show_root_toc_entry: true - show_symbol_type_heading: true - show_if_no_docstring: true - show_signature_annotations: true - show_source: false - inherited_members: true - docstring_style: "numpy" - - -extra_css: - - _static/theme.css - - stylesheets/permalinks.css diff --git a/pyproject.toml b/pyproject.toml index aa9846f..92fd552 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,82 @@ [build-system] -requires = ["setuptools", "wheel"] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "ramsey" +description = "Probabilistic deep learning using JAX" +authors = [{name = "Simon Dirmeier", email = "sfyrbnd@pm.me"}] +readme = "README.md" +license = "Apache-2.0" +homepage = "https://github.com/ramsey-devs/ramsey" +keywords=[ + "Bayes", + "jax", + "probabilistic deep learning", + "probabilistic models", + "neural process", +] +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", +] +requires-python = ">=3.10" +dependencies = [ + "chex", + "flax>=0.7.2", + "jax>=0.4.4", + "jaxlib>=0.4.4", + "numpyro", + "optax", + "pandas", + "rmsyutls", + "tqdm", +] +dynamic = ["version"] + +[project.optional-dependencies] +dev = ["pre-commit", "tox", "ruff"] +examples = ["matplotlib"] + +[project.urls] +Documentation = "https://ramsey.rtfd.io" +Homepage = "https://github.com/ramsey-devs/ramsey" + +[tool.hatch.version] +path = "ramsey/__init__.py" + + +[tool.hatch.build.targets.sdist] +exclude = [ + "/.github", + "./gitignore", + "/.pre-commit-config.yaml" +] + +[tool.hatch.envs.test] +dependencies = [ + "ruff>=0.3.0", + "pytest>=7.2.0", + "pytest-cov>=4.0.0", + "matplotlib" +] + +[tool.hatch.envs.test.scripts] +lints = 'ruff check ramsey examples' +format = 'ruff format ramsey examples' +tests = 'pytest -v --doctest-modules --cov=./ramsey --cov-report=xml ramsey' +examples = """ + python ./examples/attentive_neural_process.py -n 10 + python ./examples/experimental/bayesian_neural_network.py -n 10 + python ./examples/experimental/gaussian_process.py -n 10 + python ./examples/experimental/recurrent_attentive_neural_process.py -n 10 + python ./examples/experimental/sparse_gaussian_process.py -n 10 +""" [tool.bandit] skips = ["B101", "B310"] @@ -9,11 +86,11 @@ line-length = 80 exclude = ["*_test.py", "setup.py", "docs/**", "examples/experimental/**"] [tool.ruff.lint] -ignore= ["S101", "ANN1", "ANN2", "ANN0"] select = ["ANN", "D", "E", "F"] extend-select = [ "UP", "I", "PL", "S" ] +ignore= ["S101", "ANN1", "ANN2", "ANN0"] [tool.ruff.lint.pydocstyle] convention= 'numpy' diff --git a/ramsey/__init__.py b/ramsey/__init__.py index 1e40fc4..7a11e02 100644 --- a/ramsey/__init__.py +++ b/ramsey/__init__.py @@ -5,7 +5,7 @@ from ramsey._src.neural_process.neural_process import NP from ramsey._src.neural_process.train_neural_process import train_neural_process -__version__ = "0.2.3" +__version__ = "0.2.4" __all__ = [ "ANP", diff --git a/ramsey/_src/data/dataset_m4.py b/ramsey/_src/data/dataset_m4.py index a507d4d..c210b6d 100644 --- a/ramsey/_src/data/dataset_m4.py +++ b/ramsey/_src/data/dataset_m4.py @@ -1,7 +1,6 @@ import os from collections import namedtuple from dataclasses import dataclass -from typing import Tuple from urllib.parse import urlparse from urllib.request import urlretrieve @@ -114,7 +113,7 @@ class M4Dataset: os.path.join(os.path.dirname(__file__), ".data") ) - def load(self, interval: str) -> Tuple[pd.DataFrame, pd.DataFrame]: + def load(self, interval: str) -> tuple[pd.DataFrame, pd.DataFrame]: """Load a M4 data set. Parameters @@ -156,7 +155,7 @@ def _download(self, dataset): def _load( self, dataset, train_csv_path: str, test_csv_path: str - ) -> Tuple[pd.DataFrame, pd.DataFrame]: + ) -> tuple[pd.DataFrame, pd.DataFrame]: self._download(dataset) train_df = pd.read_csv(train_csv_path, sep=",", header=0, index_col=0) test_df = pd.read_csv(test_csv_path, sep=",", header=0, index_col=0) diff --git a/ramsey/_src/experimental/bayesian_neural_network/bayesian_linear.py b/ramsey/_src/experimental/bayesian_neural_network/bayesian_linear.py index 153603c..0117941 100644 --- a/ramsey/_src/experimental/bayesian_neural_network/bayesian_linear.py +++ b/ramsey/_src/experimental/bayesian_neural_network/bayesian_linear.py @@ -1,5 +1,3 @@ -from typing import Optional, Tuple - from flax import linen as nn from flax.linen import initializers from jax import Array @@ -49,9 +47,9 @@ class BayesianLinear(nn.Module): output_size: int use_bias: bool = True mc_sample_size: int = 10 - w_prior: Optional[dist.Distribution] = dist.Normal(loc=0.0, scale=1.0) - b_prior: Optional[dist.Distribution] = dist.Normal(loc=0.0, scale=1.0) - name: Optional[str] = None + w_prior: dist.Distribution | None = dist.Normal(loc=0.0, scale=1.0) + b_prior: dist.Distribution | None = dist.Normal(loc=0.0, scale=1.0) + name: str | None = None def setup(self): """Construct a linear Bayesian layer.""" @@ -119,7 +117,7 @@ def _get_bias(self, layer_dim, dtype): def _init_param(self, weight_name, param_name, constraint, shape, dtype): init = initializers.xavier_normal() - shape = (shape,) if not isinstance(shape, Tuple) else shape + shape = (shape,) if not isinstance(shape, tuple) else shape params = self.param(f"{weight_name}_{param_name}", init, shape, dtype) params = jnp.where( diff --git a/ramsey/_src/experimental/bayesian_neural_network/bayesian_neural_network.py b/ramsey/_src/experimental/bayesian_neural_network/bayesian_neural_network.py index df4d7a1..27ab586 100644 --- a/ramsey/_src/experimental/bayesian_neural_network/bayesian_neural_network.py +++ b/ramsey/_src/experimental/bayesian_neural_network/bayesian_neural_network.py @@ -1,4 +1,4 @@ -from typing import Iterable +from collections.abc import Iterable from flax import linen as nn from jax import Array diff --git a/ramsey/_src/experimental/distributions/autoregressive.py b/ramsey/_src/experimental/distributions/autoregressive.py index 3b7a18d..cd7d79f 100644 --- a/ramsey/_src/experimental/distributions/autoregressive.py +++ b/ramsey/_src/experimental/distributions/autoregressive.py @@ -1,5 +1,4 @@ from functools import partial -from typing import Optional import jax import numpy as np @@ -53,8 +52,8 @@ def __init__(self, loc, ar_coefficients, scale, length=None): def sample( self, rng_key: jr.PRNGKey, - length: Optional[int] = None, - initial_state: Optional[float] = None, + length: int | None = None, + initial_state: float | None = None, sample_shape=(), ): """Sample from the distribution. @@ -125,8 +124,8 @@ def log_prob(self, value: Array): def mean( self, - length: Optional[int] = None, - initial_state: Optional[float] = None, + length: int | None = None, + initial_state: float | None = None, ): """Compute the mean of the autoregressive distribution. diff --git a/ramsey/_src/experimental/gaussian_process/gaussian_process.py b/ramsey/_src/experimental/gaussian_process/gaussian_process.py index b89baca..c181334 100644 --- a/ramsey/_src/experimental/gaussian_process/gaussian_process.py +++ b/ramsey/_src/experimental/gaussian_process/gaussian_process.py @@ -1,5 +1,3 @@ -from typing import Optional - from flax import linen as nn from flax.linen import initializers from jax import Array @@ -25,7 +23,7 @@ class GP(nn.Module): """ kernel: Kernel - sigma_init: Optional[initializers.Initializer] = None + sigma_init: initializers.Initializer | None = None @nn.compact def __call__(self, x: Array, **kwargs): diff --git a/ramsey/_src/experimental/gaussian_process/kernel/non_stationary.py b/ramsey/_src/experimental/gaussian_process/kernel/non_stationary.py index 8ea78df..53b0a31 100644 --- a/ramsey/_src/experimental/gaussian_process/kernel/non_stationary.py +++ b/ramsey/_src/experimental/gaussian_process/kernel/non_stationary.py @@ -1,5 +1,3 @@ -from typing import Optional - from flax import linen as nn from flax.linen import initializers from jax import Array @@ -23,10 +21,10 @@ class Linear(Kernel, nn.Module): an initializer object from Flax or None """ - active_dims: Optional[list] = None - sigma_b_init: Optional[initializers.Initializer] = initializers.uniform() - sigma_v_init: Optional[initializers.Initializer] = initializers.uniform() - offset_init: Optional[initializers.Initializer] = initializers.uniform() + active_dims: list | None = None + sigma_b_init: initializers.Initializer | None = initializers.uniform() + sigma_v_init: initializers.Initializer | None = initializers.uniform() + offset_init: initializers.Initializer | None = initializers.uniform() def setup(self): """Construct parameters.""" diff --git a/ramsey/_src/experimental/gaussian_process/kernel/stationary.py b/ramsey/_src/experimental/gaussian_process/kernel/stationary.py index cbfa967..262d512 100644 --- a/ramsey/_src/experimental/gaussian_process/kernel/stationary.py +++ b/ramsey/_src/experimental/gaussian_process/kernel/stationary.py @@ -1,5 +1,3 @@ -from typing import Optional, Union - from flax import linen as nn from flax.linen import initializers from jax import Array @@ -26,9 +24,9 @@ class Periodic(Kernel, nn.Module): """ period: float - active_dims: Optional[list] = None - rho_init: Optional[initializers.Initializer] = initializers.uniform() - sigma_init: Optional[initializers.Initializer] = initializers.uniform() + active_dims: list | None = None + rho_init: initializers.Initializer | None = initializers.uniform() + sigma_init: initializers.Initializer | None = initializers.uniform() def setup(self): """Construct the covariance function.""" @@ -74,9 +72,9 @@ class ExponentiatedQuadratic(Kernel, nn.Module): name of the layer """ - active_dims: Optional[list] = None - rho_init: Optional[initializers.Initializer] = None - sigma_init: Optional[initializers.Initializer] = None + active_dims: list | None = None + rho_init: initializers.Initializer | None = None + sigma_init: initializers.Initializer | None = None def setup(self): """Construct a stationary covariance.""" @@ -117,7 +115,7 @@ def exponentiated_quadratic( x1: Array, x2: Array, sigma: float, - rho: Union[float, jnp.ndarray], + rho: float | jnp.ndarray, ): """Exponentiated-quadratic convariance function. diff --git a/ramsey/_src/experimental/gaussian_process/sparse_gaussian_process.py b/ramsey/_src/experimental/gaussian_process/sparse_gaussian_process.py index ea0bd2e..207b4ac 100644 --- a/ramsey/_src/experimental/gaussian_process/sparse_gaussian_process.py +++ b/ramsey/_src/experimental/gaussian_process/sparse_gaussian_process.py @@ -1,5 +1,3 @@ -from typing import Optional - from flax import linen as nn from flax.linen import initializers from jax import Array @@ -38,11 +36,11 @@ class SparseGP(nn.Module): kernel: Kernel n_inducing: int - jitter: Optional[float] = 10e-8 - log_sigma_init: Optional[initializers.Initializer] = initializers.constant( + jitter: float | None = 10e-8 + log_sigma_init: initializers.Initializer | None = initializers.constant( jnp.log(1.0) ) - inducing_init: Optional[initializers.Initializer] = initializers.uniform(1) + inducing_init: initializers.Initializer | None = initializers.uniform(1) @nn.compact def __call__(self, x: Array, **kwargs): diff --git a/ramsey/_src/experimental/timeseries/recurrent_attentive_neural_process.py b/ramsey/_src/experimental/timeseries/recurrent_attentive_neural_process.py index 224ce44..1dd422c 100644 --- a/ramsey/_src/experimental/timeseries/recurrent_attentive_neural_process.py +++ b/ramsey/_src/experimental/timeseries/recurrent_attentive_neural_process.py @@ -1,5 +1,3 @@ -from typing import Optional, Tuple - from chex import assert_axis_dimension, assert_rank from flax import linen as nn from jax import Array @@ -39,8 +37,8 @@ class RANP(ANP): """ decoder: nn.Module - latent_encoder: Optional[Tuple[nn.Module, nn.Module]] = None - deterministic_encoder: Optional[Tuple[nn.Module, Attention]] = None + latent_encoder: tuple[nn.Module, nn.Module] | None = None + deterministic_encoder: tuple[nn.Module, Attention] | None = None family: Family = Gaussian() def setup(self): diff --git a/ramsey/_src/neural_process/attentive_neural_process.py b/ramsey/_src/neural_process/attentive_neural_process.py index 5718d5e..2efe5f0 100644 --- a/ramsey/_src/neural_process/attentive_neural_process.py +++ b/ramsey/_src/neural_process/attentive_neural_process.py @@ -1,5 +1,3 @@ -from typing import Optional - from chex import assert_axis_dimension from flax import linen as nn from jax import numpy as jnp @@ -43,8 +41,8 @@ class ANP(NP): """ decoder: nn.Module - latent_encoder: Optional[nn.Module] = None - deterministic_encoder: Optional[nn.Module] = None + latent_encoder: nn.Module | None = None + deterministic_encoder: nn.Module | None = None family: Family = Gaussian() def setup(self): diff --git a/ramsey/_src/neural_process/neural_process.py b/ramsey/_src/neural_process/neural_process.py index 707697f..98fe893 100644 --- a/ramsey/_src/neural_process/neural_process.py +++ b/ramsey/_src/neural_process/neural_process.py @@ -1,5 +1,3 @@ -from typing import Optional, Tuple - import flax import jax import numpyro.distributions as dist @@ -45,8 +43,8 @@ class NP(nn.Module): """ decoder: nn.Module - latent_encoder: Optional[Tuple[flax.linen.Module, flax.linen.Module]] = None - deterministic_encoder: Optional[flax.linen.Module] = None + latent_encoder: tuple[flax.linen.Module, flax.linen.Module] | None = None + deterministic_encoder: flax.linen.Module | None = None family: Family = Gaussian() def setup(self): diff --git a/ramsey/_src/neural_process/train_neural_process.py b/ramsey/_src/neural_process/train_neural_process.py index 0f1cd2b..2f496ff 100644 --- a/ramsey/_src/neural_process/train_neural_process.py +++ b/ramsey/_src/neural_process/train_neural_process.py @@ -1,5 +1,3 @@ -from typing import Tuple, Union - import jax import numpy as np import optax @@ -33,8 +31,8 @@ def train_neural_process( neural_process: NP, # pylint: disable=invalid-name x: Array, # pylint: disable=invalid-name y: Array, # pylint: disable=invalid-name - n_context: Union[int, Tuple[int]], - n_target: Union[int, Tuple[int]], + n_context: int | tuple[int], + n_target: int | tuple[int], batch_size: int, optimizer=optax.adam(3e-4), n_iter=20000, diff --git a/ramsey/_src/nn/MLP.py b/ramsey/_src/nn/MLP.py index 9e8ce31..e125862 100644 --- a/ramsey/_src/nn/MLP.py +++ b/ramsey/_src/nn/MLP.py @@ -1,4 +1,4 @@ -from typing import Callable, Iterable, Optional +from collections.abc import Callable, Iterable import jax from flax import linen as nn @@ -29,7 +29,7 @@ class MLP(nn.Module): """ output_sizes: Iterable[int] - dropout: Optional[float] = None + dropout: float | None = None kernel_init: initializers.Initializer = default_kernel_init bias_init: initializers.Initializer = initializers.zeros_init() use_bias: bool = True diff --git a/ramsey/_src/nn/attention/attention.py b/ramsey/_src/nn/attention/attention.py index 2a026ec..acc61cf 100644 --- a/ramsey/_src/nn/attention/attention.py +++ b/ramsey/_src/nn/attention/attention.py @@ -1,5 +1,3 @@ -from typing import Optional - import chex from flax import linen as nn from jax import Array @@ -17,7 +15,7 @@ class Attention(nn.Module): an optional embedding network that embeds keys and queries """ - embedding: Optional[nn.Module] + embedding: nn.Module | None @nn.compact def __call__(self, key: Array, value: Array, query: Array): diff --git a/ramsey/_src/nn/attention/multihead_attention.py b/ramsey/_src/nn/attention/multihead_attention.py index 8a57f9e..ae19e59 100644 --- a/ramsey/_src/nn/attention/multihead_attention.py +++ b/ramsey/_src/nn/attention/multihead_attention.py @@ -1,5 +1,5 @@ import functools -from typing import Callable, Optional +from collections.abc import Callable from flax import linen as nn from flax.linen import dot_product_attention, initializers @@ -39,7 +39,7 @@ class MultiHeadAttention(Attention): num_heads: int head_size: int - embedding: Optional[nn.Module] + embedding: nn.Module | None def setup(self) -> None: """Construct the networks.""" @@ -78,11 +78,11 @@ class _MultiHeadAttention(nn.Module): num_heads: int dtype = None param_dtype = jnp.float32 - qkv_features: Optional[int] = None - out_features: Optional[int] = None + qkv_features: int | None = None + out_features: int | None = None broadcast_dropout: bool = True dropout_rate: float = 0.0 - deterministic: Optional[bool] = None + deterministic: bool | None = None precision: PrecisionLike = None kernel_init: Callable = default_kernel_init bias_init: Callable = initializers.zeros_init() @@ -98,8 +98,8 @@ def __call__( query: Array, key: Array, value: Array, - mask: Optional[Array] = None, - deterministic: Optional[bool] = None, + mask: Array | None = None, + deterministic: bool | None = None, ) -> Array: features = self.out_features or query.shape[-1] qkv_features = self.qkv_features or query.shape[-1] diff --git a/setup.py b/setup.py deleted file mode 100644 index a01d8aa..0000000 --- a/setup.py +++ /dev/null @@ -1,68 +0,0 @@ -import re -from os.path import abspath, dirname, join - -from setuptools import find_packages, setup - -PROJECT_PATH = dirname(abspath(__file__)) - - -def readme(): - with open("README.md") as fl: - return fl.read() - - -def _version(): - version = None - for line in open(join(PROJECT_PATH, "ramsey", "__init__.py")): - if line.startswith("__version__"): - version = re.match(r"__version__.*(\d+\.\d+\.\d+).*", line).group(1) - if version is None: - raise ValueError("couldn't parse version number from __init__.py") - return version - - -setup( - name="ramsey", - version=_version(), - description="Probabilistic deep learning using JAX", - long_description=readme(), - long_description_content_type="text/markdown", - url="https://github.com/ramsey-devs/ramsey", - author="The Ramsey developers", - license="Apache 2.0", - keywords=[ - "Bayes", - "jax", - "probabilistic deep learning", - "probabilistic models", - "neural process", - ], - packages=find_packages(), - include_package_data=True, - python_requires=">=3.9", - install_requires=[ - "chex", - "flax>=0.7.2", - "jax>=0.4.4", - "jaxlib>=0.4.4", - "numpyro", - "optax", - "pandas", - "rmsyutls", - "tqdm", - ], - extras_require={ - "dev": ["pre-commit", "tox", "ruff"], - "examples": ["matplotlib"], - }, - classifiers=[ - "Development Status :: 3 - Alpha", - "Intended Audience :: Science/Research", - "License :: OSI Approved :: Apache Software License", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: 3.12", - ], -) diff --git a/tox.ini b/tox.ini deleted file mode 100644 index 7912c4d..0000000 --- a/tox.ini +++ /dev/null @@ -1,48 +0,0 @@ -[tox] -envlist = format, lints, types, tests -isolated_build = True - -[testenv:format] -skip_install = true -commands_pre = - pip install ruff -commands = - ruff format ramsey examples - -[testenv:lints] -skip_install = true -commands_pre = - pip install ruff bandit - pip install -e . -commands = - bandit -r ramsey -c pyproject.toml - ruff check ramsey - - -[testenv:types] -skip_install = true -commands_pre = - pip install mypy -commands = - mypy --ignore-missing-imports ramsey - -[testenv:tests] -skip_install = true -commands_pre = - pip install pytest - pip install pytest-cov - pip install -e . -commands = - pytest -v --doctest-modules --cov=./ramsey --cov-report=xml ramsey - -[testenv:examples] -skip_install = true -commands_pre = - pip install -e . - pip install matplotlib -commands = - python examples/attentive_neural_process.py -n 10 - python examples/experimental/bayesian_neural_network.py -n 10 - python examples/experimental/gaussian_process.py -n 10 - python examples/experimental/recurrent_attentive_neural_process.py -n 10 - python examples/experimental/sparse_gaussian_process.py -n 10