Skip to content

Commit 9886004

Browse files
Merge pull request #301 from flatironinstitute/negative_axis_id
Negative axis
2 parents 913bc20 + 18cb2f3 commit 9886004

File tree

5 files changed

+11
-0
lines changed

5 files changed

+11
-0
lines changed

src/nemos/basis/_basis.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -665,6 +665,9 @@ def split_by_feature(
665665
- **Key**: Label of the basis.
666666
- **Value**: the array reshaped to: ``(..., n_inputs, n_basis_funcs, ...)``
667667
"""
668+
# convert axis to positive ints
669+
axis = axis if axis >= 0 else x.ndim + axis
670+
668671
if x.shape[axis] != self.n_output_features:
669672
raise ValueError(
670673
"`x.shape[axis]` does not match the expected number of features."

src/nemos/convolve.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,8 @@ def _shift_time_axis_and_convolve(array: NDArray, eval_basis: NDArray, axis: int
9090
-----
9191
This function supports arrays of any dimensionality greater or equal than 1.
9292
"""
93+
# convert axis
94+
axis = axis if axis >= 0 else array.ndim + axis
9395
# move time axis to first
9496
new_axis = (jnp.arange(array.ndim) + axis) % array.ndim
9597
array = jnp.transpose(array, new_axis)

src/nemos/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,8 @@ def _pad_dimension(
156156
"acausal": ((pad_size) // 2, pad_size - (pad_size) // 2),
157157
"anti-causal": (0, pad_size),
158158
}
159+
# convert negative axis in jax jit compilable way
160+
axis = axis * (axis >= 0) + (array.ndim + axis) * (axis < 0)
159161

160162
pad_width = (
161163
((0, 0),) * axis

tests/test_basis.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5156,6 +5156,8 @@ def test_duplicate_keys(bas1, bas2, bas3, basis_class_specific_params):
51565156
"x, axis, expectation, exp_shapes", # num output is 5*2 + 6*3 = 28
51575157
[
51585158
(np.ones((1, 28)), 1, does_not_raise(), [(1, 2, 5), (1, 3, 6)]),
5159+
(np.ones((1, 28)), -1, does_not_raise(), [(1, 2, 5), (1, 3, 6)]),
5160+
(np.ones((1, 28, 2)), -2, does_not_raise(), [(1, 2, 5, 2), (1, 3, 6, 2)]),
51595161
(np.ones((28,)), 0, does_not_raise(), [(2, 5), (3, 6)]),
51605162
(np.ones((2, 2, 28)), 2, does_not_raise(), [(2, 2, 2, 5), (2, 2, 3, 6)]),
51615163
(

tests/test_convolution.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,11 @@ class TestShiftTimeAxisAndConvolve:
1414
"time_series, check_func, axis",
1515
[
1616
(np.zeros((1, 20)), lambda x: x.ndim == 3, 1),
17+
(np.zeros((1, 20)), lambda x: x.ndim == 3, -1),
1718
(np.zeros((20,)), lambda x: x.ndim == 2, 0),
1819
(np.zeros((20, 1)), lambda x: x.ndim == 3, 0),
1920
(np.zeros((1, 20, 1)), lambda x: x.ndim == 4, 0),
21+
(np.zeros((1, 20, 1)), lambda x: x.ndim == 4, -3),
2022
],
2123
)
2224
def test_output_ndim(self, time_series, check_func, axis):

0 commit comments

Comments
 (0)