From b0e331722aa282f8eccfb5e1af816d27b45a3f60 Mon Sep 17 00:00:00 2001 From: Vibhu Jawa Date: Sun, 19 Jan 2025 16:11:44 -0800 Subject: [PATCH 1/7] Add pooling stratedgy Signed-off-by: Vibhu Jawa --- nemo_curator/modules/config.py | 2 + .../modules/semantic_dedup/embeddings.py | 18 +++- .../modules/semantic_dedup/semdedup.py | 1 + tests/test_semdedup.py | 88 +++++++++++++++++++ 4 files changed, 107 insertions(+), 2 deletions(-) diff --git a/nemo_curator/modules/config.py b/nemo_curator/modules/config.py index 551f261e1..b501f1754 100644 --- a/nemo_curator/modules/config.py +++ b/nemo_curator/modules/config.py @@ -123,6 +123,7 @@ class SemDedupConfig(BaseConfig): embeddings_save_loc (str): Location to save embeddings. embedding_model_name_or_path (str): Model name or path for embeddings. embedding_batch_size (int): Inital Batch size for processing embeddings. + embedding_pooling_strategy (str): Strategy for pooling embeddings, either "mean" or "last_token". Defaults to "last_token". clustering_save_loc (str): Location to save clustering results. n_clusters (int): Number of clusters. seed (int): Seed for clustering. @@ -143,6 +144,7 @@ class SemDedupConfig(BaseConfig): embeddings_save_loc: str = "embeddings" embedding_model_name_or_path: str = "sentence-transformers/all-MiniLM-L6-v2" embedding_batch_size: int = 128 + embedding_pooling_strategy: str = "last_token" # Clustering config clustering_save_loc: str = "clustering_results" diff --git a/nemo_curator/modules/semantic_dedup/embeddings.py b/nemo_curator/modules/semantic_dedup/embeddings.py index 4a0b638b0..7385f3bd2 100644 --- a/nemo_curator/modules/semantic_dedup/embeddings.py +++ b/nemo_curator/modules/semantic_dedup/embeddings.py @@ -41,6 +41,7 @@ class EmbeddingConfig: model_name_or_path: str max_seq_length: int = None + pooling_strategy: str = "mean" # Options: "mean" or "last_token" def __post_init__(self): self.max_seq_length = AutoTokenizer.from_pretrained( @@ -52,6 +53,8 @@ def __post_init__(self): self.max_seq_length = AutoConfig.from_pretrained( self.model_name_or_path ).max_position_embeddings + if self.pooling_strategy not in ["mean", "last_token"]: + raise ValueError("pooling_strategy must be either 'mean' or 'last_token'") class EmbeddingPytorchModel(nn.Module): @@ -70,7 +73,10 @@ def feature(self, input_ids, attention_mask): @torch.no_grad() def forward(self, batch): feature = self.feature(batch["input_ids"], batch["attention_mask"]) - return self._mean_pooling(feature, batch["attention_mask"]) + if self.config.pooling_strategy == "mean": + return self._mean_pooling(feature, batch["attention_mask"]) + else: + return self._get_last_token(feature) def _mean_pooling(self, model_output, attention_mask): token_embeddings = model_output[0] @@ -81,6 +87,11 @@ def _mean_pooling(self, model_output, attention_mask): sum_mask = torch.clamp(input_mask_expanded.sum(dim=1), min=1e-9) return F.normalize(sum_embeddings / sum_mask, dim=1) + def _get_last_token(self, model_output): + token_embeddings = model_output[0] + last_token_embeddings = token_embeddings[:, -1, :] + return F.normalize(last_token_embeddings, dim=1) + class EmbeddingCrossFitModel(HFModel): def __init__( @@ -116,6 +127,7 @@ def __init__( embedding_batch_size: int, embedding_output_dir: str, embedding_max_mem_gb: Optional[int] = None, + embedding_pooling_strategy: str = "last_token", input_column: str = "text", embedding_column: str = "embeddings", write_embeddings_to_disk: bool = True, @@ -132,6 +144,7 @@ def __init__( embedding_output_dir (str): Directory path where 
embeddings will be saved. embedding_max_mem_gb (int): Maximum memory usage in GB for the embedding process. If None, it defaults to the available GPU memory minus 4 GB. + embedding_pooling_strategy (str): Strategy for pooling embeddings, either "mean" or "last_token". Defaults to "last_token". input_column (str): Column name from the data to be used for embedding generation, defaults to "text". write_embeddings_to_disk (bool, optional): If True, saves the embeddings to disk, defaults to True. We recommend setting this to False when you have a delayed pipeline. @@ -152,6 +165,7 @@ def __init__( self.embeddings_config = EmbeddingConfig( model_name_or_path=embedding_model_name_or_path, + pooling_strategy=embedding_pooling_strategy, ) self.batch_size = embedding_batch_size self.logger = self._setup_logger(logger) @@ -184,7 +198,7 @@ def create_embeddings( op.Tokenizer( self.model, cols=[input_column], - tokenizer_type="sentencepiece", + tokenizer_type="default", max_length=self.embeddings_config.max_seq_length, ), op.Predictor( diff --git a/nemo_curator/modules/semantic_dedup/semdedup.py b/nemo_curator/modules/semantic_dedup/semdedup.py index a03d152b1..3c389f415 100644 --- a/nemo_curator/modules/semantic_dedup/semdedup.py +++ b/nemo_curator/modules/semantic_dedup/semdedup.py @@ -48,6 +48,7 @@ def __init__( self.embedding_creator = EmbeddingCreator( embedding_model_name_or_path=config.embedding_model_name_or_path, embedding_batch_size=config.embedding_batch_size, + embedding_pooling_strategy=config.embedding_pooling_strategy, input_column=input_column, embedding_output_dir=os.path.join(cache_dir, config.embeddings_save_loc), logger=logger, diff --git a/tests/test_semdedup.py b/tests/test_semdedup.py index 4cc66901d..326b58af6 100644 --- a/tests/test_semdedup.py +++ b/tests/test_semdedup.py @@ -13,12 +13,17 @@ # limitations under the License. 
import os +import numpy as np import pytest +import torch +import torch.nn.functional as F from dask.dataframe.utils import assert_eq from distributed import Client +from transformers import AutoConfig, AutoModel, AutoTokenizer from nemo_curator import SemDedup, SemDedupConfig from nemo_curator.datasets import DocumentDataset +from nemo_curator.modules.semantic_dedup.embeddings import EmbeddingCreator from nemo_curator.utils.import_utils import gpu_only_import, gpu_only_import_from cudf = gpu_only_import("cudf") @@ -80,3 +85,86 @@ def test_sem_dedup( duplicate_docs = [2, 3, 4, 200, 300] expected_df = cudf.Series(duplicate_docs, name="id") assert_eq(result_df["id"].sort_values(), expected_df, check_index=False) + + @pytest.mark.parametrize("pooling_strategy", ["last_token", "mean"]) + def test_embedding_creator_pooling_strategies(self, tmpdir, pooling_strategy): + test_text = "The quick brown fox jumps over the lazy dog" + test_texts = [test_text] * 32 + df = cudf.DataFrame({"text": test_texts}) + ddf = dask_cudf.from_cudf(df, 1) + cache_dir = os.path.join(tmpdir, "test_embeddings_cache") + + embedding_creator = EmbeddingCreator( + embedding_model_name_or_path="sentence-transformers/all-MiniLM-L6-v2", + embedding_batch_size=32, + embedding_pooling_strategy=pooling_strategy, + input_column="text", + embedding_output_dir=os.path.join(cache_dir, "mean_embeddings"), + ) + embeddings = embedding_creator.create_embeddings(ddf).compute() + embeddings = embeddings["embeddings"].to_arrow().to_pylist() + embeddings = np.array(embeddings) + reference_embeddings = get_reference_embeddings( + test_texts, pooling_strategy=pooling_strategy + ) + assert np.allclose( + embeddings, reference_embeddings, atol=1e-3 + ), "Embeddings should match reference embeddings" + + +def get_reference_embeddings( + texts, + model_name="sentence-transformers/all-MiniLM-L6-v2", + pooling_strategy="last_token", +): + """ + Get embeddings using either last token or mean pooling strategy. 
+ + Args: + texts: List of input texts + model_name: Name or path of the model to use + pooling_strategy: Either "last_token" for last token or "mean" for mean pooling + """ + tokenizer = AutoTokenizer.from_pretrained(model_name) + model = AutoModel.from_pretrained(model_name) + model = model.to("cuda") + model.eval() + max_len_to_use = tokenizer.model_max_length + if max_len_to_use > 1e5: + max_len_to_use = AutoConfig.from_pretrained(model_name).max_position_embeddings + max_seq_length: int = max_len_to_use + + embs = [] + for text in texts: + inputs = tokenizer( + text, + return_tensors="pt", + padding=True, + truncation=True, + max_length=max_seq_length, + ) + inputs = {k: v.to("cuda") for k, v in inputs.items()} + + with torch.no_grad(): + with torch.autocast(device_type="cuda"): + outputs = model(**inputs) + + if pooling_strategy == "last_token": + embeddings = outputs.last_hidden_state[:, -1, :] + elif pooling_strategy == "mean": + token_embeddings = outputs.last_hidden_state + attention_mask = inputs["attention_mask"] + input_mask_expanded = ( + attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() + ) + sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, dim=1) + sum_mask = torch.clamp(input_mask_expanded.sum(dim=1), min=1e-9) + embeddings = sum_embeddings / sum_mask + else: + raise ValueError("pooling_strategy must be either 'last_token' or 'mean'") + + normed_emb = F.normalize(embeddings, dim=1).cpu() + normed_emb = normed_emb.squeeze(0) + embs.append(normed_emb) + + return np.array(embs) From bc831818cd083bf5782ffdddcb7bcd173f073cca Mon Sep 17 00:00:00 2001 From: Vibhu Jawa Date: Sun, 19 Jan 2025 16:20:26 -0800 Subject: [PATCH 2/7] Ensure pytest is importable in a CPU only environment Signed-off-by: Vibhu Jawa --- tests/test_semdedup.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_semdedup.py b/tests/test_semdedup.py index 326b58af6..5c696a2b7 100644 --- a/tests/test_semdedup.py +++ b/tests/test_semdedup.py @@ -23,12 +23,14 @@ from nemo_curator import SemDedup, SemDedupConfig from nemo_curator.datasets import DocumentDataset -from nemo_curator.modules.semantic_dedup.embeddings import EmbeddingCreator from nemo_curator.utils.import_utils import gpu_only_import, gpu_only_import_from cudf = gpu_only_import("cudf") dask_cudf = gpu_only_import("dask_cudf") LocalCUDACluster = gpu_only_import_from("dask_cuda", "LocalCUDACluster") +EmbeddingCreator = gpu_only_import( + "nemo_curator.modules.semantic_dedup.embeddings.EmbeddingCreator" +) @pytest.fixture From e5d40ea77e13f5cba6a28e60cce73e507d2a36c9 Mon Sep 17 00:00:00 2001 From: Vibhu Jawa Date: Tue, 21 Jan 2025 08:52:40 -0800 Subject: [PATCH 3/7] Fix last token based on Avinash's feedback Signed-off-by: Vibhu Jawa --- nemo_curator/modules/semantic_dedup/embeddings.py | 11 ++++++++--- tests/test_semdedup.py | 9 +++++---- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/nemo_curator/modules/semantic_dedup/embeddings.py b/nemo_curator/modules/semantic_dedup/embeddings.py index 7385f3bd2..a2160b9d7 100644 --- a/nemo_curator/modules/semantic_dedup/embeddings.py +++ b/nemo_curator/modules/semantic_dedup/embeddings.py @@ -76,7 +76,7 @@ def forward(self, batch): if self.config.pooling_strategy == "mean": return self._mean_pooling(feature, batch["attention_mask"]) else: - return self._get_last_token(feature) + return self._get_last_token(feature, batch["attention_mask"]) def _mean_pooling(self, model_output, attention_mask): token_embeddings = model_output[0] @@ -87,9 
+87,14 @@ def _mean_pooling(self, model_output, attention_mask): sum_mask = torch.clamp(input_mask_expanded.sum(dim=1), min=1e-9) return F.normalize(sum_embeddings / sum_mask, dim=1) - def _get_last_token(self, model_output): + def _get_last_token(self, model_output, attention_mask): token_embeddings = model_output[0] - last_token_embeddings = token_embeddings[:, -1, :] + # Get indices of last non-padded tokens for each sequence in batch + last_token_indices = attention_mask.sum(dim=1) - 1 # -1 for 0-based indexing + batch_size = attention_mask.size(0) + batch_indices = torch.arange(batch_size, device=attention_mask.device) + # Get embeddings of last non-padded tokens + last_token_embeddings = token_embeddings[batch_indices, last_token_indices] return F.normalize(last_token_embeddings, dim=1) diff --git a/tests/test_semdedup.py b/tests/test_semdedup.py index 5c696a2b7..9558d5826 100644 --- a/tests/test_semdedup.py +++ b/tests/test_semdedup.py @@ -28,8 +28,8 @@ cudf = gpu_only_import("cudf") dask_cudf = gpu_only_import("dask_cudf") LocalCUDACluster = gpu_only_import_from("dask_cuda", "LocalCUDACluster") -EmbeddingCreator = gpu_only_import( - "nemo_curator.modules.semantic_dedup.embeddings.EmbeddingCreator" +EmbeddingCreator = gpu_only_import_from( + "nemo_curator.modules.semantic_dedup.embeddings", "EmbeddingCreator" ) @@ -90,8 +90,9 @@ def test_sem_dedup( @pytest.mark.parametrize("pooling_strategy", ["last_token", "mean"]) def test_embedding_creator_pooling_strategies(self, tmpdir, pooling_strategy): - test_text = "The quick brown fox jumps over the lazy dog" - test_texts = [test_text] * 32 + test_text_1 = "The quick brown fox jumps over the lazy dog" + test_text_2 = "The brown fox jumps over the dog" + test_texts = [test_text_1, test_text_2] * 32 df = cudf.DataFrame({"text": test_texts}) ddf = dask_cudf.from_cudf(df, 1) cache_dir = os.path.join(tmpdir, "test_embeddings_cache") From ac7a7fc9031c63c37b0fda73e3a7b76d322a9f8e Mon Sep 17 00:00:00 2001 From: Vibhu Jawa Date: Wed, 5 Feb 2025 11:22:49 -0800 Subject: [PATCH 4/7] Fix indexing issues Signed-off-by: Vibhu Jawa --- nemo_curator/modules/config.py | 3 ++- nemo_curator/modules/semantic_dedup/embeddings.py | 13 +++++++++---- tests/test_semdedup.py | 2 +- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/nemo_curator/modules/config.py b/nemo_curator/modules/config.py index b501f1754..6b0dc4a61 100644 --- a/nemo_curator/modules/config.py +++ b/nemo_curator/modules/config.py @@ -144,7 +144,8 @@ class SemDedupConfig(BaseConfig): embeddings_save_loc: str = "embeddings" embedding_model_name_or_path: str = "sentence-transformers/all-MiniLM-L6-v2" embedding_batch_size: int = 128 - embedding_pooling_strategy: str = "last_token" + # Options: "mean_pooling", "last_token" + embedding_pooling_strategy: str = "mean_pooling" # Clustering config clustering_save_loc: str = "clustering_results" diff --git a/nemo_curator/modules/semantic_dedup/embeddings.py b/nemo_curator/modules/semantic_dedup/embeddings.py index a2160b9d7..fad0a53ce 100644 --- a/nemo_curator/modules/semantic_dedup/embeddings.py +++ b/nemo_curator/modules/semantic_dedup/embeddings.py @@ -41,7 +41,7 @@ class EmbeddingConfig: model_name_or_path: str max_seq_length: int = None - pooling_strategy: str = "mean" # Options: "mean" or "last_token" + pooling_strategy: str = "mean_pooling" # Options: "mean_pooling" or "last_token" def __post_init__(self): self.max_seq_length = AutoTokenizer.from_pretrained( @@ -53,8 +53,10 @@ def __post_init__(self): self.max_seq_length = 
AutoConfig.from_pretrained( self.model_name_or_path ).max_position_embeddings - if self.pooling_strategy not in ["mean", "last_token"]: - raise ValueError("pooling_strategy must be either 'mean' or 'last_token'") + if self.pooling_strategy not in ["mean_pooling", "last_token"]: + raise ValueError( + "pooling_strategy must be either 'mean_pooling' or 'last_token'" + ) class EmbeddingPytorchModel(nn.Module): @@ -73,7 +75,7 @@ def feature(self, input_ids, attention_mask): @torch.no_grad() def forward(self, batch): feature = self.feature(batch["input_ids"], batch["attention_mask"]) - if self.config.pooling_strategy == "mean": + if self.config.pooling_strategy == "mean_pooling": return self._mean_pooling(feature, batch["attention_mask"]) else: return self._get_last_token(feature, batch["attention_mask"]) @@ -91,6 +93,9 @@ def _get_last_token(self, model_output, attention_mask): token_embeddings = model_output[0] # Get indices of last non-padded tokens for each sequence in batch last_token_indices = attention_mask.sum(dim=1) - 1 # -1 for 0-based indexing + last_token_indices = last_token_indices.to( + torch.long + ) # Ensure indices are of type long batch_size = attention_mask.size(0) batch_indices = torch.arange(batch_size, device=attention_mask.device) # Get embeddings of last non-padded tokens diff --git a/tests/test_semdedup.py b/tests/test_semdedup.py index 9558d5826..6b33e7652 100644 --- a/tests/test_semdedup.py +++ b/tests/test_semdedup.py @@ -88,7 +88,7 @@ def test_sem_dedup( expected_df = cudf.Series(duplicate_docs, name="id") assert_eq(result_df["id"].sort_values(), expected_df, check_index=False) - @pytest.mark.parametrize("pooling_strategy", ["last_token", "mean"]) + @pytest.mark.parametrize("pooling_strategy", ["last_token", "mean_pooling"]) def test_embedding_creator_pooling_strategies(self, tmpdir, pooling_strategy): test_text_1 = "The quick brown fox jumps over the lazy dog" test_text_2 = "The brown fox jumps over the dog" From d876faac10e4b0c8ef39ccf36d2ce38ec8931abd Mon Sep 17 00:00:00 2001 From: Vibhu Jawa Date: Wed, 5 Feb 2025 11:28:13 -0800 Subject: [PATCH 5/7] Merge in main Signed-off-by: Vibhu Jawa --- tests/test_semdedup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_semdedup.py b/tests/test_semdedup.py index 6b33e7652..5411e4aba 100644 --- a/tests/test_semdedup.py +++ b/tests/test_semdedup.py @@ -154,7 +154,7 @@ def get_reference_embeddings( if pooling_strategy == "last_token": embeddings = outputs.last_hidden_state[:, -1, :] - elif pooling_strategy == "mean": + elif pooling_strategy == "mean_pooling": token_embeddings = outputs.last_hidden_state attention_mask = inputs["attention_mask"] input_mask_expanded = ( From b9ce56e92e5d60e9f0bdc167475b247f4c4d3660 Mon Sep 17 00:00:00 2001 From: Vibhu Jawa Date: Wed, 5 Feb 2025 11:30:37 -0800 Subject: [PATCH 6/7] Fix Doc-string Signed-off-by: Vibhu Jawa --- nemo_curator/modules/config.py | 2 +- nemo_curator/modules/semantic_dedup/embeddings.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/nemo_curator/modules/config.py b/nemo_curator/modules/config.py index 77eee06ef..50c71017b 100644 --- a/nemo_curator/modules/config.py +++ b/nemo_curator/modules/config.py @@ -145,7 +145,7 @@ class SemDedupConfig(BaseConfig): embeddings_save_loc (str): Location to save embeddings. embedding_model_name_or_path (str): Model name or path for embeddings. embedding_batch_size (int): Inital Batch size for processing embeddings. 
- embedding_pooling_strategy (str): Strategy for pooling embeddings, either "mean_pooling" or "last_token". Defaults to "last_token". + embedding_pooling_strategy (str): Strategy for pooling embeddings, either "mean_pooling" or "last_token". Defaults to "mean_pooling". write_embeddings_to_disk (bool): If True, saves the embeddings to disk, defaults to True. We recommend setting this to False when you have a delayed pipeline. Setting it to False can lead to more memory overhead. diff --git a/nemo_curator/modules/semantic_dedup/embeddings.py b/nemo_curator/modules/semantic_dedup/embeddings.py index 03aec45ec..758a17fc9 100644 --- a/nemo_curator/modules/semantic_dedup/embeddings.py +++ b/nemo_curator/modules/semantic_dedup/embeddings.py @@ -154,7 +154,7 @@ def __init__( embedding_output_dir (str): Directory path where embeddings will be saved. embedding_max_mem_gb (int): Maximum memory usage in GB for the embedding process. If None, it defaults to the available GPU memory minus 4 GB. - embedding_pooling_strategy (str): Strategy for pooling embeddings, either "mean" or "last_token". Defaults to "last_token". + embedding_pooling_strategy (str): Strategy for pooling embeddings, either "mean_pooling" or "last_token". Defaults to "mean_pooling". input_column (str): Column name from the data to be used for embedding generation, defaults to "text". write_embeddings_to_disk (bool, optional): If True, saves the embeddings to disk, defaults to True. We recommend setting this to False when you have a delayed pipeline. From 04f8d554e8b1951853d5c46ae78f40e5f0d96ea8 Mon Sep 17 00:00:00 2001 From: Vibhu Jawa Date: Thu, 6 Feb 2025 10:57:34 -0800 Subject: [PATCH 7/7] Address Sarah's reviews Signed-off-by: Vibhu Jawa --- nemo_curator/modules/semantic_dedup/embeddings.py | 2 +- tests/test_semdedup.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/nemo_curator/modules/semantic_dedup/embeddings.py b/nemo_curator/modules/semantic_dedup/embeddings.py index 758a17fc9..7f6315e52 100644 --- a/nemo_curator/modules/semantic_dedup/embeddings.py +++ b/nemo_curator/modules/semantic_dedup/embeddings.py @@ -137,7 +137,7 @@ def __init__( embedding_batch_size: int, embedding_output_dir: str, embedding_max_mem_gb: Optional[int] = None, - embedding_pooling_strategy: str = "last_token", + embedding_pooling_strategy: str = "mean_pooling", input_column: str = "text", embedding_column: str = "embeddings", write_embeddings_to_disk: bool = True, diff --git a/tests/test_semdedup.py b/tests/test_semdedup.py index 5411e4aba..8ccf850a7 100644 --- a/tests/test_semdedup.py +++ b/tests/test_semdedup.py @@ -164,7 +164,9 @@ def get_reference_embeddings( sum_mask = torch.clamp(input_mask_expanded.sum(dim=1), min=1e-9) embeddings = sum_embeddings / sum_mask else: - raise ValueError("pooling_strategy must be either 'last_token' or 'mean'") + raise ValueError( + "pooling_strategy must be either 'last_token' or 'mean_pooling'" + ) normed_emb = F.normalize(embeddings, dim=1).cpu() normed_emb = normed_emb.squeeze(0)
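
---
Note (not part of the patch series): the sketch below is a minimal, standalone illustration of the two pooling strategies this series wires into EmbeddingConfig, written directly against Hugging Face transformers and torch rather than through NeMo Curator. The model name, example texts, and variable names are illustrative assumptions; the actual implementation is the EmbeddingPytorchModel code in the diffs above. The sketch also shows why patch 3 indexes the last non-padded token via the attention mask instead of taking position -1, which would land on a padding token for shorter sequences in a padded batch.

    # Standalone sketch of "mean_pooling" vs "last_token" pooling (illustrative only).
    import torch
    import torch.nn.functional as F
    from transformers import AutoModel, AutoTokenizer

    model_name = "sentence-transformers/all-MiniLM-L6-v2"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name).eval()

    batch = tokenizer(
        [
            "The quick brown fox jumps over the lazy dog",
            "The brown fox jumps over the dog",
        ],
        padding=True,
        truncation=True,
        return_tensors="pt",
    )

    with torch.no_grad():
        token_embeddings = model(**batch).last_hidden_state  # (batch, seq_len, hidden)

    attention_mask = batch["attention_mask"]

    # "mean_pooling": average the token embeddings, ignoring padded positions.
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    mean_emb = (token_embeddings * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
    mean_emb = F.normalize(mean_emb, dim=1)

    # "last_token": take the embedding of the last *non-padded* token per sequence.
    # attention_mask.sum(dim=1) - 1 gives that position; plain token_embeddings[:, -1, :]
    # would instead pick a padding token for the shorter sequence in this batch.
    last_idx = (attention_mask.sum(dim=1) - 1).long()
    rows = torch.arange(token_embeddings.size(0))
    last_emb = F.normalize(token_embeddings[rows, last_idx], dim=1)

    print(mean_emb.shape, last_emb.shape)  # both (2, 384) for this model

Both variants L2-normalize the result, matching _mean_pooling and _get_last_token above, so downstream cosine-similarity clustering behaves the same regardless of which embedding_pooling_strategy is configured.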