Commit 0367822

feat: at one point this worked

1 parent 0498828 commit 0367822

11 files changed: +363 -12 lines changed

.gitignore (+3)

@@ -1,4 +1,7 @@
+*.npy
 *.memmap
+miracl/miracl/*
+*.json*
 
 # Byte-compiled / optimized / DLL files
 __pycache__/

FlagEmbedding/abc/evaluation/arguments.py (+3)

@@ -107,6 +107,9 @@ class AbsEvalModelArgs:
     query_instruction_format_for_retrieval: str = field(
         default="{}{}", metadata={"help": "Format for query instruction"}
     )
+    passage_instruction_for_retrieval: Optional[str] = field(
+        default=None, metadata={"help": "Instruction for passage"}
+    )
     examples_for_task: Optional[str] = field(
         default=None, metadata={"help": "Examples for task"}
     )
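
Together with the runner change below, this lets an evaluation job pass a passage-side instruction alongside the existing query-side one. A minimal sketch of setting the new field, assuming the remaining dataclass fields keep their defaults; the "search_query: " and "search_document: " strings are illustrative Nomic-style prefixes, not values from this commit:

from FlagEmbedding.abc.evaluation.arguments import AbsEvalModelArgs

args = AbsEvalModelArgs(
    query_instruction_for_retrieval="search_query: ",
    passage_instruction_for_retrieval="search_document: ",  # new in this commit
)
print(args.passage_instruction_for_retrieval)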

FlagEmbedding/abc/evaluation/runner.py (+1)

@@ -52,6 +52,7 @@ def get_models(model_args: AbsEvalModelArgs) -> Tuple[FlagAutoModel, Union[FlagA
         use_fp16=model_args.use_fp16,
         query_instruction_for_retrieval=model_args.query_instruction_for_retrieval,
         query_instruction_format=model_args.query_instruction_format_for_retrieval,
+        passage_instruction_for_retrieval=model_args.passage_instruction_for_retrieval,
         devices=model_args.devices,
         examples_for_task=model_args.examples_for_task,
         examples_instruction_format=model_args.examples_instruction_format,

FlagEmbedding/inference/auto_embedder.py (+2)

@@ -58,6 +58,8 @@ def from_finetuned(
             AbsEmbedder: The model class to load model, which is a child class of :class:`AbsEmbedder`.
         """
         model_name = os.path.basename(model_name_or_path)
+        if "nomic" in model_name_or_path:
+            model_name = "nomic"
         if model_name.startswith("checkpoint-"):
             model_name = os.path.basename(os.path.dirname(model_name_or_path))
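
The new branch short-circuits the usual basename resolution: any path containing "nomic" maps to the fixed key "nomic" before the checkpoint- handling runs. A self-contained sketch of the resulting behavior (hypothetical helper mirroring the logic above, not part of the commit):

import os

def resolve_model_name(model_name_or_path: str) -> str:
    # Mirror of the name-resolution logic in from_finetuned (illustrative).
    model_name = os.path.basename(model_name_or_path)
    if "nomic" in model_name_or_path:
        model_name = "nomic"
    if model_name.startswith("checkpoint-"):
        model_name = os.path.basename(os.path.dirname(model_name_or_path))
    return model_name

assert resolve_model_name("/runs/nomic-embed/checkpoint-500") == "nomic"
assert resolve_model_name("/runs/bge-m3/checkpoint-500") == "bge-m3"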

FlagEmbedding/inference/embedder/__init__.py (+2 -1)

@@ -1,4 +1,4 @@
-from .encoder_only import FlagModel, BGEM3FlagModel
+from .encoder_only import FlagModel, BGEM3FlagModel, NomicModel
 from .decoder_only import FlagICLModel, FlagLLMModel
 from .model_mapping import EmbedderModelClass
 
@@ -8,4 +8,5 @@
     "FlagICLModel",
     "FlagLLMModel",
     "EmbedderModelClass",
+    "NomicModel"
 ]
FlagEmbedding/inference/embedder/encoder_only/__init__.py (+2)

@@ -1,7 +1,9 @@
 from .base import BaseEmbedder as FlagModel
 from .m3 import M3Embedder as BGEM3FlagModel
+from .nomic import NomicEmbedder as NomicModel
 
 __all__ = [
     "FlagModel",
     "BGEM3FlagModel",
+    "NomicModel"
 ]
FlagEmbedding/inference/embedder/encoder_only/nomic.py (new file, +300)

@@ -0,0 +1,300 @@
from tqdm import tqdm, trange
from typing import cast, Any, List, Union, Optional

import torch
import numpy as np
from transformers import AutoModel, AutoTokenizer

from FlagEmbedding.abc.inference import AbsEmbedder
from contrastors import BiEncoderConfig, BiEncoder


class NomicEmbedder(AbsEmbedder):
    """
    Embedder for Nomic encoder-only models loaded through contrastors.

    Args:
        model_name_or_path (str): If it's a path to a local model, it loads the model from the path. Otherwise it tries to
            download and load a model from the HuggingFace Hub with that name.
        normalize_embeddings (bool, optional): If True, normalize the embedding vector. Defaults to :data:`True`.
        use_fp16 (bool, optional): If True, use half-precision floating point to speed up computation with a slight
            performance degradation. Defaults to :data:`True`.
        query_instruction_for_retrieval (Optional[str], optional): Query instruction for retrieval tasks, which will be used
            with :attr:`query_instruction_format`. Defaults to :data:`None`.
        query_instruction_format (str, optional): The template for :attr:`query_instruction_for_retrieval`. Defaults to :data:`"{}{}"`.
        devices (Optional[Union[str, int, List[str], List[int]]], optional): Devices to use for model inference. Defaults to :data:`None`.
        pooling_method (str, optional): Pooling method to get embedding vector from the last hidden state. Defaults to :data:`"cls"`.
        trust_remote_code (bool, optional): trust_remote_code for HF datasets or models. Defaults to :data:`False`.
        cache_dir (Optional[str], optional): Cache directory for the model. Defaults to :data:`None`.
        batch_size (int, optional): Batch size for inference. Defaults to :data:`256`.
        query_max_length (int, optional): Maximum length for query. Defaults to :data:`512`.
        passage_max_length (int, optional): Maximum length for passage. Defaults to :data:`512`.
        convert_to_numpy (bool, optional): If True, the output embedding will be a Numpy array. Otherwise, it will be a Torch
            Tensor. Defaults to :data:`True`.

    Attributes:
        DEFAULT_POOLING_METHOD: The default pooling method when running the model.
    """

    DEFAULT_POOLING_METHOD = None

    def __init__(
        self,
        model_name_or_path: str,
        normalize_embeddings: bool = True,
        use_fp16: bool = True,
        query_instruction_for_retrieval: Optional[str] = None,
        query_instruction_format: str = "{}{}",  # specify the format of query_instruction_for_retrieval
        devices: Optional[Union[str, List[str]]] = None,  # specify devices, such as "cuda:0" or ["cuda:0", "cuda:1"]
        # Additional parameters for BaseEmbedder
        pooling_method: str = "cls",
        trust_remote_code: bool = False,
        cache_dir: Optional[str] = None,
        # inference
        batch_size: int = 256,
        query_max_length: int = 512,
        passage_max_length: int = 512,
        convert_to_numpy: bool = True,
        **kwargs: Any,
    ):
        super().__init__(
            model_name_or_path,
            normalize_embeddings=normalize_embeddings,
            use_fp16=use_fp16,
            query_instruction_for_retrieval=query_instruction_for_retrieval,
            query_instruction_format=query_instruction_format,
            devices=devices,
            batch_size=batch_size,
            query_max_length=query_max_length,
            passage_max_length=passage_max_length,
            convert_to_numpy=convert_to_numpy,
            **kwargs
        )
        self.pooling_method = pooling_method

        self.tokenizer = AutoTokenizer.from_pretrained(
            "FacebookAI/xlm-roberta-base",
            trust_remote_code=trust_remote_code,
            cache_dir=cache_dir
        )
        config = BiEncoderConfig.from_pretrained(model_name_or_path)
        self.model = BiEncoder.from_pretrained(
            model_name_or_path, config=config
        ).to(torch.bfloat16)
        print(self.model)

    def encode_queries(
        self,
        queries: Union[List[str], str],
        batch_size: Optional[int] = None,
        max_length: Optional[int] = None,
        convert_to_numpy: Optional[bool] = None,
        **kwargs: Any
    ) -> Union[np.ndarray, torch.Tensor]:
        """Encode the queries.

        Args:
            queries (Union[List[str], str]): Input queries to encode.
            batch_size (Optional[int], optional): Number of sentences for each iter. Defaults to :data:`None`.
            max_length (Optional[int], optional): Maximum length of tokens. Defaults to :data:`None`.
            convert_to_numpy (Optional[bool], optional): If True, the output embedding will be a Numpy array. Otherwise, it
                will be a Torch Tensor. Defaults to :data:`None`.

        Returns:
            Union[torch.Tensor, np.ndarray]: Return the embedding vectors in a numpy array or tensor.
        """
        return super().encode_queries(
            queries,
            batch_size=batch_size,
            max_length=max_length,
            convert_to_numpy=convert_to_numpy,
            **kwargs
        )

    def encode_corpus(
        self,
        corpus: Union[List[str], str],
        batch_size: Optional[int] = None,
        max_length: Optional[int] = None,
        convert_to_numpy: Optional[bool] = None,
        **kwargs: Any
    ) -> Union[np.ndarray, torch.Tensor]:
        """Encode the corpus using the instruction if provided.

        Args:
            corpus (Union[List[str], str]): Input corpus to encode.
            batch_size (Optional[int], optional): Number of sentences for each iter. Defaults to :data:`None`.
            max_length (Optional[int], optional): Maximum length of tokens. Defaults to :data:`None`.
            convert_to_numpy (Optional[bool], optional): If True, the output embedding will be a Numpy array. Otherwise, it
                will be a Torch Tensor. Defaults to :data:`None`.

        Returns:
            Union[torch.Tensor, np.ndarray]: Return the embedding vectors in a numpy array or tensor.
        """
        return super().encode_corpus(
            corpus,
            batch_size=batch_size,
            max_length=max_length,
            convert_to_numpy=convert_to_numpy,
            **kwargs
        )

    def encode(
        self,
        sentences: Union[List[str], str],
        batch_size: Optional[int] = None,
        max_length: Optional[int] = None,
        convert_to_numpy: Optional[bool] = None,
        **kwargs: Any
    ) -> Union[np.ndarray, torch.Tensor]:
        """Encode the input sentences with the embedding model.

        Args:
            sentences (Union[List[str], str]): Input sentences to encode.
            batch_size (Optional[int], optional): Number of sentences for each iter. Defaults to :data:`None`.
            max_length (Optional[int], optional): Maximum length of tokens. Defaults to :data:`None`.
            convert_to_numpy (Optional[bool], optional): If True, the output embedding will be a Numpy array. Otherwise, it
                will be a Torch Tensor. Defaults to :data:`None`.

        Returns:
            Union[torch.Tensor, np.ndarray]: Return the embedding vectors in a numpy array or tensor.
        """
        return super().encode(
            sentences,
            batch_size=batch_size,
            max_length=max_length,
            convert_to_numpy=convert_to_numpy,
            **kwargs
        )

    @torch.no_grad()
    def encode_single_device(
        self,
        sentences: Union[List[str], str],
        batch_size: int = 256,
        max_length: int = 512,
        convert_to_numpy: bool = True,
        device: Optional[str] = None,
        **kwargs: Any
    ):
        """Encode input sentences on a single device.

        Args:
            sentences (Union[List[str], str]): Input sentences to encode.
            batch_size (int, optional): Number of sentences for each iter. Defaults to :data:`256`.
            max_length (int, optional): Maximum length of tokens. Defaults to :data:`512`.
            convert_to_numpy (bool, optional): If True, the output embedding will be a Numpy array. Otherwise, it will
                be a Torch Tensor. Defaults to :data:`True`.
            device (Optional[str], optional): Device to use for encoding. Defaults to None.

        Returns:
            Union[torch.Tensor, np.ndarray]: Return the embedding vectors in a numpy array or tensor.
        """
        if device is None:
            device = self.target_devices[0]

        if device == "cpu":
            self.use_fp16 = False
        if self.use_fp16:
            self.model.to(torch.bfloat16)

        self.model.to(device)
        self.model.eval()

        input_was_string = False
        if isinstance(sentences, str):
            sentences = [sentences]
            input_was_string = True

        # tokenize without padding to get the correct length
        all_inputs = []
        for start_index in trange(0, len(sentences), batch_size, desc='pre tokenize',
                                  disable=len(sentences) < 256):
            sentences_batch = sentences[start_index:start_index + batch_size]
            inputs_batch = self.tokenizer(
                sentences_batch,
                truncation=True,
                max_length=max_length,
                **kwargs
            )
            inputs_batch = [{
                k: inputs_batch[k][i] for k in inputs_batch.keys()
            } for i in range(len(sentences_batch))]
            all_inputs.extend(inputs_batch)

        # sort by length for less padding
        length_sorted_idx = np.argsort([-len(x['input_ids']) for x in all_inputs])
        all_inputs_sorted = [all_inputs[i] for i in length_sorted_idx]

        # adjust batch size (the OOM-driven auto-tuning below is disabled;
        # the batch size is hardcoded for now)
        flag = False
        batch_size = 4

        # while flag is False:
        #     try:
        #         inputs_batch = self.tokenizer.pad(
        #             all_inputs_sorted[: batch_size],
        #             padding=True,
        #             return_tensors='pt',
        #             **kwargs
        #         ).to(device)
        #         embeddings = self.model(**inputs_batch)["embedding"]
        #         flag = True
        #     except RuntimeError as e:
        #         batch_size = batch_size * 3 // 4
        #     except torch.OutOfMemoryError as e:
        #         batch_size = batch_size * 3 // 4

        # encode
        all_embeddings = []
        for start_index in tqdm(range(0, len(sentences), batch_size), desc="Inference Embeddings",
                                disable=len(sentences) < 256):
            inputs_batch = all_inputs_sorted[start_index:start_index + batch_size]
            inputs_batch = self.tokenizer.pad(
                inputs_batch,
                padding=True,
                return_tensors='pt',
                **kwargs
            ).to(device)
            embeddings = self.model(**inputs_batch)["embedding"]
            if self.normalize_embeddings:
                embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
            embeddings = cast(torch.Tensor, embeddings)

            if convert_to_numpy:
                embeddings = embeddings.cpu().float().numpy()
            all_embeddings.append(embeddings)

        if convert_to_numpy:
            all_embeddings = np.concatenate(all_embeddings, axis=0)
        else:
            all_embeddings = torch.cat(all_embeddings, dim=0)

        # adjust the order of embeddings
        all_embeddings = all_embeddings[np.argsort(length_sorted_idx)]

        # return the embeddings
        if input_was_string:
            return all_embeddings[0]
        return all_embeddings

    def pooling(
        self,
        last_hidden_state: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ):
        """The pooling function.

        Args:
            last_hidden_state (torch.Tensor): The last hidden state of the model.
            attention_mask (Optional[torch.Tensor], optional): Attention mask. Defaults to :data:`None`.

        Raises:
            NotImplementedError: pooling method not implemented.

        Returns:
            torch.Tensor: The embedding vectors after pooling.
        """
        # pooling is done inside contrastors
        if self.pooling_method is None:
            return last_hidden_state
        else:
            raise NotImplementedError(f"pooling method {self.pooling_method} not implemented")
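
A hedged usage sketch for the new embedder: the checkpoint path is a placeholder, contrastors must be installed for BiEncoder to load, and the "search_query: " prefix is an illustrative Nomic-style instruction rather than a value from this commit. Note that encode_single_device sorts inputs by descending token count so each padded batch wastes less compute, then restores the original order via np.argsort(length_sorted_idx).

from FlagEmbedding.inference.embedder.encoder_only import NomicModel

model = NomicModel(
    "path/to/nomic-biencoder-checkpoint",  # placeholder path
    query_instruction_for_retrieval="search_query: ",
    devices="cuda:0",
)
query_emb = model.encode_queries(["what is a bi-encoder?"])
corpus_emb = model.encode_corpus(["A bi-encoder embeds queries and passages independently."])
print(query_emb.shape, corpus_emb.shape)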

FlagEmbedding/inference/embedder/model_mapping.py (+13 -2)

@@ -4,8 +4,7 @@
 from collections import OrderedDict
 
 from FlagEmbedding.abc.inference import AbsEmbedder
-from FlagEmbedding.inference.embedder import FlagModel, BGEM3FlagModel, FlagLLMModel, FlagICLModel
-
+from FlagEmbedding.inference.embedder import FlagModel, BGEM3FlagModel, FlagLLMModel, FlagICLModel, NomicModel
 
 class EmbedderModelClass(Enum):
     ENCODER_ONLY_BASE = "encoder-only-base"
@@ -211,5 +210,17 @@ class EmbedderConfig:
         'bce-embedding-base_v1',
         EmbedderConfig(FlagModel, PoolingMethod.CLS)
     ),
+    (
+        'snowflake-arctic-embed-l-v2.0',
+        EmbedderConfig(FlagModel, PoolingMethod.CLS, trust_remote_code=True)
+    ),
+    (
+        'snowflake-arctic-embed-m-v2.0',
+        EmbedderConfig(FlagModel, PoolingMethod.CLS, trust_remote_code=True)
+    ),
+    (
+        'nomic',
+        EmbedderConfig(NomicModel, None)
+    )
     # TODO: Add more models, such as Jina, Stella_v5, NV-Embed, etc.
 ])
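
These registry entries take effect once from_finetuned resolves a checkpoint path to a model name (see auto_embedder.py above) and looks it up in the mapping. A simplified, illustrative sketch of that flow, with strings standing in for the real EmbedderConfig objects:

from collections import OrderedDict

mapping = OrderedDict([
    ("bge-m3", "BGEM3FlagModel, CLS pooling"),
    ("snowflake-arctic-embed-l-v2.0", "FlagModel, CLS pooling, trust_remote_code"),
    ("nomic", "NomicModel, pooling handled inside contrastors"),
])

# Any path containing "nomic" was normalized to the key "nomic" earlier,
# so the lookup lands on the new entry.
print(mapping["nomic"])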
