
[feat] Allow loading custom modules; encode kwargs passthrough to modules #2773

Merged: tomaarsen merged 10 commits into UKPLab:master from feat/custom_modules on Sep 10, 2024

Conversation

tomaarsen
Collaborator

Hello!

Pull Request overview

  • Allow loading custom modules

Details

This is an experimental branch to allow models with custom architectures to integrate with Sentence Transformers. Behind the scenes, Sentence Transformer models always rely on a list of modules, with configuration in modules.json. Usually, this contains a Transformer module followed by a Pooling module, optionally with a Normalize and/or Dense module, but we can now fully change this up.
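For reference, a typical modules.json looks roughly like this (an abbreviated, illustrative example; the exact idx/name/path values depend on the model):

[
    {"idx": 0, "name": "0", "path": "", "type": "sentence_transformers.models.Transformer"},
    {"idx": 1, "name": "1", "path": "1_Pooling", "type": "sentence_transformers.models.Pooling"}
]

With this PR, the "type" field can instead point to a custom class, e.g. custom_st.Transformer for a modeling file hosted in the model repository itself, or tomaarsen/jina-clip-implementation-st--custom_st.Transformer for one hosted in a separate repository.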

I have two implementation formats that should both work:

  1. tomaarsen/jina-clip-v1-st hosts all custom modeling files, notably custom_st.py. The class defined there is loaded via modules.json. It handles tokenization/preprocessing of texts/images and implements the forward method, which accepts a dictionary of features from the tokenization step and must output a dictionary (see the sketch after this list). Some modules might output sentence_embedding, while others may output token_embeddings. The latter can then still be used with the Pooling module, which will convert the token embeddings into sentence embeddings.
  2. tomaarsen/jina-clip-v1-st-remote and tomaarsen/jina-clip-implementation-st are similar, but here the modeling file is in a separate repository (the latter). The modules.json file now uses tomaarsen/jina-clip-implementation-st--custom_st.Transformer, and this repository hosts custom_st.py. Otherwise, everything is the same as the other implementation.

Note: you don't have to use Transformer or custom_st as names, you can be flexible here.
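As a minimal sketch of what such a custom_st.py might contain (the method names follow the module interface described above; the base model, signatures, and defaults are illustrative assumptions, not the actual jina-clip implementation):

import torch
from torch import nn
from transformers import AutoModel, AutoTokenizer

class Transformer(nn.Module):
    # Illustrative custom module; "intfloat/e5-base-v2" is a placeholder model.
    def __init__(self, model_name_or_path: str = "intfloat/e5-base-v2", **kwargs) -> None:
        super().__init__()
        self.model = AutoModel.from_pretrained(model_name_or_path)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

    def tokenize(self, texts: list[str], padding: bool = True) -> dict[str, torch.Tensor]:
        # Preprocess raw inputs into the features dictionary that forward() receives
        return dict(self.tokenizer(texts, padding=padding, truncation=True, return_tensors="pt"))

    def forward(self, features: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        # Must return a dictionary: either "sentence_embedding" directly, or
        # "token_embeddings" so a downstream Pooling module can produce the
        # sentence embeddings.
        outputs = self.model(input_ids=features["input_ids"], attention_mask=features["attention_mask"])
        features["token_embeddings"] = outputs.last_hidden_state
        return features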

Usage

You can now use these custom modules in Sentence Transformers directly via trust_remote_code=True:

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("tomaarsen/jina-clip-v1-st-remote", trust_remote_code=True)
# or
# model = SentenceTransformer("tomaarsen/jina-clip-v1-st", trust_remote_code=True)

# New meaningful sentences
sentences = ['A blue cat', 'A red cat']

# Public image URLs
image_urls = [
    'https://i.pinimg.com/600x315/21/48/7e/21487e8e0970dd366dafaed6ab25d8d8.jpg',
    'https://i.pinimg.com/736x/c9/f2/3e/c9f23e212529f13f19bad5602d84b78b.jpg'
]

text_embeddings = model.encode(sentences)
image_embeddings = model.encode(image_urls)

# Compute similarities
print(model.similarity(text_embeddings[0], text_embeddings[1])) # text embedding similarity
print(model.similarity(text_embeddings, image_embeddings)) # text-image cross-modal similarity
"""
tensor([[0.5636]])
tensor([[0.2906, 0.0569],
        [0.1277, 0.2916]])
"""

I believe this opens up a lot of opportunities for finetuning custom architectures, too. Even modalities not previously used with Sentence Transformers (e.g. audio?) should be totally feasible.

cc @bwanglzu

  • Tom Aarsen

Comment on lines 1362 to 1377
if class_ref.startswith("sentence_transformers."):
    return import_from_string(class_ref)

if trust_remote_code:
    code_revision = model_kwargs.pop("code_revision", None) if model_kwargs else None
    try:
        return get_class_from_dynamic_module(
            class_ref,
            model_name_or_path,
            code_revision=code_revision,
        )
    except EnvironmentError:
        # Ignore the error if the file does not exist, and fall back to the default import
        pass

return import_from_string(class_ref)
tomaarsen (Collaborator, Author)

@muellerzr I'd love to get your opinion on whether this might be problematic in terms of security. The first if imports something from the Sentence Transformers module, which should be fine unless a malicious package is installed that uses sentence_transformers as the import name.

The second if imports the class via the transformers remote code functionality. I assume this is safe.

The third case, the final fallback return, imports whatever the class reference is, so this can also be a class in a malicious module. But the user would already have to have that malicious package installed. I'm not too concerned about this case, as preventing it would also prevent third parties from creating non-malicious custom ST Modules.
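For concreteness, the three branches correspond to class references along these lines (the first two values come from this PR; the third is a hypothetical example of an arbitrary importable path):

# Branch 1: built-in module, imported from sentence_transformers itself
class_ref = "sentence_transformers.models.Transformer"
# Branch 2: remote code, resolved via the transformers dynamic module machinery
class_ref = "tomaarsen/jina-clip-implementation-st--custom_st.Transformer"
# Branch 3 (fallback): any importable path, e.g. from a locally installed package (hypothetical)
class_ref = "my_package.MyModule"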

  • Tom Aarsen

bwanglzu (Contributor) left a comment

This PR looks good to me. I'll do some manual testing on this branch using different models to make sure everything works correctly.

@tomaarsen tomaarsen changed the title [feat] Allow loading custom modules [feat] Allow loading custom modules; encode kwargs passthrough to modules Sep 10, 2024
@tomaarsen tomaarsen merged commit 6257cb0 into UKPLab:master Sep 10, 2024
11 checks passed
@tomaarsen tomaarsen deleted the feat/custom_modules branch September 10, 2024 15:53