diff --git a/nemo_curator/modules/config.py b/nemo_curator/modules/config.py
index d29f02f49..50c71017b 100644
--- a/nemo_curator/modules/config.py
+++ b/nemo_curator/modules/config.py
@@ -145,6 +145,7 @@ class SemDedupConfig(BaseConfig):
         embeddings_save_loc (str): Location to save embeddings.
         embedding_model_name_or_path (str): Model name or path for embeddings.
         embedding_batch_size (int): Inital Batch size for processing embeddings.
+        embedding_pooling_strategy (str): Strategy for pooling embeddings, either "mean_pooling" or "last_token". Defaults to "mean_pooling".
         write_embeddings_to_disk (bool): If True, saves the embeddings to disk, defaults to True.
             We recommend setting this to False when you have a delayed pipeline.
             Setting it to False can lead to more memory overhead.
@@ -168,6 +169,8 @@ class SemDedupConfig(BaseConfig):
     embeddings_save_loc: str = "embeddings"
     embedding_model_name_or_path: str = "sentence-transformers/all-MiniLM-L6-v2"
     embedding_batch_size: int = 128
+    # Options: "mean_pooling", "last_token"
+    embedding_pooling_strategy: str = "mean_pooling"
     write_embeddings_to_disk: bool = True
 
     # Clustering config
diff --git a/nemo_curator/modules/semantic_dedup/embeddings.py b/nemo_curator/modules/semantic_dedup/embeddings.py
index 7c607b63e..7f6315e52 100644
--- a/nemo_curator/modules/semantic_dedup/embeddings.py
+++ b/nemo_curator/modules/semantic_dedup/embeddings.py
@@ -41,6 +41,7 @@ class EmbeddingConfig:
     model_name_or_path: str
     max_seq_length: int = None
+    pooling_strategy: str = "mean_pooling"  # Options: "mean_pooling" or "last_token"
 
     def __post_init__(self):
         self.max_seq_length = AutoTokenizer.from_pretrained(
@@ -52,6 +53,10 @@ def __post_init__(self):
             self.max_seq_length = AutoConfig.from_pretrained(
                 self.model_name_or_path
             ).max_position_embeddings
+        if self.pooling_strategy not in ["mean_pooling", "last_token"]:
+            raise ValueError(
+                "pooling_strategy must be either 'mean_pooling' or 'last_token'"
+            )
 
 
 class EmbeddingPytorchModel(nn.Module):
@@ -70,7 +75,10 @@ def feature(self, input_ids, attention_mask):
     @torch.no_grad()
     def forward(self, batch):
         feature = self.feature(batch["input_ids"], batch["attention_mask"])
-        return self._mean_pooling(feature, batch["attention_mask"])
+        if self.config.pooling_strategy == "mean_pooling":
+            return self._mean_pooling(feature, batch["attention_mask"])
+        else:
+            return self._get_last_token(feature, batch["attention_mask"])
 
     def _mean_pooling(self, model_output, attention_mask):
         token_embeddings = model_output[0]
@@ -81,6 +89,19 @@ def _mean_pooling(self, model_output, attention_mask):
         sum_mask = torch.clamp(input_mask_expanded.sum(dim=1), min=1e-9)
         return F.normalize(sum_embeddings / sum_mask, dim=1)
 
+    def _get_last_token(self, model_output, attention_mask):
+        token_embeddings = model_output[0]
+        # Get indices of last non-padded tokens for each sequence in batch
+        last_token_indices = attention_mask.sum(dim=1) - 1  # -1 for 0-based indexing
+        last_token_indices = last_token_indices.to(
+            torch.long
+        )  # Ensure indices are of type long
+        batch_size = attention_mask.size(0)
+        batch_indices = torch.arange(batch_size, device=attention_mask.device)
+        # Get embeddings of last non-padded tokens
+        last_token_embeddings = token_embeddings[batch_indices, last_token_indices]
+        return F.normalize(last_token_embeddings, dim=1)
+
 
 class EmbeddingCrossFitModel(HFModel):
     def __init__(
@@ -116,6 +137,7 @@ def __init__(
         embedding_batch_size: int,
         embedding_output_dir: str,
         embedding_max_mem_gb: Optional[int] = None,
+        embedding_pooling_strategy: str = "mean_pooling",
         input_column: str = "text",
         embedding_column: str = "embeddings",
         write_embeddings_to_disk: bool = True,
@@ -132,6 +154,7 @@ def __init__(
             embedding_output_dir (str): Directory path where embeddings will be saved.
             embedding_max_mem_gb (int): Maximum memory usage in GB for the embedding process.
                 If None, it defaults to the available GPU memory minus 4 GB.
+            embedding_pooling_strategy (str): Strategy for pooling embeddings, either "mean_pooling" or "last_token". Defaults to "mean_pooling".
             input_column (str): Column name from the data to be used for embedding generation, defaults to "text".
             write_embeddings_to_disk (bool, optional): If True, saves the embeddings to disk, defaults to True.
                 We recommend setting this to False when you have a delayed pipeline.
@@ -152,6 +175,7 @@ def __init__(
 
         self.embeddings_config = EmbeddingConfig(
             model_name_or_path=embedding_model_name_or_path,
+            pooling_strategy=embedding_pooling_strategy,
         )
         self.batch_size = embedding_batch_size
         self.logger = self._setup_logger(logger)
@@ -184,7 +208,7 @@ def create_embeddings(
             op.Tokenizer(
                 self.model,
                 cols=[input_column],
-                tokenizer_type="sentencepiece",
+                tokenizer_type="default",
                 max_length=self.embeddings_config.max_seq_length,
             ),
             op.Predictor(
diff --git a/nemo_curator/modules/semantic_dedup/semdedup.py b/nemo_curator/modules/semantic_dedup/semdedup.py
index eff5e2ec6..d145a4cdd 100644
--- a/nemo_curator/modules/semantic_dedup/semdedup.py
+++ b/nemo_curator/modules/semantic_dedup/semdedup.py
@@ -48,6 +48,7 @@ def __init__(
         self.embedding_creator = EmbeddingCreator(
             embedding_model_name_or_path=config.embedding_model_name_or_path,
             embedding_batch_size=config.embedding_batch_size,
+            embedding_pooling_strategy=config.embedding_pooling_strategy,
             input_column=input_column,
             embedding_output_dir=os.path.join(cache_dir, config.embeddings_save_loc),
             write_embeddings_to_disk=config.write_embeddings_to_disk,
diff --git a/tests/test_semdedup.py b/tests/test_semdedup.py
index 4cc66901d..8ccf850a7 100644
--- a/tests/test_semdedup.py
+++ b/tests/test_semdedup.py
@@ -13,9 +13,13 @@
 # limitations under the License.
 import os
 
+import numpy as np
 import pytest
+import torch
+import torch.nn.functional as F
 from dask.dataframe.utils import assert_eq
 from distributed import Client
+from transformers import AutoConfig, AutoModel, AutoTokenizer
 
 from nemo_curator import SemDedup, SemDedupConfig
 from nemo_curator.datasets import DocumentDataset
@@ -24,6 +28,9 @@
 cudf = gpu_only_import("cudf")
 dask_cudf = gpu_only_import("dask_cudf")
 LocalCUDACluster = gpu_only_import_from("dask_cuda", "LocalCUDACluster")
+EmbeddingCreator = gpu_only_import_from(
+    "nemo_curator.modules.semantic_dedup.embeddings", "EmbeddingCreator"
+)
 
 
 @pytest.fixture
@@ -80,3 +87,89 @@ def test_sem_dedup(
         duplicate_docs = [2, 3, 4, 200, 300]
         expected_df = cudf.Series(duplicate_docs, name="id")
         assert_eq(result_df["id"].sort_values(), expected_df, check_index=False)
+
+    @pytest.mark.parametrize("pooling_strategy", ["last_token", "mean_pooling"])
+    def test_embedding_creator_pooling_strategies(self, tmpdir, pooling_strategy):
+        test_text_1 = "The quick brown fox jumps over the lazy dog"
+        test_text_2 = "The brown fox jumps over the dog"
+        test_texts = [test_text_1, test_text_2] * 32
+        df = cudf.DataFrame({"text": test_texts})
+        ddf = dask_cudf.from_cudf(df, 1)
+        cache_dir = os.path.join(tmpdir, "test_embeddings_cache")
+
+        embedding_creator = EmbeddingCreator(
+            embedding_model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",
+            embedding_batch_size=32,
+            embedding_pooling_strategy=pooling_strategy,
+            input_column="text",
+            embedding_output_dir=os.path.join(cache_dir, "mean_embeddings"),
+        )
+        embeddings = embedding_creator.create_embeddings(ddf).compute()
+        embeddings = embeddings["embeddings"].to_arrow().to_pylist()
+        embeddings = np.array(embeddings)
+        reference_embeddings = get_reference_embeddings(
+            test_texts, pooling_strategy=pooling_strategy
+        )
+        assert np.allclose(
+            embeddings, reference_embeddings, atol=1e-3
+        ), "Embeddings should match reference embeddings"
+
+
+def get_reference_embeddings(
+    texts,
+    model_name="sentence-transformers/all-MiniLM-L6-v2",
+    pooling_strategy="last_token",
+):
+    """
+    Get embeddings using either last token or mean pooling strategy.
+
+    Args:
+        texts: List of input texts
+        model_name: Name or path of the model to use
+        pooling_strategy: Either "last_token" for last-token pooling or "mean_pooling" for mean pooling
+    """
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    model = AutoModel.from_pretrained(model_name)
+    model = model.to("cuda")
+    model.eval()
+    max_len_to_use = tokenizer.model_max_length
+    if max_len_to_use > 1e5:
+        max_len_to_use = AutoConfig.from_pretrained(model_name).max_position_embeddings
+    max_seq_length: int = max_len_to_use
+
+    embs = []
+    for text in texts:
+        inputs = tokenizer(
+            text,
+            return_tensors="pt",
+            padding=True,
+            truncation=True,
+            max_length=max_seq_length,
+        )
+        inputs = {k: v.to("cuda") for k, v in inputs.items()}
+
+        with torch.no_grad():
+            with torch.autocast(device_type="cuda"):
+                outputs = model(**inputs)
+
+        if pooling_strategy == "last_token":
+            embeddings = outputs.last_hidden_state[:, -1, :]
+        elif pooling_strategy == "mean_pooling":
+            token_embeddings = outputs.last_hidden_state
+            attention_mask = inputs["attention_mask"]
+            input_mask_expanded = (
+                attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+            )
+            sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, dim=1)
+            sum_mask = torch.clamp(input_mask_expanded.sum(dim=1), min=1e-9)
+            embeddings = sum_embeddings / sum_mask
+        else:
+            raise ValueError(
+                "pooling_strategy must be either 'last_token' or 'mean_pooling'"
+            )
+
+        normed_emb = F.normalize(embeddings, dim=1).cpu()
+        normed_emb = normed_emb.squeeze(0)
+        embs.append(normed_emb)
+
+    return np.array(embs)
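
Usage sketch (illustrative, mirroring the new test rather than part of the patch): it shows the added embedding_pooling_strategy argument selecting between the two pooling modes. The sample texts, partition count, variable names, and the "semdedup_cache" output directory are assumptions made only for this example; EmbeddingCreator, its keyword arguments, and create_embeddings come from the diff above.

    import os

    import cudf
    import dask_cudf

    from nemo_curator.modules.semantic_dedup.embeddings import EmbeddingCreator

    # Toy input: any dask_cudf DataFrame with a text column works.
    texts = [
        "The quick brown fox jumps over the lazy dog",
        "The brown fox jumps over the dog",
    ]
    ddf = dask_cudf.from_cudf(cudf.DataFrame({"text": texts}), npartitions=1)

    # "mean_pooling" (default) averages token embeddings under the attention mask;
    # "last_token" takes the embedding of the final non-padded token.
    # Both are L2-normalized by the model wrapper.
    creator = EmbeddingCreator(
        embedding_model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",
        embedding_batch_size=32,
        embedding_pooling_strategy="last_token",  # or "mean_pooling"
        input_column="text",
        embedding_output_dir=os.path.join("semdedup_cache", "embeddings"),  # hypothetical path
    )

    # Returns a dask_cudf DataFrame with an "embeddings" column, as exercised in the test.
    embeddings_df = creator.create_embeddings(ddf).compute()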