Skip to content

Commit dffc5e1

Browse files
committed
feat: Add NVIDIA embedder and update eval scripts for MIRACL benchmark
1 parent 99d2983 commit dffc5e1

File tree

5 files changed

+184
-10
lines changed

5 files changed

+184
-10
lines changed

FlagEmbedding/abc/evaluation/evaluator.py

+1
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ def __call__(
158158

159159
no_reranker_search_results_dict = {}
160160
if flag:
161+
print(f"{retriever} is running..., loading corpus for {dataset_name}")
161162
corpus = self.data_loader.load_corpus(dataset_name=dataset_name)
162163

163164
queries_dict = {

FlagEmbedding/inference/auto_embedder.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
AUTO_EMBEDDER_MAPPING, EMBEDDER_CLASS_MAPPING
88
)
99

10-
from FlagEmbedding.inference.embedder.encoder_only import VoyageModel
10+
from FlagEmbedding.inference.embedder.encoder_only import VoyageModel, NvidiaModel
1111

1212
logger = logging.getLogger(__name__)
1313

@@ -70,6 +70,11 @@ def from_finetuned(
7070
model_class = VoyageModel
7171
_model_class = VoyageModel
7272

73+
elif "nvidia" in model_name_or_path:
74+
model_name = "nvidia"
75+
model_class = NvidiaModel
76+
_model_class = NvidiaModel
77+
7378
elif model_class is not None:
7479
_model_class = EMBEDDER_CLASS_MAPPING[EmbedderModelClass(model_class)]
7580
if pooling_method is None:

FlagEmbedding/inference/embedder/encoder_only/__init__.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
from .m3 import M3Embedder as BGEM3FlagModel
33
#from .nomic import NomicEmbedder as NomicModel
44
from .voyage import VoyageEmbedder as VoyageModel
5-
5+
from .nvidia import NvidiaEmbedder as NvidiaModel
66
__all__ = [
77
"FlagModel",
88
"BGEM3FlagModel",
99
# "NomicModel",
10-
"VoyageModel"
10+
"VoyageModel",
11+
"NvidiaModel"
1112
]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
import os
2+
import torch
3+
import numpy as np
4+
from typing import List, Dict, Optional, Union
5+
from tqdm import tqdm
6+
from openai import OpenAI
7+
from tenacity import retry, stop_after_attempt, wait_exponential
8+
from FlagEmbedding.abc.inference import AbsEmbedder
9+
from transformers.configuration_utils import PretrainedConfig
10+
11+
12+
class NvidiaEmbedderConfig(PretrainedConfig):
13+
def __init__(self, **kwargs):
14+
super().__init__(**kwargs)
15+
self._name_or_path = "nvidia"
16+
17+
18+
class NvidiaMockModel:
19+
def __init__(self):
20+
self.config = NvidiaEmbedderConfig()
21+
22+
23+
class NvidiaEmbedder(AbsEmbedder):
24+
def __init__(
25+
self,
26+
model_name_or_path: str,
27+
normalize_embeddings: bool = True,
28+
use_fp16: bool = True,
29+
query_instruction_for_retrieval: Optional[str] = None,
30+
query_instruction_format: str = "{}{}",
31+
devices: Optional[Union[str, List[str]]] = None,
32+
batch_size: int = 2048,
33+
query_max_length: int = 512,
34+
passage_max_length: int = 512,
35+
convert_to_numpy: bool = True,
36+
**kwargs
37+
):
38+
super().__init__(
39+
model_name_or_path,
40+
normalize_embeddings=normalize_embeddings,
41+
use_fp16=use_fp16,
42+
query_instruction_for_retrieval=query_instruction_for_retrieval,
43+
query_instruction_format=query_instruction_format,
44+
devices=devices,
45+
batch_size=batch_size,
46+
query_max_length=query_max_length,
47+
passage_max_length=passage_max_length,
48+
convert_to_numpy=convert_to_numpy,
49+
**kwargs
50+
)
51+
52+
self.model = NvidiaMockModel()
53+
self.client = OpenAI(
54+
api_key="not-needed", # API key not needed for local server
55+
base_url="http://localhost:8000/v1"
56+
)
57+
self.model_name = "nvidia/llama-3.2-nv-embedqa-1b-v2"
58+
59+
@retry(
60+
stop=stop_after_attempt(5),
61+
wait=wait_exponential(multiplier=1, min=4, max=30),
62+
reraise=True
63+
)
64+
def _get_embeddings(self, texts: List[str], input_type: str = "query") -> np.ndarray:
65+
"""Get embeddings for a batch of texts with automatic retries and truncation on token size errors."""
66+
def try_with_texts(current_texts: List[str], retry_count: int = 0) -> np.ndarray:
67+
try:
68+
response = self.client.embeddings.create(
69+
input=current_texts,
70+
model=self.model_name,
71+
encoding_format="float",
72+
extra_body={"input_type": input_type, "truncate": "END"}
73+
)
74+
return np.array([data.embedding for data in response.data])
75+
except Exception as e:
76+
error_str = str(e)
77+
print(f"Error in _get_embeddings: {error_str}")
78+
79+
# If we hit token size limit and haven't retried too many times, truncate and retry
80+
if "token size" in error_str.lower() and retry_count < 3:
81+
truncated_texts = [t[:len(t)//2] for t in current_texts]
82+
print(f"Retrying with truncated texts (retry {retry_count + 1})")
83+
return try_with_texts(truncated_texts, retry_count + 1)
84+
85+
raise
86+
87+
return try_with_texts(texts)
88+
89+
def encode_queries(self, queries: List[str], batch_size: int = 128, **kwargs) -> np.ndarray:
90+
"""Encode queries with input_type='query'."""
91+
all_embeddings = []
92+
for i in tqdm(range(0, len(queries), batch_size), desc="Encoding queries"):
93+
batch = queries[i:i + batch_size]
94+
embeddings = self._get_embeddings(batch, input_type="query")
95+
all_embeddings.append(embeddings)
96+
return np.vstack(all_embeddings)
97+
98+
@staticmethod
99+
def _process_batch_static(args):
100+
"""Static method to process a batch of passages."""
101+
index, batch, model_name, base_url = args
102+
try:
103+
# Create a new client for each process
104+
client = OpenAI(
105+
api_key="not-needed",
106+
base_url=base_url
107+
)
108+
109+
# Get embeddings
110+
response = client.embeddings.create(
111+
input=batch,
112+
model=model_name,
113+
encoding_format="float",
114+
extra_body={"input_type": "passage", "truncate": "END"}
115+
)
116+
embeddings = np.array([data.embedding for data in response.data])
117+
return index, embeddings
118+
except Exception as e:
119+
print(f"Error processing batch {index}: {str(e)}")
120+
return index, None
121+
122+
def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int = 128, num_processes: int = 16, **kwargs) -> np.ndarray:
123+
"""Encode corpus passages with input_type='passage' using multiple processes."""
124+
if isinstance(corpus[0], dict):
125+
passages = [f"{doc.get('title', '')} {doc['text']}".strip() for doc in corpus]
126+
else:
127+
passages = corpus
128+
129+
# Prepare batches with their indices and required parameters
130+
batches = []
131+
for i in range(0, len(passages), batch_size):
132+
batch = passages[i:i + batch_size]
133+
# Include model name and base URL for each batch
134+
batches.append((len(batches), batch, self.model_name, self.client.base_url))
135+
136+
# Process batches in parallel
137+
from multiprocessing import Pool
138+
with Pool(processes=num_processes) as pool:
139+
# Use tqdm to show progress
140+
results = list(tqdm(
141+
pool.imap(self._process_batch_static, batches),
142+
total=len(batches),
143+
desc=f"Encoding corpus with {num_processes} processes"
144+
))
145+
146+
# Sort results by index and collect embeddings
147+
sorted_results = sorted(results, key=lambda x: x[0])
148+
all_embeddings = []
149+
for _, embeddings in sorted_results:
150+
if embeddings is not None:
151+
all_embeddings.append(embeddings)
152+
else:
153+
raise Exception("One or more batches failed to process")
154+
155+
return np.vstack(all_embeddings)
156+
157+
@torch.no_grad()
158+
def encode_single_device(
159+
self,
160+
sentences: Union[List[str], str],
161+
batch_size: int = 128,
162+
max_length: int = 512,
163+
convert_to_numpy: bool = True,
164+
device: Optional[str] = None,
165+
):
166+
"""Single device encoding method that defaults to query encoding."""
167+
if isinstance(sentences, str):
168+
sentences = [sentences]
169+
return self.encode_queries(sentences, batch_size=batch_size)

examples/evaluation/miracl/eval_miracl.sh

+5-7
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ fi
44

55
dataset_names="ar bn de en es fa fi fr hi id ja ko ru sw te th yo zh"
66

7-
VENV="/home/ubuntu/bstadt-smol/flagemb/env"
7+
VENV="/home/ubuntu/FlagEmbedding/.venv"
88
source $VENV/bin/activate
99

1010
eval_args="\
@@ -24,16 +24,14 @@ eval_args="\
2424
"
2525

2626
model_args="\
27-
--embedder_name_or_path voyage
28-
--devices cuda:1 \
27+
--embedder_name_or_path nvidia
28+
--devices cuda:0 \
2929
--trust_remote_code \
30-
--query_instruction_for_retrieval 'search_query: ' \
31-
--passage_instruction_for_retrieval 'search_document: ' \
32-
--embedder_batch_size 32 \
30+
--embedder_batch_size 1024 \
3331
--cache_dir $HF_HUB_CACHE
3432
"
3533

36-
cmd="/home/ubuntu/bstadt-smol/flagemb/env/bin/python -m FlagEmbedding.evaluation.miracl \
34+
cmd="/home/ubuntu/FlagEmbedding/.venv/bin/python -m FlagEmbedding.evaluation.miracl \
3735
$eval_args \
3836
$model_args \
3937
"

0 commit comments

Comments
 (0)