diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py
index e44e573a5..79deef989 100644
--- a/sentence_transformers/SentenceTransformer.py
+++ b/sentence_transformers/SentenceTransformer.py
@@ -114,6 +114,7 @@ def encode(self, sentences: Union[str, List[str]],
                output_value: str = 'sentence_embedding',
                convert_to_numpy: bool = True,
                convert_to_tensor: bool = False,
+               move_to_cpu: bool = False,
                device: str = None,
                normalize_embeddings: bool = False) -> Union[List[Tensor], ndarray, Tensor]:
         """
@@ -125,6 +126,7 @@ def encode(self, sentences: Union[str, List[str]],
         :param output_value: Default sentence_embedding, to get sentence embeddings. Can be set to token_embeddings to get wordpiece token embeddings. Set to None, to get all output values
         :param convert_to_numpy: If true, the output is a list of numpy vectors. Else, it is a list of pytorch tensors.
         :param convert_to_tensor: If true, you get one large tensor as return. Overwrites any setting from convert_to_numpy
+        :param move_to_cpu: If true, the obtained embedding tensors are sequentially moved to the CPU.
         :param device: Which torch.device to use for the computation
         :param normalize_embeddings: If set to true, returned vectors will have length 1. In that case, the faster dot-product (util.dot_score) instead of cosine similarity can be used.

@@ -142,6 +144,9 @@ def encode(self, sentences: Union[str, List[str]],
             convert_to_tensor = False
             convert_to_numpy = False

+        if convert_to_numpy:
+            move_to_cpu = True
+
         input_was_string = False
         if isinstance(sentences, str) or not hasattr(sentences, '__len__'): #Cast an individual sentence to a list with length 1
             sentences = [sentences]
@@ -171,7 +176,10 @@ def encode(self, sentences: Union[str, List[str]],
                         while last_mask_id > 0 and attention[last_mask_id].item() == 0:
                             last_mask_id -= 1

-                        embeddings.append(token_emb[0:last_mask_id+1])
+                        token_embeddings = token_emb[0:last_mask_id+1]
+                        if move_to_cpu:
+                            token_embeddings = token_embeddings.cpu()
+                        embeddings.append(token_embeddings)
                 elif output_value is None:  #Return all outputs
                     embeddings = []
                     for sent_idx in range(len(out_features['sentence_embedding'])):
@@ -184,7 +192,7 @@ def encode(self, sentences: Union[str, List[str]],
                         embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)

                     # fixes for #522 and #487 to avoid oom problems on gpu with large datasets
-                    if convert_to_numpy:
+                    if move_to_cpu:
                         embeddings = embeddings.cpu()

                 all_embeddings.extend(embeddings)
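
For context, a minimal usage sketch of the new flag this diff introduces. The model name, the CUDA device, and the example sentences below are illustrative assumptions and not part of the patch:

```python
from sentence_transformers import SentenceTransformer

# Assumed model name and device for illustration only.
model = SentenceTransformer("all-MiniLM-L6-v2", device="cuda")
sentences = ["This is an example sentence.", "Each batch is encoded on the GPU."]

# With convert_to_numpy=False the result stays a list of torch tensors, and
# move_to_cpu=True (the flag added in this diff) offloads each batch to the
# CPU as it is produced, avoiding GPU OOM on large datasets.
embeddings = model.encode(
    sentences,
    convert_to_numpy=False,   # keep torch tensors instead of numpy arrays
    convert_to_tensor=False,
    move_to_cpu=True,         # new parameter from this diff
)
print(embeddings[0].device)   # expected: cpu
```

Note that when `convert_to_numpy=True` (the default), the patch forces `move_to_cpu=True`, so the explicit flag mainly matters when tensor output is requested.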