Commit ef4e12a

fix: current status before merging and trying to upgrade for eurobert

1 parent c2ccb72
6 files changed: +76 −32 lines

FlagEmbedding/inference/auto_embedder.py (+2 −2)

@@ -58,8 +58,8 @@ def from_finetuned(
             AbsEmbedder: The model class to load model, which is child class of :class:`AbsEmbedder`.
         """
         model_name = os.path.basename(model_name_or_path)
-        if "nomic" in model_name_or_path:
-            model_name = "nomic"
+        # if "nomic" in model_name_or_path:
+        #     model_name = "nomic"
         if model_name.startswith("checkpoint-"):
             model_name = os.path.basename(os.path.dirname(model_name_or_path))
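With the "nomic" special case commented out, nomic checkpoints now resolve through the same generic path-based name inference as every other model. A minimal sketch of that fallback logic, lifted from the surrounding lines (the wrapper name resolve_model_name is hypothetical):

import os

def resolve_model_name(model_name_or_path: str) -> str:
    # Take the last path component, e.g. "nomic-ai/eurobert-210m-2e4-128sl-subset"
    # -> "eurobert-210m-2e4-128sl-subset".
    model_name = os.path.basename(model_name_or_path)
    # Trainer checkpoints live at <run_name>/checkpoint-<step>, so fall back
    # to the parent directory name to recover the run/model name.
    if model_name.startswith("checkpoint-"):
        model_name = os.path.basename(os.path.dirname(model_name_or_path))
    return model_name

print(resolve_model_name("ckpts/eurobert-210m-2e4-128sl-subset/checkpoint-500"))
# eurobert-210m-2e4-128sl-subset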

FlagEmbedding/inference/embedder/encoder_only/nomic.py (+42 −14)

@@ -28,7 +28,7 @@ def _transform_func(tokenizer,

 # Triton is not thread safe AFAICT so using naive DataParallel fails
 class EncoderWorker(mp.Process):
-    def __init__(self, rank, world_size, input_queue, output_queue, model_name, tokenizer_name, batch_size, master_port=12345):
+    def __init__(self, rank, world_size, input_queue, output_queue, model_name, tokenizer_name, batch_size, master_port=12344):
         super().__init__()
         self.rank = rank
         self.world_size = world_size
@@ -99,7 +99,7 @@ def run(self):

         local_embeds = []
         with torch.no_grad():
-            for batch_dict in tqdm(loader, desc=f"Rank {self.rank}"):
+            for batch_dict in tqdm(loader, desc=f"Rank {self.rank}", disable=True):
                 batch_dict = {k: v.cuda(self.rank) for k, v in batch_dict.items()}
                 with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                     outputs = encoder(**batch_dict)
@@ -215,11 +215,15 @@ def encode_corpus(self, corpus: List[Dict[str, str]], **kwargs) -> np.ndarray:
     def encode_single_device(
         self,
         sentences: Union[List[str], str],
-        batch_size: int = 256,
+        batch_size: int = 512,
         max_length: int = 512,
         convert_to_numpy: bool = True,
         device: Optional[str] = None,
     ):
+        if isinstance(sentences, str):
+            sentences = [sentences]
+
+        # Initialize workers if not already initialized
         if len(self.workers) == 0:
             for rank in range(self.world_size):
                 worker = EncoderWorker(
@@ -234,17 +238,41 @@ def encode_single_device
                 worker.start()
             self.workers.append(worker)

-        if isinstance(sentences, str):
-            sentences = [sentences]
-
-        for _ in range(self.world_size):
-            self.input_queue.put(sentences)
-        result = self.output_queue.get()
-
-        if isinstance(result, Exception):
-            raise result
-
-        return result
+        # Calculate number of batches
+        total_samples = len(sentences)
+        batch_size = 65536
+        num_batches = (total_samples + batch_size - 1) // batch_size
+
+        all_results = []
+
+        # Process sentences in batches
+        for batch_idx in tqdm(range(num_batches)):
+            start_idx = batch_idx * batch_size
+            end_idx = min((batch_idx + 1) * batch_size, total_samples)
+            batch_sentences = sentences[start_idx:end_idx]
+
+            # Distribute batch to workers
+            for _ in range(self.world_size):
+                self.input_queue.put(batch_sentences)
+
+            # Get results for this batch
+            batch_result = self.output_queue.get()
+
+            if isinstance(batch_result, Exception):
+                raise batch_result
+
+            all_results.append(batch_result)
+
+        # Concatenate results from all batches
+        if len(all_results) > 1:
+            if isinstance(all_results[0], np.ndarray):
+                final_result = np.concatenate(all_results, axis=0)
+            else:  # Assuming torch.Tensor
+                final_result = torch.cat(all_results, dim=0)
+        else:
+            final_result = all_results[0]
+
+        return final_result

     def __del__(self):
         # Send poison pills to workers
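The rewritten encode loop chunks the input so the worker queues never hold the full corpus at once: each chunk of at most 65,536 sentences is broadcast to the workers, one merged result is read back per chunk, and the per-chunk embeddings are concatenated at the end. Below is a stripped-down, runnable sketch of that chunk/broadcast/concatenate pattern. It simplifies deliberately: a dummy per-rank "encoder" stands in for the real model, and each worker gets a private input queue rather than the diff's single shared queue (which relies on the ranks staying in lockstep), so the toy version cannot misroute chunks.

import multiprocessing as mp
import numpy as np

def worker(rank, world_size, input_queue, output_queue):
    # Consume chunks from this rank's private queue, "encode" the rank's
    # shard of each chunk, and report (rank, shard) so the parent can
    # reassemble the original order.
    while True:
        chunk = input_queue.get()
        if chunk is None:                      # poison pill, as in __del__
            return
        shard = chunk[rank::world_size]        # round-robin shard for this rank
        embeds = np.full((len(shard), 4), rank, dtype=np.float32)  # dummy encode
        output_queue.put((rank, embeds))

def encode_chunked(sentences, world_size=2, chunk_size=65536):
    input_queues = [mp.Queue() for _ in range(world_size)]
    output_queue = mp.Queue()
    procs = [mp.Process(target=worker, args=(r, world_size, input_queues[r], output_queue))
             for r in range(world_size)]
    for p in procs:
        p.start()
    results = []
    for start in range(0, len(sentences), chunk_size):
        chunk = sentences[start:start + chunk_size]
        for q in input_queues:                 # broadcast the chunk to all ranks
            q.put(chunk)
        shards = dict(output_queue.get() for _ in range(world_size))
        # Interleave the per-rank shards back into the chunk's original order.
        merged = np.empty((len(chunk), 4), dtype=np.float32)
        for rank, embeds in shards.items():
            merged[rank::world_size] = embeds
        results.append(merged)
    for q in input_queues:
        q.put(None)                            # shut the workers down
    for p in procs:
        p.join()
    return np.concatenate(results, axis=0) if len(results) > 1 else results[0]

if __name__ == "__main__":
    out = encode_chunked([f"s{i}" for i in range(10)], chunk_size=4)
    print(out.shape)                           # (10, 4)

One wrinkle in the diff worth noting: the local reassignment batch_size = 65536 overwrites the method's batch_size argument (default 512) and repurposes the name as a chunk size; a separately named chunk_size variable would avoid the shadowing.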

FlagEmbedding/inference/embedder/model_mapping.py (+8 −3)

@@ -234,9 +234,14 @@ class EmbedderConfig:
         EmbedderConfig(FlagModel, PoolingMethod.CLS, trust_remote_code=True)
     ),
     (
-        'nomic',
-        EmbedderConfig(NomicModel, None)
-    )
+        'eurobert-210m-2e4-128sl-subset',
+        EmbedderConfig(FlagModel, PoolingMethod.MEAN, trust_remote_code=True)
+    ),
+    # (
+    #     'nomic',
+    #     EmbedderConfig(NomicModel, None)
+    # )
+    # TODO: Add more models, such as Jina, Stella_v5, NV-Embed, etc.
 ])

 # Combine all mappings
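This swap registers the eurobert finetune under the generic FlagModel path with mean pooling instead of the custom NomicModel wrapper. A minimal sketch of how an ordered name-to-config registry like this is typically consulted, pairing with the name resolution shown earlier (the dataclass fields, substring matching, and CLS fallback below are illustrative assumptions, not the library's exact API):

from collections import OrderedDict
from dataclasses import dataclass
from enum import Enum

class PoolingMethod(Enum):
    CLS = "cls"
    MEAN = "mean"

@dataclass
class EmbedderConfig:
    model_class: str              # stand-in for the real model class object
    pooling_method: PoolingMethod
    trust_remote_code: bool = False

MODEL_MAPPING = OrderedDict([
    ("eurobert-210m-2e4-128sl-subset",
     EmbedderConfig("FlagModel", PoolingMethod.MEAN, trust_remote_code=True)),
])

def config_for(model_name: str) -> EmbedderConfig:
    # First registered key matching the resolved name wins; unknown names
    # fall back to a generic CLS-pooling config here (an assumption).
    for key, config in MODEL_MAPPING.items():
        if key in model_name.lower():
            return config
    return EmbedderConfig("FlagModel", PoolingMethod.CLS)

print(config_for("nomic-ai/eurobert-210m-2e4-128sl-subset").pooling_method)
# PoolingMethod.MEAN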

examples/evaluation/miracl/e5.md (+3, new file)

@@ -0,0 +1,3 @@
+| Model | Reranker | average | ar-dev | bn-dev | de-dev | en-dev | es-dev | fa-dev | fi-dev | fr-dev | hi-dev | id-dev | ja-dev | ko-dev | ru-dev | sw-dev | te-dev | th-dev | yo-dev | zh-dev |
+| :---- | :---- | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
+| multilingual-e5-base | NoReranker | 39.219 | 51.822 | 52.237 | 23.059 | 31.310 | 26.891 | 32.108 | 56.341 | 20.316 | 29.566 | 36.206 | 43.846 | 46.233 | 34.874 | 54.593 | 61.314 | 56.953 | 20.704 | 27.565 |

examples/evaluation/miracl/eval_miracl.sh (+11 −7)

@@ -2,10 +2,12 @@ if [ -z "$HF_HUB_CACHE" ]; then
     export HF_HUB_CACHE="$HOME/.cache/huggingface/hub"
 fi

-dataset_names="ar bn de en es fa fi fr hi id ja ko ru sw te th yo zh"
+# pass in language via cli, default is all languages
+# "ar bn de en es fa fi fr hi id ja ko ru sw te th yo zh"
+#   0  0  1  1  2  2  3  3  3  4  4  5  5  6  6  7  7
+dataset_names=(${1:-"ar bn de en es fa fi fr hi id ja ko ru sw te th yo zh"})
+device=${2:-"cuda:0 cuda:1 cuda:2 cuda:3 cuda:4 cuda:5 cuda:6 cuda:7"}

-VENV="/home/ubuntu/contrastors-dev/env/"
-source $VENV/bin/activate

 eval_args="\
     --eval_name miracl \
@@ -24,16 +26,18 @@ eval_args="\
 "

 model_args="\
-    --embedder_name_or_path /home/ubuntu/contrastors-dev/src/contrastors/ckpts/nomic-multi-finetune-bge-bge-m3-filtered-data-512tokens/epoch_0_model \
-    --devices cuda:1 \
+    --embedder_name_or_path nomic-ai/eurobert-210m-2e4-128sl-subset \
+    --devices $device \
     --trust_remote_code \
     --query_instruction_for_retrieval 'search_query: ' \
     --passage_instruction_for_retrieval 'search_document: ' \
-    --embedder_batch_size 32 \
+    --embedder_batch_size 512 \
+    --embedder_query_max_length 128 \
+    --embedder_passage_max_length 128 \
     --cache_dir $HF_HUB_CACHE
 "

-cmd="/home/ubuntu/contrastors-dev/env/bin/python -m FlagEmbedding.evaluation.miracl \
+cmd="uv run python -W ignore -m FlagEmbedding.evaluation.miracl \
     $eval_args \
     $model_args \
 "
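With the new positional arguments, the script can be scoped to a subset of languages and a matching device list, e.g. (a hedged example, assuming the script is invoked from its own directory):

bash eval_miracl.sh "de en" "cuda:0 cuda:1"

Omitting both arguments keeps the old behavior of evaluating all 18 MIRACL languages across the eight default GPUs. Note that dataset_names is now a bash array, while device remains a plain string that word-splits into the --devices list.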

examples/evaluation/miracl/miracl/miracl_eval_results.md (+10 −6)

@@ -2,21 +2,25 @@

 | Model | Reranker | average | ar-dev | bn-dev | de-dev | en-dev | es-dev | fa-dev | fi-dev | fr-dev | hi-dev | id-dev | ja-dev | ko-dev | ru-dev | sw-dev | te-dev | th-dev | yo-dev | zh-dev |
 | :---- | :---- | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
+| finetune_256_tokens | NoReranker | 65.756 | 76.435 | 72.657 | 57.761 | 56.709 | **57.068** | 58.854 | 77.334 | 55.999 | 60.278 | 54.011 | 67.139 | 66.395 | 64.520 | 64.530 | 82.571 | 77.493 | 77.345 | 56.514 |
+| finetune_512_tokens | NoReranker | 65.644 | 76.229 | 72.891 | 57.554 | 56.727 | 56.677 | 58.892 | 77.359 | 55.832 | 60.038 | 54.137 | 66.875 | 65.942 | 64.612 | 64.461 | 82.553 | 77.309 | 76.302 | 57.209 |
 | snowflake-arctic-embed-l-v2.0 | NoReranker | 66.263 | 76.046 | 74.416 | **58.565** | 53.688 | 55.598 | 60.288 | 77.079 | 56.658 | 58.368 | 52.254 | 66.452 | 66.248 | 67.071 | 70.756 | 83.489 | 77.520 | 78.317 | 59.917 |
-| gte-multilingual-base | NoReranker | 63.560 | 71.407 | 72.908 | 49.722 | 54.030 | 51.779 | 54.007 | 73.497 | 54.490 | 51.888 | 50.315 | 65.798 | 62.862 | 63.244 | 69.925 | 83.076 | 74.037 | 79.332 | 61.765 |
-| multilingual-e5-base | NoReranker | 39.219 | 51.822 | 52.237 | 23.059 | 31.310 | 26.891 | 32.108 | 56.341 | 20.316 | 29.566 | 36.206 | 43.846 | 46.233 | 34.874 | 54.593 | 61.314 | 56.953 | 20.704 | 27.565 |
+| epoch_0_model | NoReranker | 65.996 | 76.675 | 73.627 | 56.597 | 54.657 | 56.303 | 59.219 | 77.095 | 55.831 | **60.485** | 54.281 | 67.037 | 65.898 | 65.148 | 66.303 | 82.615 | 78.366 | 78.260 | 59.529 |
 | snowflake-arctic-embed-m-v2.0 | NoReranker | 60.604 | 69.689 | 67.648 | 56.645 | 55.739 | 55.416 | 52.611 | 68.359 | 54.035 | 53.662 | 48.267 | 58.268 | 59.696 | 58.766 | 52.289 | 81.711 | 74.249 | 75.559 | 48.270 |
 | bge-m3 | NoReranker | **69.202** | **78.445** | **79.941** | 56.764 | **56.888** | 56.080 | **60.866** | **78.619** | **58.228** | 59.458 | **56.020** | **72.802** | **69.624** | **70.109** | **78.607** | **86.156** | **82.619** | **81.794** | **62.616** |
-| epoch_0_model | NoReranker | 65.756 | 76.435 | 72.657 | 57.761 | 56.709 | **57.068** | 58.854 | 77.334 | 55.999 | **60.278** | 54.011 | 67.139 | 66.395 | 64.520 | 64.530 | 82.571 | 77.493 | 77.345 | 56.514 |
+| multilingual-e5-base | NoReranker | - | 57.134 | 52.770 | 27.977 | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - |
+| gte-multilingual-base | NoReranker | 63.560 | 71.407 | 72.908 | 49.722 | 54.030 | 51.779 | 54.007 | 73.497 | 54.490 | 51.888 | 50.315 | 65.798 | 62.862 | 63.244 | 69.925 | 83.076 | 74.037 | 79.332 | 61.765 |

 ## recall_at_100

 | Model | Reranker | average | ar-dev | bn-dev | de-dev | en-dev | es-dev | fa-dev | fi-dev | fr-dev | hi-dev | id-dev | ja-dev | ko-dev | ru-dev | sw-dev | te-dev | th-dev | yo-dev | zh-dev |
 | :---- | :---- | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
+| finetune_256_tokens | NoReranker | 94.843 | **97.891** | 98.174 | **92.072** | 90.887 | **92.019** | 93.590 | 97.832 | 93.921 | 92.519 | 89.337 | 96.291 | 94.549 | 95.254 | 92.732 | 98.027 | 98.722 | **99.160** | 94.201 |
+| finetune_512_tokens | NoReranker | 94.824 | 97.891 | 97.991 | 91.409 | **91.150** | 91.534 | 93.627 | 97.367 | 93.970 | 93.044 | 89.109 | 96.638 | 94.653 | 95.141 | 92.383 | 98.209 | 98.636 | 98.739 | 95.349 |
 | snowflake-arctic-embed-l-v2.0 | NoReranker | 94.278 | 97.241 | 97.121 | 91.970 | 89.552 | 91.000 | 92.760 | 96.891 | 93.017 | 94.062 | 86.110 | 96.118 | 92.769 | 95.455 | 94.470 | 98.732 | 98.224 | 97.479 | 94.028 |
-| gte-multilingual-base | NoReranker | 92.247 | 95.043 | 96.225 | 86.854 | 88.911 | 86.107 | 88.971 | 96.083 | 91.727 | 88.495 | 84.238 | 94.870 | 89.577 | 92.710 | 94.238 | 98.168 | 97.195 | 96.218 | 94.820 |
-| multilingual-e5-base | NoReranker | 73.026 | 81.777 | 88.929 | 56.858 | 67.272 | 59.417 | 69.016 | 87.301 | 56.573 | 63.666 | 68.083 | 84.659 | 79.308 | 70.050 | 85.076 | 92.029 | 90.767 | 42.577 | 71.105 |
+| epoch_0_model | NoReranker | 94.849 | 97.762 | 97.986 | 91.433 | 90.256 | 91.463 | 93.298 | 97.623 | **93.994** | 92.669 | 89.585 | 96.594 | 95.188 | 95.061 | 93.675 | 98.390 | 98.754 | 98.319 | 95.228 |
 | snowflake-arctic-embed-m-v2.0 | NoReranker | 90.959 | 93.378 | 95.647 | 91.043 | 89.698 | 90.084 | 88.826 | 92.640 | 92.619 | 88.746 | 83.466 | 92.040 | 87.793 | 91.541 | 84.051 | 97.967 | 96.123 | 96.078 | 85.514 |
 | bge-m3 | NoReranker | **95.539** | 97.645 | **98.702** | 91.021 | 90.685 | 91.130 | **93.836** | **97.914** | 93.800 | **94.434** | **90.463** | **97.444** | **95.456** | **95.870** | **97.206** | **99.396** | **99.095** | 98.739 | **96.862** |
-| epoch_0_model | NoReranker | 94.843 | **97.891** | 98.174 | **92.072** | **90.887** | **92.019** | 93.590 | 97.832 | **93.921** | 92.519 | 89.337 | 96.291 | 94.549 | 95.254 | 92.732 | 98.027 | 98.722 | **99.160** | 94.201 |
+| multilingual-e5-base | NoReranker | - | 86.559 | 87.519 | 62.369 | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - |
+| gte-multilingual-base | NoReranker | 92.247 | 95.043 | 96.225 | 86.854 | 88.911 | 86.107 | 88.971 | 96.083 | 91.727 | 88.495 | 84.238 | 94.870 | 89.577 | 92.710 | 94.238 | 98.168 | 97.195 | 96.218 | 94.820 |