Enforce Dataframe Backend Checks #514

Open · wants to merge 14 commits into base: main
8 changes: 5 additions & 3 deletions nemo_curator/classifiers/base.py
@@ -15,7 +15,7 @@
from dataclasses import dataclass

os.environ["RAPIDS_NO_INITIALIZE"] = "1"
from abc import ABC, abstractmethod
from abc import abstractmethod
from typing import List, Optional, Union

import torch
@@ -25,10 +25,11 @@
from transformers import AutoModel

from nemo_curator.datasets import DocumentDataset
from nemo_curator.modules.base import Module
from nemo_curator.utils.distributed_utils import get_gpu_memory_info


class DistributedDataClassifier(ABC):
class DistributedDataClassifier(Module):
"""Abstract class for running multi-node multi-GPU data classification"""

def __init__(
@@ -43,6 +44,7 @@ def __init__(
device_type: str,
autocast: bool,
):
super().__init__(input_backend="cudf")
self.model = model
self.labels = labels
self.filter_by = filter_by
@@ -53,7 +55,7 @@ def __init__(
self.device_type = device_type
self.autocast = autocast

def __call__(self, dataset: DocumentDataset) -> DocumentDataset:
def call(self, dataset: DocumentDataset) -> DocumentDataset:
result_doc_dataset = self._run_classifier(dataset)
if self.filter_by is not None:
return self._filter_documents(result_doc_dataset)
9 changes: 9 additions & 0 deletions nemo_curator/filters/doc_filter.py
@@ -81,6 +81,15 @@ def keep_document(self, scores: Any) -> bool:
"keep_document method must be implemented by subclasses"
)

def get_backend(self) -> str:
"""
The dataframe backend the filter operates on.
Can be 'pandas', 'cudf', or 'any'. Defaults to 'pandas'.
Returns:
str: A string representing the dataframe backend the filter needs as input
"""
return "pandas"

@property
def name(self):
return self._name
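As a hedged sketch of how a filter can use this new hook: a hypothetical GPU-based filter could override get_backend to request cudf input, and ScoreFilter (changed below) will then declare that backend requirement for the whole module. The filter below is illustrative only and assumes the usual score_document/keep_document pair of the DocumentFilter interface.

```python
from nemo_curator.filters import DocumentFilter


class HypotheticalGPULengthFilter(DocumentFilter):
    """Illustrative only: a filter that wants its scores computed on cudf-backed data."""

    def __init__(self, min_length: int = 10):
        super().__init__()
        self._min_length = min_length
        self._name = "hypothetical_gpu_length_filter"

    def score_document(self, text: str) -> int:
        # Score each document by its character length.
        return len(text)

    def keep_document(self, score: int) -> bool:
        return score >= self._min_length

    def get_backend(self) -> str:
        # Ask for GPU dataframes instead of the default "pandas".
        return "cudf"
```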
9 changes: 9 additions & 0 deletions nemo_curator/modifiers/doc_modifier.py
@@ -26,3 +26,12 @@ def __init__(self):
@abstractmethod
def modify_document(self, text):
pass

def get_backend(self) -> str:
"""
The dataframe backend the modifier operates on.
Can be 'pandas', 'cudf', or 'any'. Defaults to 'pandas'.
Returns:
str: A string representing the dataframe backend the modifier needs as input
"""
return "pandas"
4 changes: 4 additions & 0 deletions nemo_curator/modules/__init__.py
@@ -22,12 +22,14 @@
from nemo_curator.utils.import_utils import gpu_only_import_from

from .add_id import AddId
from .base import Module
from .config import FuzzyDuplicatesConfig, SemDedupConfig
from .dataset_ops import blend_datasets, Shuffle
from .exact_dedup import ExactDuplicates
from .meta import Sequential
from .modify import Modify
from .task import TaskDecontamination
from .to_backend import ToBackend

# GPU packages
MinHash = gpu_only_import_from("nemo_curator.modules.fuzzy_dedup.minhash", "MinHash")
@@ -88,4 +90,6 @@
"ClusteringModel",
"SemanticClusterLevelDedup",
"SemDedup",
"Module",
"ToBackend",
]
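ToBackend is newly exported here for moving a dataset between dataframe backends mid-pipeline. A sketch of how it might compose with Sequential (its constructor signature is not shown in this diff, so the single backend-name argument below is an assumption):

```python
from nemo_curator.modules import Sequential, ToBackend

# Illustrative pipeline only: hop to cudf ahead of a GPU-only module,
# then return to pandas for CPU-side modules. Assumes ToBackend takes
# the target backend name as its argument.
pipeline = Sequential(
    [
        ToBackend("cudf"),    # pandas -> cudf
        # ... a cudf-only module such as a classifier would go here ...
        ToBackend("pandas"),  # cudf -> pandas
    ]
)
```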
6 changes: 4 additions & 2 deletions nemo_curator/modules/add_id.py
@@ -19,18 +19,20 @@
from dask import delayed

from nemo_curator.datasets import DocumentDataset
from nemo_curator.modules.base import Module
from nemo_curator.utils.module_utils import count_digits


class AddId:
class AddId(Module):
def __init__(
self, id_field, id_prefix: str = "doc_id", start_index: Optional[int] = None
) -> None:
super().__init__(input_backend="pandas")
self.id_field = id_field
self.id_prefix = id_prefix
self.start_index = start_index

def __call__(self, dataset: DocumentDataset) -> DocumentDataset:
def call(self, dataset: DocumentDataset) -> DocumentDataset:
if self.start_index is None:
return self._add_id_fast(dataset)
else:
50 changes: 50 additions & 0 deletions nemo_curator/modules/base.py
@@ -0,0 +1,50 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod

from nemo_curator.datasets import DocumentDataset


class Module(ABC):
SUPPORTED_BACKENDS = ["pandas", "cudf", "any"]

def __init__(self, input_backend: str, name=None) -> None:
super().__init__()
self.name = name or self.__class__.__name__

if input_backend not in self.SUPPORTED_BACKENDS:
raise ValueError(
f"{input_backend} not one of the supported backends {self.SUPPORTED_BACKENDS}"
)
self.input_backend = input_backend

@abstractmethod
def call(self, dataset: DocumentDataset) -> DocumentDataset:
raise NotImplementedError("call method must be implemented by subclasses")

def _check_backend(self, ddf):
if self.input_backend == "any":
return

backend = type(ddf._meta).__module__.split(".")[0]
Review thread on this line:

Collaborator: Can we add a test case for this? And also type hints for ddf, which would be dd.DataFrame.

Collaborator (Author): What part of this do you want to test? I'm not quite seeing which part of this you want me to isolate.

Collaborator: As someone reading this, it's not obviously clear what .__module__.split(".")[0] on the type of the object's _meta attribute will output. So I would have proposed moving this into a function, writing tests for that function, and using that function here. However, I see you've changed this to is_cudf_type(..), which is much more readable, and now we can unit test that directly.
if backend != self.input_backend:
raise ValueError(
f"Module {self.name} requires dataset to have backend {self.input_backend} but got backend {backend}"
)

def __call__(self, dataset: DocumentDataset) -> DocumentDataset:
self._check_backend(dataset.df)

return self.call(dataset)
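To make the new contract concrete, here is a minimal sketch (not part of the PR) of a subclass: subclasses implement call, while __call__ runs the backend check first. The dataset construction assumes DocumentDataset wraps a Dask dataframe in its df attribute, which is what _check_backend's use of ddf._meta implies.

```python
import dask.dataframe as dd
import pandas as pd

from nemo_curator.datasets import DocumentDataset
from nemo_curator.modules.base import Module


class Uppercase(Module):
    """Illustrative pandas-only module that uppercases a text column."""

    def __init__(self, text_field: str = "text"):
        super().__init__(input_backend="pandas")
        self.text_field = text_field

    def call(self, dataset: DocumentDataset) -> DocumentDataset:
        dataset.df[self.text_field] = dataset.df[self.text_field].str.upper()
        return dataset


ddf = dd.from_pandas(pd.DataFrame({"text": ["hello", "world"]}), npartitions=1)
dataset = DocumentDataset(ddf)

result = Uppercase()(dataset)  # __call__ checks the backend, then delegates to call()

# On a cudf-backed dataset the same line would raise:
# ValueError: Module Uppercase requires dataset to have backend pandas but got backend cudf
```

As the review thread notes, the inline __module__ inspection was later replaced by is_cudf_type from nemo_curator.utils.gpu_utils (visible in the exact_dedup imports below), which can be unit tested directly.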
6 changes: 4 additions & 2 deletions nemo_curator/modules/dataset_ops.py
@@ -5,13 +5,14 @@
import numpy as np

from nemo_curator.datasets.doc_dataset import DocumentDataset
from nemo_curator.modules.base import Module


def default_filename(partition_num: int) -> str:
return f"file_{partition_num:010d}.jsonl"


class Shuffle:
class Shuffle(Module):
def __init__(
self,
seed: Optional[int] = None,
@@ -32,13 +33,14 @@ def __init__(
will look like given the partition number. The default method names the partition
f'file_{partition_num:010d}.jsonl' and should be changed if the user is not using a .jsonl format.
"""
super().__init__(input_backend="pandas")
self.seed = seed
self.npartitions = npartitions
self.partition_to_filename = partition_to_filename
self.rand_col = "_shuffle_rand"
self.filename_col = filename_col

def __call__(self, dataset: DocumentDataset) -> DocumentDataset:
def call(self, dataset: DocumentDataset) -> DocumentDataset:
if self.seed is None:
return self.shuffle_nondeterministic(dataset)
else:
6 changes: 4 additions & 2 deletions nemo_curator/modules/exact_dedup.py
@@ -29,11 +29,12 @@
from nemo_curator._compat import DASK_P2P_ERROR
from nemo_curator.datasets import DocumentDataset
from nemo_curator.log import create_logger
from nemo_curator.modules.base import Module
from nemo_curator.utils.distributed_utils import performance_report_if_with_ts_suffix
from nemo_curator.utils.gpu_utils import is_cudf_type


class ExactDuplicates:
class ExactDuplicates(Module):
"""Find exact duplicates in a document corpus"""

SUPPORTED_HASHES = {"md5"}
@@ -59,6 +60,7 @@ def __init__(
cache_dir: str, Default None
If specified, will compute & write duplicate id's to cache directory.
"""
super().__init__(input_backend="any")

if hash_method not in self.SUPPORTED_HASHES:
raise ValueError(
@@ -135,7 +137,7 @@ def hash_documents(
# TODO: Generalize by using self.hash_method
return df.apply(lambda x: md5(x.encode()).hexdigest())

def __call__(self, dataset: DocumentDataset) -> Union[DocumentDataset, str]:
def call(self, dataset: DocumentDataset) -> Union[DocumentDataset, str]:
"""
Find document ID's for exact duplicates in a given DocumentDataset
Parameters
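Because ExactDuplicates declares input_backend="any", _check_backend returns immediately and the module accepts either dataframe backend. A hedged construction sketch (the id_field and text_field parameter names are assumptions; only hash_method and cache_dir appear in this diff):

```python
from nemo_curator.modules import ExactDuplicates

# Works on pandas- or cudf-backed datasets alike, since the backend
# check is a no-op for input_backend="any".
dedup = ExactDuplicates(id_field="id", text_field="text", hash_method="md5")
# duplicates = dedup(dataset)  # returns duplicate IDs (or a cache path if cache_dir is set)
```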
22 changes: 13 additions & 9 deletions nemo_curator/modules/filter.py
@@ -22,6 +22,7 @@
from nemo_curator.datasets import DocumentDataset
from nemo_curator.datasets.parallel_dataset import ParallelDataset
from nemo_curator.filters import DocumentFilter
from nemo_curator.modules.base import Module
from nemo_curator.utils.module_utils import is_batched

# Override so that pd.NA is not passed during the metadata inference
@@ -31,7 +32,7 @@
)


class Score:
class Score(Module):
"""
The module responsible for adding metadata to records based on statistics about the text.
It accepts an arbitrary scoring function that accepts a text field and returns a score.
@@ -56,12 +57,13 @@ def __init__(
text_field (str): The field the documents will be read from.
score_type (Union[type, str]): The datatype of the score that will be made for each document.
"""
super().__init__(input_backend="pandas")
self.score_fn = score_fn
self.score_field = score_field
self.text_field = text_field
self.score_type = score_type

def __call__(self, dataset: DocumentDataset) -> DocumentDataset:
def call(self, dataset: DocumentDataset) -> DocumentDataset:
"""
Applies the scoring to a dataset

@@ -89,7 +91,7 @@ def __call__(self, dataset: DocumentDataset) -> DocumentDataset:
return dataset


class Filter:
class Filter(Module):
"""
The module responsible for filtering records based on a metadata field.
It accepts an arbitrary filter function that accepts a metadata field and returns True if the field should be kept.
@@ -107,6 +109,7 @@ def __init__(self, filter_fn: Callable, filter_field: str, invert: bool = False)
filter_field (str): The field(s) to be passed into the filter function.
invert (bool): Whether to invert the filter condition.
"""
super().__init__(input_backend="pandas")
self.filter_fn = filter_fn
self.filter_field = filter_field
self.invert = invert
@@ -134,7 +137,7 @@ def compute_filter_mask(self, dataset: DocumentDataset):

return bool_mask

def __call__(self, dataset: DocumentDataset) -> DocumentDataset:
def call(self, dataset: DocumentDataset) -> DocumentDataset:
"""
Applies the filtering to a dataset

@@ -148,7 +151,7 @@ def __call__(self, dataset: DocumentDataset) -> DocumentDataset:
return DocumentDataset(dataset.df[bool_mask])


class ScoreFilter:
class ScoreFilter(Module):
"""
The module responsible for applying a filter to all documents in a DocumentDataset.
It accepts an arbitrary DocumentFilter and first computes the score for a document.
@@ -176,6 +179,7 @@ def __init__(
score_type (Union[type, str]): The datatype of the score that will be made for each document.
invert (bool): If True, will keep all documents that are normally discarded.
"""
super().__init__(input_backend=filter_obj.get_backend())
self.filter_obj = filter_obj
self.text_field = text_field
self.score_field = score_field
@@ -219,7 +223,7 @@ def compute_filter_mask(self, dataset: DocumentDataset):

return bool_mask

def __call__(self, dataset: DocumentDataset) -> DocumentDataset:
def call(self, dataset: DocumentDataset) -> DocumentDataset:
"""
Scores and filters all records in the dataset

@@ -233,7 +237,7 @@ def __call__(self, dataset: DocumentDataset) -> DocumentDataset:
return DocumentDataset(dataset.df[bool_mask])


class ParallelScoreFilter:
class ParallelScoreFilter(Module):
def __init__(
self,
src_filter_obj,
@@ -263,15 +267,15 @@ def __init__(
score_type (Optional[str]): The datatype of the score that will be made for each document. Defaults to None.
invert (bool, optional): If True, will keep all documents that are normally discarded. Defaults to False.
"""

super().__init__(input_backend=src_filter_obj.get_backend())
self.source_score_filter = ScoreFilter(
src_filter_obj, src_field, src_score, score_type, invert
)
self.target_score_filter = ScoreFilter(
tgt_filter_obj, tgt_field, tgt_score, score_type, invert
)

def __call__(self, dataset: ParallelDataset):
def call(self, dataset: ParallelDataset):
src_bool_mask = self.source_score_filter.compute_filter_mask(dataset)
tgt_bool_mask = self.target_score_filter.compute_filter_mask(dataset)

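With input_backend now derived from filter_obj.get_backend(), a filter's declared backend propagates to the module-level check automatically. A hedged usage sketch (the word-count filter is hypothetical):

```python
from nemo_curator.filters import DocumentFilter
from nemo_curator.modules.filter import ScoreFilter


class MinWordCount(DocumentFilter):
    """Hypothetical filter: keep documents with at least min_words words."""

    def __init__(self, min_words: int = 5):
        super().__init__()
        self._min_words = min_words
        self._name = "min_word_count"

    def score_document(self, text: str) -> int:
        return len(text.split())

    def keep_document(self, score: int) -> bool:
        return score >= self._min_words


# get_backend is inherited and returns "pandas", so this ScoreFilter
# instance will raise if handed a cudf-backed DocumentDataset.
filter_step = ScoreFilter(MinWordCount(), text_field="text", score_field="word_count")
# filtered = filter_step(dataset)
```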
6 changes: 4 additions & 2 deletions nemo_curator/modules/fuzzy_dedup/fuzzyduplicates.py
@@ -23,6 +23,7 @@

from nemo_curator.datasets import DocumentDataset
from nemo_curator.log import create_logger
from nemo_curator.modules.base import Module
from nemo_curator.modules.config import FuzzyDuplicatesConfig
from nemo_curator.modules.fuzzy_dedup._mapbuckets import _MapBuckets
from nemo_curator.modules.fuzzy_dedup._shuffle import _Shuffle
@@ -35,7 +36,7 @@
from nemo_curator.utils.distributed_utils import performance_report_if_with_ts_suffix


class FuzzyDuplicates:
class FuzzyDuplicates(Module):
def __init__(
self,
config: FuzzyDuplicatesConfig,
@@ -53,6 +54,7 @@ def __init__(
DocumentDataset containing IDs of all documents and the corresponding duplicate group
they belong to. Documents in the same group are near duplicates.
"""
super().__init__(input_backend="cudf")
if isinstance(logger, str):
self._logger = create_logger(
rank=0,
@@ -129,7 +131,7 @@ def __init__(
profile_dir=self.config.profile_dir,
)

def __call__(self, dataset: DocumentDataset):
def call(self, dataset: DocumentDataset):
"""
Parameters
----------
7 changes: 4 additions & 3 deletions nemo_curator/modules/modify.py
@@ -11,18 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nemo_curator.datasets import DocumentDataset
from nemo_curator.modifiers import DocumentModifier
from nemo_curator.modules.base import Module
from nemo_curator.utils.module_utils import is_batched


class Modify:
class Modify(Module):
def __init__(self, modifier: DocumentModifier, text_field="text"):
super().__init__(input_backend=modifier.get_backend())
self.modifier = modifier
self.text_field = text_field

def __call__(self, dataset: DocumentDataset) -> DocumentDataset:
def call(self, dataset: DocumentDataset) -> DocumentDataset:
if is_batched(self.modifier.modify_document):
dataset.df[self.text_field] = dataset.df[self.text_field].map_partitions(
self.modifier.modify_document, meta=(None, str)