Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor beautifications and updates #41

Merged
merged 1 commit into from
Mar 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ jobs:
pip install tox
- name: Run format, sort, lints and types
run: |
tox -e format,sort,lints,types
tox -e format,lints,types
test:
name: unit tests
Expand Down
38 changes: 5 additions & 33 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,6 @@ repos:
- id: requirements-txt-fixer
- id: trailing-whitespace

- repo: https://github.com/asottile/pyupgrade
rev: v2.29.1
hooks:
- id: pyupgrade
args: [--py38-plus]

- repo: https://github.com/psf/black
rev: 22.3.0
hooks:
- id: black
args: ["--config=pyproject.toml"]
files: "(ramsey|examples)"

- repo: https://github.com/pycqa/isort
rev: 5.12.0
hooks:
- id: isort
args: ["--settings-path=pyproject.toml"]
files: "(ramsey|examples)"

- repo: https://github.com/pycqa/bandit
rev: 1.7.1
hooks:
Expand All @@ -44,15 +24,6 @@ repos:
additional_dependencies: ["toml"]
files: "(ramsey|examples)"

- repo: https://github.com/PyCQA/flake8
rev: 5.0.1
hooks:
- id: flake8
additional_dependencies: [
flake8-typing-imports==1.14.0,
flake8-pyproject==1.1.0.post0
]

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.910-1
hooks:
Expand All @@ -77,8 +48,9 @@ repos:
- id: gitlint
- id: gitlint-ci

- repo: https://github.com/pycqa/pydocstyle
rev: 6.1.1
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.3.0
hooks:
- id: pydocstyle
additional_dependencies: ["toml"]
- id: ruff
args: [ --fix ]
- id: ruff-format
57 changes: 11 additions & 46 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,54 +1,19 @@
[build-system]
requires = ["setuptools", "wheel"]

[tool.black]
line-length = 80
target-version = ['py38']
exclude = '''
/(
\.eggs
| \.git
| \.hg
| \.mypy_cache
| \.tox
| \.venv
| _build
| buck-out
| build
| dist
)/
'''

[tool.isort]
profile = "black"
line_length = 80
include_trailing_comma = true
[tool.bandit]
skips = ["B101", "B310"]

[tool.flake8]
max-line-length = 80
extend-ignore = ["E203", "W503", "E731", "E501"]
per-file-ignores = [
'__init__.py:F401',
]
[tool.ruff]
line-length = 80
exclude = ["*_test.py", "setup.py"]

