Add Pooling Strategy Option for embedding creation #491

Merged: 8 commits, Feb 6, 2025
Changes from 7 commits
3 changes: 3 additions & 0 deletions nemo_curator/modules/config.py
@@ -145,6 +145,7 @@ class SemDedupConfig(BaseConfig):
embeddings_save_loc (str): Location to save embeddings.
embedding_model_name_or_path (str): Model name or path for embeddings.
embedding_batch_size (int): Initial batch size for processing embeddings.
embedding_pooling_strategy (str): Strategy for pooling embeddings, either "mean_pooling" or "last_token". Defaults to "mean_pooling".
write_embeddings_to_disk (bool): If True, saves the embeddings to disk, defaults to True.
We recommend setting this to False when you have a delayed pipeline.
Setting it to False can lead to more memory overhead.
@@ -168,6 +169,8 @@ class SemDedupConfig(BaseConfig):
embeddings_save_loc: str = "embeddings"
embedding_model_name_or_path: str = "sentence-transformers/all-MiniLM-L6-v2"
embedding_batch_size: int = 128
# Options: "mean_pooling", "last_token"
embedding_pooling_strategy: str = "mean_pooling"
write_embeddings_to_disk: bool = True

# Clustering config
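For readers following along, here is a minimal sketch of how the new field would be set when constructing the config in Python. The field names come from this diff; `cache_dir` and its value are assumptions for illustration, not shown in the hunk above.

```python
from nemo_curator import SemDedupConfig

# Sketch only: remaining fields keep the defaults shown in this diff.
config = SemDedupConfig(
    cache_dir="./semdedup_cache",  # assumed cache location for intermediate outputs
    embedding_model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",
    embedding_batch_size=128,
    embedding_pooling_strategy="last_token",  # new option; "mean_pooling" is the default
)
```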
28 changes: 26 additions & 2 deletions nemo_curator/modules/semantic_dedup/embeddings.py
@@ -41,6 +41,7 @@
class EmbeddingConfig:
model_name_or_path: str
max_seq_length: int = None
pooling_strategy: str = "mean_pooling" # Options: "mean_pooling" or "last_token"

def __post_init__(self):
self.max_seq_length = AutoTokenizer.from_pretrained(
@@ -52,6 +53,10 @@ def __post_init__(self):
self.max_seq_length = AutoConfig.from_pretrained(
self.model_name_or_path
).max_position_embeddings
if self.pooling_strategy not in ["mean_pooling", "last_token"]:
raise ValueError(
"pooling_strategy must be either 'mean_pooling' or 'last_token'"
)


class EmbeddingPytorchModel(nn.Module):
@@ -70,7 +75,10 @@ def feature(self, input_ids, attention_mask):
@torch.no_grad()
def forward(self, batch):
feature = self.feature(batch["input_ids"], batch["attention_mask"])
return self._mean_pooling(feature, batch["attention_mask"])
if self.config.pooling_strategy == "mean_pooling":
return self._mean_pooling(feature, batch["attention_mask"])
else:
return self._get_last_token(feature, batch["attention_mask"])

def _mean_pooling(self, model_output, attention_mask):
token_embeddings = model_output[0]
@@ -81,6 +89,19 @@ def _mean_pooling(self, model_output, attention_mask):
sum_mask = torch.clamp(input_mask_expanded.sum(dim=1), min=1e-9)
return F.normalize(sum_embeddings / sum_mask, dim=1)

def _get_last_token(self, model_output, attention_mask):
token_embeddings = model_output[0]
# Get indices of last non-padded tokens for each sequence in batch
last_token_indices = attention_mask.sum(dim=1) - 1 # -1 for 0-based indexing
last_token_indices = last_token_indices.to(
torch.long
) # Ensure indices are of type long
batch_size = attention_mask.size(0)
batch_indices = torch.arange(batch_size, device=attention_mask.device)
# Get embeddings of last non-padded tokens
last_token_embeddings = token_embeddings[batch_indices, last_token_indices]
return F.normalize(last_token_embeddings, dim=1)


class EmbeddingCrossFitModel(HFModel):
def __init__(
@@ -116,6 +137,7 @@ def __init__(
embedding_batch_size: int,
embedding_output_dir: str,
embedding_max_mem_gb: Optional[int] = None,
embedding_pooling_strategy: str = "mean_pooling",
input_column: str = "text",
embedding_column: str = "embeddings",
write_embeddings_to_disk: bool = True,
@@ -132,6 +154,7 @@ def __init__(
embedding_output_dir (str): Directory path where embeddings will be saved.
embedding_max_mem_gb (int): Maximum memory usage in GB for the embedding process.
If None, it defaults to the available GPU memory minus 4 GB.
embedding_pooling_strategy (str): Strategy for pooling embeddings, either "mean_pooling" or "last_token". Defaults to "mean_pooling".
input_column (str): Column name from the data to be used for embedding generation, defaults to "text".
write_embeddings_to_disk (bool, optional): If True, saves the embeddings to disk, defaults to True.
We recommend setting this to False when you have a delayed pipeline.
@@ -152,6 +175,7 @@ def __init__(

self.embeddings_config = EmbeddingConfig(
model_name_or_path=embedding_model_name_or_path,
pooling_strategy=embedding_pooling_strategy,
)
self.batch_size = embedding_batch_size
self.logger = self._setup_logger(logger)
@@ -184,7 +208,7 @@ def create_embeddings(
op.Tokenizer(
self.model,
cols=[input_column],
tokenizer_type="sentencepiece",
tokenizer_type="default",
max_length=self.embeddings_config.max_seq_length,
),
op.Predictor(
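To illustrate what the two strategies compute, here is a small self-contained sketch (not part of the PR) that mirrors `_mean_pooling` and `_get_last_token` on a toy batch; the shapes and mask values are made up for illustration.

```python
import torch
import torch.nn.functional as F

token_embeddings = torch.randn(2, 4, 8)        # (batch, seq_len, hidden)
attention_mask = torch.tensor([[1, 1, 1, 1],
                               [1, 1, 0, 0]])  # second sequence has two padding positions

# Mean pooling: average over non-padded positions only.
mask = attention_mask.unsqueeze(-1).float()
mean_pooled = (token_embeddings * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
mean_pooled = F.normalize(mean_pooled, dim=1)

# Last-token pooling: pick the final non-padded position of each sequence.
last_idx = (attention_mask.sum(dim=1) - 1).long()           # tensor([3, 1])
last_pooled = token_embeddings[torch.arange(2), last_idx]
last_pooled = F.normalize(last_pooled, dim=1)

print(mean_pooled.shape, last_pooled.shape)  # torch.Size([2, 8]) torch.Size([2, 8])
```

Last-token pooling is typically the better fit for decoder-style embedding models, while mean pooling matches how sentence-transformers models such as all-MiniLM-L6-v2 were trained.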
1 change: 1 addition & 0 deletions nemo_curator/modules/semantic_dedup/semdedup.py
@@ -48,6 +48,7 @@ def __init__(
self.embedding_creator = EmbeddingCreator(
embedding_model_name_or_path=config.embedding_model_name_or_path,
embedding_batch_size=config.embedding_batch_size,
embedding_pooling_strategy=config.embedding_pooling_strategy,
input_column=input_column,
embedding_output_dir=os.path.join(cache_dir, config.embeddings_save_loc),
write_embeddings_to_disk=config.write_embeddings_to_disk,
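And a hedged end-to-end sketch of how the option flows through `SemDedup`. Constructor arguments beyond `config` and `input_column`, the `cache_dir` field, and the commented-out dataset call are assumptions for illustration rather than part of this diff.

```python
from nemo_curator import SemDedup, SemDedupConfig

config = SemDedupConfig(
    cache_dir="./semdedup_cache",             # assumed field, as in the earlier sketch
    embedding_pooling_strategy="last_token",
)
sem_dedup = SemDedup(config=config, input_column="text")
# dataset = DocumentDataset.read_json("docs/*.jsonl", backend="cudf")  # hypothetical input
# deduplicated_ids = sem_dedup(dataset)
```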
91 changes: 91 additions & 0 deletions tests/test_semdedup.py
@@ -13,9 +13,13 @@
# limitations under the License.
import os

import numpy as np
import pytest
import torch
import torch.nn.functional as F
from dask.dataframe.utils import assert_eq
from distributed import Client
from transformers import AutoConfig, AutoModel, AutoTokenizer

from nemo_curator import SemDedup, SemDedupConfig
from nemo_curator.datasets import DocumentDataset
@@ -24,6 +28,9 @@
cudf = gpu_only_import("cudf")
dask_cudf = gpu_only_import("dask_cudf")
LocalCUDACluster = gpu_only_import_from("dask_cuda", "LocalCUDACluster")
EmbeddingCreator = gpu_only_import_from(
"nemo_curator.modules.semantic_dedup.embeddings", "EmbeddingCreator"
)


@pytest.fixture
@@ -80,3 +87,87 @@ def test_sem_dedup(
duplicate_docs = [2, 3, 4, 200, 300]
expected_df = cudf.Series(duplicate_docs, name="id")
assert_eq(result_df["id"].sort_values(), expected_df, check_index=False)

@pytest.mark.parametrize("pooling_strategy", ["last_token", "mean_pooling"])
def test_embedding_creator_pooling_strategies(self, tmpdir, pooling_strategy):
test_text_1 = "The quick brown fox jumps over the lazy dog"
test_text_2 = "The brown fox jumps over the dog"
test_texts = [test_text_1, test_text_2] * 32
df = cudf.DataFrame({"text": test_texts})
ddf = dask_cudf.from_cudf(df, 1)
cache_dir = os.path.join(tmpdir, "test_embeddings_cache")

embedding_creator = EmbeddingCreator(
embedding_model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",
embedding_batch_size=32,
embedding_pooling_strategy=pooling_strategy,
input_column="text",
embedding_output_dir=os.path.join(cache_dir, "mean_embeddings"),
)
embeddings = embedding_creator.create_embeddings(ddf).compute()
embeddings = embeddings["embeddings"].to_arrow().to_pylist()
embeddings = np.array(embeddings)
reference_embeddings = get_reference_embeddings(
test_texts, pooling_strategy=pooling_strategy
)
assert np.allclose(
embeddings, reference_embeddings, atol=1e-3
), "Embeddings should match reference embeddings"


def get_reference_embeddings(
texts,
model_name="sentence-transformers/all-MiniLM-L6-v2",
pooling_strategy="last_token",
):
"""
Get embeddings using either last token or mean pooling strategy.

Args:
texts: List of input texts
model_name: Name or path of the model to use
pooling_strategy: Either "last_token" for last-token pooling or "mean_pooling" for mean pooling
"""
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)
model = model.to("cuda")
model.eval()
max_len_to_use = tokenizer.model_max_length
if max_len_to_use > 1e5:
max_len_to_use = AutoConfig.from_pretrained(model_name).max_position_embeddings
max_seq_length: int = max_len_to_use

embs = []
for text in texts:
inputs = tokenizer(
text,
return_tensors="pt",
padding=True,
truncation=True,
max_length=max_seq_length,
)
inputs = {k: v.to("cuda") for k, v in inputs.items()}

with torch.no_grad():
with torch.autocast(device_type="cuda"):
outputs = model(**inputs)

if pooling_strategy == "last_token":
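# A single, untruncated sequence gets no padding here, so index -1 is the last real token.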
embeddings = outputs.last_hidden_state[:, -1, :]
elif pooling_strategy == "mean_pooling":
token_embeddings = outputs.last_hidden_state
attention_mask = inputs["attention_mask"]
input_mask_expanded = (
attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
)
sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, dim=1)
sum_mask = torch.clamp(input_mask_expanded.sum(dim=1), min=1e-9)
embeddings = sum_embeddings / sum_mask
else:
raise ValueError("pooling_strategy must be either 'last_token' or 'mean_pooling'")

normed_emb = F.normalize(embeddings, dim=1).cpu()
normed_emb = normed_emb.squeeze(0)
embs.append(normed_emb)

return np.array(embs)