@@ -39,6 +39,75 @@ class VoyageMockModel():
39
39
    def __init__(self):
        """Initialise the mock model with a default VoyageEmbedderConfig."""
        # NOTE(review): VoyageEmbedderConfig is defined elsewhere in this file/project;
        # presumably it holds embedding-model settings — confirm against its definition.
        self.config = VoyageEmbedderConfig()
42
# Voyage embedding model name used by call_api_query / call_api_document.
MODEL = 'voyage-3-large'
44
def call_api_query(text_chunk, cache_dir="./voyage_cache"):
    """Embed a batch of query texts with the Voyage API, with on-disk caching.

    Plain module-level function (no shared state) so it can be dispatched
    from a multiprocessing pool.

    Args:
        text_chunk: List of query strings embedded in a single API call.
        cache_dir: Directory where embeddings are cached as ``.npy`` files.

    Returns:
        List of embedding vectors, one per input text.

    Raises:
        Exception: Re-raises the last API error once all retries are exhausted.
    """
    cache_dir = pathlib.Path(cache_dir)
    # parents=True so a nested cache path does not make mkdir fail.
    cache_dir.mkdir(parents=True, exist_ok=True)

    # Cache key is a hash of the exact input batch (md5 is fine here: not security).
    chunks_str = json.dumps(text_chunk, sort_keys=True)
    cache_key = hashlib.md5(chunks_str.encode()).hexdigest()
    cache_file = cache_dir / f"query_{cache_key}.npy"

    # Serve from cache when possible; a corrupt cache file falls through to the API.
    if cache_file.exists():
        try:
            return np.load(cache_file).tolist()
        except Exception as e:
            print(f"Failed to load cache file {cache_file}: {e}")

    # Cache miss: call the API with bounded retries and a fixed backoff.
    vo = voyageai.Client(api_key=VOYAGE_API_KEY)
    max_retries = 5
    for attempt in range(max_retries):
        try:
            result = vo.embed(text_chunk, model=MODEL, input_type='query')
            embeddings = result.embeddings
            # Populate the cache so the next run is a hit.
            np.save(cache_file, embeddings)
            return embeddings
        except Exception as e:
            if attempt == max_retries - 1:
                print(f"Failed after {max_retries} attempts: {str(e)}")
                raise  # bare raise preserves the original traceback
            time.sleep(5)
77
def call_api_document(text_chunk, cache_dir="./voyage_cache"):
    """Embed a batch of document texts with the Voyage API, with on-disk caching.

    Plain module-level function (no shared state) so it can be dispatched
    from a multiprocessing pool. Mirrors ``call_api_query`` but uses the
    'document' input type and a distinct cache-file prefix.

    Args:
        text_chunk: List of document strings embedded in a single API call.
        cache_dir: Directory where embeddings are cached as ``.npy`` files.

    Returns:
        List of embedding vectors, one per input text.

    Raises:
        Exception: Re-raises the last API error once all retries are exhausted.
    """
    cache_dir = pathlib.Path(cache_dir)
    # parents=True so a nested cache path does not make mkdir fail.
    cache_dir.mkdir(parents=True, exist_ok=True)

    # Cache key is a hash of the exact input batch (md5 is fine here: not security).
    chunks_str = json.dumps(text_chunk, sort_keys=True)
    cache_key = hashlib.md5(chunks_str.encode()).hexdigest()
    cache_file = cache_dir / f"doc_{cache_key}.npy"

    # Serve from cache when possible; a corrupt cache file falls through to the API.
    if cache_file.exists():
        try:
            return np.load(cache_file).tolist()
        except Exception as e:
            print(f"Failed to load cache file {cache_file}: {e}")

    # Cache miss: call the API with bounded retries and a fixed backoff.
    vo = voyageai.Client(api_key=VOYAGE_API_KEY)
    max_retries = 5
    for attempt in range(max_retries):
        try:
            result = vo.embed(text_chunk, model=MODEL, input_type='document')
            embeddings = result.embeddings
            # Populate the cache so the next run is a hit.
            np.save(cache_file, embeddings)
            return embeddings
        except Exception as e:
            if attempt == max_retries - 1:
                print(f"Failed after {max_retries} attempts: {str(e)}")
                raise  # bare raise preserves the original traceback
            time.sleep(5)