[tool.pylint.'MESSAGES CONTROL']
max-line-length=80
disable = [
"missing-module-docstring",
"missing-function-docstring",
"no-name-in-module",
"too-many-arguments",
"duplicate-code",
"invalid-name",
"attribute-defined-outside-init",
"unsubscriptable-object",
"unpacking-non-sequence",
"arguments-differ"
[tool.ruff.lint]
ignore= ["S101", "ANN1", "ANN2", "ANN0"]
select = ["ANN", "D", "E", "F"]
extend-select = [
"UP", "D", "I", "PL", "S"
]

[tool.bandit]
skips = ["B101", "B310"]

[tool.pydocstyle]
[tool.ruff.lint.pydocstyle]
convention= 'numpy'
match = '^ramsey/.*/((?!test).)*\.py'
4 changes: 1 addition & 3 deletions ramsey/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
"""
ramsey: Probabilistic deep learning using JAX
"""
"""ramsey: Probabilistic deep learning using JAX."""

from ramsey._src.neural_process.attentive_neural_process import ANP
from ramsey._src.neural_process.doubly_attentive_neural_process import DANP
Expand Down
4 changes: 1 addition & 3 deletions ramsey/_src/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,7 @@ def m4_data(interval: str = "hourly", drop_na: bool = True):
train_idxs = jnp.arange(train.shape[1])
test_idxs = jnp.arange(test.shape[1]) + train.shape[1]

return namedtuple(
"data", ["y", "x", "train_idxs", "test_idxs"]
)( # type: ignore
return namedtuple("data", ["y", "x", "train_idxs", "test_idxs"])( # type: ignore
y, x, train_idxs, test_idxs
)

Expand Down
1 change: 1 addition & 0 deletions ramsey/_src/data/dataset_m4.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ def load(self, interval: str) -> Tuple[pd.DataFrame, pd.DataFrame]:
train, test = self._load(dataset, train_csv_path, test_csv_path)
return train, test

# ruff: noqa: S310
def _download(self, dataset):
for url in dataset.urls:
file = os.path.basename(urlparse(url).path)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def _get_bias(self, layer_dim, dtype):
)
return samples, params

# ruff: noqa: PLR0913
def _init_param(self, weight_name, param_name, constraint, shape, dtype):
init = initializers.xavier_normal()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@


class BNN(nn.Module):
"""
A Bayesian neural network.
"""A Bayesian neural network.
The BNN layers can a mix of Bayesian layers and conventional layers.
The training objective is the ELBO and is calculated according to [1].
Expand Down
4 changes: 2 additions & 2 deletions ramsey/_src/experimental/bayesian_neural_network/train_bnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from rmsyutls import as_batch_iterator
from tqdm import tqdm

# pylint: disable=line-too-long
# ruff: noqa: E501
from ramsey._src.experimental.bayesian_neural_network.bayesian_neural_network import (
BNN,
)
Expand All @@ -24,7 +24,7 @@ def _create_train_state(rng, model, optimizer, **init_data):
return state


# pylint: disable=too-many-locals
# ruff: noqa: PLR0913
def train_bnn(
rng_key,
bnn: BNN,
Expand Down
6 changes: 2 additions & 4 deletions ramsey/_src/experimental/gaussian_process/gaussian_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@

# pylint: disable=too-many-instance-attributes,duplicate-code
class GP(nn.Module):
"""
A Gaussian process.
"""A Gaussian process.
Attributes
----------
Expand All @@ -30,8 +29,7 @@ class GP(nn.Module):

@nn.compact
def __call__(self, x: Array, **kwargs):
"""
Evaluate the Gaussian process.
"""Evaluate the Gaussian process.
Parameters
----------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@

# pylint: disable=invalid-name
class Periodic(Kernel, nn.Module):
"""
Periodic covariance function.
"""Periodic covariance function.
Attributes
----------
Expand Down Expand Up @@ -60,8 +59,7 @@ def __call__(self, x1: Array, x2: Array = None):


class ExponentiatedQuadratic(Kernel, nn.Module):
"""
Exponentiated quadratic covariance function.
"""Exponentiated quadratic covariance function.
Attributes
----------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
)


# pylint: disable=too-many-locals,invalid-name
# ruff: noqa: PLR0913
def train_gaussian_process(
rng_key: jr.PRNGKey,
gaussian_process: GP,
Expand Down Expand Up @@ -79,7 +79,7 @@ def obj_fn(params):
return state.params, objectives


# pylint: disable=too-many-locals,invalid-name
# ruff: noqa: D406
def train_sparse_gaussian_process(
rng_key: jr.PRNGKey,
gaussian_process: SparseGP,
Expand Down
3 changes: 1 addition & 2 deletions ramsey/_src/neural_process/attentive_neural_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
__all__ = ["ANP"]


# pylint: disable=too-many-instance-attributes,duplicate-code
# pylint: disable=unpacking-non-sequence,
# ruff: noqa: PLR0913
class ANP(NP):
"""An attentive neural process.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@

# pylint: disable=too-many-instance-attributes
class DANP(ANP):
"""
A doubly-attentive neural process.
"""A doubly-attentive neural process.
Implements the core structure of a 'doubly-attentive' neural process [1],
i.e., a deterministic encoder, a latent encoder with self-attention module,
Expand Down
12 changes: 3 additions & 9 deletions ramsey/_src/neural_process/neural_process_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,9 @@ def module():
chex.assert_shape(params["latent_encoder_0"]["linear_0"]["kernel"], (2, 3))
chex.assert_shape(params["latent_encoder_0"]["linear_1"]["kernel"], (3, 3))
chex.assert_shape(params["latent_encoder_1"]["linear_0"]["kernel"], (3, 3))
chex.assert_shape(
params["latent_encoder_1"]["linear_1"]["kernel"], (3, 2 * 3)
)
chex.assert_shape(
params["deterministic_encoder"]["linear_0"]["kernel"], (2, 4)
)
chex.assert_shape(
params["deterministic_encoder"]["linear_1"]["kernel"], (4, 4)
)
chex.assert_shape(params["latent_encoder_1"]["linear_1"]["kernel"], (3, 2 * 3))
chex.assert_shape(params["deterministic_encoder"]["linear_0"]["kernel"], (2, 4))
chex.assert_shape(params["deterministic_encoder"]["linear_1"]["kernel"], (4, 4))
chex.assert_shape(params["decoder"]["linear_0"]["kernel"], (3 + 4 + 1, 3))
chex.assert_shape(params["decoder"]["linear_1"]["kernel"], (3, 2))

Expand Down
1 change: 1 addition & 0 deletions ramsey/_src/neural_process/train_neural_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def _split_data(
}


# ruff: noqa: ANN001,ANN003,PLR0913
def _create_train_state(rng, model, optimizer, **init_data):
init_key, sample_key = jr.split(rng)
params = model.init({"sample": sample_key, "params": init_key}, **init_data)
Expand Down
9 changes: 4 additions & 5 deletions ramsey/_src/nn/MLP.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@


class MLP(nn.Module):
"""
A multi-layer perceptron.
"""A multi-layer perceptron.
Attributes
----------
Expand Down Expand Up @@ -56,19 +55,19 @@ def setup(self):
self.dropout_layer = nn.Dropout(self.dropout)

# pylint: disable=too-many-function-args
def __call__(self, inputs: Array, is_training=False):
def __call__(self, inputs: Array, is_training: bool = False):
"""Transform the inputs through the MLP.
Parameters
----------
inputs: jax.Array
inputs: Array
input data of dimension (*batch_dims, spatial_dims..., feature_dims)
is_training: boolean
if true, uses training mode (i.e., dropout)
Returns
-------
jax.Array
Array
returns the transformed inputs
"""
num_layers = len(self.layers)
Expand Down
3 changes: 2 additions & 1 deletion ramsey/_src/nn/attention/dotproduct_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
from ramsey._src.nn.attention.attention import Attention


# ruff: noqa: PLR0913
class DotProductAttention(Attention):
"""Dot-product attention."""

def __call__(self, key: Array, value: Array, query: Array):
def __call__(self, key: Array, value: Array, query: Array) -> Array:
"""Apply attention to the query.
Arguments
Expand Down
11 changes: 6 additions & 5 deletions ramsey/_src/nn/attention/multihead_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
from ramsey._src.nn.attention.attention import Attention


# ruff: noqa: PLR0913
class MultiHeadAttention(Attention):
"""
Multi-head attention.
"""Multi-head attention.
As described in [1].
Expand All @@ -41,7 +41,7 @@ class MultiHeadAttention(Attention):
head_size: int
embedding: Optional[nn.Module]

def setup(self):
def setup(self) -> None:
"""Construct the networks."""
self._attention = _MultiHeadAttention(
num_heads=self.num_heads,
Expand All @@ -50,7 +50,7 @@ def setup(self):
)

@nn.compact
def __call__(self, key: Array, value: Array, query: Array):
def __call__(self, key: Array, value: Array, query: Array) -> Array:
"""Apply attention to the query.
Arguments
Expand All @@ -73,6 +73,7 @@ def __call__(self, key: Array, value: Array, query: Array):
return rep


# ruff: noqa: E501
class _MultiHeadAttention(nn.Module):
num_heads: int
dtype = None
Expand All @@ -99,7 +100,7 @@ def __call__(
value: Array,
mask: Optional[Array] = None,
deterministic: Optional[bool] = None,
):
) -> Array:
features = self.out_features or query.shape[-1]
qkv_features = self.qkv_features or query.shape[-1]
assert (
Expand Down
2 changes: 2 additions & 0 deletions ramsey/data.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Methods for downlaading data set."""

from ramsey._src.data.data import (
m4_data,
sample_from_gaussian_process,
Expand Down
Loading
Loading