Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core] Optimizing cross-attention QKVParallelLinear computation #12325

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions vllm/model_executor/layers/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -1165,3 +1165,98 @@ def extra_repr(self) -> str:
s += f", tp_size={self.tp_size}"
s += f", reduce_results={self.reduce_results}"
return s


class QKVCrossParallelLinear(torch.nn.Module):

def __init__(self,
hidden_size: int,
head_size: int,
total_num_heads: int,
total_num_kv_heads: Optional[int] = None,
bias: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
# Empty placeholders for loading as a single module.
self.weight = torch.nn.Parameter()
set_weight_attrs(self.weight, {
"weight_loader": self.weight_loader_weight,
})
# Use a dictionary to avoid submodules parameters auto-registration:
# drop-in replacement for a `QKVParallelLinear` module.
self.proj = dict()
self.proj["q_proj_decoder"] = ColumnParallelLinear(
input_size=hidden_size,
output_size=total_num_heads * head_size,
bias=bias,
quant_config=quant_config,
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
prefix=f"{prefix}.q_proj_decoder")

self.proj["kv_proj_encoder"] = QKVParallelLinear(
hidden_size=hidden_size,
head_size=head_size,
total_num_heads=0,
total_num_kv_heads=total_num_kv_heads,
bias=bias,
quant_config=quant_config,
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
prefix=f"{prefix}.kv_proj_encoder")

# `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1.
self.kv_size = self.kv_proj_encoder.num_kv_heads * head_size

if bias:
self.bias = torch.nn.Parameter()
set_weight_attrs(self.bias, {
"weight_loader": self.weight_loader_bias,
})

@property
def q_proj_decoder(self):
return self.proj["q_proj_decoder"]

@property
def kv_proj_encoder(self):
return self.proj["kv_proj_encoder"]

def forward(self, decoder_hidden_states, encoder_hidden_states):
q, _ = self.q_proj_decoder(decoder_hidden_states)
if encoder_hidden_states is None:
# Encoder KV already cached.
k = None
v = None
else:
# Prefill phase, encoder KV cached here.
kv_enc, _ = self.kv_proj_encoder(encoder_hidden_states)
# Split kv in half
k, v = kv_enc.split(self.kv_size, dim=-1)
return q, k, v

def weight_loader_weight(self,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[str] = None):
# NOTE Use QKV/ColumnParallel weight_loader, ignore placeholder param.
param = self.q_proj_decoder.weight if loaded_shard_id == "q" \
else self.kv_proj_encoder.weight
param.weight_loader(
param,
loaded_weight) if loaded_shard_id == "q" else param.weight_loader(
param, loaded_weight, loaded_shard_id)

def weight_loader_bias(self,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[str] = None):
param = self.q_proj_decoder.bias if loaded_shard_id == "q" \
else self.kv_proj_encoder.bias
param.weight_loader(
param,
loaded_weight) if loaded_shard_id == "q" else param.weight_loader(
param, loaded_weight, loaded_shard_id)
39 changes: 13 additions & 26 deletions vllm/model_executor/models/bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVCrossParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
Expand Down Expand Up @@ -168,7 +169,7 @@ def __init__(
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_world_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
self.num_kv_heads = self.num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim

Expand Down Expand Up @@ -248,7 +249,7 @@ def __init__(
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_world_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
self.num_kv_heads = self.num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim

Expand Down Expand Up @@ -300,14 +301,14 @@ def __init__(
f" and `num_heads`: {num_heads}).")
self.scaling = self.head_dim**-0.5

self.qkv_proj = QKVParallelLinear(
self.d_model,
self.d_model // self.total_num_heads,
self.total_num_heads,
self.total_num_kv_heads,
bias=bias,
quant_config=quant_config,
)
# TP sharding sizes is accounted for within "*Parallel" layers.
self.qkv_proj = QKVCrossParallelLinear(self.d_model,
self.d_model //
self.total_num_heads,
self.total_num_heads,
self.total_num_kv_heads,
bias,
quant_config=quant_config)

self.out_proj = RowParallelLinear(
embed_dim,
Expand All @@ -328,10 +329,7 @@ def __init__(
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_world_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim

self.num_kv_heads = self.num_heads # No GQA in bart
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
Expand All @@ -350,18 +348,7 @@ def forward(
) -> torch.Tensor:
"""Input shape: Batch x Time x Channel"""

# (afeldman-nm 2024/07/22) TODO:
# Need a more efficient solution for q/k/v
qkv_dec, _ = self.qkv_proj(decoder_hidden_states)
q, _, _ = qkv_dec.split([self.q_size, self.kv_size, self.kv_size],
dim=-1)
if encoder_hidden_states is None:
k = None
v = None
else:
qkv_enc, _ = self.qkv_proj(encoder_hidden_states)
_, k, v = qkv_enc.split([self.q_size, self.kv_size, self.kv_size],
dim=-1)
q, k, v = self.qkv_proj(decoder_hidden_states, encoder_hidden_states)

attn_output = self.attn(q, k, v, kv_cache, attn_metadata)

Expand Down
41 changes: 17 additions & 24 deletions vllm/model_executor/models/mllama.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVCrossParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
Expand Down Expand Up @@ -772,21 +773,12 @@ def __init__(
super().__init__()
self.config = config
self.model_parallel_size = get_tensor_model_parallel_world_size()
self.num_heads = self.config.num_attention_heads
self.num_local_heads = self.num_heads // self.model_parallel_size
self.num_key_value_heads = self.config.num_key_value_heads
self.num_local_key_value_heads = \
self.num_key_value_heads // self.model_parallel_size
self.dropout = config.dropout
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = config.hidden_size // self.num_heads
self.layer_idx = layer_idx
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.q_local_size = self.num_local_heads * self.head_dim
self.kv_local_size = self.num_local_key_value_heads * self.head_dim
self.num_key_value_heads = config.num_key_value_heads

# TODO: change to Q/KV separate linear after #7448 is merged
self.qkv_proj = QKVParallelLinear(
self.qkv_proj = QKVCrossParallelLinear(
self.hidden_size,
self.head_dim,
self.num_heads,
Expand All @@ -795,6 +787,15 @@ def __init__(
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)

self.num_local_heads = self.num_heads // self.model_parallel_size
self.num_local_key_value_heads = \
self.num_key_value_heads // self.model_parallel_size
self.layer_idx = layer_idx
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.q_local_size = self.num_local_heads * self.head_dim
self.kv_local_size = self.num_local_key_value_heads * self.head_dim

self.o_proj = RowParallelLinear(
self.num_heads * self.head_dim,
self.hidden_size,
Expand Down Expand Up @@ -827,21 +828,12 @@ def forward(
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv_dec, _ = self.qkv_proj(hidden_states)
q, _, _ = qkv_dec.split(
[self.q_local_size, self.kv_local_size, self.kv_local_size],
dim=-1)
if cross_attention_states is None:
k = None
v = None
else:
qkv_enc, _ = self.qkv_proj(cross_attention_states)
_, k, v = qkv_enc.split(
[self.q_local_size, self.kv_local_size, self.kv_local_size],
dim=-1)
q, k, v = self.qkv_proj(hidden_states, cross_attention_states)
if cross_attention_states is not None:
k = k.view(-1, self.num_local_key_value_heads, self.head_dim)
v = v.view(-1, self.num_local_key_value_heads, self.head_dim)
k = self.k_norm(k)

q = q.view(-1, self.num_local_heads, self.head_dim)
q = self.q_norm(q)

Expand All @@ -868,6 +860,7 @@ def _attention_with_mask(
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
# Skip writing kv-cache for the initial profiling run.
# TODO (NickLucche) replace with custom attn bias and use standard attn
if len(kv_cache.shape) > 1:
i = torch.ones(1, dtype=torch.float32)
if self.attn.backend in (_Backend.FLASH_ATTN,
Expand Down