Commit 99d2983

ug

1 parent 1d79690 commit 99d2983

File tree

1 file changed: +104 −50 lines

  • FlagEmbedding/inference/embedder/encoder_only/voyage.py

FlagEmbedding/inference/embedder/encoder_only/voyage.py (+104 −50)
@@ -39,6 +39,75 @@ class VoyageMockModel():
     def __init__(self):
         self.config = VoyageEmbedderConfig()
 
+MODEL = 'voyage-3-large'
+
+def call_api_query(text_chunk, cache_dir="./voyage_cache"):
+    """Simple function that just calls the API - used in multiprocessing"""
+    cache_dir = pathlib.Path(cache_dir)
+    cache_dir.mkdir(exist_ok=True)
+
+    # Create cache key from input
+    chunks_str = json.dumps(text_chunk, sort_keys=True)
+    cache_key = hashlib.md5(chunks_str.encode()).hexdigest()
+    cache_file = cache_dir / f"query_{cache_key}.npy"
+
+    # Check cache first
+    if cache_file.exists():
+        try:
+            return np.load(cache_file).tolist()
+        except Exception as e:
+            print(f"Failed to load cache file {cache_file}: {e}")
+
+    # If not in cache, call API
+    vo = voyageai.Client(api_key=VOYAGE_API_KEY)
+    max_retries = 5
+    for attempt in range(max_retries):
+        try:
+            result = vo.embed(text_chunk, model=MODEL, input_type='query')
+            embeddings = result.embeddings
+            # Save to cache
+            np.save(cache_file, embeddings)
+            return embeddings
+        except Exception as e:
+            if attempt == max_retries - 1:
+                print(f"Failed after {max_retries} attempts: {str(e)}")
+                raise e
+            time.sleep(5)
+
+def call_api_document(text_chunk, cache_dir="./voyage_cache"):
+    """Simple function that just calls the API - used in multiprocessing"""
+    cache_dir = pathlib.Path(cache_dir)
+    cache_dir.mkdir(exist_ok=True)
+
+    # Create cache key from input
+    chunks_str = json.dumps(text_chunk, sort_keys=True)
+    cache_key = hashlib.md5(chunks_str.encode()).hexdigest()
+    cache_file = cache_dir / f"doc_{cache_key}.npy"
+
+    # Check cache first
+    if cache_file.exists():
+        try:
+            return np.load(cache_file).tolist()
+        except Exception as e:
+            print(f"Failed to load cache file {cache_file}: {e}")
+
+    # If not in cache, call API
+    vo = voyageai.Client(api_key=VOYAGE_API_KEY)
+    max_retries = 5
+    for attempt in range(max_retries):
+        try:
+            result = vo.embed(text_chunk, model=MODEL, input_type='document')
+            embeddings = result.embeddings
+            # Save to cache
+            np.save(cache_file, embeddings)
+            return embeddings
+        except Exception as e:
+            if attempt == max_retries - 1:
+                print(f"Failed after {max_retries} attempts: {str(e)}")
+                raise e
+            time.sleep(5)
+
+
 class VoyageEmbedder(AbsEmbedder):
     def __init__(
         self,
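Two details of these helpers matter for the rest of the commit. They are module-level functions rather than methods so that `multiprocessing.Pool` can pickle them when dispatching batches to worker processes, and they deduplicate API calls through a content-addressed cache: each batch is serialized deterministically and hashed, so an identical batch always resolves to the same `.npy` file. A minimal sketch of that key scheme, assuming a hypothetical `cache_path` helper (the commit inlines this logic in both functions):

```python
import hashlib
import json
import pathlib

def cache_path(text_chunk, prefix, cache_dir="./voyage_cache"):
    # json.dumps with sort_keys=True gives a deterministic serialization,
    # so the same batch always hashes to the same key.
    chunks_str = json.dumps(text_chunk, sort_keys=True)
    cache_key = hashlib.md5(chunks_str.encode()).hexdigest()
    return pathlib.Path(cache_dir) / f"{prefix}_{cache_key}.npy"

# Identical batches map to the same cache file; reordered ones do not.
assert cache_path(["a", "b"], "query") == cache_path(["a", "b"], "query")
assert cache_path(["a", "b"], "query") != cache_path(["b", "a"], "query")
```

Note that `input_type` is not part of the hash; the `query_`/`doc_` filename prefixes are what keep query and document embeddings of the same text batch from colliding in the cache.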
@@ -106,64 +175,49 @@ def encode_single_device(
     ):
         return self.encode_queries(sentences, batch_size=batch_size)
 
-    def encode_queries(self, queries: List[str], batch_size: int = 256, **kwargs) -> np.ndarray:
+
+    def encode_queries(self, queries: List[str], batch_size: int = 128, num_parallel=32, **kwargs) -> np.ndarray:
+        print('Encoding queries')
+        # Prepare chunks outside of multiprocessing
+        chunks = list(split_list(queries, batch_size))
+
+        # Set up the pool and process chunks
+        with mp.Pool(num_parallel) as pool:
+            results = list(tqdm(
+                pool.imap(call_api_query, chunks),
+                desc="Encoding queries",
+                total=len(chunks)
+            ))
 
-        #queries = [cutoff_long_text_for_embedding_generation(query, self.encoding, cutoff=4096) for query in queries]
+        # Flatten results
         total_encoded_queries = []
-        #for query_chunks in tqdm(split_list(queries, self.encoder_batch_size), total=len(queries)//self.encoder_batch_size):
-        for query_chunks in tqdm(split_list(queries, batch_size), total=len(queries)//batch_size):
-            try:
-                encoded_queries = self.vo.embed(query_chunks, model=self.embedding_model, input_type='query')
-                encoded_queries = encoded_queries.embeddings
-            except Exception as e:
-                raise e
-                time.sleep(5)
-                encoded_queries = self.vo.embed(query_chunks, model=self.embedding_model, input_type='query')
-                encoded_queries = encoded_queries.embeddings
-
-            #encoded_queries = [query_encoding for query_encoding in encoded_queries]
-            total_encoded_queries += encoded_queries
+        for result in results:
+            total_encoded_queries.extend(result)
+
         return np.array(total_encoded_queries)
 
-    # Write your own encoding corpus function (Returns: Document embeddings as numpy array)
-    def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int = 256, **kwargs) -> np.ndarray:
+    def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int = 128, num_parallel=32, **kwargs) -> np.ndarray:
+        print('Encoding corpus')
+        # Prepare passages outside of multiprocessing
         if isinstance(corpus[0], dict):
             passages = ['{} {}'.format(doc.get('title', ''), doc['text']).strip() for doc in corpus]
         else:
             passages = corpus
+
+        passages = [passage[:8192*8] for passage in passages]
+        chunks = list(split_list(passages, batch_size))
+
+        with mp.Pool(num_parallel) as pool:
+            results = list(tqdm(
+                pool.imap(call_api_document, chunks),
+                desc="Encoding documents",
+                total=len(chunks)
+            ))
+
+        # Flatten results
+        total_encoded_passages = []
+        for result in results:
+            total_encoded_passages.extend(result)
 
-        passages = [
-            passage[:8192*8] #modify for context length
-            for passage in passages
-        ]
 
-        total_encoded_passages = []
-        #for passage_chunks in tqdm(split_list(passages, self.encoder_batch_size), total=len(passages)//self.encoder_batch_size):
-        for passage_chunks in tqdm(split_list(passages, batch_size), total=len(passages)//batch_size):
-            # Create a hash of the passage chunks for the cache filename
-            chunks_str = json.dumps(passage_chunks, sort_keys=True)
-            cache_key = hashlib.md5(chunks_str.encode()).hexdigest()
-            cache_file = self.cache_dir / f"{cache_key}.npy"
-
-            if cache_file.exists():
-                # Load cached embeddings if they exist
-                self.logger.info(f"Cache hit for key: {cache_key[:8]}...")
-                encoded_passages = np.load(cache_file)
-                encoded_passages = encoded_passages.tolist()
-            else:
-                attempts = 0
-                while attempts < 5 and not cache_file.exists():
-                    try:
-                        encoded_passages = self.vo.embed(passage_chunks, model=self.embedding_model, input_type='document')
-                        encoded_passages = encoded_passages.embeddings
-                        # Cache the results
-                        np.save(cache_file, encoded_passages)
-                    except Exception as e:
-                        attempts += 1
-                        self.logger.warning(f"API call failed for key: {cache_key[:8]}... Retrying after 30s")
-                        time.sleep(5)
-                if not cache_file.exists():
-                    raise Exception(f"Failed to retrieve embeddings after 5 attempts for key: {cache_key[:8]}")
-
-            total_encoded_passages += encoded_passages
         return np.array(total_encoded_passages)
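Both rewritten methods now follow the same fan-out pattern: batch the inputs, let a process pool issue the cached API calls, then flatten the ordered results. Since `pool.imap` (unlike `imap_unordered`) yields results in submission order, the flattened embeddings stay aligned with the input texts. A self-contained sketch of the pattern, where `fake_call_api` is a stand-in for `call_api_query`/`call_api_document` and `split_list` is re-declared with its assumed behavior (its real definition sits outside this diff):

```python
import multiprocessing as mp
from tqdm import tqdm

def split_list(items, n):
    # Assumed behavior of the repo's split_list: successive n-sized batches.
    for i in range(0, len(items), n):
        yield items[i:i + n]

def fake_call_api(text_chunk):
    # Stand-in for call_api_query/call_api_document: one vector per text.
    return [[float(len(text))] for text in text_chunk]

if __name__ == "__main__":
    texts = [f"text {i}" for i in range(1000)]
    chunks = list(split_list(texts, 128))

    # imap yields batch results in submission order, so the flattened
    # list below stays aligned with `texts`; tqdm shows per-batch progress.
    with mp.Pool(4) as pool:
        results = list(tqdm(pool.imap(fake_call_api, chunks), total=len(chunks)))

    embeddings = [vec for batch in results for vec in batch]
    assert len(embeddings) == len(texts)
```

One consequence of caching at batch granularity: changing `batch_size` regroups the texts into different batches, so embeddings cached under the old batching will not be reused.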
