# test_duplicates_removal.py
# (forked from NVIDIA/NeMo-Curator)
from typing import Literal

import pandas as pd
import pytest
from dask import dataframe as dd

from nemo_curator.utils.duplicates_removal import remove_duplicates
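
# The tests below rely on the following presumed semantics for
# `remove_duplicates` (a minimal pandas sketch for orientation, NOT the actual
# NeMo-Curator implementation): keep one representative id per duplicate
# group, then drop every other duplicated id from `left` via a left anti join.
def _reference_remove_duplicates(left: pd.DataFrame, duplicates: pd.DataFrame) -> pd.DataFrame:
    # Keep the first id seen in each group; the rest of each group is treated
    # as removable duplicates.
    kept = duplicates.drop_duplicates(subset="group", keep="first")
    to_remove = duplicates.loc[~duplicates["id"].isin(kept["id"]), "id"]
    # Left anti join: retain only rows of `left` whose id is not slated for removal.
    return left[~left["id"].isin(to_remove)]


# For ids a0...d9 grouped by first letter, this sketch keeps a0, b0, c0, d0
# and removes the remaining 36 rows.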
@pytest.fixture
def ids() -> list[str]:
    # The dataset has ids a0...a9, b0...b9, c0...c9, d0...d9.
    return [f"{group}{i}" for group in ["a", "b", "c", "d"] for i in range(10)]


@pytest.fixture
def sample_data(ids: list[str]) -> dd.DataFrame:
    df = pd.DataFrame(
        {
            "id": ids,
            "text": [f"text for {_id}" for _id in ids],
        }
    )
    return dd.from_pandas(df, npartitions=4)


@pytest.fixture
def duplicate_data(ids: list[str]) -> dd.DataFrame:
    # Within each group we want to keep only the first occurrence
    # (i.e. a0, b0, c0, d0).
    df = pd.DataFrame([{"id": _id, "group": _id[0]} for _id in ids])
    return dd.from_pandas(df, npartitions=2)

@pytest.mark.parametrize(
    "backend",
    [
        "pandas",
        pytest.param("cudf", marks=pytest.mark.gpu),
    ],
)
@pytest.mark.parametrize("perform_shuffle", [False, True])
def test_remove_duplicates_basic(
    backend: Literal["cudf", "pandas"],
    perform_shuffle: bool,
    sample_data: dd.DataFrame,
    duplicate_data: dd.DataFrame,
):
    if perform_shuffle:
        # Shuffle the duplicates so that rows from the same group are spread
        # across partitions rather than co-located.
        duplicate_data = duplicate_data.sample(frac=1).reset_index(drop=True)

    sample_data = sample_data.to_backend(backend)
    duplicate_data = duplicate_data.to_backend(backend)

    # Test basic duplicate-removal functionality.
    result = remove_duplicates(
        left=sample_data,
        duplicates=duplicate_data,
        id_field="id",
        group_field="group",
        perform_shuffle=perform_shuffle,
    ).to_backend("pandas")
    result = result.compute()

    assert list(result.columns) == ["id", "text"]
    assert len(result) == 4
    # It's not guaranteed that exactly a0, b0, c0, d0 survive, so we only
    # check the first character (the group) of each surviving id.
    assert set(result["id"].apply(lambda x: x[0]).tolist()) == {"a", "b", "c", "d"}

@pytest.mark.parametrize(
    "backend",
    [
        "pandas",
        pytest.param("cudf", marks=pytest.mark.gpu),
    ],
)
@pytest.mark.parametrize("perform_shuffle", [False, True])
def test_remove_duplicates_all_duplicates(
    backend: Literal["cudf", "pandas"],
    perform_shuffle: bool,
    ids: list[str],
    sample_data: dd.DataFrame,
):
    # Every id belongs to the same group, so at most one row should survive.
    duplicates = dd.from_pandas(
        pd.DataFrame({"id": ids, "group": [1] * len(ids)}), npartitions=2
    )
    sample_data = sample_data.to_backend(backend)
    duplicates = duplicates.to_backend(backend)

    result = remove_duplicates(
        left=sample_data,
        duplicates=duplicates,
        id_field="id",
        group_field="group",
        perform_shuffle=perform_shuffle,
    ).to_backend("pandas")

    assert list(result.columns) == ["id", "text"]
    result = result.compute()
    if perform_shuffle:
        assert len(result) == 1
    else:
        # If we don't shuffle, both partitions contain rows from the same
        # group, so each partition keeps one row after "deduplication" and
        # the left anti join leaves us with 2 rows.
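        # Worked example, assuming dask splits the 40 duplicate rows evenly
        # and the group-wise "keep first" runs per partition without a shuffle:
        #   partition 0 (a0...b9, all group 1) keeps a0
        #   partition 1 (c0...d9, all group 1) keeps c0
        # so 38 ids are removed and 2 rows survive the left anti join.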
        assert len(result) == 2

@pytest.mark.parametrize(
    "backend",
    [
        "pandas",
        pytest.param("cudf", marks=pytest.mark.gpu),
    ],
)
@pytest.mark.parametrize("perform_shuffle", [False, True])
def test_not_remove_duplicates_unique(
    backend: Literal["cudf", "pandas"],
    perform_shuffle: bool,
    ids: list[str],
    sample_data: dd.DataFrame,
):
    # We create a duplicates dataset where the first 30 ids are in one group,
    # the next 9 ids are each in their own distinct group,
    # and the last id does not appear in the duplicates at all.
    duplicates = dd.from_pandas(
        pd.DataFrame(
            {
                "id": ids[:30] + ids[30:39],
                "group": ["group0"] * 30 + [f"group{i}" for i in range(1, 10)],
            }
        ),
        npartitions=2,
    )
    sample_data = sample_data.to_backend(backend)
    duplicates = duplicates.to_backend(backend)

    if perform_shuffle:
        # Shuffle the duplicates so that rows from the same group are spread
        # across partitions rather than co-located.
        duplicates = duplicates.sample(frac=1, random_state=42).reset_index(drop=True)

    result = remove_duplicates(
        left=sample_data,
        duplicates=duplicates,
        id_field="id",
        group_field="group",
        perform_shuffle=perform_shuffle,
    ).to_backend("pandas")
    result = result.compute()

    assert list(result.columns) == ["id", "text"]
    if perform_shuffle:
        # Since we've performed a shuffle, groups are co-located, so we expect:
        # 1. one row from the first group of 30,
        # 2. 9 rows from the 9 distinct groups,
        # 3. one row for the last id, which never appears in the duplicates.
        assert len(result) == 1 + 9 + 1
        # The last 10 ids should all be in the result; there will also be one
        # more id from the first group of 30.
        assert set(ids[30:]).issubset(set(result["id"].tolist()))
    else:
        # If we don't shuffle, we'd be left with 2 partitions that both
        # contain rows from group0, so one extra group0 row survives.
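        # Worked trace, assuming an even ~20/19 split of the 39 duplicate rows:
        #   partition 0 (a0...b9, all group0) keeps one row;
        #   partition 1 (c0...c9 in group0, plus d0...d8 in distinct groups)
        #   keeps one group0 row plus all 9 distinct rows; d9 is never listed
        #   as a duplicate, so it always survives.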
        assert len(result) == 2 + 9 + 1

@pytest.mark.parametrize(
    "backend",
    [
        "pandas",
        pytest.param("cudf", marks=pytest.mark.gpu),
    ],
)
def test_remove_duplicates_raise_error(
    backend: Literal["cudf", "pandas"],
):
    # Create sample dataframes with specific partition counts.
    df1 = dd.from_pandas(
        pd.DataFrame({"id": ["a1", "a2", "a3"], "text": ["text1", "text2", "text3"]}),
        npartitions=2,
    )  # dataset with 2 partitions
    duplicates = dd.from_pandas(
        pd.DataFrame(
            {"id": ["a1", "a2", "a3"], "group": ["group1", "group1", "group1"]}
        ),
        npartitions=3,
    )  # duplicates dataset with 3 partitions
    df1 = df1.to_backend(backend)
    duplicates = duplicates.to_backend(backend)

    # remove_duplicates should raise a ValueError when `duplicates` has more
    # partitions than `left`.
    with pytest.raises(ValueError) as exc_info:
        remove_duplicates(
            left=df1,
            duplicates=duplicates,
            id_field="id",
            group_field="group",
        )
    expected_msg = (
        "The number of partitions in `left` is less than the number of partitions in the duplicates dataset. "
        "This may lead to a shuffle join. Please re-read left and right with different partition sizes, or repartition left / right."
    )
    assert str(exc_info.value) == expected_msg
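
# To run this file (assuming pytest is installed and the "gpu" marker is
# registered in the project's pytest configuration), skipping GPU-only cases:
#   pytest test_duplicates_removal.py -m "not gpu"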