Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Negative axis #301

Merged
merged 2 commits into from
Jan 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/nemos/basis/_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,6 +665,9 @@ def split_by_feature(
- **Key**: Label of the basis.
- **Value**: the array reshaped to: ``(..., n_inputs, n_basis_funcs, ...)``
"""
# convert axis to positive ints
axis = axis if axis >= 0 else x.ndim + axis

if x.shape[axis] != self.n_output_features:
raise ValueError(
"`x.shape[axis]` does not match the expected number of features."
Expand Down
2 changes: 2 additions & 0 deletions src/nemos/convolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ def _shift_time_axis_and_convolve(array: NDArray, eval_basis: NDArray, axis: int
-----
This function supports arrays of any dimensionality greater or equal than 1.
"""
# convert axis
axis = axis if axis >= 0 else array.ndim + axis
# move time axis to first
new_axis = (jnp.arange(array.ndim) + axis) % array.ndim
array = jnp.transpose(array, new_axis)
Expand Down
2 changes: 2 additions & 0 deletions src/nemos/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,8 @@ def _pad_dimension(
"acausal": ((pad_size) // 2, pad_size - (pad_size) // 2),
"anti-causal": (0, pad_size),
}
# convert negative axis in jax jit compilable way
axis = axis * (axis >= 0) + (array.ndim + axis) * (axis < 0)

pad_width = (
((0, 0),) * axis
Expand Down
2 changes: 2 additions & 0 deletions tests/test_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -5156,6 +5156,8 @@ def test_duplicate_keys(bas1, bas2, bas3, basis_class_specific_params):
"x, axis, expectation, exp_shapes", # num output is 5*2 + 6*3 = 28
[
(np.ones((1, 28)), 1, does_not_raise(), [(1, 2, 5), (1, 3, 6)]),
(np.ones((1, 28)), -1, does_not_raise(), [(1, 2, 5), (1, 3, 6)]),
(np.ones((1, 28, 2)), -2, does_not_raise(), [(1, 2, 5, 2), (1, 3, 6, 2)]),
(np.ones((28,)), 0, does_not_raise(), [(2, 5), (3, 6)]),
(np.ones((2, 2, 28)), 2, does_not_raise(), [(2, 2, 2, 5), (2, 2, 3, 6)]),
(
Expand Down
2 changes: 2 additions & 0 deletions tests/test_convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@ class TestShiftTimeAxisAndConvolve:
"time_series, check_func, axis",
[
(np.zeros((1, 20)), lambda x: x.ndim == 3, 1),
(np.zeros((1, 20)), lambda x: x.ndim == 3, -1),
(np.zeros((20,)), lambda x: x.ndim == 2, 0),
(np.zeros((20, 1)), lambda x: x.ndim == 3, 0),
(np.zeros((1, 20, 1)), lambda x: x.ndim == 4, 0),
(np.zeros((1, 20, 1)), lambda x: x.ndim == 4, -3),
],
)
def test_output_ndim(self, time_series, check_func, axis):
Expand Down