Skip to content

Commit

Permalink
Minor beautifications and updates (#41)
Browse files Browse the repository at this point in the history
  • Loading branch information
dirmeier authored Mar 1, 2024
1 parent 9e68fb9 commit f79b352
Show file tree
Hide file tree
Showing 25 changed files with 71 additions and 145 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ jobs:
pip install tox
- name: Run format, sort, lints and types
run: |
tox -e format,sort,lints,types
tox -e format,lints,types
test:
name: unit tests
Expand Down
38 changes: 5 additions & 33 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,6 @@ repos:
- id: requirements-txt-fixer
- id: trailing-whitespace

- repo: https://github.com/asottile/pyupgrade
rev: v2.29.1
hooks:
- id: pyupgrade
args: [--py38-plus]

- repo: https://github.com/psf/black
rev: 22.3.0
hooks:
- id: black
args: ["--config=pyproject.toml"]
files: "(ramsey|examples)"

- repo: https://github.com/pycqa/isort
rev: 5.12.0
hooks:
- id: isort
args: ["--settings-path=pyproject.toml"]
files: "(ramsey|examples)"

- repo: https://github.com/pycqa/bandit
rev: 1.7.1
hooks:
Expand All @@ -44,15 +24,6 @@ repos:
additional_dependencies: ["toml"]
files: "(ramsey|examples)"

- repo: https://github.com/PyCQA/flake8
rev: 5.0.1
hooks:
- id: flake8
additional_dependencies: [
flake8-typing-imports==1.14.0,
flake8-pyproject==1.1.0.post0
]

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.910-1
hooks:
Expand All @@ -77,8 +48,9 @@ repos:
- id: gitlint
- id: gitlint-ci

- repo: https://github.com/pycqa/pydocstyle
rev: 6.1.1
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.3.0
hooks:
- id: pydocstyle
additional_dependencies: ["toml"]
- id: ruff
args: [ --fix ]
- id: ruff-format
57 changes: 11 additions & 46 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,54 +1,19 @@
[build-system]
requires = ["setuptools", "wheel"]

[tool.black]
line-length = 80
target-version = ['py38']
exclude = '''
/(
\.eggs
| \.git
| \.hg
| \.mypy_cache
| \.tox
| \.venv
| _build
| buck-out
| build
| dist
)/
'''

[tool.isort]
profile = "black"
line_length = 80
include_trailing_comma = true
[tool.bandit]
skips = ["B101", "B310"]

[tool.flake8]
max-line-length = 80
extend-ignore = ["E203", "W503", "E731", "E501"]
per-file-ignores = [
'__init__.py:F401',
]
[tool.ruff]
line-length = 80
exclude = ["*_test.py", "setup.py"]

[tool.pylint.'MESSAGES CONTROL']
max-line-length=80
disable = [
"missing-module-docstring",
"missing-function-docstring",
"no-name-in-module",
"too-many-arguments",
"duplicate-code",
"invalid-name",
"attribute-defined-outside-init",
"unsubscriptable-object",
"unpacking-non-sequence",
"arguments-differ"
[tool.ruff.lint]
ignore= ["S101", "ANN1", "ANN2", "ANN0"]
select = ["ANN", "D", "E", "F"]
extend-select = [
"UP", "D", "I", "PL", "S"
]

[tool.bandit]
skips = ["B101", "B310"]

[tool.pydocstyle]
[tool.ruff.lint.pydocstyle]
convention= 'numpy'
match = '^ramsey/.*/((?!test).)*\.py'
4 changes: 1 addition & 3 deletions ramsey/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
"""
ramsey: Probabilistic deep learning using JAX
"""
"""ramsey: Probabilistic deep learning using JAX."""

from ramsey._src.neural_process.attentive_neural_process import ANP
from ramsey._src.neural_process.doubly_attentive_neural_process import DANP
Expand Down
4 changes: 1 addition & 3 deletions ramsey/_src/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,7 @@ def m4_data(interval: str = "hourly", drop_na: bool = True):
train_idxs = jnp.arange(train.shape[1])
test_idxs = jnp.arange(test.shape[1]) + train.shape[1]

return namedtuple(
"data", ["y", "x", "train_idxs", "test_idxs"]
)( # type: ignore
return namedtuple("data", ["y", "x", "train_idxs", "test_idxs"])( # type: ignore
y, x, train_idxs, test_idxs
)

Expand Down
1 change: 1 addition & 0 deletions ramsey/_src/data/dataset_m4.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ def load(self, interval: str) -> Tuple[pd.DataFrame, pd.DataFrame]:
train, test = self._load(dataset, train_csv_path, test_csv_path)
return train, test

# ruff: noqa: S310
def _download(self, dataset):
for url in dataset.urls:
file = os.path.basename(urlparse(url).path)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def _get_bias(self, layer_dim, dtype):
)
return samples, params

# ruff: noqa: PLR0913
def _init_param(self, weight_name, param_name, constraint, shape, dtype):
init = initializers.xavier_normal()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@


class BNN(nn.Module):
"""
A Bayesian neural network.
"""A Bayesian neural network.
The BNN layers can a mix of Bayesian layers and conventional layers.
The training objective is the ELBO and is calculated according to [1].
Expand Down
4 changes: 2 additions & 2 deletions ramsey/_src/experimental/bayesian_neural_network/train_bnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from rmsyutls import as_batch_iterator
from tqdm import tqdm

# pylint: disable=line-too-long
# ruff: noqa: E501
from ramsey._src.experimental.bayesian_neural_network.bayesian_neural_network import (
BNN,
)
Expand All @@ -24,7 +24,7 @@ def _create_train_state(rng, model, optimizer, **init_data):
return state


# pylint: disable=too-many-locals
# ruff: noqa: PLR0913
def train_bnn(
rng_key,
bnn: BNN,
Expand Down
6 changes: 2 additions & 4 deletions ramsey/_src/experimental/gaussian_process/gaussian_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@

# pylint: disable=too-many-instance-attributes,duplicate-code
class GP(nn.Module):
"""
A Gaussian process.
"""A Gaussian process.
Attributes
----------
Expand All @@ -30,8 +29,7 @@ class GP(nn.Module):

@nn.compact
def __call__(self, x: Array, **kwargs):
"""
Evaluate the Gaussian process.
"""Evaluate the Gaussian process.
Parameters
----------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@

# pylint: disable=invalid-name
class Periodic(Kernel, nn.Module):
"""
Periodic covariance function.
"""Periodic covariance function.
Attributes
----------
Expand Down Expand Up @@ -60,8 +59,7 @@ def __call__(self, x1: Array, x2: Array = None):


class ExponentiatedQuadratic(Kernel, nn.Module):
"""
Exponentiated quadratic covariance function.
"""Exponentiated quadratic covariance function.
Attributes
----------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
)


# pylint: disable=too-many-locals,invalid-name
# ruff: noqa: PLR0913
def train_gaussian_process(
rng_key: jr.PRNGKey,
gaussian_process: GP,
Expand Down Expand Up @@ -79,7 +79,7 @@ def obj_fn(params):
return state.params, objectives


# pylint: disable=too-many-locals,invalid-name
# ruff: noqa: D406
def train_sparse_gaussian_process(
rng_key: jr.PRNGKey,
gaussian_process: SparseGP,
Expand Down
3 changes: 1 addition & 2 deletions ramsey/_src/neural_process/attentive_neural_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
__all__ = ["ANP"]


# pylint: disable=too-many-instance-attributes,duplicate-code
# pylint: disable=unpacking-non-sequence,
# ruff: noqa: PLR0913
class ANP(NP):
"""An attentive neural process.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@

# pylint: disable=too-many-instance-attributes
class DANP(ANP):
"""
A doubly-attentive neural process.
"""A doubly-attentive neural process.
Implements the core structure of a 'doubly-attentive' neural process [1],
i.e., a deterministic encoder, a latent encoder with self-attention module,
Expand Down
12 changes: 3 additions & 9 deletions ramsey/_src/neural_process/neural_process_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,9 @@ def module():
chex.assert_shape(params["latent_encoder_0"]["linear_0"]["kernel"], (2, 3))
chex.assert_shape(params["latent_encoder_0"]["linear_1"]["kernel"], (3, 3))
chex.assert_shape(params["latent_encoder_1"]["linear_0"]["kernel"], (3, 3))
chex.assert_shape(
params["latent_encoder_1"]["linear_1"]["kernel"], (3, 2 * 3)
)
chex.assert_shape(
params["deterministic_encoder"]["linear_0"]["kernel"], (2, 4)
)
chex.assert_shape(
params["deterministic_encoder"]["linear_1"]["kernel"], (4, 4)
)
chex.assert_shape(params["latent_encoder_1"]["linear_1"]["kernel"], (3, 2 * 3))
chex.assert_shape(params["deterministic_encoder"]["linear_0"]["kernel"], (2, 4))
chex.assert_shape(params["deterministic_encoder"]["linear_1"]["kernel"], (4, 4))
chex.assert_shape(params["decoder"]["linear_0"]["kernel"], (3 + 4 + 1, 3))
chex.assert_shape(params["decoder"]["linear_1"]["kernel"], (3, 2))

Expand Down
1 change: 1 addition & 0 deletions ramsey/_src/neural_process/train_neural_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def _split_data(
}


# ruff: noqa: ANN001,ANN003,PLR0913
def _create_train_state(rng, model, optimizer, **init_data):
init_key, sample_key = jr.split(rng)
params = model.init({"sample": sample_key, "params": init_key}, **init_data)
Expand Down
9 changes: 4 additions & 5 deletions ramsey/_src/nn/MLP.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@


class MLP(nn.Module):
"""
A multi-layer perceptron.
"""A multi-layer perceptron.
Attributes
----------
Expand Down Expand Up @@ -56,19 +55,19 @@ def setup(self):
self.dropout_layer = nn.Dropout(self.dropout)

# pylint: disable=too-many-function-args
def __call__(self, inputs: Array, is_training=False):
def __call__(self, inputs: Array, is_training: bool = False):
"""Transform the inputs through the MLP.
Parameters
----------
inputs: jax.Array
inputs: Array
input data of dimension (*batch_dims, spatial_dims..., feature_dims)
is_training: boolean
if true, uses training mode (i.e., dropout)
Returns
-------
jax.Array
Array
returns the transformed inputs
"""
num_layers = len(self.layers)
Expand Down
3 changes: 2 additions & 1 deletion ramsey/_src/nn/attention/dotproduct_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
from ramsey._src.nn.attention.attention import Attention


# ruff: noqa: PLR0913
class DotProductAttention(Attention):
"""Dot-product attention."""

def __call__(self, key: Array, value: Array, query: Array):
def __call__(self, key: Array, value: Array, query: Array) -> Array:
"""Apply attention to the query.
Arguments
Expand Down
11 changes: 6 additions & 5 deletions ramsey/_src/nn/attention/multihead_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
from ramsey._src.nn.attention.attention import Attention


# ruff: noqa: PLR0913
class MultiHeadAttention(Attention):
"""
Multi-head attention.
"""Multi-head attention.
As described in [1].
Expand All @@ -41,7 +41,7 @@ class MultiHeadAttention(Attention):
head_size: int
embedding: Optional[nn.Module]

def setup(self):
def setup(self) -> None:
"""Construct the networks."""
self._attention = _MultiHeadAttention(
num_heads=self.num_heads,
Expand All @@ -50,7 +50,7 @@ def setup(self):
)

@nn.compact
def __call__(self, key: Array, value: Array, query: Array):
def __call__(self, key: Array, value: Array, query: Array) -> Array:
"""Apply attention to the query.
Arguments
Expand All @@ -73,6 +73,7 @@ def __call__(self, key: Array, value: Array, query: Array):
return rep


# ruff: noqa: E501
class _MultiHeadAttention(nn.Module):
num_heads: int
dtype = None
Expand All @@ -99,7 +100,7 @@ def __call__(
value: Array,
mask: Optional[Array] = None,
deterministic: Optional[bool] = None,
):
) -> Array:
features = self.out_features or query.shape[-1]
qkv_features = self.qkv_features or query.shape[-1]
assert (
Expand Down
2 changes: 2 additions & 0 deletions ramsey/data.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Methods for downlaading data set."""

from ramsey._src.data.data import (
m4_data,
sample_from_gaussian_process,
Expand Down
Loading

0 comments on commit f79b352

Please sign in to comment.