42
111
class VoyageEmbedder (AbsEmbedder ):
43
112
def __init__ (
44
113
self ,
@@ -106,64 +175,49 @@ def encode_single_device(
106
175
):
107
176
return self .encode_queries (sentences , batch_size = batch_size )
108
177
109
- def encode_queries (self , queries : List [str ], batch_size : int = 256 , ** kwargs ) -> np .ndarray :
178
+
179
+ def encode_queries (self , queries : List [str ], batch_size : int = 128 , num_parallel = 32 , ** kwargs ) -> np .ndarray :
180
+ print ('Encoding queries' )
181
+ # Prepare chunks outside of multiprocessing
182
+ chunks = list (split_list (queries , batch_size ))
183
+
184
+ # Setup the pool and process chunks
185
+ with mp .Pool (num_parallel ) as pool :
186
+ results = list (tqdm (
187
+ pool .imap (call_api_query , chunks ),
188
+ desc = "Encoding queries" ,
189
+ total = len (chunks )
190
+ ))
110
191
111
- #queries = [cutoff_long_text_for_embedding_generation(query, self.encoding, cutoff=4096) for query in queries]
192
+ # Flatten results
112
193
total_encoded_queries = []
113
- #for query_chunks in tqdm(split_list(queries, self.encoder_batch_size), total=len(queries)//self.encoder_batch_size):
114
- for query_chunks in tqdm (split_list (queries , batch_size ), total = len (queries )// batch_size ):
115
- try :
116
- encoded_queries = self .vo .embed (query_chunks , model = self .embedding_model , input_type = 'query' )
117
- encoded_queries = encoded_queries .embeddings
118
- except Exception as e :
119
- raise e
120
- time .sleep (5 )
121
- encoded_queries = self .vo .embed (query_chunks , model = self .embedding_model , input_type = 'query' )
122
- encoded_queries = encoded_queries .embeddings
123
-
124
- #encoded_queries = [query_encoding for query_encoding in encoded_queries]
125
- total_encoded_queries += encoded_queries
194
+ for result in results :
195
+ total_encoded_queries .extend (result )
196
+
126
197
return np .array (total_encoded_queries )
127
198
128
- # Write your own encoding corpus function (Returns: Document embeddings as numpy array)
129
- def encode_corpus (self , corpus : List [Dict [str , str ]], batch_size : int = 256 , ** kwargs ) -> np .ndarray :
199
+ def encode_corpus (self , corpus : List [Dict [str , str ]], batch_size : int = 128 , num_parallel = 32 , ** kwargs ) -> np .ndarray :
200
+ print ('Encoding corpus' )
201
+ # Prepare passages outside of multiprocessing
130
202
if isinstance (corpus [0 ], dict ):
131
203
passages = ['{} {}' .format (doc .get ('title' , '' ), doc ['text' ]).strip () for doc in corpus ]
132
204
else :
133
205
passages = corpus
206
+
207
+ passages = [passage [:8192 * 8 ] for passage in passages ]
208
+ chunks = list (split_list (passages , batch_size ))
209
+
210
+ with mp .Pool (num_parallel ) as pool :
211
+ results = list (tqdm (
212
+ pool .imap (call_api_document , chunks ),
213
+ desc = "Encoding documents" ,
214
+ total = len (chunks )
215
+ ))
216
+
217
+ # Flatten results
218
+ total_encoded_queries = []
219
+ for result in results :
220
+ total_encoded_queries .extend (result )
134
221
135
- passages = [
136
- passage [:8192 * 8 ] #modify for context length
137
- for passage in passages
138
- ]
139
222
140
- total_encoded_passages = []
141
- #for passage_chunks in tqdm(split_list(passages, self.encoder_batch_size), total=len(passages)//self.encoder_batch_size):
142
- for passage_chunks in tqdm (split_list (passages , batch_size ), total = len (passages )// batch_size ):
143
- # Create a hash of the passage chunks for the cache filename
144
- chunks_str = json .dumps (passage_chunks , sort_keys = True )
145
- cache_key = hashlib .md5 (chunks_str .encode ()).hexdigest ()
146
- cache_file = self .cache_dir / f"{ cache_key } .npy"
147
-
148
- if cache_file .exists ():
149
- # Load cached embeddings if they exist
150
- self .logger .info (f"Cache hit for key: { cache_key [:8 ]} ..." )
151
- encoded_passages = np .load (cache_file )
152
- encoded_passages = encoded_passages .tolist ()
153
- else :
154
- attempts = 0
155
- while attempts < 5 and not cache_file .exists ():
156
- try :
157
- encoded_passages = self .vo .embed (passage_chunks , model = self .embedding_model , input_type = 'document' )
158
- encoded_passages = encoded_passages .embeddings
159
- # Cache the results
160
- np .save (cache_file , encoded_passages )
161
- except Exception as e :
162
- attempts += 1
163
- self .logger .warning (f"API call failed for key: { cache_key [:8 ]} ... Retrying after 30s" )
164
- time .sleep (5 )
165
- if not cache_file .exists ():
166
- raise Exception (f"Failed to retrieve embeddings after 5 attempts for key: { cache_key [:8 ]} " )
167
-
168
- total_encoded_passages += encoded_passages
169
223
return np .array (total_encoded_passages )
0 commit comments