diff --git a/docs/sentence_transformer/usage/custom_models.rst b/docs/sentence_transformer/usage/custom_models.rst
new file mode 100644
index 000000000..eacabef3f
--- /dev/null
+++ b/docs/sentence_transformer/usage/custom_models.rst
@@ -0,0 +1,318 @@
+Creating Custom Models
+=======================
+
+Structure of Sentence Transformer Models
+----------------------------------------
+
+A Sentence Transformer model consists of a collection of modules (`docs <../../package_reference/sentence_transformer/models.html>`_) that are executed sequentially. The most common architecture is a combination of a :class:`~sentence_transformers.models.Transformer` module, a :class:`~sentence_transformers.models.Pooling` module, and optionally, a :class:`~sentence_transformers.models.Dense` module and/or a :class:`~sentence_transformers.models.Normalize` module.
+
+* :class:`~sentence_transformers.models.Transformer`: This module processes the input text and generates contextualized token embeddings.
+* :class:`~sentence_transformers.models.Pooling`: This module aggregates the token embeddings from the Transformer module into a single fixed-size sentence embedding. Common pooling strategies include mean pooling and CLS pooling.
+* :class:`~sentence_transformers.models.Dense`: This module contains a linear layer that post-processes the embedding output from the Pooling module.
+* :class:`~sentence_transformers.models.Normalize`: This module normalizes the embedding from the previous layer.
+
+For example, the popular `all-MiniLM-L6-v2 <https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2>`_ model can also be loaded by initializing the 3 specific modules that make up that model:
+
+.. code-block:: python
+
+    from sentence_transformers import models, SentenceTransformer
+
+    transformer = models.Transformer("sentence-transformers/all-MiniLM-L6-v2", max_seq_length=256)
+    pooling = models.Pooling(transformer.get_word_embedding_dimension(), pooling_mode="mean")
+    normalize = models.Normalize()
+
+    model = SentenceTransformer(modules=[transformer, pooling, normalize])
+
+Saving Sentence Transformer Models
+++++++++++++++++++++++++++++++++++
+
+Whenever a Sentence Transformer model is saved, three types of files are generated:
+
+* ``modules.json``: This file contains a list of module names, paths, and types that are used to reconstruct the model.
+* ``config_sentence_transformers.json``: This file contains some configuration options of the Sentence Transformer model, including saved prompts, the model's similarity function, and the Sentence Transformers package version used by the model author.
+* **Module-specific files**: Each module is saved in a separate folder, with the first module saved in the root folder and all subsequent modules saved in subfolders named after the module index and the module name (e.g., ``1_Pooling``, ``2_Normalize``).
+  Most module folders contain a ``config.json`` (or ``sentence_bert_config.json`` for the :class:`~sentence_transformers.models.Transformer` module) file that stores default values for keyword arguments passed to that module. So, a ``sentence_bert_config.json`` of::
+
+    {
+        "max_seq_length": 4096,
+        "do_lower_case": false
+    }
+
+  means that the :class:`~sentence_transformers.models.Transformer` module will be initialized with ``max_seq_length=4096`` and ``do_lower_case=False``.
+
+As a result, if I call :meth:`SentenceTransformer.save_pretrained("local-all-MiniLM-L6-v2") <sentence_transformers.SentenceTransformer.save_pretrained>` on the ``model`` from the previous snippet, the following files are generated:
+
+.. code-block:: bash
+
+    local-all-MiniLM-L6-v2/
+    ├── 1_Pooling
+    │   └── config.json
+    ├── 2_Normalize
+    ├── README.md
+    ├── config.json
+    ├── config_sentence_transformers.json
+    ├── model.safetensors
+    ├── modules.json
+    ├── sentence_bert_config.json
+    ├── special_tokens_map.json
+    ├── tokenizer.json
+    ├── tokenizer_config.json
+    └── vocab.txt
+
+This contains a ``modules.json`` with these contents:
+
+.. code-block:: json
+
+    [
+      {
+        "idx": 0,
+        "name": "0",
+        "path": "",
+        "type": "sentence_transformers.models.Transformer"
+      },
+      {
+        "idx": 1,
+        "name": "1",
+        "path": "1_Pooling",
+        "type": "sentence_transformers.models.Pooling"
+      },
+      {
+        "idx": 2,
+        "name": "2",
+        "path": "2_Normalize",
+        "type": "sentence_transformers.models.Normalize"
+      }
+    ]
+
+And a ``config_sentence_transformers.json`` with these contents:
+
+.. code-block:: json
+
+    {
+      "__version__": {
+        "sentence_transformers": "3.0.1",
+        "transformers": "4.43.4",
+        "pytorch": "2.5.0"
+      },
+      "prompts": {},
+      "default_prompt_name": null,
+      "similarity_fn_name": null
+    }
+
+Additionally, the ``1_Pooling`` directory contains the configuration file for the :class:`~sentence_transformers.models.Pooling` module, while the ``2_Normalize`` directory is empty because the :class:`~sentence_transformers.models.Normalize` module does not require any configuration. The ``sentence_bert_config.json`` file contains the configuration of the :class:`~sentence_transformers.models.Transformer` module, which also saves its tokenizer files and the model weights in the root directory.
+
+Loading Sentence Transformer Models
++++++++++++++++++++++++++++++++++++
+
+To load a Sentence Transformer model from a saved model directory, the ``modules.json`` file is read to determine the modules that make up the model. Each module is initialized with the configuration stored in the corresponding module directory, after which the :class:`~sentence_transformers.SentenceTransformer` class is instantiated with the loaded modules.
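+Conceptually, this procedure boils down to a loop like the following sketch. Note that this is a simplification for illustration only: the real implementation additionally handles remote code, caching, module-specific keyword arguments, and extra configuration for the :class:`~sentence_transformers.models.Transformer` module, and ``load_modules`` is a hypothetical helper name::
+
+    import json
+    import os
+
+    from sentence_transformers import SentenceTransformer
+    from sentence_transformers.util import import_from_string
+
+    def load_modules(model_dir: str) -> SentenceTransformer:
+        # modules.json lists each module's index, name, path, and type
+        with open(os.path.join(model_dir, "modules.json")) as fIn:
+            modules_config = json.load(fIn)
+
+        modules = []
+        for module_config in modules_config:
+            # Import the module class, e.g. sentence_transformers.models.Pooling
+            module_class = import_from_string(module_config["type"])
+            # Initialize the module from the (sub)folder it was saved in
+            modules.append(module_class.load(os.path.join(model_dir, module_config["path"])))
+
+        return SentenceTransformer(modules=modules)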
+
+Sentence Transformer Model from a Transformers Model
+----------------------------------------------------
+
+When you initialize a Sentence Transformer model with a pure Transformers model (e.g., BERT, RoBERTa, DistilBERT, T5), Sentence Transformers creates a Transformer module and a Mean Pooling module by default. This provides a simple way to leverage pre-trained language models for sentence embeddings.
+
+To be specific, these two snippets are identical::
+
+    from sentence_transformers import SentenceTransformer
+
+    model = SentenceTransformer("bert-base-uncased")
+
+::
+
+    from sentence_transformers import models, SentenceTransformer
+
+    transformer = models.Transformer("bert-base-uncased")
+    pooling = models.Pooling(transformer.get_word_embedding_dimension(), pooling_mode="mean")
+    model = SentenceTransformer(modules=[transformer, pooling])
+
+Advanced: Custom Modules
+++++++++++++++++++++++++
+
+To create custom Sentence Transformer models, you can implement your own modules by subclassing PyTorch's :class:`torch.nn.Module` class and implementing these methods:
+
+* A :meth:`torch.nn.Module.forward` method that accepts a ``features`` dictionary with keys like ``input_ids``, ``attention_mask``, ``token_type_ids``, ``token_embeddings``, and ``sentence_embedding``, depending on where the module is in the model pipeline.
+* A ``save`` method that accepts a ``save_dir`` argument and saves the module's configuration to that directory.
+* A ``load`` static method that accepts a ``load_dir`` argument and initializes the module given the module's configuration from that directory.
+* (If 1st module) A ``get_max_seq_length`` method that returns the maximum sequence length the module can process. Only required if the module processes input text.
+* (If 1st module) A ``tokenize`` method that accepts a list of inputs and returns a dictionary with keys like ``input_ids``, ``attention_mask``, ``token_type_ids``, ``pixel_values``, etc. This dictionary will be passed along to the module's ``forward`` method.
+* (Optional) A ``get_sentence_embedding_dimension`` method that returns the dimensionality of the sentence embeddings produced by the module. Only required if the module generates the embeddings or updates their dimensionality.
+* (Optional) A ``get_config_dict`` method that returns a dictionary with the module's configuration. This method can be used to save the module's configuration to disk and to include it in the model card.
+
+For example, we can create a custom pooling method by implementing a custom module:
+
+.. code-block:: python
+
+    # decay_pooling.py
+
+    import json
+    import os
+
+    import torch
+    import torch.nn as nn
+
+    class DecayMeanPooling(nn.Module):
+        def __init__(self, dimension: int, decay: float = 0.95) -> None:
+            super(DecayMeanPooling, self).__init__()
+            self.dimension = dimension
+            self.decay = decay
+
+        def forward(self, features: dict[str, torch.Tensor], **kwargs) -> dict[str, torch.Tensor]:
+            token_embeddings = features["token_embeddings"]
+            attention_mask = features["attention_mask"].unsqueeze(-1)
+
+            # Apply the attention mask to filter away padding tokens
+            token_embeddings = token_embeddings * attention_mask
+            # Calculate mean of token embeddings
+            sentence_embeddings = token_embeddings.sum(1) / attention_mask.sum(1)
+            # Apply exponential decay
+            importance_per_dim = self.decay ** torch.arange(sentence_embeddings.size(1), device=sentence_embeddings.device)
+            features["sentence_embedding"] = sentence_embeddings * importance_per_dim
+            return features
+
+        def get_config_dict(self) -> dict[str, float]:
+            return {"dimension": self.dimension, "decay": self.decay}
+
+        def get_sentence_embedding_dimension(self) -> int:
+            return self.dimension
+
+        def save(self, save_dir: str, **kwargs) -> None:
+            with open(os.path.join(save_dir, "config.json"), "w") as fOut:
+                json.dump(self.get_config_dict(), fOut, indent=4)
+
+        @staticmethod
+        def load(load_dir: str, **kwargs) -> "DecayMeanPooling":
+            with open(os.path.join(load_dir, "config.json")) as fIn:
+                config = json.load(fIn)
+
+            return DecayMeanPooling(**config)
+
+.. note::
+
+    Adding ``**kwargs`` to the ``__init__``, ``forward``, ``save``, ``load``, and ``tokenize`` methods is recommended to ensure that the methods are compatible with future updates to the Sentence Transformers library.
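+
+Before wiring the module into a full model, it can be handy to sanity-check it in isolation by calling it with dummy tensors. This is a minimal sketch under assumed shapes (a batch of 2 texts, 16 tokens each, 768 dimensions); it only verifies the ``features`` dictionary contract described above::
+
+    import torch
+
+    from decay_pooling import DecayMeanPooling
+
+    pooling = DecayMeanPooling(dimension=768, decay=0.95)
+    features = {
+        "token_embeddings": torch.randn(2, 16, 768),  # (batch, tokens, dims)
+        "attention_mask": torch.ones(2, 16),  # no padding in this dummy batch
+    }
+    features = pooling(features)
+    print(features["sentence_embedding"].shape)
+    # torch.Size([2, 768])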
+
+This can now be used as a module in a Sentence Transformer model::
+
+    from sentence_transformers import models, SentenceTransformer
+    from decay_pooling import DecayMeanPooling
+
+    transformer = models.Transformer("bert-base-uncased", max_seq_length=256)
+    decay_mean_pooling = DecayMeanPooling(transformer.get_word_embedding_dimension(), decay=0.99)
+    normalize = models.Normalize()
+
+    model = SentenceTransformer(modules=[transformer, decay_mean_pooling, normalize])
+    print(model)
+    """
+    SentenceTransformer(
+      (0): Transformer({'max_seq_length': 256, 'do_lower_case': False}) with Transformer model: BertModel
+      (1): DecayMeanPooling()
+      (2): Normalize()
+    )
+    """
+
+    texts = [
+        "Hello, World!",
+        "The quick brown fox jumps over the lazy dog.",
+        "I am a sentence that is used for testing purposes.",
+        "This is a test sentence.",
+        "This is another test sentence.",
+    ]
+    embeddings = model.encode(texts)
+    print(embeddings.shape)
+    # [5, 768]
+
+You can save this model with :meth:`SentenceTransformer.save_pretrained <sentence_transformers.SentenceTransformer.save_pretrained>`, resulting in a ``modules.json`` of::
+
+    [
+      {
+        "idx": 0,
+        "name": "0",
+        "path": "",
+        "type": "sentence_transformers.models.Transformer"
+      },
+      {
+        "idx": 1,
+        "name": "1",
+        "path": "1_DecayMeanPooling",
+        "type": "decay_pooling.DecayMeanPooling"
+      },
+      {
+        "idx": 2,
+        "name": "2",
+        "path": "2_Normalize",
+        "type": "sentence_transformers.models.Normalize"
+      }
+    ]
+
+To ensure that ``decay_pooling.DecayMeanPooling`` can be imported, you should copy the ``decay_pooling.py`` file to the directory where you saved the model. If you push the model to the `Hugging Face Hub <https://huggingface.co>`_, then you should also upload the ``decay_pooling.py`` file to the model's repository, as shown in the sketch below. Then, everyone can use your custom module by calling :meth:`SentenceTransformer("your-username/your-model-id", trust_remote_code=True) <sentence_transformers.SentenceTransformer>`.
+
+.. note::
+
+    Using a custom module with remote code stored on the Hugging Face Hub requires that your users specify ``trust_remote_code`` as ``True`` when loading the model. This is a security measure to prevent remote code execution attacks.
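+
+For example, pushing both the model and the module file may look roughly like the following sketch. Here, ``your-username/your-model-id`` is a placeholder repository id, and we assume ``decay_pooling.py`` should end up at the repository root, next to ``modules.json``::
+
+    from huggingface_hub import upload_file
+
+    # Push the saved model (modules.json, configs, weights, ...)
+    model.push_to_hub("your-username/your-model-id")
+
+    # Upload the file that defines the custom module, so that loading
+    # with trust_remote_code=True can import it from the repository
+    upload_file(
+        path_or_fileobj="decay_pooling.py",
+        path_in_repo="decay_pooling.py",
+        repo_id="your-username/your-model-id",
+    )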
+
+If you have your models and custom modelling code on the Hugging Face Hub, then it might make sense to separate your custom modules into a separate repository. This way, you only have to maintain one implementation of your custom module, and you can reuse it across multiple models. You can do this by updating the ``type`` in the ``modules.json`` file to include the path to the repository where the custom module is stored, formatted as ``{repository_id}--{dot_path_to_module}``. For example, if the ``decay_pooling.py`` file is stored in a repository called ``my-user/my-model-implementation`` and the module is called ``DecayMeanPooling``, then the ``modules.json`` file may look like this::
+
+    [
+      {
+        "idx": 0,
+        "name": "0",
+        "path": "",
+        "type": "sentence_transformers.models.Transformer"
+      },
+      {
+        "idx": 1,
+        "name": "1",
+        "path": "1_DecayMeanPooling",
+        "type": "my-user/my-model-implementation--decay_pooling.DecayMeanPooling"
+      },
+      {
+        "idx": 2,
+        "name": "2",
+        "path": "2_Normalize",
+        "type": "sentence_transformers.models.Normalize"
+      }
+    ]
+
+Advanced: Keyword argument passthrough in Custom Modules
+++++++++++++++++++++++++++++++++++++++++++++++++++++++++
+
+If you want your users to be able to specify custom keyword arguments via the :meth:`SentenceTransformer.encode <sentence_transformers.SentenceTransformer.encode>` method, then you can add their names to the ``modules.json`` file. For example, if your module should behave differently when users specify a ``task_type`` keyword argument, then your ``modules.json`` might look like this::
+
+    [
+      {
+        "idx": 0,
+        "name": "0",
+        "path": "",
+        "type": "custom_transformer.CustomTransformer",
+        "kwargs": ["task_type"]
+      },
+      {
+        "idx": 1,
+        "name": "1",
+        "path": "1_Pooling",
+        "type": "sentence_transformers.models.Pooling"
+      },
+      {
+        "idx": 2,
+        "name": "2",
+        "path": "2_Normalize",
+        "type": "sentence_transformers.models.Normalize"
+      }
+    ]
+
+Then, you can access the ``task_type`` keyword argument in the ``forward`` method of your custom module::
+
+    from typing import Optional
+
+    import torch
+
+    from sentence_transformers.models import Transformer
+
+    class CustomTransformer(Transformer):
+        def forward(self, features: dict[str, torch.Tensor], task_type: Optional[str] = None) -> dict[str, torch.Tensor]:
+            if task_type == "default":
+                ...  # Do something
+            else:
+                ...  # Do something else
+            return features
+
+This way, users can specify the ``task_type`` keyword argument when calling :meth:`SentenceTransformer.encode <sentence_transformers.SentenceTransformer.encode>`::
+
+    from sentence_transformers import SentenceTransformer
+
+    model = SentenceTransformer("your-username/your-model-id", trust_remote_code=True)
+    texts = [...]
+    model.encode(texts, task_type="default")
diff --git a/docs/sentence_transformer/usage/usage.rst b/docs/sentence_transformer/usage/usage.rst
index a542db85c..c7cddc0e6 100644
--- a/docs/sentence_transformer/usage/usage.rst
+++ b/docs/sentence_transformer/usage/usage.rst
@@ -56,4 +56,5 @@ Once you have `installed <../../installation.html>`_ Sentence Transformers, you
     ../../../examples/applications/parallel-sentence-mining/README
     ../../../examples/applications/image-search/README
     ../../../examples/applications/embedding-quantization/README
+    custom_models
diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py
index 38924b2b0..86abbbcd3 100644
--- a/sentence_transformers/SentenceTransformer.py
+++ b/sentence_transformers/SentenceTransformer.py
@@ -7,6 +7,8 @@
 import math
 import os
 import queue
+import shutil
+import sys
 import tempfile
 import traceback
 import warnings
@@ -25,6 +27,7 @@ from torch import Tensor, device, nn
 from tqdm.autonotebook import trange
 from transformers import is_torch_npu_available
+from transformers.dynamic_module_utils import get_class_from_dynamic_module, get_relative_import_files
 
 from sentence_transformers.model_card import SentenceTransformerModelCardData, generate_model_card
 from sentence_transformers.similarity_functions import SimilarityFunction
@@ -170,6 +173,7 @@ def __init__(
         self.trust_remote_code = trust_remote_code
         self.truncate_dim = truncate_dim
         self.model_card_data = model_card_data or SentenceTransformerModelCardData()
+        self.module_kwargs = None
         self._model_card_vars = {}
         self._model_card_text = None
         self._model_config = {}
@@ -287,7 +291,7 @@ def __init__(
             revision=revision,
             local_files_only=local_files_only,
         ):
-            modules = self._load_sbert_model(
+            modules, self.module_kwargs = self._load_sbert_model(
                 model_name_or_path,
                 token=token,
                 cache_folder=cache_folder,
@@ -368,6 +372,7 @@ def encode(
         convert_to_tensor: Literal[False] = ...,
         device: str = ...,
         normalize_embeddings: bool = ...,
+        **kwargs,
     ) -> Tensor: ...
 
     @overload
@@ -384,6 +389,7 @@ def encode(
         convert_to_tensor: Literal[False] = ...,
         device: str = ...,
         normalize_embeddings: bool = ...,
+        **kwargs,
     ) -> np.ndarray: ...
 
     @overload
@@ -400,6 +406,7 @@ def encode(
         convert_to_tensor: Literal[True] = ...,
         device: str = ...,
         normalize_embeddings: bool = ...,
+        **kwargs,
     ) -> Tensor: ...
 
     @overload
@@ -416,6 +423,7 @@ def encode(
         convert_to_tensor: Literal[False] = ...,
         device: str = ...,
         normalize_embeddings: bool = ...,
+        **kwargs,
     ) -> list[Tensor]: ...
 
     def encode(
@@ -431,6 +439,7 @@ def encode(
         convert_to_tensor: bool = False,
         device: str = None,
         normalize_embeddings: bool = False,
+        **kwargs,
     ) -> list[Tensor] | np.ndarray | Tensor:
         """
         Computes sentence embeddings.
@@ -579,7 +588,7 @@ def encode(
             features.update(extra_features)
 
             with torch.no_grad():
-                out_features = self.forward(features)
+                out_features = self.forward(features, **kwargs)
 
                 if self.device.type == "hpu":
                     out_features = copy.deepcopy(out_features)
@@ -639,6 +648,16 @@ def encode(
 
         return all_embeddings
 
+    def forward(self, input: dict[str, torch.Tensor], **kwargs) -> dict[str, torch.Tensor]:
+        if self.module_kwargs is None:
+            return super().forward(input)
+
+        for module_name, module in self.named_children():
+            module_kwarg_keys = self.module_kwargs.get(module_name, [])
+            module_kwargs = {key: value for key, value in kwargs.items() if key in module_kwarg_keys}
+            input = module(input, **module_kwargs)
+        return input
+
     @property
     def similarity_fn_name(self) -> str | None:
         """Return the name of the similarity function used by :meth:`SentenceTransformer.similarity` and :meth:`SentenceTransformer.similarity_pairwise`.
@@ -1099,7 +1118,7 @@ def save(
         # Save modules
         for idx, name in enumerate(self._modules):
             module = self._modules[name]
-            if idx == 0 and isinstance(module, Transformer):  # Save transformer model in the main folder
+            if idx == 0:  # Save first module in the main folder
                 model_path = path + "/"
             else:
                 model_path = os.path.join(path, str(idx) + "_" + type(module).__name__)
@@ -1111,9 +1130,28 @@ def save(
             except TypeError:
                 module.save(model_path)
 
-            modules_config.append(
-                {"idx": idx, "name": name, "path": os.path.basename(model_path), "type": type(module).__module__}
-            )
+            # Saving only the module path works for Sentence Transformers modules, as their module and class names match
+            class_ref = type(module).__module__
+            # For remote modules, we want to remove "transformers_modules.{repo_name}":
+            if class_ref.startswith("transformers_modules."):
+                class_file = sys.modules[class_ref].__file__
+
+                # Save the custom module file
+                dest_file = Path(model_path) / (Path(class_file).name)
+                shutil.copy(class_file, dest_file)
+
+                # Save all files imported in the custom module file
+                for needed_file in get_relative_import_files(class_file):
+                    dest_file = Path(model_path) / (Path(needed_file).name)
+                    shutil.copy(needed_file, dest_file)
+
+                # For remote modules, we want to ignore the "transformers_modules.{repo_id}" part,
+                # i.e. we only want the filename
+                class_ref = f"{class_ref.split('.')[-1]}.{type(module).__name__}"
+            # For other cases, we want to add the class name:
+            elif not class_ref.startswith("sentence_transformers."):
+                class_ref = f"{class_ref}.{type(module).__name__}"
+            modules_config.append({"idx": idx, "name": name, "path": os.path.basename(model_path), "type": class_ref})
 
         with open(os.path.join(path, "modules.json"), "w") as fOut:
             json.dump(modules_config, fOut, indent=2)
@@ -1414,6 +1452,28 @@ def _load_auto_model(
         self.model_card_data.set_base_model(model_name_or_path, revision=revision)
         return [transformer_model, pooling_model]
 
+    def _load_module_class_from_ref(
+        self, class_ref: str, model_name_or_path: str, trust_remote_code: bool, model_kwargs: dict[str, Any] | None
+    ) -> nn.Module:
+        # If the class is from sentence_transformers, we can directly import it,
+        # otherwise, we try to import it dynamically, and if that fails, we fall back to the default import
+        if class_ref.startswith("sentence_transformers."):
+            return import_from_string(class_ref)
+
+        if trust_remote_code:
+            code_revision = model_kwargs.pop("code_revision", None) if model_kwargs else None
+            try:
+                return get_class_from_dynamic_module(
+                    class_ref,
+                    model_name_or_path,
+                    code_revision=code_revision,
+                )
+            except OSError:
+                # Ignore the error if the file does not exist, and fall back to the default import
+                pass
+
+        return import_from_string(class_ref)
+
     def _load_sbert_model(
         self,
         model_name_or_path: str,
@@ -1504,11 +1564,16 @@ def _load_sbert_model(
             modules_config = json.load(fIn)
 
         modules = OrderedDict()
+        module_kwargs = OrderedDict()
         for module_config in modules_config:
-            module_class = import_from_string(module_config["type"])
+            class_ref = module_config["type"]
+            module_class = self._load_module_class_from_ref(
+                class_ref, model_name_or_path, trust_remote_code, model_kwargs
+            )
+
+            # For the module saved in the root folder, don't load the full directory, rely on `transformers` instead
             # But, do load the config file first.
-            if module_class == Transformer and module_config["path"] == "":
+            if module_config["path"] == "":
                 kwargs = {}
                 for config_name in [
                     "sentence_bert_config.json",
@@ -1566,7 +1631,12 @@ def _load_sbert_model(
                 if config_kwargs:
                     kwargs["config_args"].update(config_kwargs)
 
-                module = Transformer(model_name_or_path, cache_dir=cache_folder, **kwargs)
+                # Try to initialize the module with the full set of kwargs, but only if the module supports them.
+                # Otherwise, we fall back to the module's own `load` method
+                try:
+                    module = module_class(model_name_or_path, cache_dir=cache_folder, **kwargs)
+                except TypeError:
+                    module = module_class.load(model_name_or_path)
             else:
                 # Normalize does not require any files to be loaded
                 if module_class == Normalize:
@@ -1582,6 +1652,7 @@ def _load_sbert_model(
                     )
                 module = module_class.load(module_path)
             modules[module_config["name"]] = module
+            module_kwargs[module_config["name"]] = module_config.get("kwargs", [])
 
         if revision is None:
             path_parts = Path(modules_json_path)
@@ -1590,7 +1661,7 @@ def _load_sbert_model(
                 if len(revision_path_part) == 40:
                     revision = revision_path_part
         self.model_card_data.set_base_model(model_name_or_path, revision=revision)
-        return modules
+        return modules, module_kwargs
 
     @staticmethod
     def load(input_path) -> SentenceTransformer:
diff --git a/sentence_transformers/models/Pooling.py b/sentence_transformers/models/Pooling.py
index d0535045a..85ad3f2a4 100644
--- a/sentence_transformers/models/Pooling.py
+++ b/sentence_transformers/models/Pooling.py
@@ -132,7 +132,11 @@ def get_pooling_mode_str(self) -> str:
 
     def forward(self, features: dict[str, Tensor]) -> dict[str, Tensor]:
         token_embeddings = features["token_embeddings"]
-        attention_mask = features["attention_mask"]
+        attention_mask = (
+            features["attention_mask"]
+            if "attention_mask" in features
+            else torch.ones(token_embeddings.shape[:-1], device=token_embeddings.device, dtype=torch.int64)
+        )
 
         if not self.include_prompt and "prompt_length" in features:
             attention_mask[:, : features["prompt_length"]] = 0
diff --git a/sentence_transformers/models/Transformer.py b/sentence_transformers/models/Transformer.py
index 4b19d3815..ea66a13cb 100644
--- a/sentence_transformers/models/Transformer.py
+++ b/sentence_transformers/models/Transformer.py
@@ -109,13 +109,13 @@ def _load_mt5_model(self, model_name_or_path, config, cache_dir, **model_args):
     def __repr__(self) -> str:
         return f"Transformer({self.get_config_dict()}) with Transformer model: {self.auto_model.__class__.__name__} "
 
-    def forward(self, features: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+    def forward(self, features: dict[str, torch.Tensor], **kwargs) -> dict[str, torch.Tensor]:
         """Returns token_embeddings, cls_token"""
         trans_features = {"input_ids": features["input_ids"], "attention_mask": features["attention_mask"]}
         if "token_type_ids" in features:
             trans_features["token_type_ids"] = features["token_type_ids"]
 
-        output_states = self.auto_model(**trans_features, return_dict=False)
+        output_states = self.auto_model(**trans_features, **kwargs, return_dict=False)
         output_tokens = output_states[0]
 
         features.update({"token_embeddings": output_tokens, "attention_mask": features["attention_mask"]})