diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py
index 80d6c9cbd..a73993e33 100644
--- a/sentence_transformers/SentenceTransformer.py
+++ b/sentence_transformers/SentenceTransformer.py
@@ -38,6 +38,7 @@
 from .models import Normalize, Pooling, Transformer
 from .quantization import quantize_embeddings
 from .util import (
+    ImageChannelDimension,
     batch_to_device,
     get_device_name,
     import_from_string,
@@ -382,6 +383,7 @@ def encode(
         convert_to_tensor: Literal[False] = ...,
         device: str = ...,
         normalize_embeddings: bool = ...,
+        image_channel_dimension: str = ImageChannelDimension.LAST,
         **kwargs,
     ) -> Tensor: ...
@@ -399,6 +401,7 @@ def encode(
         convert_to_tensor: Literal[False] = ...,
         device: str = ...,
         normalize_embeddings: bool = ...,
+        image_channel_dimension: str = ImageChannelDimension.LAST,
         **kwargs,
     ) -> np.ndarray: ...
@@ -416,6 +419,7 @@ def encode(
         convert_to_tensor: Literal[True] = ...,
         device: str = ...,
         normalize_embeddings: bool = ...,
+        image_channel_dimension: str = ImageChannelDimension.LAST,
         **kwargs,
     ) -> Tensor: ...
@@ -433,6 +437,7 @@ def encode(
         convert_to_tensor: Literal[False] = ...,
         device: str = ...,
         normalize_embeddings: bool = ...,
+        image_channel_dimension: str = ImageChannelDimension.LAST,
         **kwargs,
     ) -> list[Tensor]: ...
@@ -449,6 +454,7 @@ def encode(
         convert_to_tensor: bool = False,
         device: str = None,
         normalize_embeddings: bool = False,
+        image_channel_dimension: str = ImageChannelDimension.LAST,
         **kwargs,
     ) -> list[Tensor] | np.ndarray | Tensor:
         """
@@ -480,6 +486,8 @@
             device (str, optional): Which :class:`torch.device` to use for the computation. Defaults to None.
             normalize_embeddings (bool, optional): Whether to normalize returned vectors to have length 1. In that case, the faster dot-product (util.dot_score) instead of cosine similarity can be used. Defaults to False.
+            image_channel_dimension (str, optional): Whether the color channel of an image is the first or the last element of its shape.
+                Set this to ``ImageChannelDimension.FIRST`` if your images are in channels-first format. Defaults to ``ImageChannelDimension.LAST``.
 
         Returns:
             Union[List[Tensor], ndarray, Tensor]: By default, a 2d numpy array with shape [num_inputs, output_dimension] is returned.
@@ -511,6 +519,9 @@
             self.is_hpu_graph_enabled = True
 
         self.eval()
+        # Will be used in Image Tokenizer
+        self.image_channel_dimension = image_channel_dimension
+
         if show_progress_bar is None:
             show_progress_bar = logger.getEffectiveLevel() in (logging.INFO, logging.DEBUG)
@@ -1021,7 +1032,7 @@ def tokenize(self, texts: list[str] | list[dict] | list[tuple[str, str]]) -> dict[str, torch.Tensor]:
             Dict[str, Tensor]: A dictionary of tensors with the tokenized texts. Common keys are "input_ids", "attention_mask", and "token_type_ids".
""" - return self._first_module().tokenize(texts) + return self._first_module().tokenize(texts, image_channel_dimension=self.image_channel_dimension) def get_sentence_features(self, *features) -> dict[Literal["sentence_embedding"], torch.Tensor]: return self._first_module().get_sentence_features(*features) diff --git a/sentence_transformers/models/CLIPModel.py b/sentence_transformers/models/CLIPModel.py index 4eccb3c55..ff3442ef2 100644 --- a/sentence_transformers/models/CLIPModel.py +++ b/sentence_transformers/models/CLIPModel.py @@ -4,6 +4,7 @@ import transformers from PIL import Image from torch import nn +from ..util import ImageChannelDimension class CLIPModel(nn.Module): @@ -51,10 +52,11 @@ def forward(self, features: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: return features - def tokenize(self, texts, padding: str | bool = True) -> dict[str, torch.Tensor]: + def tokenize(self, texts, padding: str | bool = True, **kwargs) -> dict[str, torch.Tensor]: images = [] texts_values = [] image_text_info = [] + image_channel_dimension = kwargs.get("image_channel_dimension", ImageChannelDimension.LAST) for idx, data in enumerate(texts): if isinstance(data, Image.Image): # An Image @@ -69,7 +71,7 @@ def tokenize(self, texts, padding: str | bool = True) -> dict[str, torch.Tensor] encoding = self.processor.tokenizer(texts_values, return_tensors="pt", padding=padding) if len(images): - image_features = self.processor.image_processor(images, return_tensors="pt") + image_features = self.processor.image_processor(images, return_tensors="pt", input_data_format=image_channel_dimension) encoding["pixel_values"] = image_features.pixel_values encoding["image_text_info"] = image_text_info diff --git a/sentence_transformers/models/Transformer.py b/sentence_transformers/models/Transformer.py index ea66a13cb..449e00547 100644 --- a/sentence_transformers/models/Transformer.py +++ b/sentence_transformers/models/Transformer.py @@ -134,7 +134,7 @@ def get_word_embedding_dimension(self) -> int: return self.auto_model.config.hidden_size def tokenize( - self, texts: list[str] | list[dict] | list[tuple[str, str]], padding: str | bool = True + self, texts: list[str] | list[dict] | list[tuple[str, str]], padding: str | bool = True, **kwargs ) -> dict[str, torch.Tensor]: """Tokenizes a text and maps tokens to token-ids""" output = {} diff --git a/sentence_transformers/util.py b/sentence_transformers/util.py index 5b8baa5f7..b22fc4cf6 100644 --- a/sentence_transformers/util.py +++ b/sentence_transformers/util.py @@ -29,6 +29,12 @@ from sentence_transformers.cross_encoder.CrossEncoder import CrossEncoder from sentence_transformers.SentenceTransformer import SentenceTransformer +class ImageChannelDimension(): + """ + Defines the color channels' position in an Image's shape + """ + FIRST = "channels_first" + LAST = "channels_last" def _convert_to_tensor(a: list | np.ndarray | Tensor) -> Tensor: """