diff --git a/tests/test_semdedup.py b/tests/test_semdedup.py index 6b33e7652..5411e4aba 100644 --- a/tests/test_semdedup.py +++ b/tests/test_semdedup.py @@ -154,7 +154,7 @@ def get_reference_embeddings( if pooling_strategy == "last_token": embeddings = outputs.last_hidden_state[:, -1, :] - elif pooling_strategy == "mean": + elif pooling_strategy == "mean_pooling": token_embeddings = outputs.last_hidden_state attention_mask = inputs["attention_mask"] input_mask_expanded = (