Cross Encoder support #251

Open
michalwarda opened this issue Sep 21, 2023 · 4 comments
Labels
kind:feature New feature or request

Comments

@michalwarda

Hi, I'm currently trying to implement a feature called "hybrid search" inside my application. It returns query results from multiple databases and then scores the results from the different sources together. To score them I want to use cross-encoder models, e.g. https://huggingface.co/cross-encoder/ms-marco-MiniLM-L-6-v2.

I'm trying to understand whether Bumblebee currently supports models like this and, if so, how to use it for that.

If it does, I'll be very happy to write some documentation for it after getting some hints. If not, it would be a very cool feature to handle those types of operations :)

@jonatanklosko
Member

Hey @michalwarda! The repository uses BERT, so Bumblebee supports the model; however, we don't have any serving that fits the cross-encoder use case. So currently you could do this:

# The checkpoint repository has no tokenizer.json, so the tokenizer is
# loaded from the base model the cross-encoder was fine-tuned from
{:ok, model_info} = Bumblebee.load_model({:hf, "cross-encoder/ms-marco-MiniLM-L-6-v2"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "bert-base-uncased"})

# Each entry is a {query, passage} pair, tokenized into a single sequence
inputs =
  Bumblebee.apply_tokenizer(tokenizer, [
    {"How many people live in Berlin?",
     "Berlin has a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers."},
    {"How many people live in Berlin?",
     "New York City is famous for the Metropolitan Museum of Art."}
  ])

# A single forward pass; each pair yields one relevance logit
outputs = Axon.predict(model_info.model, model_info.params, inputs)

outputs.logits
#=> #Nx.Tensor<
#=>   f32[2][1]
#=>   EXLA.Backend<fake_gpu:0, 0.1151319922.1832779796.76254>
#=>   [
#=>     [8.845853805541992],
#=>     [-11.245560646057129]
#=>   ]
#=> >

I think we can add a serving like Bumblebee.Text.cross_encoding to optimise for this use case. Ideally we would also open PRs adding tokenizer.json to the HF repositories, because in this case it's far from obvious that bert-base-uncased is the place to look.
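
Until such a serving exists, the snippet above can be wrapped in a small helper. This is a minimal sketch, assuming the model_info and tokenizer loaded above; the cross_encode name is hypothetical, and the sigmoid step is just one common way to map the raw logits to 0..1 relevance scores:

# Hypothetical helper, not a Bumblebee API: scores {query, passage} pairs
cross_encode = fn pairs ->
  inputs = Bumblebee.apply_tokenizer(tokenizer, pairs)
  outputs = Axon.predict(model_info.model, model_info.params, inputs)

  # The model emits one raw logit per pair; sigmoid squashes it into 0..1
  outputs.logits
  |> Nx.sigmoid()
  |> Nx.to_flat_list()
end

Applied to the two pairs above, this would return roughly [0.9999, 0.00001], i.e. the sigmoids of the logits shown.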

@samrat

samrat commented Oct 17, 2023

Thank you for the hint about bert-base-uncased :)

I am also interested in this use case. Is this a feature that will be added?

@jonatanklosko
Member

> Thank you for the hint about bert-base-uncased :)

We now have a more specific error message when a tokenizer is missing from the repository, with suggested steps for getting a compatible one, so hopefully it's more intuitive now, without guessing repositories :)

> I am also interested in this use case. Is this a feature that will be added?

Yeah, I think Bumblebee.Text.cross_encoding makes sense. It's not the top priority right now, but contributions are also welcome.
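
For reference, a hand-rolled serving along those lines could look roughly like the sketch below. It assumes the model_info and tokenizer loaded earlier and an Nx version from around the time of this thread; the Nx.Serving callback signatures have changed across releases, so treat it as an outline rather than a drop-in implementation:

serving =
  Nx.Serving.new(fn defn_options ->
    # Compile the model with whatever compiler options the serving is given
    {_init_fun, predict_fun} = Axon.build(model_info.model, defn_options)
    fn batch -> predict_fun.(model_info.params, batch) end
  end)

# Tokenized inputs are a map of tensors, which Nx.Batch accepts as a container
batch =
  Nx.Batch.concatenate([
    Bumblebee.apply_tokenizer(tokenizer, [
      {"How many people live in Berlin?",
       "Berlin has a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers."}
    ])
  ])

Nx.Serving.run(serving, batch).logits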

@toranb

toranb commented Nov 11, 2023

Huge thanks to @jonatanklosko for sharing this solution! I ran into this today and wanted to share a working Nx serving I put together for my use case (RAG with mixed search, using Postgres PGVector and full-text search together):

toranb/rag-n-drop@4b515ed

I included a working use case for those who might bump into this later on :)
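
For anyone assembling something similar, the scoring step usually boils down to re-ranking the merged candidates with the cross-encoder. A minimal sketch, assuming a cross_encode-style scorer like the one above; semantic_results, fulltext_results, and the :text field are illustrative names, not from any library:

query = "How many people live in Berlin?"

# Candidates merged from the different retrievers (e.g. vector + full-text)
candidates = semantic_results ++ fulltext_results

# Score each {query, candidate} pair, then sort best-first
scores = cross_encode.(Enum.map(candidates, &{query, &1.text}))

candidates
|> Enum.zip(scores)
|> Enum.sort_by(fn {_candidate, score} -> score end, :desc)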

@jonatanklosko jonatanklosko added the kind:feature New feature or request label Dec 13, 2023