Describe the issue
I see there is a Qwen2 7B model in this repo. I am trying to use SHALlamaAttention with the Qwen2 0.5B model, but I am getting gibberish output. I reviewed the code for qwen2_7b_instruct_quantized, but that recipe uses a model that has already been converted by QNN; I want to apply SHALlamaAttention to the native Qwen2 model. How do I adapt the code for this?
To Reproduce
Based on the llama3 code, I made the following modifications to adapt it to Qwen2.

In models/_shared/llama3/model.py:
# Swap in the Qwen2 modeling module and patch it the same way the repo
# patches Llama: SHA attention, fused RMSNorm, and bypassed rotary embedding.
from transformers.models.qwen2 import modeling_qwen2 as modeling_llama

modeling_llama.QWEN2_ATTENTION_CLASSES["eager"] = SHALlamaAttention
modeling_llama.Qwen2RMSNorm.forward = LlamaRMSNorm_forward
modeling_llama.Qwen2RotaryEmbedding.forward = bypass_RotaryEmbedding

model = modeling_llama.Qwen2ForCausalLM.from_pretrained(
    self.huggingface_model_name,
    config=self.llm_config,
    ignore_mismatched_sizes=_make_small_for_debugging,
)
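As a sanity check while adapting, I also dumped the Qwen2 config fields that differ from Llama, since gibberish output would be expected if the llama3 SHA path hard-codes Llama's values. This is a diagnostic sketch, not repo code; I use the public Hugging Face id for the instruct checkpoint (adjust if you are on the base model):

from transformers import AutoConfig

# Inspect the fields that a Llama-specific SHA implementation might hard-code.
config = AutoConfig.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
print("num_attention_heads:", config.num_attention_heads)             # expect 14
print("num_key_value_heads:", config.num_key_value_heads)             # expect 2 (GQA)
print("head_dim:", config.hidden_size // config.num_attention_heads)  # expect 64
print("rope_theta:", config.rope_theta)                               # expect 1000000.0

Note also that Qwen2's q/k/v projections carry bias terms, while Llama's do not. If SHALlamaAttention builds its per-head projections without biases, the checkpoint's q/k/v biases might be silently dropped (or masked by ignore_mismatched_sizes), which alone could produce gibberish.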
In models/_shared/llama/model.py, I also changed these defaults to match Qwen2 0.5B (head_dim 64, RoPE theta 1000000.0):

def __init__(self, head_dim: int = 64, max_length: int = 1024):
    ...

def precompute_freqs_cis(self, dim: int, end: int, theta: float = 1000000.0):
    ...
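For reference, here is the standard Llama-style RoPE table that a precompute_freqs_cis with this signature conventionally builds. This is a sketch of the usual reference implementation, not the repo's exact code; layout and dtype may differ:

import torch

def precompute_freqs_cis(dim: int, end: int, theta: float = 1000000.0) -> torch.Tensor:
    # Per-pair frequencies theta^(-2k/dim) for k = 0 .. dim/2 - 1.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    # Angle for every (position, frequency) pair, turned into unit complex rotations.
    angles = torch.outer(torch.arange(end).float(), freqs)
    return torch.polar(torch.ones_like(angles), angles)  # shape (end, dim // 2), complex64

Since bypass_RotaryEmbedding replaces the Hugging Face rotary forward, the table built here presumably has to carry Qwen2's values (dim=64, theta=1000000.0), or positions will be rotated with the wrong frequencies, which would also show up as gibberish.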
Expected behavior
Valid generation output. How can I use SHALlamaAttention with the Qwen2 0.5B model?
Stack trace
None; there is no error, the model runs but the output is gibberish.
Host configuration:
Python 3.10.0
QNN 2.26.0.240828
transformers 4.46.3