Skip to content

Commit

Permalink
Merge pull request #394 from pynapple-org/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
gviejo authored Jan 15, 2025
2 parents f2ad617 + e8c8ea0 commit ae9129b
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 45 deletions.
16 changes: 8 additions & 8 deletions .github/workflows/documentation.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ jobs:
- name: Sphinx build
run: |
sphinx-build doc _build
- name: Deploy to GitHub Pages
uses: peaceiris/actions-gh-pages@v3
if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }}
with:
publish_branch: gh-pages
github_token: ${{ secrets.GITHUB_TOKEN }}
publish_dir: _build/
force_orphan: true
# - name: Deploy to GitHub Pages
# uses: peaceiris/actions-gh-pages@v3
# if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }}
# with:
# publish_branch: gh-pages
# github_token: ${{ secrets.GITHUB_TOKEN }}
# publish_dir: _build/
# force_orphan: true
2 changes: 1 addition & 1 deletion doc/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ pynapple: python neural analysis package

.. grid-item-card:: Filtering
:text-align: center
:link: ./user_guide/07_decoding.html
:link: ./user_guide/12_filtering.html

.. image:: _static/example_thumbs/filtering.svg
:class: dark-light
Expand Down
4 changes: 2 additions & 2 deletions pynapple/core/time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1206,8 +1206,8 @@ def __getitem__(self, key, *args, **kwargs):
else:
index = self.index.__getitem__(key)

if isinstance(index, Number):
index = np.array([index])
# if isinstance(index, Number):
# index = np.array([index])

if all(is_array_like(a) for a in [index, output]):
if isinstance(key, tuple):
Expand Down
81 changes: 47 additions & 34 deletions tests/test_time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1029,28 +1029,34 @@ def test_horizontal_slicing(self, tsdframe, index, nap_type):
],
)
def test_vertical_slicing(self, tsdframe, index):
assert isinstance(tsdframe[index], nap.TsdFrame)
if len(tsdframe[index] == 1):
# use ravel to ignore shape mismatch
np.testing.assert_array_almost_equal(
tsdframe.values[index].ravel(), tsdframe[index].values.ravel()
)
if isinstance(index, int):
assert isinstance(tsdframe[index], np.ndarray)
else:
assert isinstance(tsdframe[index], nap.TsdFrame)

output = tsdframe[index]
if isinstance(output, nap.TsdFrame):
if len(output == 1):
# use ravel to ignore shape mismatch
np.testing.assert_array_almost_equal(
tsdframe.values[index].ravel(), output.values.ravel()
)
else:
np.testing.assert_array_almost_equal(
tsdframe.values[index], output.values
)
assert isinstance(output.time_support, nap.IntervalSet)
np.testing.assert_array_almost_equal(
tsdframe.values[index], tsdframe[index].values
output.time_support, tsdframe.time_support
)
assert isinstance(tsdframe[index].time_support, nap.IntervalSet)
np.testing.assert_array_almost_equal(
tsdframe[index].time_support, tsdframe.time_support
)
if len(tsdframe.metadata_columns):
assert np.all(tsdframe[index].metadata_columns == tsdframe.metadata_columns)
assert np.all(tsdframe[index].metadata_index == tsdframe.metadata_index)
if len(tsdframe.metadata_columns):
assert np.all(output.metadata_columns == tsdframe.metadata_columns)
assert np.all(output.metadata_index == tsdframe.metadata_index)

@pytest.mark.parametrize(
"row",
[
0,
# 0,
[0, 2],
slice(20, 30),
np.hstack([np.zeros(10, bool), True, True, True, np.zeros(87, bool)]),
Expand Down Expand Up @@ -1108,28 +1114,35 @@ def test_vert_and_horz_slicing(self, tsdframe, row, col, expected):
else:
assert isinstance(tsdframe[row, col], expected)

if len(tsdframe[row, col] == 1):
# use ravel to ignore shape mismatch
output = tsdframe[row, col]

if isinstance(output, nap.TsdFrame):
if len(tsdframe[row, col] == 1):
# use ravel to ignore shape mismatch
np.testing.assert_array_almost_equal(
tsdframe.values[row, col].ravel(),
tsdframe[row, col].values.ravel(),
)
else:
np.testing.assert_array_almost_equal(
tsdframe.values[row, col], tsdframe[row, col].values
)
assert isinstance(tsdframe[row, col].time_support, nap.IntervalSet)
np.testing.assert_array_almost_equal(
tsdframe.values[row, col].ravel(), tsdframe[row, col].values.ravel()
tsdframe[row, col].time_support, tsdframe.time_support
)
if isinstance(tsdframe[row, col], nap.TsdFrame) and len(
tsdframe[row, col].metadata_columns
):
assert np.all(
tsdframe[row, col].metadata_columns == tsdframe.metadata_columns
)
assert np.all(
tsdframe[row, col].metadata_index
== tsdframe.metadata_index[col]
)
else:
np.testing.assert_array_almost_equal(
tsdframe.values[row, col], tsdframe[row, col].values
)
assert isinstance(tsdframe[row, col].time_support, nap.IntervalSet)
np.testing.assert_array_almost_equal(
tsdframe[row, col].time_support, tsdframe.time_support
)
if isinstance(tsdframe[row, col], nap.TsdFrame) and len(
tsdframe[row, col].metadata_columns
):
assert np.all(
tsdframe[row, col].metadata_columns == tsdframe.metadata_columns
)
assert np.all(
tsdframe[row, col].metadata_index == tsdframe.metadata_index[col]
)
np.testing.assert_array_almost_equal(output, tsdframe.values[row, col])

@pytest.mark.parametrize("index", [0, [0, 2]])
def test_str_indexing(self, tsdframe, index):
Expand Down

0 comments on commit ae9129b

Please sign in to comment.