Enforce Dataframe Backend Checks #514

Open

wants to merge 14 commits into base: main
5 changes: 4 additions & 1 deletion docs/user-guide/api/dask.rst
@@ -4,4 +4,7 @@ Dask Cluster Functions

.. autofunction:: nemo_curator.get_client

.. autofunction:: nemo_curator.get_network_interfaces
.. autofunction:: nemo_curator.get_network_interfaces

.. autoclass:: nemo_curator.ToBackend
:members:
40 changes: 40 additions & 0 deletions docs/user-guide/cpuvsgpu.rst
@@ -85,6 +85,46 @@ To read a dataset into GPU memory, one could use the following function call.
Even if you start a GPU Dask cluster, you can't operate on datasets that use a ``pandas`` backend.
The ``DocumentDataset`` must either have been originally read in with a ``cudf`` backend, or it must be transferred to the GPU during the script.

-----------------------------------------
Moving data between CPU and GPU
-----------------------------------------

The ``ToBackend`` module provides a way to move data between CPU memory and GPU memory by swapping between pandas and cuDF backends for your dataset.
To see how it works, take a look at this example.

.. code-block:: python

    from nemo_curator import Sequential, ToBackend, ScoreFilter, get_client
    from nemo_curator.datasets import DocumentDataset
    from nemo_curator.classifiers import DomainClassifier
    from nemo_curator.filters import RepeatingTopNGramsFilter, NonAlphaNumericFilter

    def main():
        client = get_client(cluster_type="gpu")

        dataset = DocumentDataset.read_json("books.jsonl")
        curation_pipeline = Sequential([
            ScoreFilter(RepeatingTopNGramsFilter(n=5)),
            ToBackend("cudf"),
            DomainClassifier(),
ToBackend("pandas"),
ScoreFilter(NonAlphaNumericFilter()),
])

curated_dataset = curation_pipeline(dataset)

curated_dataset.to_json("curated_books.jsonl")

if __name__ == "__main__":
main()

Let's highlight some of the important parts of this example.

* ``client = get_client(cluster_type="gpu")``: Creates a local Dask cluster with access to the GPUs. To use or swap to a cuDF dataframe backend, make sure you are running on a GPU Dask cluster.
* ``dataset = DocumentDataset.read_json("books.jsonl")``: Reads in the dataset to a pandas (CPU) backend by default.
* ``curation_pipeline = ...``: Defines a curation pipeline consisting of a CPU filtering step, a GPU classifier step, and another CPU filtering step. ``ToBackend("cudf")`` moves the dataset from CPU to GPU for the classifier, and ``ToBackend("pandas")`` moves it back to the CPU for the last filter (a sketch of what happens without this swap follows the list).
* ``curated_dataset.to_json("curated_books.jsonl")``: Writes the dataset directly to disk from the GPU. There is no need to transfer back to the CPU before writing to disk.
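
As an illustration of the backend checks this change introduces, here is a rough sketch of what happens if the CPU-to-GPU swap is omitted (``broken_pipeline`` is hypothetical and the exact traceback may differ): ``DomainClassifier`` requires a ``cudf``-backed dataset, so the validation added in ``BaseModule`` rejects the pandas-backed input.

.. code-block:: python

    # Hypothetical pipeline that omits ToBackend("cudf") before the classifier
    broken_pipeline = Sequential([
        ScoreFilter(RepeatingTopNGramsFilter(n=5)),
        DomainClassifier(),  # expects a cudf-backed DocumentDataset
    ])

    # Raises roughly:
    # ValueError: Module DomainClassifier requires dataset to have backend cudf
    # but got backend pandas. Try using nemo_curator.ToBackend to swap dataframe
    # backends before running this module.
    broken_pipeline(dataset)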

-----------------------------------------
Dask with Slurm
-----------------------------------------
8 changes: 5 additions & 3 deletions nemo_curator/classifiers/base.py
@@ -15,7 +15,7 @@
from dataclasses import dataclass

os.environ["RAPIDS_NO_INITIALIZE"] = "1"
from abc import ABC, abstractmethod
from abc import abstractmethod
from typing import List, Optional, Union

import torch
@@ -25,10 +25,11 @@
from transformers import AutoModel

from nemo_curator.datasets import DocumentDataset
from nemo_curator.modules.base import BaseModule
from nemo_curator.utils.distributed_utils import get_gpu_memory_info


class DistributedDataClassifier(ABC):
class DistributedDataClassifier(BaseModule):
"""Abstract class for running multi-node multi-GPU data classification"""

def __init__(
@@ -43,6 +44,7 @@ def __init__(
device_type: str,
autocast: bool,
):
super().__init__(input_backend="cudf")
self.model = model
self.labels = labels
self.filter_by = filter_by
@@ -53,7 +55,7 @@ def __init__(
self.device_type = device_type
self.autocast = autocast

def __call__(self, dataset: DocumentDataset) -> DocumentDataset:
def call(self, dataset: DocumentDataset) -> DocumentDataset:
result_doc_dataset = self._run_classifier(dataset)
if self.filter_by is not None:
return self._filter_documents(result_doc_dataset)
12 changes: 11 additions & 1 deletion nemo_curator/filters/doc_filter.py
Expand Up @@ -14,7 +14,7 @@

import importlib
from abc import ABC, abstractmethod
from typing import Any, Union
from typing import Any, Literal, Union

from nemo_curator.filters.bitext_filter import BitextFilter

@@ -81,6 +81,16 @@ def keep_document(self, scores: Any) -> bool:
"keep_document method must be implemented by subclasses"
)

@property
def backend(self) -> Literal["pandas", "cudf", "any"]:
"""
The dataframe backend the filter operates on.
Can be 'pandas', 'cudf', or 'any'. Defaults to 'pandas'.
Returns:
str: A string representing the dataframe backend the filter needs as input
"""
return "pandas"

@property
def name(self):
return self._name
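
To illustrate the new property (a minimal sketch, not part of this PR, using a hypothetical filter class), a GPU-only filter would override ``backend`` so the pipeline knows it must receive a cuDF-backed dataset:

    from nemo_curator.filters import DocumentFilter

    class HypotheticalGpuFilter(DocumentFilter):
        def score_document(self, text):
            # scoring logic that assumes cuDF-backed data
            return len(text)

        def keep_document(self, score):
            return score > 0

        @property
        def backend(self):
            # signal that this filter needs the dataset on the GPU
            return "cudf"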
11 changes: 11 additions & 0 deletions nemo_curator/modifiers/doc_modifier.py
@@ -13,6 +13,7 @@
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Literal


class DocumentModifier(ABC):
@@ -26,3 +27,13 @@ def __init__(self):
@abstractmethod
def modify_document(self, text):
pass

@property
def backend(self) -> Literal["pandas", "cudf", "any"]:
"""
The dataframe backend the modifier operates on.
Can be 'pandas', 'cudf', or 'any'. Defaults to 'pandas'.
Returns:
str: A string representing the dataframe backend the modifier needs as input
"""
return "pandas"
4 changes: 4 additions & 0 deletions nemo_curator/modules/__init__.py
@@ -22,12 +22,14 @@
from nemo_curator.utils.import_utils import gpu_only_import_from

from .add_id import AddId
from .base import BaseModule
from .config import FuzzyDuplicatesConfig, SemDedupConfig
from .dataset_ops import blend_datasets, Shuffle
from .exact_dedup import ExactDuplicates
from .meta import Sequential
from .modify import Modify
from .task import TaskDecontamination
from .to_backend import ToBackend

# GPU packages
MinHash = gpu_only_import_from("nemo_curator.modules.fuzzy_dedup.minhash", "MinHash")
@@ -88,4 +90,6 @@
"ClusteringModel",
"SemanticClusterLevelDedup",
"SemDedup",
"BaseModule",
"ToBackend",
]
6 changes: 4 additions & 2 deletions nemo_curator/modules/add_id.py
@@ -19,18 +19,20 @@
from dask import delayed

from nemo_curator.datasets import DocumentDataset
from nemo_curator.modules.base import BaseModule
from nemo_curator.utils.module_utils import count_digits


class AddId:
class AddId(BaseModule):
def __init__(
self, id_field, id_prefix: str = "doc_id", start_index: Optional[int] = None
) -> None:
super().__init__(input_backend="pandas")
self.id_field = id_field
self.id_prefix = id_prefix
self.start_index = start_index

def __call__(self, dataset: DocumentDataset) -> DocumentDataset:
def call(self, dataset: DocumentDataset) -> DocumentDataset:
if self.start_index is None:
return self._add_id_fast(dataset)
else:
84 changes: 84 additions & 0 deletions nemo_curator/modules/base.py
@@ -0,0 +1,84 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Literal, Optional

import dask.dataframe as dd

from nemo_curator.datasets import DocumentDataset
from nemo_curator.utils.gpu_utils import is_cudf_type


class BaseModule(ABC):
"""
Base class for all NeMo Curator modules.

Handles validating that data lives on the correct device for each module
"""

SUPPORTED_BACKENDS = ["pandas", "cudf", "any"]

def __init__(
self,
input_backend: Literal["pandas", "cudf", "any"],
name: Optional[str] = None,
) -> None:
"""
Constructs a Module

Args:
input_backend (Literal["pandas", "cudf", "any"]): The backend the input dataframe must be on for the module to work
name (str, Optional): The name of the module. If None, defaults to self.__class__.__name__
"""
super().__init__()
self.name = name or self.__class__.__name__

if input_backend not in self.SUPPORTED_BACKENDS:
raise ValueError(
f"{input_backend} not one of the supported backends {self.SUPPORTED_BACKENDS}"
)
self.input_backend = input_backend

@abstractmethod
def call(self, dataset: DocumentDataset):
"""
Performs an arbitrary operation on a dataset

Args:
dataset (DocumentDataset): The dataset to operate on
"""
raise NotImplementedError("call method must be implemented by subclasses")

def _validate_correct_backend(self, ddf: dd.DataFrame):
if self.input_backend == "any":
return

backend = "cudf" if is_cudf_type(ddf) else "pandas"
if backend != self.input_backend:
raise ValueError(
f"Module {self.name} requires dataset to have backend {self.input_backend} but got backend {backend}."
"Try using nemo_curator.ToBackend to swap dataframe backends before running this module."
)

def __call__(self, dataset: DocumentDataset):
"""
Validates the dataset is on the right backend, and performs an arbitrary operation on it

Args:
dataset (DocumentDataset): The dataset to operate on
"""
self._validate_correct_backend(dataset.df)

return self.call(dataset)
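
For context, here is a minimal sketch of how a downstream module would plug into this base class (the module below is hypothetical, not part of this PR): subclasses declare their required backend in ``__init__`` and implement ``call``; the inherited ``__call__`` runs the backend validation before delegating.

    from nemo_curator.datasets import DocumentDataset
    from nemo_curator.modules.base import BaseModule

    class HypotheticalUppercase(BaseModule):
        def __init__(self, text_field="text"):
            # pandas-only module: __call__ raises if handed a cudf-backed dataset
            super().__init__(input_backend="pandas")
            self.text_field = text_field

        def call(self, dataset: DocumentDataset) -> DocumentDataset:
            # uppercase the text column and wrap the result back into a DocumentDataset
            df = dataset.df
            df[self.text_field] = df[self.text_field].str.upper()
            return DocumentDataset(df)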
6 changes: 4 additions & 2 deletions nemo_curator/modules/dataset_ops.py
@@ -5,13 +5,14 @@
import numpy as np

from nemo_curator.datasets.doc_dataset import DocumentDataset
from nemo_curator.modules.base import BaseModule


def default_filename(partition_num: int) -> str:
return f"file_{partition_num:010d}.jsonl"


class Shuffle:
class Shuffle(BaseModule):
def __init__(
self,
seed: Optional[int] = None,
@@ -32,13 +33,14 @@ def __init__(
will look like given the partition number. The default method names the partition
f'file_{partition_num:010d}.jsonl' and should be changed if the user is not using a .jsonl format.
"""
super().__init__(input_backend="pandas")
self.seed = seed
self.npartitions = npartitions
self.partition_to_filename = partition_to_filename
self.rand_col = "_shuffle_rand"
self.filename_col = filename_col

def __call__(self, dataset: DocumentDataset) -> DocumentDataset:
def call(self, dataset: DocumentDataset) -> DocumentDataset:
if self.seed is None:
return self.shuffle_nondeterministic(dataset)
else:
6 changes: 4 additions & 2 deletions nemo_curator/modules/exact_dedup.py
@@ -29,11 +29,12 @@
from nemo_curator._compat import DASK_P2P_ERROR
from nemo_curator.datasets import DocumentDataset
from nemo_curator.log import create_logger
from nemo_curator.modules.base import BaseModule
from nemo_curator.utils.distributed_utils import performance_report_if_with_ts_suffix
from nemo_curator.utils.gpu_utils import is_cudf_type


class ExactDuplicates:
class ExactDuplicates(BaseModule):
"""Find exact duplicates in a document corpus"""

SUPPORTED_HASHES = {"md5"}
@@ -59,6 +60,7 @@ def __init__(
cache_dir: str, Default None
If specified, will compute & write duplicate id's to cache directory.
"""
super().__init__(input_backend="any")

if hash_method not in self.SUPPORTED_HASHES:
raise ValueError(
@@ -135,7 +137,7 @@ def hash_documents(
# TODO: Generalize by using self.hash_method
return df.apply(lambda x: md5(x.encode()).hexdigest())

def __call__(self, dataset: DocumentDataset) -> Union[DocumentDataset, str]:
def call(self, dataset: DocumentDataset) -> Union[DocumentDataset, str]:
"""
Find document ID's for exact duplicates in a given DocumentDataset
Parameters