@@ -193,6 +193,7 @@ def multilingual_extractive_match_metric(
193193 fallback_mode : Literal ["no_fallback" , "first_match" ] = "first_match" ,
194194 extraction_mode : Literal ["first_match" , "any_match" ] = "any_match" ,
195195 precision : int = 6 ,
196+ timeout_seconds : int = 5 ,
196197) -> SampleLevelMetric :
197198 """Creates a language-aware extractive match metric that extracts answers from the model's output.
198199
@@ -222,6 +223,8 @@ def multilingual_extractive_match_metric(
222223
223224 precision: int
224225 Number of decimal places to use when comparing numerical values. Defaults to 6.
226+ timeout_seconds: int
227+ Timeout for the extraction (each attempt) and comparison. Defaults to 5.
225228
226229 Returns:
227230 A sample level metric that extracts and compares mathematical expressions.
@@ -245,11 +248,12 @@ def sample_level_fn(golds: list[str], predictions: list[str], formatted_doc: Doc
245248 pred_extraction_regexes = get_extraction_regexes (formatted_doc , pred_extraction_target , language )
246249
247250 extracted_predictions = [
248- extract_target_from_pred (pred , pred_extraction_regexes , fallback_mode , extraction_mode )
251+ extract_target_from_pred (pred , pred_extraction_regexes , fallback_mode , extraction_mode , timeout_seconds )
249252 for pred in predictions
250253 ]
251254 extracted_golds = [
252- extract_target_from_pred (gold , gold_extraction_regexes , fallback_mode , extraction_mode ) for gold in golds
255+ extract_target_from_pred (gold , gold_extraction_regexes , fallback_mode , extraction_mode , timeout_seconds )
256+ for gold in golds
253257 ]
254258
255259 # Assert on empty gold and warn on empty pred
@@ -265,12 +269,19 @@ def sample_level_fn(golds: list[str], predictions: list[str], formatted_doc: Doc
265269 # We have to use timeout because the sypmy to str conversion can be very slow
266270 try :
267271 add_to_specifics_with_timeout (formatted_doc , extracted_predictions , extracted_golds )
268- except : # noqa: E722
272+ except Exception : # noqa: E722
269273 logger .warning ("Timeout when adding extracted predictions and golds to specific" )
270274
271275 return aggregation_function (
272276 [
273- (1.0 if any (compare_gold_target (gold , pred , precision ) for gold in extracted_golds ) else 0.0 )
277+ (
278+ 1.0
279+ if any (
280+ compare_gold_target (gold , pred , precision , timeout_seconds = timeout_seconds )
281+ for gold in extracted_golds
282+ )
283+ else 0.0
284+ )
274285 for pred in extracted_predictions
275286 ]
276287 )
0 commit comments