import os
import torch
import numpy as np
from typing import List, Dict, Optional, Union
from tqdm import tqdm
from openai import OpenAI
from tenacity import retry, stop_after_attempt, wait_exponential
from FlagEmbedding.abc.inference import AbsEmbedder
from transformers.configuration_utils import PretrainedConfig


class NvidiaEmbedderConfig(PretrainedConfig):
    """Minimal config object so code that inspects `model.config` keeps working."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._name_or_path = "nvidia"


class NvidiaMockModel:
    """Placeholder model; actual inference is delegated to the embedding API client."""

    def __init__(self):
        self.config = NvidiaEmbedderConfig()


class NvidiaEmbedder(AbsEmbedder):
    """Embedder that delegates encoding to a local OpenAI-compatible endpoint
    (nvidia/llama-3.2-nv-embedqa-1b-v2 served at http://localhost:8000/v1)
    instead of running a model in-process."""

    def __init__(
        self,
        model_name_or_path: str,
        normalize_embeddings: bool = True,
        use_fp16: bool = True,
        query_instruction_for_retrieval: Optional[str] = None,
        query_instruction_format: str = "{}{}",
        devices: Optional[Union[str, List[str]]] = None,
        batch_size: int = 2048,
        query_max_length: int = 512,
        passage_max_length: int = 512,
        convert_to_numpy: bool = True,
        **kwargs
    ):
        super().__init__(
            model_name_or_path,
            normalize_embeddings=normalize_embeddings,
            use_fp16=use_fp16,
            query_instruction_for_retrieval=query_instruction_for_retrieval,
            query_instruction_format=query_instruction_format,
            devices=devices,
            batch_size=batch_size,
            query_max_length=query_max_length,
            passage_max_length=passage_max_length,
            convert_to_numpy=convert_to_numpy,
            **kwargs
        )

        self.model = NvidiaMockModel()
        self.client = OpenAI(
            api_key="not-needed",  # API key not needed for local server
            base_url="http://localhost:8000/v1"
        )
        self.model_name = "nvidia/llama-3.2-nv-embedqa-1b-v2"

    @retry(
        stop=stop_after_attempt(5),
        wait=wait_exponential(multiplier=1, min=4, max=30),
        reraise=True
    )
    def _get_embeddings(self, texts: List[str], input_type: str = "query") -> np.ndarray:
        """Get embeddings for a batch of texts with automatic retries and truncation on token size errors."""
        def try_with_texts(current_texts: List[str], retry_count: int = 0) -> np.ndarray:
            try:
                response = self.client.embeddings.create(
                    input=current_texts,
                    model=self.model_name,
                    encoding_format="float",
                    extra_body={"input_type": input_type, "truncate": "END"}
                )
                return np.array([data.embedding for data in response.data])
            except Exception as e:
                error_str = str(e)
                print(f"Error in _get_embeddings: {error_str}")

                # If we hit token size limit and haven't retried too many times, truncate and retry
                if "token size" in error_str.lower() and retry_count < 3:
                    truncated_texts = [t[:len(t) // 2] for t in current_texts]
                    print(f"Retrying with truncated texts (retry {retry_count + 1})")
                    return try_with_texts(truncated_texts, retry_count + 1)

                raise

        return try_with_texts(texts)

    def encode_queries(self, queries: List[str], batch_size: int = 128, **kwargs) -> np.ndarray:
        """Encode queries with input_type='query'."""
        all_embeddings = []
        for i in tqdm(range(0, len(queries), batch_size), desc="Encoding queries"):
            batch = queries[i:i + batch_size]
            embeddings = self._get_embeddings(batch, input_type="query")
            all_embeddings.append(embeddings)
        return np.vstack(all_embeddings)

    @staticmethod
    def _process_batch_static(args):
        """Static method to process a batch of passages."""
        index, batch, model_name, base_url = args
        try:
            # Create a new client for each process
            client = OpenAI(
                api_key="not-needed",
                base_url=base_url
            )

            # Get embeddings
            response = client.embeddings.create(
                input=batch,
                model=model_name,
                encoding_format="float",
                extra_body={"input_type": "passage", "truncate": "END"}
            )
            embeddings = np.array([data.embedding for data in response.data])
            return index, embeddings
        except Exception as e:
            print(f"Error processing batch {index}: {str(e)}")
            return index, None

    def encode_corpus(self, corpus: Union[List[Dict[str, str]], List[str]], batch_size: int = 128, num_processes: int = 16, **kwargs) -> np.ndarray:
        """Encode corpus passages with input_type='passage' using multiple processes."""
        if isinstance(corpus[0], dict):
            passages = [f"{doc.get('title', '')} {doc['text']}".strip() for doc in corpus]
        else:
            passages = corpus

        # Prepare batches with their indices and the parameters each worker needs
        batches = []
        for i in range(0, len(passages), batch_size):
            batch = passages[i:i + batch_size]
            # Include the model name and base URL (as a plain string so it pickles cleanly)
            batches.append((len(batches), batch, self.model_name, str(self.client.base_url)))

        # Process batches in parallel
        from multiprocessing import Pool
        with Pool(processes=num_processes) as pool:
            # Use tqdm to show progress
            results = list(tqdm(
                pool.imap(self._process_batch_static, batches),
                total=len(batches),
                desc=f"Encoding corpus with {num_processes} processes"
            ))

        # Sort results by index and collect embeddings
        sorted_results = sorted(results, key=lambda x: x[0])
        all_embeddings = []
        for _, embeddings in sorted_results:
            if embeddings is not None:
                all_embeddings.append(embeddings)
            else:
                raise Exception("One or more batches failed to process")

        return np.vstack(all_embeddings)

    @torch.no_grad()
    def encode_single_device(
        self,
        sentences: Union[List[str], str],
        batch_size: int = 128,
        max_length: int = 512,
        convert_to_numpy: bool = True,
        device: Optional[str] = None,
        **kwargs
    ):
        """Single-device encoding method; defaults to query-style encoding."""
        if isinstance(sentences, str):
            sentences = [sentences]
        return self.encode_queries(sentences, batch_size=batch_size)
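

# Example usage: a minimal sketch, not part of the embedder itself. It assumes an
# OpenAI-compatible embedding server is already running at http://localhost:8000/v1
# and serving nvidia/llama-3.2-nv-embedqa-1b-v2; the query and document below are
# illustrative placeholders.
if __name__ == "__main__":
    embedder = NvidiaEmbedder(model_name_or_path="nvidia/llama-3.2-nv-embedqa-1b-v2")

    query_emb = embedder.encode_queries(["what is retrieval augmented generation?"])
    corpus_emb = embedder.encode_corpus(
        [{"title": "RAG", "text": "Retrieval augmented generation combines search with LLMs."}],
        num_processes=1,  # single process is enough for a quick smoke test
    )

    # If the server returns normalized vectors, cosine similarity reduces to a dot product.
    print(query_emb @ corpus_emb.T)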