Skip to content

Commit

Permalink
Support the new minhash 25.02 api (#445)
Browse files Browse the repository at this point in the history
Signed-off-by: Praateek <[email protected]>
  • Loading branch information
praateekmahajan authored Dec 30, 2024
1 parent 4fb7f54 commit d401333
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 16 deletions.
13 changes: 9 additions & 4 deletions nemo_curator/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,15 @@
except (ImportError, TypeError):
CURRENT_CUDF_VERSION = parse_version("24.10.0")

# TODO remove this once 24.12.0 becomes the base version of cudf in nemo-curator
MINHASH_PERMUTED_AVAILABLE = CURRENT_CUDF_VERSION >= parse_version("24.12.0") or (
CURRENT_CUDF_VERSION.is_prerelease
and CURRENT_CUDF_VERSION.base_version >= "24.12.0"
# TODO remove this once 25.02 becomes the base version of cudf in nemo-curator

# cuDF < 24.12 exposed a minhash(txt) API, which was deprecated in favor of
# minhash(a, b) in 25.02; in between, 24.12 introduced minhash_permuted(a, b).
MINHASH_DEPRECATED_API = (
CURRENT_CUDF_VERSION.base_version < parse_version("24.12").base_version
)
MINHASH_PERMUTED_AVAILABLE = (CURRENT_CUDF_VERSION.major == 24) & (
CURRENT_CUDF_VERSION.minor == 12
)

# TODO: remove when dask min version gets bumped
Expand Down
36 changes: 24 additions & 12 deletions nemo_curator/modules/fuzzy_dedup.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from dask.utils import M
from tqdm import tqdm

from nemo_curator._compat import MINHASH_PERMUTED_AVAILABLE
from nemo_curator._compat import MINHASH_DEPRECATED_API, MINHASH_PERMUTED_AVAILABLE
from nemo_curator.datasets import DocumentDataset
from nemo_curator.log import create_logger
from nemo_curator.modules.config import FuzzyDuplicatesConfig
Expand Down Expand Up @@ -98,15 +98,17 @@ def __init__(
"""
self.num_hashes = num_hashes
self.char_ngram = char_ngrams
if MINHASH_PERMUTED_AVAILABLE:
if MINHASH_DEPRECATED_API:
self.seeds = self.generate_seeds(n_seeds=self.num_hashes, seed=seed)
else:
self.seeds = self.generate_hash_permutation_seeds(
bit_width=64 if use_64bit_hash else 32,
n_permutations=self.num_hashes,
seed=seed,
)
else:
self.seeds = self.generate_seeds(n_seeds=self.num_hashes, seed=seed)

self.minhash_method = self.minhash64 if use_64bit_hash else self.minhash32

self.id_field = id_field
self.text_field = text_field

Expand Down Expand Up @@ -171,7 +173,7 @@ def minhash32(
if not isinstance(ser, cudf.Series):
raise TypeError("Expected data of type cudf.Series")

if not MINHASH_PERMUTED_AVAILABLE:
if MINHASH_DEPRECATED_API:
warnings.warn(
"Using an outdated minhash implementation, please update to cuDF version 24.12 "
"or later for improved performance. "
Expand All @@ -184,9 +186,14 @@ def minhash32(
seeds_a = cudf.Series(seeds[:, 0], dtype="uint32")
seeds_b = cudf.Series(seeds[:, 1], dtype="uint32")

return ser.str.minhash_permuted(
a=seeds_a, b=seeds_b, seed=seeds[0][0], width=char_ngram
)
if MINHASH_PERMUTED_AVAILABLE:
return ser.str.minhash_permuted(
a=seeds_a, b=seeds_b, seed=seeds[0][0], width=char_ngram
)
else:
return ser.str.minhash(
a=seeds_a, b=seeds_b, seed=seeds[0][0], width=char_ngram
)

def minhash64(
self, ser: cudf.Series, seeds: np.ndarray, char_ngram: int
Expand All @@ -196,7 +203,7 @@ def minhash64(
"""
if not isinstance(ser, cudf.Series):
raise TypeError("Expected data of type cudf.Series")
if not MINHASH_PERMUTED_AVAILABLE:
if MINHASH_DEPRECATED_API:
warnings.warn(
"Using an outdated minhash implementation, please update to cuDF version 24.12 "
"or later for improved performance. "
Expand All @@ -209,9 +216,14 @@ def minhash64(
seeds_a = cudf.Series(seeds[:, 0], dtype="uint64")
seeds_b = cudf.Series(seeds[:, 1], dtype="uint64")

return ser.str.minhash64_permuted(
a=seeds_a, b=seeds_b, seed=seeds[0][0], width=char_ngram
)
if MINHASH_PERMUTED_AVAILABLE:
return ser.str.minhash64_permuted(
a=seeds_a, b=seeds_b, seed=seeds[0][0], width=char_ngram
)
else:
return ser.str.minhash64(
a=seeds_a, b=seeds_b, seed=seeds[0][0], width=char_ngram
)

def __call__(self, dataset: DocumentDataset) -> Union[str, DocumentDataset]:
"""
Expand Down

0 comments on commit d401333

Please sign in to comment.