Add Pooling Strategy Option for embedding creation (#491)
* Add pooling strategy

Signed-off-by: Vibhu Jawa <[email protected]>

* Ensure pytest is importable in a CPU-only environment

Signed-off-by: Vibhu Jawa <[email protected]>

* Fix last token based on Avinash's feedback

Signed-off-by: Vibhu Jawa <[email protected]>

* Fix indexing issues

Signed-off-by: Vibhu Jawa <[email protected]>

* Merge in main

Signed-off-by: Vibhu Jawa <[email protected]>

* Fix docstring

Signed-off-by: Vibhu Jawa <[email protected]>

* Address Sarah's reviews

Signed-off-by: Vibhu Jawa <[email protected]>

---------

Signed-off-by: Vibhu Jawa <[email protected]>
VibhuJawa authored Feb 6, 2025
1 parent 1dab545 commit 97aa372
Showing 4 changed files with 123 additions and 2 deletions.
3 changes: 3 additions & 0 deletions nemo_curator/modules/config.py
@@ -145,6 +145,7 @@ class SemDedupConfig(BaseConfig):
embeddings_save_loc (str): Location to save embeddings.
embedding_model_name_or_path (str): Model name or path for embeddings.
embedding_batch_size (int): Initial batch size for processing embeddings.
embedding_pooling_strategy (str): Strategy for pooling embeddings, either "mean_pooling" or "last_token". Defaults to "mean_pooling".
write_embeddings_to_disk (bool): If True, saves the embeddings to disk, defaults to True.
We recommend setting this to False when you have a delayed pipeline.
Setting it to False can lead to more memory overhead.
@@ -168,6 +169,8 @@
embeddings_save_loc: str = "embeddings"
embedding_model_name_or_path: str = "sentence-transformers/all-MiniLM-L6-v2"
embedding_batch_size: int = 128
# Options: "mean_pooling", "last_token"
embedding_pooling_strategy: str = "mean_pooling"
write_embeddings_to_disk: bool = True

# Clustering config
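For anyone trying the new option, a minimal configuration sketch follows. The import matches tests/test_semdedup.py below; the cache_dir value is a placeholder assumption, not something taken from this diff.

from nemo_curator import SemDedupConfig

# Sketch only: cache_dir is an assumed placeholder; the remaining fields keep
# the defaults shown in the config.py hunk above.
config = SemDedupConfig(
    cache_dir="./semdedup_cache",
    embedding_model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",
    embedding_batch_size=128,
    embedding_pooling_strategy="last_token",  # or "mean_pooling" (the default)
)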
28 changes: 26 additions & 2 deletions nemo_curator/modules/semantic_dedup/embeddings.py
@@ -41,6 +41,7 @@
class EmbeddingConfig:
model_name_or_path: str
max_seq_length: int = None
pooling_strategy: str = "mean_pooling" # Options: "mean_pooling" or "last_token"

def __post_init__(self):
self.max_seq_length = AutoTokenizer.from_pretrained(
@@ -52,6 +53,10 @@ def __post_init__(self):
self.max_seq_length = AutoConfig.from_pretrained(
self.model_name_or_path
).max_position_embeddings
if self.pooling_strategy not in ["mean_pooling", "last_token"]:
raise ValueError(
"pooling_strategy must be either 'mean_pooling' or 'last_token'"
)


class EmbeddingPytorchModel(nn.Module):
@@ -70,7 +75,10 @@ def feature(self, input_ids, attention_mask):
@torch.no_grad()
def forward(self, batch):
feature = self.feature(batch["input_ids"], batch["attention_mask"])
return self._mean_pooling(feature, batch["attention_mask"])
if self.config.pooling_strategy == "mean_pooling":
return self._mean_pooling(feature, batch["attention_mask"])
else:
return self._get_last_token(feature, batch["attention_mask"])

def _mean_pooling(self, model_output, attention_mask):
token_embeddings = model_output[0]
@@ -81,6 +89,19 @@ def _mean_pooling(self, model_output, attention_mask):
sum_mask = torch.clamp(input_mask_expanded.sum(dim=1), min=1e-9)
return F.normalize(sum_embeddings / sum_mask, dim=1)

def _get_last_token(self, model_output, attention_mask):
token_embeddings = model_output[0]
# Get indices of last non-padded tokens for each sequence in batch
last_token_indices = attention_mask.sum(dim=1) - 1 # -1 for 0-based indexing
last_token_indices = last_token_indices.to(
torch.long
) # Ensure indices are of type long
batch_size = attention_mask.size(0)
batch_indices = torch.arange(batch_size, device=attention_mask.device)
# Get embeddings of last non-padded tokens
last_token_embeddings = token_embeddings[batch_indices, last_token_indices]
return F.normalize(last_token_embeddings, dim=1)


class EmbeddingCrossFitModel(HFModel):
def __init__(
@@ -116,6 +137,7 @@ def __init__(
embedding_batch_size: int,
embedding_output_dir: str,
embedding_max_mem_gb: Optional[int] = None,
embedding_pooling_strategy: str = "mean_pooling",
input_column: str = "text",
embedding_column: str = "embeddings",
write_embeddings_to_disk: bool = True,
@@ -132,6 +154,7 @@ def __init__(
embedding_output_dir (str): Directory path where embeddings will be saved.
embedding_max_mem_gb (int): Maximum memory usage in GB for the embedding process.
If None, it defaults to the available GPU memory minus 4 GB.
embedding_pooling_strategy (str): Strategy for pooling embeddings, either "mean_pooling" or "last_token". Defaults to "mean_pooling".
input_column (str): Column name from the data to be used for embedding generation, defaults to "text".
write_embeddings_to_disk (bool, optional): If True, saves the embeddings to disk, defaults to True.
We recommend setting this to False when you have a delayed pipeline.
@@ -152,6 +175,7 @@

self.embeddings_config = EmbeddingConfig(
model_name_or_path=embedding_model_name_or_path,
pooling_strategy=embedding_pooling_strategy,
)
self.batch_size = embedding_batch_size
self.logger = self._setup_logger(logger)
@@ -184,7 +208,7 @@ def create_embeddings(
op.Tokenizer(
self.model,
cols=[input_column],
tokenizer_type="sentencepiece",
tokenizer_type="default",
max_length=self.embeddings_config.max_seq_length,
),
op.Predictor(
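As a sanity check on the indexing used in _get_last_token above, here is a small self-contained sketch in plain PyTorch (no NeMo Curator dependency) showing how attention_mask.sum(dim=1) - 1 selects the last non-padded token of each sequence before normalization:

import torch
import torch.nn.functional as F

# Toy batch: 2 sequences padded to length 5; 1 marks a real token, 0 padding.
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])
token_embeddings = torch.randn(2, 5, 4)  # (batch, seq_len, hidden_size)

# Index of the last non-padded token per row: here 2 and 4.
last_token_indices = (attention_mask.sum(dim=1) - 1).to(torch.long)
batch_indices = torch.arange(attention_mask.size(0))

# Advanced indexing picks one embedding per sequence; L2-normalize as in the diff.
last_token_embeddings = token_embeddings[batch_indices, last_token_indices]
print(F.normalize(last_token_embeddings, dim=1).shape)  # torch.Size([2, 4])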
1 change: 1 addition & 0 deletions nemo_curator/modules/semantic_dedup/semdedup.py
@@ -50,6 +50,7 @@ def __init__(
self.embedding_creator = EmbeddingCreator(
embedding_model_name_or_path=config.embedding_model_name_or_path,
embedding_batch_size=config.embedding_batch_size,
embedding_pooling_strategy=config.embedding_pooling_strategy,
input_column=input_column,
embedding_output_dir=os.path.join(cache_dir, config.embeddings_save_loc),
write_embeddings_to_disk=config.write_embeddings_to_disk,
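End to end, the new field flows from SemDedupConfig through SemDedup into EmbeddingCreator. A rough usage sketch follows, assuming a GPU environment as in the tests below; the sample data, cache_dir, and variable names are illustrative assumptions rather than part of this change.

import cudf
import dask_cudf

from nemo_curator import SemDedup, SemDedupConfig
from nemo_curator.datasets import DocumentDataset

# Placeholder data; any dask-cuDF DataFrame with an id and a text column works.
ddf = dask_cudf.from_cudf(
    cudf.DataFrame({"id": [1, 2], "text": ["a cat sat", "a kitten sat"]}),
    npartitions=1,
)
config = SemDedupConfig(
    cache_dir="./semdedup_cache",  # assumed placeholder
    embedding_pooling_strategy="last_token",
)
sem_dedup = SemDedup(config=config, input_column="text")
duplicates = sem_dedup(DocumentDataset(ddf))  # dataset of likely-duplicate ids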
93 changes: 93 additions & 0 deletions tests/test_semdedup.py
@@ -13,9 +13,13 @@
# limitations under the License.
import os

import numpy as np
import pytest
import torch
import torch.nn.functional as F
from dask.dataframe.utils import assert_eq
from distributed import Client
from transformers import AutoConfig, AutoModel, AutoTokenizer

from nemo_curator import SemDedup, SemDedupConfig
from nemo_curator.datasets import DocumentDataset
@@ -24,6 +28,9 @@
cudf = gpu_only_import("cudf")
dask_cudf = gpu_only_import("dask_cudf")
LocalCUDACluster = gpu_only_import_from("dask_cuda", "LocalCUDACluster")
EmbeddingCreator = gpu_only_import_from(
"nemo_curator.modules.semantic_dedup.embeddings", "EmbeddingCreator"
)


@pytest.fixture
@@ -80,3 +87,89 @@ def test_sem_dedup(
duplicate_docs = [2, 3, 4, 200, 300]
expected_df = cudf.Series(duplicate_docs, name="id")
assert_eq(result_df["id"].sort_values(), expected_df, check_index=False)

@pytest.mark.parametrize("pooling_strategy", ["last_token", "mean_pooling"])
def test_embedding_creator_pooling_strategies(self, tmpdir, pooling_strategy):
test_text_1 = "The quick brown fox jumps over the lazy dog"
test_text_2 = "The brown fox jumps over the dog"
test_texts = [test_text_1, test_text_2] * 32
df = cudf.DataFrame({"text": test_texts})
ddf = dask_cudf.from_cudf(df, 1)
cache_dir = os.path.join(tmpdir, "test_embeddings_cache")

embedding_creator = EmbeddingCreator(
embedding_model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",
embedding_batch_size=32,
embedding_pooling_strategy=pooling_strategy,
input_column="text",
embedding_output_dir=os.path.join(cache_dir, "mean_embeddings"),
)
embeddings = embedding_creator.create_embeddings(ddf).compute()
embeddings = embeddings["embeddings"].to_arrow().to_pylist()
embeddings = np.array(embeddings)
reference_embeddings = get_reference_embeddings(
test_texts, pooling_strategy=pooling_strategy
)
assert np.allclose(
embeddings, reference_embeddings, atol=1e-3
), "Embeddings should match reference embeddings"


def get_reference_embeddings(
texts,
model_name="sentence-transformers/all-MiniLM-L6-v2",
pooling_strategy="last_token",
):
"""
Get embeddings using either last token or mean pooling strategy.
Args:
texts: List of input texts
model_name: Name or path of the model to use
pooling_strategy: Either "last_token" for last-token pooling or "mean_pooling" for mean pooling
"""
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)
model = model.to("cuda")
model.eval()
max_len_to_use = tokenizer.model_max_length
if max_len_to_use > 1e5:
max_len_to_use = AutoConfig.from_pretrained(model_name).max_position_embeddings
max_seq_length: int = max_len_to_use

embs = []
for text in texts:
inputs = tokenizer(
text,
return_tensors="pt",
padding=True,
truncation=True,
max_length=max_seq_length,
)
inputs = {k: v.to("cuda") for k, v in inputs.items()}

with torch.no_grad():
with torch.autocast(device_type="cuda"):
outputs = model(**inputs)

if pooling_strategy == "last_token":
embeddings = outputs.last_hidden_state[:, -1, :]
elif pooling_strategy == "mean_pooling":
token_embeddings = outputs.last_hidden_state
attention_mask = inputs["attention_mask"]
input_mask_expanded = (
attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
)
sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, dim=1)
sum_mask = torch.clamp(input_mask_expanded.sum(dim=1), min=1e-9)
embeddings = sum_embeddings / sum_mask
else:
raise ValueError(
"pooling_strategy must be either 'last_token' or 'mean_pooling'"
)

normed_emb = F.normalize(embeddings, dim=1).cpu()
normed_emb = normed_emb.squeeze(0)
embs.append(normed_emb)

return np.array(embs)
