
Can I select causal attention for retrieval embeddings when using GritLM #45

Open
Yangseung opened this issue Jul 2, 2024 · 17 comments

Comments

@Yangseung

In the paper, the ablation study on attention for embedding and generation is interesting.

Are these all separate models, each trained with a different attention configuration?

Can I select causal attention for both cases when using GritLM-7B?
If not, could you share a model that uses causal attention for both?

I selected 'cccc' and 'mean' for retrieval embeddings, but the performance was significantly degraded :(
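
For reference, a minimal sketch of what selecting 'cccc' and 'mean' at load time could look like, assuming the gritlm package's GritLM wrapper exposes attn and pooling_method keyword arguments and an encode() that takes an instruction string; the argument names and the example texts here are assumptions, not taken from this thread.

```python
# Minimal sketch: load GritLM-7B with causal attention for both embedding and
# generation ('cccc') and plain mean pooling. Note that GritLM-7B itself was trained
# with bidirectional attention for embedding, so forcing 'cccc' at inference time
# is a train/test mismatch.
from gritlm import GritLM

model = GritLM(
    "GritLM/GritLM-7B",
    torch_dtype="auto",
    attn="cccc",            # causal attention for embedding and generation
    pooling_method="mean",  # plain mean pooling over the final hidden states
)

queries = ["Which attention does GritLM use for embeddings?"]   # placeholder text
docs = ["GRIT unifies embedding and generation in one model."]  # placeholder text

q_emb = model.encode(queries, instruction="<|embed|>\n")
d_emb = model.encode(docs, instruction="<|embed|>\n")
scores = q_emb @ d_emb.T  # cosine similarity, assuming normalized embeddings
```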

@Muennighoff
Collaborator

The models are all uploaded here: https://huggingface.co/collections/GritLM/gritlm-65cc1403df78d51bb89f1651
and linked from the Appendix table of the paper: https://arxiv.org/pdf/2402.09906
The collection includes cccc models; for the best performance, you probably want the one with cccc & weighted mean. Let me know if anything is unclear.

@Yangseung
Author

Thanks for your help.

I found the models for CCCC WM and CCCC LT.
CCCC WM https://hf.co/GritLM/gritlm_m7_sq2048_medi
CCCC LT https://hf.co/GritLM/gritlm_m7_sq2048_medi_lasttoken

I wonder if there is a CCCC & Mean model available, and whether using the WM model instead causes severe performance degradation. In other words, can the CCCC & WM model serve as a substitute where I would otherwise use CCCC & Mean?

@Muennighoff
Collaborator

There's no trained CCCC + M model because WM is better. Why do you not want to use WM?

@Yangseung
Author

I am just wondering why CCCC & M models have lower performance than CCCC & WM models.

Also, I have one question.
Excluding the prompt section, is cache reuse still possible when using this model? If so, could it potentially improve the performance of Table 7 in the main paper?

@Muennighoff
Collaborator

> I am just wondering why CCCC & M models have lower performance than CCCC & WM models.

The intuition for why WM is better than M for decoders, along with more comparisons, is in this paper: https://arxiv.org/pdf/2202.08904. It's not super well-written though, sorry!
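
To make the comparison concrete, here is a rough sketch of the two pooling variants (plain mean, M, versus the position-weighted mean, WM, from the SGPT paper linked above), where later tokens get linearly larger weights because, with causal attention, later hidden states have seen more of the sequence. The tensor shapes and masking convention are illustrative assumptions, not GritLM's exact implementation.

```python
import torch

def mean_pool(hidden, mask):
    # hidden: (batch, seq, dim); mask: (batch, seq) with 1 for real tokens, 0 for padding
    mask = mask.unsqueeze(-1).float()
    return (hidden * mask).sum(dim=1) / mask.sum(dim=1)

def weighted_mean_pool(hidden, mask):
    # Position-weighted mean: token i gets weight proportional to its position,
    # so later (more context-aware) states contribute more under causal attention.
    weights = (torch.cumsum(mask, dim=1) * mask).unsqueeze(-1).float()  # 1, 2, ..., L
    return (hidden * weights).sum(dim=1) / weights.sum(dim=1)
```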

> Also, I have one question. Excluding the prompt section, is cache reuse still possible when using this model? If so, could it potentially improve the performance of Table 7 in the main paper?

Yes, if you use CCCC then the performance in Table 7 with caching will be much better (at the cost of slightly worse retrieval, as BB is better).

> I selected 'cccc' and 'mean' for retrieval embeddings, but the performance was significantly degraded :(

I'm surprised GritLM-7B gets that much worse if you do cccc & mean. Did you benchmark on MTEB to see how much of a difference it makes? It is probably worth checking that and comparing with CCCC WM (https://hf.co/GritLM/gritlm_m7_sq2048_medi).
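
A rough sketch of such a comparison follows, assuming the mteb package's get_tasks/MTEB API and that the GritLM wrapper's encode() satisfies its model interface; the task choice, the keyword-argument names, and the 'weightedmean' string are assumptions for illustration.

```python
# Sanity-check sketch: compare GritLM-7B forced to cccc & mean against the CCCC WM
# checkpoint on one small MTEB retrieval task.
import mteb
from gritlm import GritLM

configs = {
    "GritLM-7B-cccc-mean": ("GritLM/GritLM-7B", dict(attn="cccc", pooling_method="mean")),
    "medi-cccc-wm": ("GritLM/gritlm_m7_sq2048_medi", dict(attn="cccc", pooling_method="weightedmean")),
}

tasks = mteb.get_tasks(tasks=["SciFact"])  # any small retrieval task works as a quick check
for label, (name, kwargs) in configs.items():
    model = GritLM(name, torch_dtype="auto", mode="embedding", **kwargs)
    mteb.MTEB(tasks=tasks).run(model, output_folder=f"results/{label}")
```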

@Yangseung
Author

Thank you for sharing your good paper :)

We are experimenting with GritLM in our application; the top-2 hit rate is around 98% with bidirectional attention, but with causal attention the performance is below 50%.

@Muennighoff
Collaborator

> We are experimenting with GritLM in our application; the top-2 hit rate is around 98% with bidirectional attention, but with causal attention the performance is below 50%.

Oh, that's poor. https://hf.co/GritLM/gritlm_m7_sq2048_medi2 should be even better than https://hf.co/GritLM/gritlm_m7_sq2048_medi (Table 15). Unfortunately, I didn't train a CCCC E5 model :/ I think the caching issue could easily be solved by further finetuning GritLM to get it used to that format, but I haven't had the time to try.

@Yangseung
Author

Thanks for your suggestion.
I think the suggested model is an embedding-only model.
Our goal is to share both the retriever and the generator, and to reduce inference cost by also sharing the cache.
I will try to utilize cccc & wm. Thank you.

@Muennighoff
Collaborator

Oh https://huggingface.co/GritLM/gritlm_m7_sq2048_medi2 should be both embedding & generation.
Did GritLM-7B not work well for your use case? Do you want to cache docs or queries or both?

@Yangseung
Author

In our case, we want to cache queries.
In the GritLM-7B case, query caches are computed with bidirectional attention, so generation performance degrades a lot when using our instruction (prompt).
Therefore, we want to use the https://huggingface.co/GritLM/gritlm_m7_sq2048_medi2 model.

@Muennighoff
Collaborator

I see, let me know how it goes!

@Yangseung
Author

We are experiencing the following issue and are hoping you can provide some comments or solutions:

We are using the aforementioned causal/causal model and are proceeding with the following process (a rough sketch of the flow is included after the steps):
Step 1: We forward the query for the retrieval embedding, store its KV cache, and retrieve the top-2 documents.
Step 2: We pass the query cache in as past key values, forward the documents, and generate the response.
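
Here is a sketch of those two steps with plain transformers; the checkpoint name and prompt format are taken from this thread, the query/document strings are placeholders, and whether generate() accepts a precomputed past_key_values in this form depends on the transformers version.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "GritLM/gritlm_m7_sq2048_medi"  # CCCC WM checkpoint linked above
tok = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name, torch_dtype=torch.bfloat16, device_map="auto")

query = "What attention does GritLM use?"        # placeholder query
doc = "GRIT unifies embedding and generation."   # placeholder retrieved document

embed_part = "<|embed|>\n" + query
gen_part = ("\n<|user|>\n " + doc +
            "\n\nOptionally using the prior context answer the query prior to it\n<|assistant|>\n")

# Step 1: forward the query once; pool the hidden states for retrieval and keep the KV cache.
q_ids = tok(embed_part, return_tensors="pt").input_ids.to(model.device)  # BOS added by default
with torch.no_grad():
    out = model(q_ids, use_cache=True, output_hidden_states=True)
query_embedding = out.hidden_states[-1].mean(dim=1)  # or weighted mean, matching the checkpoint
cache = out.past_key_values

# Step 2: reuse the cache for generation; feed the full formatted sequence so that the
# cached prefix and the new tokens line up position-wise.
full_ids = tok(embed_part + gen_part, return_tensors="pt").input_ids.to(model.device)
gen = model.generate(full_ids, past_key_values=cache, max_new_tokens=128, do_sample=False)
print(tok.decode(gen[0, full_ids.shape[1]:], skip_special_tokens=True))
```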

When we run this process, the results differ between reusing the cache and not reusing it. When we input the query and the documents without the cache, the response is accurate. However, when we input the cache and the documents, the response becomes completely strange (repetitive phrases or nonsensical output).

It also seems that the causal/causal model uses a system prompt at the beginning, unlike the released GritLM, and even after removing it for evaluation we still get similarly strange and incorrect results.

@Muennighoff
Collaborator

> causal/causal model uses a system prompt

What do you mean by this? There should be no system prompt besides the formatting, i.e. <|embed|> etc. If you include the formatting both with and without the cache, the results should be exactly the same.

@Yangseung
Author

It seems these two outcomes should be exactly the same, but since I am getting different results, I wanted to ask.

When using the query and text together, the following format is used: FULL_FORMAT = "<|embed|>\n{query}\n<|user|>\n {text}\n\nOptionally using the prior context answer the query prior to it\n<|assistant|>\n"

When forwarding the query, "<|embed|>\n" + query is used, and when forwarding the text, the format "\n<|user|>\n {text}\n\nOptionally using the prior context answer the query prior to it\n<|assistant|>\n" is used.

It seems both should be exactly the same, but the result of the first approach (forwarding FULL_FORMAT in one pass, without the cache) comes out much better.
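
One cheap check here (a sketch with made-up example strings; the tokenizer checkpoint is one of those linked above): compare the token ids of the split construction against tokenizing FULL_FORMAT in one pass, since identical decoded text does not guarantee identical token ids.

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("GritLM/gritlm_m7_sq2048_medi2")

query, text = "example query", "example retrieved text"  # placeholders
embed_part = "<|embed|>\n" + query
gen_part = ("\n<|user|>\n " + text +
            "\n\nOptionally using the prior context answer the query prior to it\n<|assistant|>\n")

full_ids = tok(embed_part + gen_part, add_special_tokens=True).input_ids
split_ids = (tok(embed_part, add_special_tokens=True).input_ids
             + tok(gen_part, add_special_tokens=False).input_ids)
print(full_ids == split_ids)  # False means the cached prefix does not line up with the full prompt
```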

@Muennighoff
Collaborator

Yes, that looks good to me and it should be the same. A few checks:

  • Where is the BOS token <s> being added? Make sure this is also the same. The tokenizer adds it automatically unless you pass add_special_tokens=False.
  • Try forwarding "<|embed|>\n{query}\n<|user|>\n {text}\n\nOptionally using the prior context answer the query prior to it\n<|assistant|>\n" and check that the key value states for the first X tokens are the same as with "<|embed|>\n" + query (see the sketch after this list).
  • Ensure you use causal attention, the correct pooling method, and are passing the key value states.
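
A sketch of the second check, reusing the model, q_ids, and full_ids variables from the caching sketch earlier in the thread, and assuming the legacy tuple layout for past_key_values (per-layer (key, value) tensors of shape (batch, heads, seq_len, head_dim)):

```python
import torch

def as_tuples(cache):
    # Newer transformers return a Cache object; convert to per-layer (key, value) tuples.
    return cache.to_legacy_cache() if hasattr(cache, "to_legacy_cache") else cache

with torch.no_grad():
    q_cache = as_tuples(model(q_ids, use_cache=True).past_key_values)
    full_cache = as_tuples(model(full_ids, use_cache=True).past_key_values)

n = q_ids.shape[1]
for layer, ((qk, qv), (fk, fv)) in enumerate(zip(q_cache, full_cache)):
    assert torch.allclose(qk, fk[:, :, :n], atol=1e-3), f"key states diverge at layer {layer}"
    assert torch.allclose(qv, fv[:, :, :n], atol=1e-3), f"value states diverge at layer {layer}"
print("query-prefix key/value states match")
```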

@Yangseung
Author

Thank you for answering.

  1. I used add_special_tokens=True for query caching and add_special_tokens=False for generation. I checked the decoded result of the input and confirmed that the two inputs are the same.
  2. Let me check :). Thanks for the valuable comments.
  3. In the generation stage, I think causal attention is used automatically and no pooling method is applied. Is that correct?

@Muennighoff
Collaborator

  1. yes
