Skip to content

Commit

Permalink
add column names for compute_wavelet_transform
Browse files Browse the repository at this point in the history
  • Loading branch information
sjvenditto committed Jan 29, 2025
1 parent c961c66 commit 71a1f54
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
5 changes: 4 additions & 1 deletion pynapple/process/wavelets.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,10 @@ def compute_wavelet_transform(

if len(output_shape) == 2:
return nap.TsdFrame(
t=sig.index, d=np.squeeze(cwt, axis=1), time_support=sig.time_support
t=sig.index,
d=np.squeeze(cwt, axis=1),
time_support=sig.time_support,
columns=freqs,
)
else:
return nap.TsdTensor(
Expand Down
3 changes: 3 additions & 0 deletions tests/test_signal_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,9 @@ def test_compute_wavelet_transform(
np.testing.assert_array_almost_equal(
mwt.time_support.values, sig.time_support.values
)
if isinstance(mwt, nap.TsdFrame):
# test column names if TsdFrame
np.testing.assert_array_almost_equal(mwt.columns, freqs)


@pytest.mark.parametrize(
Expand Down

0 comments on commit 71a1f54

Please sign in to comment.