address review
Signed-off-by: NickLucche <[email protected]>
NickLucche committed Feb 19, 2025
1 parent 8393b98 commit 75ce6ac
Showing 4 changed files with 101 additions and 104 deletions.
95 changes: 95 additions & 0 deletions vllm/model_executor/layers/linear.py
@@ -1165,3 +1165,98 @@ def extra_repr(self) -> str:
s += f", tp_size={self.tp_size}"
s += f", reduce_results={self.reduce_results}"
return s


class QKVCrossParallelLinear(torch.nn.Module):

def __init__(self,
hidden_size: int,
head_size: int,
total_num_heads: int,
total_num_kv_heads: Optional[int] = None,
bias: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
# Empty placeholders for loading as a single module.
self.weight = torch.nn.Parameter()
set_weight_attrs(self.weight, {
"weight_loader": self.weight_loader_weight,
})
# Use a dictionary to avoid submodules parameters auto-registration:
# drop-in replacement for a `QKVParallelLinear` module.
self.proj = dict()
self.proj["q_proj_decoder"] = ColumnParallelLinear(
input_size=hidden_size,
output_size=total_num_heads * head_size,
bias=bias,
quant_config=quant_config,
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
prefix=f"{prefix}.q_proj_decoder")

self.proj["kv_proj_encoder"] = QKVParallelLinear(
hidden_size=hidden_size,
head_size=head_size,
total_num_heads=0,
total_num_kv_heads=total_num_kv_heads,
bias=bias,
quant_config=quant_config,
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
prefix=f"{prefix}.kv_proj_encoder")

# `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1.
self.kv_size = self.kv_proj_encoder.num_kv_heads * head_size

if bias:
self.bias = torch.nn.Parameter()
set_weight_attrs(self.bias, {
"weight_loader": self.weight_loader_bias,
})

@property
def q_proj_decoder(self):
return self.proj["q_proj_decoder"]

@property
def kv_proj_encoder(self):
return self.proj["kv_proj_encoder"]

def forward(self, decoder_hidden_states, encoder_hidden_states):
q, _ = self.q_proj_decoder(decoder_hidden_states)
if encoder_hidden_states is None:
# Encoder KV already cached.
k = None
v = None
else:
# Prefill phase, encoder KV cached here.
kv_enc, _ = self.kv_proj_encoder(encoder_hidden_states)
# Split kv in half
k, v = kv_enc.split(self.kv_size, dim=-1)
return q, k, v

def weight_loader_weight(self,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[str] = None):
# NOTE Use QKV/ColumnParallel weight_loader, ignore placeholder param.
param = self.q_proj_decoder.weight if loaded_shard_id == "q" \
else self.kv_proj_encoder.weight
param.weight_loader(
param,
loaded_weight) if loaded_shard_id == "q" else param.weight_loader(
param, loaded_weight, loaded_shard_id)

def weight_loader_bias(self,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[str] = None):
param = self.q_proj_decoder.bias if loaded_shard_id == "q" \
else self.kv_proj_encoder.bias
param.weight_loader(
param,
loaded_weight) if loaded_shard_id == "q" else param.weight_loader(
param, loaded_weight, loaded_shard_id)
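
Context note (not part of the diff): `self.proj` is a plain dict rather than an `nn.ModuleDict` precisely so the two inner projection layers are not auto-registered as submodules; only the placeholder `weight`/`bias` parameters, which carry the custom `weight_loader_*` hooks, are exposed for checkpoint loading, and those hooks forward each shard ("q", "k" or "v") to the matching inner layer's own weight loader. The short sketch below illustrates that registration trick and the q/k/v split in `forward`, using plain `nn.Linear` stand-ins instead of vLLM's `ColumnParallelLinear`/`QKVParallelLinear` so no tensor-parallel setup is needed; the class and parameter names here are illustrative only, not vLLM API.

import torch
import torch.nn as nn

class CrossQKVSketch(nn.Module):
    """Illustrative stand-in for QKVCrossParallelLinear (no tensor parallelism)."""

    def __init__(self, hidden_size: int, head_size: int, num_heads: int,
                 num_kv_heads: int):
        super().__init__()
        # Placeholder parameter: the only parameter external loading code sees.
        self.weight = nn.Parameter(torch.empty(0))
        # Plain dict, not nn.ModuleDict: inner parameters are NOT auto-registered.
        self.proj = {
            "q_proj_decoder": nn.Linear(hidden_size, num_heads * head_size),
            "kv_proj_encoder": nn.Linear(hidden_size,
                                         2 * num_kv_heads * head_size),
        }
        self.kv_size = num_kv_heads * head_size

    def forward(self, decoder_hidden_states, encoder_hidden_states=None):
        q = self.proj["q_proj_decoder"](decoder_hidden_states)
        if encoder_hidden_states is None:
            # Decode phase: encoder KV assumed already cached.
            return q, None, None
        kv = self.proj["kv_proj_encoder"](encoder_hidden_states)
        k, v = kv.split(self.kv_size, dim=-1)  # split concatenated KV halves
        return q, k, v

layer = CrossQKVSketch(hidden_size=16, head_size=4, num_heads=4, num_kv_heads=2)
print([name for name, _ in layer.named_parameters()])  # ['weight'] only
q, k, v = layer(torch.randn(1, 16), torch.randn(1, 16))
print(q.shape, k.shape, v.shape)  # (1, 16), (1, 8), (1, 8)
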
7 changes: 4 additions & 3 deletions vllm/model_executor/models/bart.py
@@ -31,6 +31,7 @@
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVCrossParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -43,7 +44,7 @@
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .utils import QKVCrossParallelLinear, maybe_prefix
from .utils import maybe_prefix

logger = logging.get_logger(__name__)

@@ -168,7 +169,7 @@ def __init__(
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_world_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
self.num_kv_heads = self.num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim

@@ -248,7 +249,7 @@ def __init__(
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_world_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
self.num_kv_heads = self.num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim

3 changes: 2 additions & 1 deletion vllm/model_executor/models/mllama.py
@@ -42,6 +42,7 @@
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVCrossParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -64,7 +65,7 @@
from .clip import CLIPMLP
from .interfaces import SupportsMultiModal
from .llama import LlamaDecoderLayer, LlamaMLP
from .utils import QKVCrossParallelLinear, maybe_prefix
from .utils import maybe_prefix

logger = init_logger(__name__)

100 changes: 0 additions & 100 deletions vllm/model_executor/models/utils.py
@@ -12,12 +12,7 @@

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.utils import set_weight_attrs
from vllm.multimodal import MultiModalPlaceholderMap, NestedTensors
from vllm.sequence import IntermediateTensors
from vllm.utils import is_pin_memory_available
@@ -646,98 +641,3 @@ def extract_layer_index(layer_name: str) -> int:
assert len(int_vals) == 1, (f"layer name {layer_name} should"
" only contain one integer")
return int_vals[0]


class QKVCrossParallelLinear(torch.nn.Module):

def __init__(self,
hidden_size: int,
head_size: int,
total_num_heads: int,
total_num_kv_heads: Optional[int] = None,
bias: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
# Empty placeholders for loading as a single module.
self.weight = torch.nn.Parameter()
set_weight_attrs(self.weight, {
"weight_loader": self.weight_loader_weight,
})
# Use a dictionary to avoid submodules parameters auto-registration:
# drop-in replacement for a `QKVParallelLinear` module.
self.proj = dict()
self.proj["q_proj_decoder"] = ColumnParallelLinear(
input_size=hidden_size,
output_size=total_num_heads * head_size,
bias=bias,
quant_config=quant_config,
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
prefix=f"{prefix}.q_proj_decoder")

self.proj["kv_proj_encoder"] = QKVParallelLinear(
hidden_size=hidden_size,
head_size=head_size,
total_num_heads=0,
total_num_kv_heads=total_num_kv_heads,
bias=bias,
quant_config=quant_config,
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
prefix=f"{prefix}.kv_proj_encoder")

# `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1.
self.kv_size = self.kv_proj_encoder.num_kv_heads * head_size

if bias:
self.bias = torch.nn.Parameter()
set_weight_attrs(self.bias, {
"weight_loader": self.weight_loader_bias,
})

@property
def q_proj_decoder(self):
return self.proj["q_proj_decoder"]

@property
def kv_proj_encoder(self):
return self.proj["kv_proj_encoder"]

def forward(self, decoder_hidden_states, encoder_hidden_states):
q, _ = self.q_proj_decoder(decoder_hidden_states)
if encoder_hidden_states is None:
# Encoder KV already cached.
k = None
v = None
else:
# Prefill phase, encoder KV cached here.
kv_enc, _ = self.kv_proj_encoder(encoder_hidden_states)
# Split kv in half
k, v = kv_enc.split(self.kv_size, dim=-1)
return q, k, v

def weight_loader_weight(self,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[str] = None):
# NOTE Use QKV/ColumnParallel weight_loader, ignore placeholder param.
param = self.q_proj_decoder.weight if loaded_shard_id == "q" \
else self.kv_proj_encoder.weight
param.weight_loader(
param,
loaded_weight) if loaded_shard_id == "q" else param.weight_loader(
param, loaded_weight, loaded_shard_id)

def weight_loader_bias(self,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[str] = None):
param = self.q_proj_decoder.bias if loaded_shard_id == "q" \
else self.kv_proj_encoder.bias
param.weight_loader(
param,
loaded_weight) if loaded_shard_id == "q" else param.weight_loader(
param, loaded_weight, loaded_shard_id)
