We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
2 parents 913bc20 + 18cb2f3 commit 9886004Copy full SHA for 9886004
src/nemos/basis/_basis.py
@@ -665,6 +665,9 @@ def split_by_feature(
665
- **Key**: Label of the basis.
666
- **Value**: the array reshaped to: ``(..., n_inputs, n_basis_funcs, ...)``
667
"""
668
+ # convert axis to positive ints
669
+ axis = axis if axis >= 0 else x.ndim + axis
670
+
671
if x.shape[axis] != self.n_output_features:
672
raise ValueError(
673
"`x.shape[axis]` does not match the expected number of features."
src/nemos/convolve.py
@@ -90,6 +90,8 @@ def _shift_time_axis_and_convolve(array: NDArray, eval_basis: NDArray, axis: int
90
-----
91
This function supports arrays of any dimensionality greater or equal than 1.
92
93
+ # convert axis
94
+ axis = axis if axis >= 0 else array.ndim + axis
95
# move time axis to first
96
new_axis = (jnp.arange(array.ndim) + axis) % array.ndim
97
array = jnp.transpose(array, new_axis)
src/nemos/utils.py
@@ -156,6 +156,8 @@ def _pad_dimension(
156
"acausal": ((pad_size) // 2, pad_size - (pad_size) // 2),
157
"anti-causal": (0, pad_size),
158
}
159
+ # convert negative axis in jax jit compilable way
160
+ axis = axis * (axis >= 0) + (array.ndim + axis) * (axis < 0)
161
162
pad_width = (
163
((0, 0),) * axis
tests/test_basis.py
@@ -5156,6 +5156,8 @@ def test_duplicate_keys(bas1, bas2, bas3, basis_class_specific_params):
5156
"x, axis, expectation, exp_shapes", # num output is 5*2 + 6*3 = 28
5157
[
5158
(np.ones((1, 28)), 1, does_not_raise(), [(1, 2, 5), (1, 3, 6)]),
5159
+ (np.ones((1, 28)), -1, does_not_raise(), [(1, 2, 5), (1, 3, 6)]),
5160
+ (np.ones((1, 28, 2)), -2, does_not_raise(), [(1, 2, 5, 2), (1, 3, 6, 2)]),
5161
(np.ones((28,)), 0, does_not_raise(), [(2, 5), (3, 6)]),
5162
(np.ones((2, 2, 28)), 2, does_not_raise(), [(2, 2, 2, 5), (2, 2, 3, 6)]),
5163
(
tests/test_convolution.py
@@ -14,9 +14,11 @@ class TestShiftTimeAxisAndConvolve:
14
"time_series, check_func, axis",
15
16
(np.zeros((1, 20)), lambda x: x.ndim == 3, 1),
17
+ (np.zeros((1, 20)), lambda x: x.ndim == 3, -1),
18
(np.zeros((20,)), lambda x: x.ndim == 2, 0),
19
(np.zeros((20, 1)), lambda x: x.ndim == 3, 0),
20
(np.zeros((1, 20, 1)), lambda x: x.ndim == 4, 0),
21
+ (np.zeros((1, 20, 1)), lambda x: x.ndim == 4, -3),
22
],
23
)
24
def test_output_ndim(self, time_series, check_func, axis):
0 commit comments