Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix duplicate context point sampling in "int" strategy for pandas #153

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ dist/*
_build
*.png
deepsensor.egg-info/
.venv/
9 changes: 5 additions & 4 deletions deepsensor/data/loader.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from deepsensor.data.task import Task, flatten_X

import os
import json
import copy
Expand All @@ -10,7 +8,8 @@

from typing import List, Tuple, Union, Optional

from deepsensor.errors import InvalidSamplingStrategyError
from deepsensor.data.task import Task
from deepsensor.errors import InvalidSamplingStrategyError, SamplingTooManyPointsError


class TaskLoader:
Expand Down Expand Up @@ -696,7 +695,9 @@ def sample_df(
if isinstance(sampling_strat, (int, np.integer)):
N = sampling_strat
rng = np.random.default_rng(seed)
idx = rng.choice(df.index, N)
if N > df.index.size:
raise SamplingTooManyPointsError(requested=N, available=df.index.size)
idx = rng.choice(df.index, N, replace=False)
X_c = df.loc[idx].reset_index()[["x1", "x2"]].values.T.astype(self.dtype)
Y_c = df.loc[idx].values.T
elif isinstance(sampling_strat, str) and sampling_strat in [
Expand Down
9 changes: 9 additions & 0 deletions deepsensor/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,12 @@ class InvalidSamplingStrategyError(Exception):
"""Raised when TaskLoader sampling strategy is invalid."""

pass


class SamplingTooManyPointsError(ValueError):
    """Raised when more points are requested than the dataset contains.

    Args:
        requested: Number of points the caller asked to sample.
        available: Number of points that can actually be sampled.
    """

    def __init__(self, requested: int, available: int):
        super().__init__(
            f"Requested {requested} points to sample, but only {available} are available."
        )
        # Keep the raw counts so handlers can inspect them without parsing the message.
        self.requested = requested
        self.available = available

    def __reduce__(self):
        # The default BaseException.__reduce__ rebuilds the instance from ``self.args``,
        # which holds only the formatted message. That does not match this two-argument
        # constructor, so unpickling would fail. Rebuild from the original counts instead.
        return (type(self), (self.requested, self.available))
35 changes: 31 additions & 4 deletions tests/test_task_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,10 @@
import pandas as pd
import unittest

import os
import shutil
import tempfile
import copy

from deepsensor.errors import InvalidSamplingStrategyError
from deepsensor.errors import InvalidSamplingStrategyError, SamplingTooManyPointsError
from tests.utils import (
gen_random_data_xr,
gen_random_data_pandas,
Expand Down Expand Up @@ -59,6 +57,8 @@ class TestTaskLoader(unittest.TestCase):

@classmethod
def setUpClass(cls):
# Set fixed random seed for deterministic tests
np.random.seed(42)
# It's safe to share data between tests because the TaskLoader does not modify data
cls.da = _gen_data_xr()
cls.aux_da = cls.da.isel(time=0)
Expand Down Expand Up @@ -156,6 +156,33 @@ def test_aux_at_contexts_and_aux_at_targets(self):
target_sampling,
) in self._gen_task_loader_call_args(len(context), 1):
task = tl("2020-01-01", context_sampling, target_sampling)

def test_int_sampling_strat_pandas(self):
    """Test the integer sampling strategy of ``TaskLoader.__call__`` on pandas data."""
    target_sampling = 10

    task_loader = TaskLoader(context=self.df, target=self.df)

    # Size of the index of the cross-section taken at the sampled date
    n_available = len(self.df.xs("2020-01-01", level="time").index)

    # Ask for exactly as many context points as the cross-section holds
    task = task_loader("2020-01-01", n_available, target_sampling)
    context_coords = task["X_c"][0]
    sampled = list(zip(context_coords[0], context_coords[1]))

    # An integer sampling strategy must not return the same coordinate twice
    self.assertEqual(len(sampled), len(set(sampled)))

    # Asking for one more point than the cross-section holds must raise
    with self.assertRaises(SamplingTooManyPointsError):
        task_loader("2020-01-01", n_available + 1, target_sampling)


def test_invalid_sampling_strat(self):
"""Test invalid sampling strategy in ``TaskLoader.__call__``."""
Expand Down Expand Up @@ -215,7 +242,7 @@ def test_wrong_links(self):
tl = TaskLoader(context=self.df, target=self.df, links=[(0, 1)])

def test_links_gapfill_da(self) -> None:
"""TODO"""
"""Test gapfill sampling with NaN values in data."""
da_with_nans = copy.deepcopy(self.da)
# Convert 10% of data to NaNs
nan_idxs = np.random.randint(
Expand Down