
Does SpinQuant implement R3 when using quantized kv cache? #9705


Open
WeiMa01 opened this issue Mar 27, 2025 · 4 comments

WeiMa01 commented Mar 27, 2025

When we run SpinQuant through ExecuTorch, we observe that only R4 is applied as an online rotation, while R3 is not. We would like to confirm whether ExecuTorch does not support R3 for SpinQuant.

  1. Convert the model to .pte with quantized KV cache enabled:
    python -m examples.models.llama.export_llama \
      --model "llama3_2" \
      --checkpoint "/home/zhuan.zhang/llama_models/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8/consolidated.00.pth" \
      --params "/home/zhuan.zhang/llama_models/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8/params.json" \
      --use_sdpa_with_kv_cache \
      -X \
      --xnnpack-extended-ops \
      --preq_mode 8da4w_output_8da8w \
      --preq_group_size 32 \
      --max_seq_length 2048 \
      -kv \
      -d fp32 \
      --preq_embedding_quantize 8,0 \
      --quantize_kv_cache \
      --output_name 'llama3_2_spinquant_qkv.pte' \
      --use_spin_quant native \
      --generate_etrecord

  2. At runtime, the delegation summary shows "llama_fast_hadamard_transform_default" being called 16 times (once per decoder layer), which corresponds to R4:
    | op_type                       | occurrences_in_delegated_graphs | occurrences_in_non_delegated_graphs |
    |-------------------------------|---------------------------------|-------------------------------------|
    | llama_fast_hadamard_transform | 0                               | 16                                  |

  3. The source code shows that when SpinQuant is enabled, the export only replaces FeedForward with FeedForwardNativeCustom via inject_fast_hadamard_transform_native_for_spin_quant:

# Excerpt from the export flow (imports elided):
def _get_source_transforms(  # noqa
    modelname: str, dtype_override: Optional[DType], args
) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
    transforms = []
    if args.use_spin_quant:
        if args.use_spin_quant == "cuda":
            from .source_transformation.spin_quant import (
                inject_fast_hadamard_transform_cuda_for_spin_quant,
            )

            transforms.append(inject_fast_hadamard_transform_cuda_for_spin_quant)
        elif args.use_spin_quant == "native":
            from .source_transformation.spin_quant import (
                inject_fast_hadamard_transform_native_for_spin_quant,
            )

            transforms.append(inject_fast_hadamard_transform_native_for_spin_quant)
    # ... (remaining transforms elided)


# Excerpt from source_transformation/spin_quant.py (imports elided):
def _inject_fast_hadamard_transform_native_for_spin_quant(module: torch.nn.Module):
    """
    SpinQuant needs two Hadamard matrices: R3 and R4. Here we are only injecting R4 in the feed forward layer.
    R3 needs to be injected as well when KV cache quantization is enabled.
    """

    class FeedForwardNativeCustom(nn.Module):
        def __init__(self, w1, w2, w3):
            super().__init__()
            self.w1 = w1
            self.w2 = w2
            self.w3 = w3

        def forward(self, x):
            # R4: rotate the FFN activation with a fast Hadamard transform
            # before the down projection (w2), so its input quantizes better.
            return self.w2(
                torch.ops.llama.fast_hadamard_transform(F.silu(self.w1(x)) * self.w3(x))
            )

    for name, child in module.named_children():
        if isinstance(child, FeedForward):
            setattr(module, name, FeedForwardNativeCustom(child.w1, child.w2, child.w3))
        else:
            _inject_fast_hadamard_transform_native_for_spin_quant(child)
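
For reference, below is a rough pure-PyTorch sketch of the rotation that torch.ops.llama.fast_hadamard_transform applies to the FFN activation (the R4 path above). The helper names and the 1/sqrt(n) normalization are assumptions made for illustration, not a description of the custom op's exact numerics.

import torch

def hadamard_matrix(n: int) -> torch.Tensor:
    """Sylvester construction of an n x n Hadamard matrix (n must be a power of two)."""
    assert n > 0 and n & (n - 1) == 0, "n must be a power of two"
    H = torch.ones(1, 1)
    while H.shape[0] < n:
        H = torch.cat([torch.cat([H, H], dim=1), torch.cat([H, -H], dim=1)], dim=0)
    return H

def reference_hadamard_transform(x: torch.Tensor) -> torch.Tensor:
    """Rotate the last dimension of x by a normalized (orthonormal) Hadamard matrix."""
    n = x.shape[-1]
    H = hadamard_matrix(n).to(x.dtype) / (n ** 0.5)
    return x @ H

# The rotation is orthogonal, so once the inverse rotation is folded into w2's
# weights offline the FFN output is unchanged; its purpose is to spread
# activation outliers so the input to w2 quantizes better.
x = torch.randn(2, 8)
y = reference_hadamard_transform(x)
print(torch.allclose(reference_hadamard_transform(y), x, atol=1e-5))  # True: the normalized Hadamard is its own inverse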

cc @kimishpatel @jerryzh168 @larryliu0820 @mergennachin @cccclai @helunwencser @jackzhxng

JacobSzwejbka (Contributor) commented:

@kimishpatel @helunwencser I think you two worked on the SpinQuant / quantized KV cache stuff, can you take a look?

JacobSzwejbka added the module: llm and module: quantization labels on Mar 27, 2025
github-project-automation bot moved this to "To triage" in ExecuTorch Core on Mar 27, 2025
helunwencser (Contributor) commented:

At the time when we enabled SpinQuant, we did not have kv cache quantization. So R3 is not injected right now. Given that we have kv cache quantization now, we should be able to enable kv cache quantization and inject R3 in SpinQuant.

WeiMa01 (Author) commented Mar 28, 2025

> At the time when we enabled SpinQuant, we did not have kv cache quantization. So R3 is not injected right now. Given that we have kv cache quantization now, we should be able to enable kv cache quantization and inject R3 in SpinQuant.

Yes, I see that ExecuTorch already supports quantized KV cache, but R3 is not injected. Do you mean your team will support injecting R3 in the future?

jackzhxng (Contributor) commented:

No plans to inject R3 atm, would you like to help with this?
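
For anyone picking this up, here is a minimal sketch of what an R3 rotation could look like. In the SpinQuant paper, R3 is an online Hadamard rotation applied along the head dimension of queries and keys (after RoPE), so that the keys written into a quantized KV cache have fewer outliers. Everything below (the apply_r3 helper, the tensor layout, and the call site) is a hypothetical illustration, not an existing ExecuTorch transform; wiring it in would need a source transform on the attention module analogous to the R4 one above.

import torch
from scipy.linalg import hadamard  # Sylvester Hadamard matrix; size must be a power of two

def apply_r3(q: torch.Tensor, k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical R3: rotate q and k along head_dim right before attention,
    i.e. before k is written into the (quantized) KV cache.
    Assumes layout [bsz, n_heads, seq_len, head_dim] with head_dim a power of two."""
    head_dim = q.shape[-1]  # 64 for Llama 3.2 1B
    H = torch.from_numpy(hadamard(head_dim)).to(q.dtype) / (head_dim ** 0.5)
    # Applying the same orthogonal rotation to q and k leaves the attention
    # scores q @ k^T unchanged, so accuracy is preserved while the cached k
    # has fewer outliers and quantizes more accurately.
    return q @ H, k @ H

# Example with roughly Llama 3.2 1B shapes (32 query heads, 8 KV heads, head_dim 64):
q = torch.randn(1, 32, 16, 64)
k = torch.randn(1, 8, 16, 64)
q_r, k_r = apply_r3(q, k)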
