diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 6511181..7d1102d 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -10,8 +10,9 @@ jobs: build: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 - - uses: actions/setup-python@v2 + - uses: actions/checkout@v3 + - uses: actions/setup-python@v3 + - uses: pre-commit/action@v3.0.0 test: name: unit tests @@ -22,9 +23,9 @@ jobs: matrix: python-version: [3.9] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v3 with: python-version: ${{ matrix.python-version }} - name: Install dependencies @@ -53,9 +54,9 @@ jobs: matrix: python-version: [3.9] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v3 with: python-version: ${{ matrix.python-version }} - name: Install dependencies @@ -74,9 +75,9 @@ jobs: matrix: python-version: [3.9] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v3 with: python-version: ${{ matrix.python-version }} - name: Install dependencies diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 36f03bb..f4fa180 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -24,14 +24,14 @@ repos: hooks: - id: black args: ["--config=pyproject.toml"] - files: "(ramsey)" + files: "(ramsey|examples)" - repo: https://github.com/pycqa/isort - rev: 5.10.1 + rev: 5.12.0 hooks: - id: isort args: ["--settings-path=pyproject.toml"] - files: "(ramsey)" + files: "(ramsey|examples)" - repo: https://github.com/pycqa/bandit rev: 1.7.1 @@ -42,22 +42,23 @@ repos: types: [python] args: ["-c", "pyproject.toml"] additional_dependencies: ["toml"] - files: "(ramsey)" + files: "(ramsey|examples)" -- repo: local +- repo: https://github.com/PyCQA/flake8 + rev: 5.0.1 hooks: - - id: pylint - name: pylint - entry: pylint - language: python - files: "ramsey" + - id: flake8 + additional_dependencies: [ + flake8-typing-imports==1.14.0, + flake8-pyproject==1.1.0.post0 + ] - repo: https://github.com/pre-commit/mirrors-mypy rev: v0.910-1 hooks: - id: mypy args: ["--ignore-missing-imports"] - files: "(ramsey)" + files: "(ramsey|examples)" - repo: https://github.com/nbQA-dev/nbQA rev: 1.3.1 diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 8807d40..fa2543e 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -1,13 +1,16 @@ version: 2 build: - os: ubuntu-20.04 + os: "ubuntu-20.04" tools: python: "3.9" sphinx: - configuration: docs/source/conf.py + configuration: docs/conf.py + fail_on_warning: false python: - install: - - requirements: docs/requirements.txt + install: + - method: pip + path: . + - requirements: docs/requirements.txt diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..990a179 --- /dev/null +++ b/Makefile @@ -0,0 +1,5 @@ +PKG_VERSION=`python setup.py --version` + +tag: + git tag -a $(PKG_VERSION) -m $(PKG_VERSION) + git push --tag diff --git a/docs/.gitignore b/docs/.gitignore index 51ac3e8..af6aca3 100644 --- a/docs/.gitignore +++ b/docs/.gitignore @@ -1,2 +1,6 @@ source/examples/ source/examples/* +build/ +build/* +_autosummary/ +_autosummary/* diff --git a/docs/Makefile b/docs/Makefile index 43c0fbd..0ad1935 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -5,7 +5,7 @@ # from the environment for the first two. SPHINXOPTS ?= SPHINXBUILD ?= sphinx-build -SOURCEDIR = source +SOURCEDIR = ./ BUILDDIR = build # Put it first so that "make" without argument is like "make help". diff --git a/docs/_static/theme.css b/docs/_static/theme.css new file mode 100644 index 0000000..503d9c4 --- /dev/null +++ b/docs/_static/theme.css @@ -0,0 +1,27 @@ +html[data-theme="dark"], html[data-theme="light"] { + --pst-color-primary: #0048bc; +} + +h1 > code > span { + font-family: var(--pst-font-family-monospace); + color: #0048bc; + font-weight: 700; +} + +nav > li > a > code.literal { + padding-top: 0rem; + padding-bottom: 0rem; + background-color: white; + border: 0px; +} + + +nav.bd-links p.caption { + text-transform: uppercase; +} + +code.literal { + background-color: white; + border: 0px; + border-radius: 0px; +} diff --git a/docs/source/conf.py b/docs/conf.py similarity index 91% rename from docs/source/conf.py rename to docs/conf.py index 7824f1b..976278f 100644 --- a/docs/source/conf.py +++ b/docs/conf.py @@ -1,15 +1,13 @@ -import glob -import os from datetime import date project = "Ramsey" copyright = f"{date.today().year}, the Ramsey developers" -author = "Ramsey developers" -release = "0.0.2" +author = "the Ramsey developers" extensions = [ "nbsphinx", "sphinx.ext.autodoc", + 'sphinx_autodoc_typehints', "sphinx.ext.autosummary", "sphinx.ext.doctest", "sphinx.ext.intersphinx", @@ -19,10 +17,13 @@ "sphinx_autodoc_typehints", "sphinx_math_dollar", "IPython.sphinxext.ipython_console_highlighting", + 'sphinx_design' ] + templates_path = ["_templates"] html_static_path = ["_static"] +html_css_files = ['theme.css'] autodoc_default_options = { "member-order": "bysource", diff --git a/docs/getting_started.rst b/docs/getting_started.rst new file mode 100644 index 0000000..8b06c7b --- /dev/null +++ b/docs/getting_started.rst @@ -0,0 +1,4 @@ +Philosophy +========== + +Todo diff --git a/docs/source/index.rst b/docs/index.rst similarity index 65% rename from docs/source/index.rst rename to docs/index.rst index 45b880b..9650f83 100644 --- a/docs/source/index.rst +++ b/docs/index.rst @@ -1,10 +1,17 @@ :github_url: https://github.com/ramsey-devs/ramsey/ -Ramsey documentation -==================== +Ramsey: probabilistic modelling using Haiku +=========================================== Ramsey is a library for probabilistic modelling using `Haiku `_ and `JAX `_. -It builds upon the same module system that Haiku is using and is hence fully compatible with Haiku's, NumPyro's API. +It builds upon the same module system that Haiku is using and is hence fully compatible with its API. Ramsey implements **probabilistic** models, such as neural processes, Gaussian processes, +Bayesian neural networks, Bayesian timeseries models and state-space-models, and more. + +Example +------- + +Ramsey uses to Haiku's module system to construct probabilistic models +and define parameters. For instance, a simple neural process can be constructed like this: .. code-block:: python @@ -65,14 +72,23 @@ Contributions in the form of pull requests are more than welcome. A good way to In order to contribute: -1) Install Ramsey and dev dependencies via :code:`pip install -e '.[dev]'`, -2) test your contribution/implementation by calling :code:`tox` on the (Unix) command line before submitting a PR. +1) Clone Ramsey and install it and its dev dependencies via :code:`pip install -e '.[dev]'`, +2) create a new branch locally :code:`git checkout -b feature/my-new-feature` or :code:`git checkout -b issue/fixes-bug`, +3) implement your contribution, +4) test it by calling :code:`tox` on the (Unix) command line, +5) submit a PR 🙂 License ------- -Ramsey is licensed under a Apache 2.0 License +Ramsey is licensed under the Apache 2.0 License. + + +.. toctree:: + :maxdepth: 1 + :hidden: + Home .. toctree:: :caption: Tutorials @@ -80,12 +96,16 @@ Ramsey is licensed under a Apache 2.0 License :hidden: notebooks/neural_process - notebooks/gaussian_process notebooks/forecasting .. toctree:: - :caption: API reference + :caption: API :maxdepth: 1 :hidden: - api + ramsey + ramsey.attention + ramsey.contrib + ramsey.family + ramsey.kernels + ramsey.train diff --git a/docs/source/notebooks/forecasting.ipynb b/docs/notebooks/forecasting.ipynb similarity index 99% rename from docs/source/notebooks/forecasting.ipynb rename to docs/notebooks/forecasting.ipynb index e18f367..6c176b4 100644 --- a/docs/source/notebooks/forecasting.ipynb +++ b/docs/notebooks/forecasting.ipynb @@ -216,7 +216,7 @@ "outputs": [], "source": [ "from ramsey import GP\n", - "from ramsey.covariance_functions import ExponentiatedQuadratic, Periodic\n", + "from ramsey.kernels import ExponentiatedQuadratic, Periodic\n", "from ramsey.train import train_gaussian_process" ] }, @@ -620,7 +620,6 @@ " n_context,\n", " n_target,\n", "):\n", - "\n", " key, sample_key = random.split(key, 2)\n", " sample_idxs = random.choice(\n", " sample_key,\n", diff --git a/docs/source/notebooks/neural_process.ipynb b/docs/notebooks/neural_process.ipynb similarity index 99% rename from docs/source/notebooks/neural_process.ipynb rename to docs/notebooks/neural_process.ipynb index cb67ab6..d0b1bcf 100644 --- a/docs/source/notebooks/neural_process.ipynb +++ b/docs/notebooks/neural_process.ipynb @@ -129,7 +129,7 @@ } }, "source": [ - "## Neural process\n", + "# Neural process\n", "\n", "Having sampled a data set, we can define the model. A neural process model takes a `decoder` argument, a `latent_encoder` argument and an optional `deterministic_encoder` argument.\n", "\n", @@ -334,7 +334,7 @@ } }, "source": [ - "## Attentive neural process\n", + "# Attentive neural process\n", "\n", "An attentive neural process model takes a `decoder` argument, a `latent_encoder` argument and a `deterministic_encoder` argument with an attention module. In comparison to before the\n", "`deterministic_encoder` is also a tuple in this case. In addition to the network that generates the representation, one is also requires to provide a mechanism for cross-attention.\n", diff --git a/docs/ramsey.attention.rst b/docs/ramsey.attention.rst new file mode 100644 index 0000000..ce679eb --- /dev/null +++ b/docs/ramsey.attention.rst @@ -0,0 +1,18 @@ +``ramsey.attention`` +==================== + +.. currentmodule:: ramsey.attention + +.. automodule:: ramsey.attention + +Attention +--------- + +.. autosummary:: + MultiHeadAttention + +MultiHeadAttention +~~~~~~~~~~~~~~~~~~ + +.. autoclass:: MultiHeadAttention + :members: diff --git a/docs/ramsey.contrib.rst b/docs/ramsey.contrib.rst new file mode 100644 index 0000000..e48f593 --- /dev/null +++ b/docs/ramsey.contrib.rst @@ -0,0 +1,37 @@ +``ramsey.contrib`` +================== + +.. currentmodule:: ramsey.contrib + +.. automodule:: ramsey.contrib + +Modules +------- + +.. autosummary:: + BayesianLinear + +BayesianLinear +~~~~~~~~~~~~~~ + +.. autoclass:: BayesianLinear + :members: + +Models +------ + +.. autosummary:: + BNN + RANP + +Bayesian neural network +~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: BNN + :members: + +Recurrent attentive neural process +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: RANP + :members: diff --git a/docs/ramsey.family.rst b/docs/ramsey.family.rst new file mode 100644 index 0000000..9e02a24 --- /dev/null +++ b/docs/ramsey.family.rst @@ -0,0 +1,25 @@ +``ramsey.family`` +================= + +.. currentmodule:: ramsey.family + +.. automodule:: ramsey.family + +Exponential families +-------------------- + +.. autosummary:: + Gaussian + NegativeBinomial + +Gaussian +~~~~~~~~ + +.. autoclass:: Gaussian + :members: + +NegativeBinomial +~~~~~~~~~~~~~~~~ + +.. autoclass:: NegativeBinomial + :members: diff --git a/docs/ramsey.kernels.rst b/docs/ramsey.kernels.rst new file mode 100644 index 0000000..a60fd83 --- /dev/null +++ b/docs/ramsey.kernels.rst @@ -0,0 +1,32 @@ +``ramsey.kernels`` +================== + +.. currentmodule:: ramsey.kernels + +.. automodule:: ramsey.kernels + +Covariance functions +-------------------- + +.. autosummary:: + ExponentiatedQuadratic + Linear + Periodic + +ExponentiatedQuadratic +~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: ExponentiatedQuadratic + :members: + +Linear +~~~~~~ + +.. autoclass:: Linear + :members: + +Periodic +~~~~~~~~~ + +.. autoclass:: Periodic + :members: diff --git a/docs/ramsey.rst b/docs/ramsey.rst new file mode 100644 index 0000000..2b7d6fa --- /dev/null +++ b/docs/ramsey.rst @@ -0,0 +1,45 @@ +``ramsey`` +========== + +.. currentmodule:: ramsey + +Models +------ + +.. autosummary:: + ANP + DANP + NP + GP + SparseGP + +Neural process +~~~~~~~~~~~~~~ + +.. autoclass:: NP + :members: + + +Attentive neural process +~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: ANP + :members: + +Doubly attentive neural process +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: DANP + :members: + +Gaussian process +~~~~~~~~~~~~~~~~ + +.. autoclass:: GP + :members: + +Sparse Gaussian process +~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: SparseGP + :members: diff --git a/docs/ramsey.train.rst b/docs/ramsey.train.rst new file mode 100644 index 0000000..ffa9e46 --- /dev/null +++ b/docs/ramsey.train.rst @@ -0,0 +1,30 @@ +``ramsey.train`` +================ + +.. currentmodule:: ramsey.train + +.. automodule:: ramsey.train + + +Train functions +--------------- + +.. autosummary:: + train_neural_process + train_gaussian_process + train_sparse_gaussian_process + +train_neural_process +~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: train_neural_process + +train_gaussian_process +~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: train_gaussian_process + +train_sparse_gaussian_process +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: train_sparse_gaussian_process diff --git a/docs/source/references.bib b/docs/references.bib similarity index 100% rename from docs/source/references.bib rename to docs/references.bib diff --git a/docs/requirements.txt b/docs/requirements.txt index 1a0d3f6..8ac3576 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,3 +1,4 @@ +-e . git+https://github.com/dirmeier/palettes@v0.1.1 ipython matplotlib @@ -8,6 +9,7 @@ sphinx-autobuild sphinx-book-theme sphinx-math-dollar sphinx_autodoc_typehints +sphinx_design sphinx_fontawesome sphinx_gallery sphinxcontrib-fulltoc diff --git a/docs/source/_static/.gitkeep b/docs/source/_static/.gitkeep deleted file mode 100644 index e69de29..0000000 diff --git a/docs/source/api.rst b/docs/source/api.rst deleted file mode 100644 index 9279fca..0000000 --- a/docs/source/api.rst +++ /dev/null @@ -1,69 +0,0 @@ -ramsey package -============== - -.. currentmodule:: ramsey - -ramsey.attention ----------------- - -.. currentmodule:: ramsey.attention - -.. autoclass:: MultiHeadAttention - :members: - - -ramsey.covariance_functions ---------------------------- - -.. currentmodule:: ramsey.covariance_functions - -.. autofunction:: - exponentiated_quadratic - - -ramsey.data ------------ - -.. automodule:: ramsey.data - -.. autosummary:: - - sample_from_sinus_function - sample_from_gaussian_process - - -ramsey.family -------------- - -.. automodule:: ramsey.family - -.. autoclass:: Gaussian - :members: - -.. autoclass:: NegativeBinomial - :members: - - -ramsey.models -------------- - -.. automodule:: ramsey.models - -.. autoclass:: NP - :members: - -.. autoclass:: ANP - :members: - -.. autoclass:: DANP - :members: - - -ramsey.train ------------- - -.. automodule:: ramsey.train - -.. autosummary:: - - train_neural_process diff --git a/examples/attentive_neural_process.py b/examples/attentive_neural_process.py index 43fe6c1..6c6f78c 100644 --- a/examples/attentive_neural_process.py +++ b/examples/attentive_neural_process.py @@ -17,9 +17,9 @@ from jax import numpy as jnp from jax import random +from ramsey import ANP from ramsey.attention import MultiHeadAttention from ramsey.data import sample_from_gaussian_process -from ramsey import ANP from ramsey.train import train_neural_process diff --git a/examples/gaussian_process.py b/examples/gaussian_process.py index eece52f..c41736e 100644 --- a/examples/gaussian_process.py +++ b/examples/gaussian_process.py @@ -19,9 +19,9 @@ from jax import random from jax.config import config -from ramsey.covariance_functions import ExponentiatedQuadratic -from ramsey.data import sample_from_gaussian_process from ramsey import GP +from ramsey.data import sample_from_gaussian_process +from ramsey.kernels import ExponentiatedQuadratic from ramsey.train import train_gaussian_process config.update("jax_enable_x64", True) diff --git a/examples/sparse_gaussian_process.py b/examples/sparse_gaussian_process.py index 698cf05..9da8d41 100644 --- a/examples/sparse_gaussian_process.py +++ b/examples/sparse_gaussian_process.py @@ -20,9 +20,9 @@ from jax import random from jax.config import config -from ramsey.covariance_functions import ExponentiatedQuadratic -from ramsey.data import sample_from_gaussian_process from ramsey import SparseGP +from ramsey.data import sample_from_gaussian_process +from ramsey.kernels import ExponentiatedQuadratic from ramsey.train import train_sparse_gaussian_process config.update("jax_enable_x64", True) diff --git a/pyproject.toml b/pyproject.toml index e5bb2cd..115c6cd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,12 @@ profile = "black" line_length = 80 include_trailing_comma = true +[tool.flake8] +max-line-length = 80 +extend-ignore = ["E203", "W503", "E731", "E501"] +per-file-ignores = [ + '__init__.py:F401', +] [tool.pylint.'MESSAGES CONTROL'] max-line-length=80 @@ -45,4 +51,4 @@ module = [ no_strict_optional = true [tool.bandit] -skips = ["B101"] +skips = ["B101", "B310"] diff --git a/ramsey/_src/conftest.py b/ramsey/_src/conftest.py index 93f3a54..42d7f90 100644 --- a/ramsey/_src/conftest.py +++ b/ramsey/_src/conftest.py @@ -1,3 +1,5 @@ +# pylint: skip-file + import haiku as hk import pytest diff --git a/ramsey/_src/datasets.py b/ramsey/_src/datasets.py index dd521b6..d2e7571 100644 --- a/ramsey/_src/datasets.py +++ b/ramsey/_src/datasets.py @@ -1,60 +1,60 @@ import os from collections import namedtuple from dataclasses import dataclass -from typing import ClassVar, List, Tuple +from typing import Tuple from urllib.parse import urlparse from urllib.request import urlretrieve import pandas as pd dset = namedtuple("dset", ["name", "urls"]) - +URL__ = "https://github.com/Mcompetitions/M4-methods/raw/master/Dataset/" __M4_HOURLY = dset( "hourly", [ - "https://github.com/Mcompetitions/M4-methods/raw/master/Dataset/Train/Hourly-train.csv", # pylint: disable=line-too-long - "https://github.com/Mcompetitions/M4-methods/raw/master/Dataset/Test/Hourly-test.csv", # pylint: disable=line-too-long + f"{URL__}/Train/Hourly-train.csv", + f"{URL__}/Test/Hourly-test.csv", ], ) __M4_DAILY = dset( "daily", [ - "https://github.com/Mcompetitions/M4-methods/raw/master/Dataset/Train/Daily-train.csv", # pylint: disable=line-too-long - "https://github.com/Mcompetitions/M4-methods/raw/master/Dataset/Test/Daily-test.csv", # pylint: disable=line-too-long + "f{URL__}/Train/Daily-train.csv", + "f{URL__}/Test/Daily-test.csv", ], ) __M4_WEEKLY = dset( "weekly", [ - "https://github.com/Mcompetitions/M4-methods/raw/master/Dataset/Train/Weekly-train.csv", # pylint: disable=line-too-long - "https://github.com/Mcompetitions/M4-methods/raw/master/Dataset/Test/Weekly-test.csv", # pylint: disable=line-too-long + f"{URL__}/Train/Weekly-train.csv", + f"{URL__}/Test/Weekly-test.csv", ], ) __M4_MONTHLY = dset( "monthly", [ - "https://github.com/Mcompetitions/M4-methods/raw/master/Dataset/Train/Monthly-train.csv", # pylint: disable=line-too-long - "https://github.com/Mcompetitions/M4-methods/raw/master/Dataset/Test/Monthly-test.csv", # pylint: disable=line-too-long + f"{URL__}/Train/Monthly-train.csv", + f"{URL__}/Test/Monthly-test.csv", ], ) __M4_QUARTERLY = dset( "quarterly", [ - "https://github.com/Mcompetitions/M4-methods/raw/master/Dataset/Train/Quarterly-train.csv", # pylint: disable=line-too-long - "https://github.com/Mcompetitions/M4-methods/raw/master/Dataset/Test/Quarterly-test.csv", # pylint: disable=line-too-long + f"{URL__}/Train/Quarterly-train.csv", + f"{URL__}/Test/Quarterly-test.csv", ], ) __M4_YEARLY = dset( "yearly", [ - "https://github.com/Mcompetitions/M4-methods/raw/master/Dataset/Train/Yearly-train.csv", # pylint: disable=line-too-long - "https://github.com/Mcompetitions/M4-methods/raw/master/Dataset/Test/Yearly-test.csv", # pylint: disable=line-too-long + f"{URL__}/Train/Yearly-train.csv", + f"{URL__}/Test/Yearly-test.csv", ], ) @@ -103,7 +103,7 @@ class M4Dataset: """A wrapper class to load M4 data""" - __INTERVALS__: ClassVar[List[str]] = [ + __INTERVALS__ = [ "hourly", "daily", "weekly", diff --git a/ramsey/_src/family.py b/ramsey/_src/family.py index 6b5ac46..aaabf4c 100644 --- a/ramsey/_src/family.py +++ b/ramsey/_src/family.py @@ -58,7 +58,7 @@ def __call__(self, target: jnp.ndarray, **kwargs): concentration = kwargs.get("concentration", None) if concentration is None: mean, concentration = jnp.split(target, 2, axis=-1) - mean = jnp.exp(target) + mean = jnp.exp(mean) concentration = jnp.exp(concentration) return dist.NegativeBinomial2(mean=mean, concentration=concentration) diff --git a/ramsey/_src/neural_process/test_train_neural_process.py b/ramsey/_src/neural_process/test_train_neural_process.py index 018bfcb..91075d1 100644 --- a/ramsey/_src/neural_process/test_train_neural_process.py +++ b/ramsey/_src/neural_process/test_train_neural_process.py @@ -1,3 +1,5 @@ +# pylint: skip-file + import haiku as hk from jax import random diff --git a/ramsey/_src/test_networks.py b/ramsey/_src/test_networks.py index a2bad8a..d0de443 100644 --- a/ramsey/_src/test_networks.py +++ b/ramsey/_src/test_networks.py @@ -1,3 +1,5 @@ +# pylint: skip-file + import chex import haiku as hk import pytest diff --git a/ramsey/data.py b/ramsey/data.py index 07bffb4..5240264 100644 --- a/ramsey/data.py +++ b/ramsey/data.py @@ -9,7 +9,7 @@ from jax import random from ramsey._src.datasets import M4Dataset -from ramsey.covariance_functions import exponentiated_quadratic +from ramsey.kernels import exponentiated_quadratic # pylint: disable=too-many-locals,invalid-name diff --git a/ramsey/covariance_functions.py b/ramsey/kernels.py similarity index 100% rename from ramsey/covariance_functions.py rename to ramsey/kernels.py diff --git a/setup.py b/setup.py index 5d98b7d..d25d245 100644 --- a/setup.py +++ b/setup.py @@ -44,6 +44,7 @@ def _version(): "blackjax", "chex", "dm-haiku>=0.0.9", + "distrax", "numpyro", "optax", "pandas" diff --git a/tox.ini b/tox.ini index 285bf87..6f13ad5 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = format, sort, lints, types +envlist = format, sort, lints, types, tests isolated_build = True [testenv:format] @@ -49,8 +49,8 @@ commands_pre = pip install -e . python -m ipykernel install --name ramsey-dev --user commands = - jupyter nbconvert --execute docs/source/notebooks/neural_process.ipynb --ExecutePreprocessor.kernel_name=ramsey-dev --to html - rm -rf docs/source/notebooks/neural_process.html + jupyter nbconvert --execute docs/notebooks/neural_process.ipynb --ExecutePreprocessor.kernel_name=ramsey-dev --to html + rm -rf docs/notebooks/neural_process.html [testenv:examples] skip_install = true