tensor parallel MOE implementation #2293

Closed. This pull request wanted to merge 50 commits; the changes shown below are from 17 commits.

Commits (50)
92c3a3c
expert parallel moe
scv119 Dec 27, 2023
d24d9dd
update
scv119 Dec 27, 2023
408ed6d
update
scv119 Dec 28, 2023
ec913db
update
scv119 Dec 28, 2023
56f3220
update
scv119 Dec 28, 2023
6067ec6
update
scv119 Dec 28, 2023
869e0c5
update
scv119 Dec 28, 2023
ea44a0f
update
scv119 Dec 28, 2023
90e223a
update
scv119 Dec 28, 2023
69b5a55
update
scv119 Dec 28, 2023
357b046
update
scv119 Dec 28, 2023
86f7e1e
update
scv119 Dec 28, 2023
40f0f59
update
scv119 Dec 28, 2023
baa90d2
update
scv119 Dec 29, 2023
4367d6a
update
scv119 Dec 29, 2023
b4df657
update
scv119 Dec 29, 2023
1ac4890
update
scv119 Dec 29, 2023
14f29b3
update
scv119 Dec 30, 2023
b3a1b77
update
scv119 Dec 30, 2023
8045832
update
scv119 Dec 30, 2023
5eb304a
update
scv119 Dec 30, 2023
a172a7c
update
scv119 Dec 30, 2023
a56b2df
update
scv119 Dec 30, 2023
ca7110e
update
scv119 Dec 30, 2023
82de999
update
scv119 Dec 30, 2023
20fcbc0
update
scv119 Dec 30, 2023
92709c1
update
scv119 Dec 30, 2023
22daa9b
update
scv119 Dec 31, 2023
42c659d
update
scv119 Dec 31, 2023
4e84e02
update
scv119 Dec 31, 2023
d586474
update
scv119 Dec 31, 2023
15f820f
update
scv119 Dec 31, 2023
e0d9440
update
scv119 Dec 31, 2023
817e0bb
update
scv119 Dec 31, 2023
285e7af
Apply suggestions from code review
scv119 Jan 4, 2024
d850834
update
scv119 Jan 4, 2024
920209f
update
scv119 Jan 4, 2024
fdd5b77
update
scv119 Jan 4, 2024
17d17f1
update
scv119 Jan 4, 2024
6c60c3b
update
scv119 Jan 4, 2024
9749e64
update
scv119 Jan 4, 2024
4bee472
update
scv119 Jan 4, 2024
0fe75f3
update
scv119 Jan 4, 2024
cce13fb
update
scv119 Jan 4, 2024
f0f1d5e
update
scv119 Jan 4, 2024
42b3cc3
update
scv119 Jan 4, 2024
0a8069b
update
scv119 Jan 4, 2024
f955162
update
scv119 Jan 4, 2024
1089dd8
reorder operations
scv119 Jan 17, 2024
43ec685
Merge remote-tracking branch 'origin/main' into moe
scv119 Jan 17, 2024

Files changed
1 change: 1 addition & 0 deletions benchmarks/benchmark_latency.py
@@ -24,6 +24,7 @@ def main(args: argparse.Namespace):
trust_remote_code=args.trust_remote_code,
dtype=args.dtype,
enforce_eager=args.enforce_eager,
load_format="dummy",
)

sampling_params = SamplingParams(
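The added load_format="dummy" makes the latency benchmark initialize random weights instead of loading a real checkpoint. A minimal standalone sketch of the same idea (the model name is only a placeholder; generated text is meaningless with dummy weights):

from vllm import LLM, SamplingParams

# "dummy" skips loading checkpoint weights; the model emits garbage tokens,
# which is acceptable when only kernel latency is being measured.
llm = LLM(model="mistralai/Mixtral-8x7B-v0.1",
          load_format="dummy",
          enforce_eager=True)
outputs = llm.generate(["hello"], SamplingParams(max_tokens=16))
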
32 changes: 32 additions & 0 deletions csrc/bincount.cu
@@ -0,0 +1,32 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

#include <ATen/ATen.h>
#include <THC/THCAtomics.cuh>

#include "cuda_compat.h"
#include "dispatch_utils.h"

#define THREADS 1024
#define BLOCKS(N) (((N) + THREADS - 1) / THREADS)

namespace vllm {
template <typename scalar_t>
__global__ void bincount_kernel(scalar_t *__restrict__ src, int32_t *out,
size_t numel) {
const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
const size_t stride = blockDim.x * gridDim.x;
for (ptrdiff_t i = index; i < numel; i += stride) {
atomicAdd(out + (ptrdiff_t)src[i], 1);
}
}
}

void vllm_bincount(torch::Tensor src, torch::Tensor out) {
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_ALL_TYPES(
src.scalar_type(), "bincount_kernel", [&] {
vllm::bincount_kernel<scalar_t><<<BLOCKS(src.numel()), THREADS, 0, stream>>>(
        src.data_ptr<scalar_t>(), out.data_ptr<int32_t>(), src.numel());
});
}
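Once the extension is rebuilt, the new binding can be sanity-checked against torch.bincount. A minimal sketch, assuming the ops.bincount(src, out) signature registered in csrc/pybind.cpp below and an int32 output buffer as used by the MoE layer later in this PR:

import torch
from vllm._C import ops

num_experts = 8
selected = torch.randint(0, num_experts, (4096,),
                         dtype=torch.int64, device="cuda")
counts = torch.zeros(num_experts, dtype=torch.int32, device="cuda")
ops.bincount(selected, counts)  # atomically accumulates one count per element
assert torch.equal(counts.long(),
                   torch.bincount(selected, minlength=num_experts))
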
4 changes: 4 additions & 0 deletions csrc/ops.h
@@ -89,3 +89,7 @@ torch::Tensor gptq_gemm(
void gptq_shuffle(
torch::Tensor q_weight,
torch::Tensor q_perm);

void vllm_bincount(
torch::Tensor src,
torch::Tensor out);
4 changes: 4 additions & 0 deletions csrc/pybind.cpp
@@ -55,6 +55,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
ops.def(
"bincount",
&vllm_bincount,
"Gather key and value from the cache into contiguous QKV tensors");

// Cache ops
pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
1 change: 1 addition & 0 deletions setup.py
@@ -213,6 +213,7 @@ def get_torch_arch_list() -> Set[str]:
ext_modules = []

vllm_extension_sources = [
"csrc/bincount.cu",
"csrc/cache_kernels.cu",
"csrc/attention/attention_kernels.cu",
"csrc/pos_encoding_kernels.cu",
62 changes: 62 additions & 0 deletions tests/kernels/test_moe_grouped_matmul.py
@@ -0,0 +1,62 @@
import itertools
import random
import pytest
import torch

from vllm.model_executor.layers.moe import grouped_matmul


def ref_grouped_matmul(
fused_input: torch.Tensor,
cum_group_range: torch.Tensor,
fused_group_b: torch.Tensor,
activation: str = "",
) -> torch.Tensor:
num_groups = cum_group_range.shape[0] - 1
output = torch.zeros(fused_input.shape[0],
fused_group_b.shape[2],
device=fused_input.device,
dtype=fused_input.dtype)
for i in range(num_groups):
group_i = fused_input[cum_group_range[i]:cum_group_range[i + 1]]
group_i_b = fused_group_b[i]
output[cum_group_range[i]:cum_group_range[i + 1]] = group_i @ group_i_b

if activation == "silu":
output = torch.nn.functional.silu(output)
return output


@pytest.mark.parametrize("group_size", [1, 2, 4, 8, 16])
@pytest.mark.parametrize("m", [1, 5, 33, 81])
@pytest.mark.parametrize("n", [128, 1024, 2000])
@pytest.mark.parametrize("k", [128, 1024, 2000])
@pytest.mark.parametrize("activation", ["", "silu"])
@pytest.mark.parametrize("dtype",
[torch.float16, torch.float32, torch.bfloat16])
def test_moe_grouped_matmul(
group_size: int,
m: int,
n: int,
k: int,
activation: str,
dtype: torch.dtype,
):
groups = [random.randint(1, m) for _ in range(group_size)]
batch_size = sum(groups)
fused_input = torch.randn(batch_size, k, dtype=dtype, device="cuda")
cum_group_range = torch.tensor([0] + list(itertools.accumulate(groups)),
dtype=torch.int32,
device="cuda")
fused_group_b = torch.randn(group_size, k, n, dtype=dtype, device="cuda")

ref_output = ref_grouped_matmul(fused_input, cum_group_range,
fused_group_b, activation)
output = grouped_matmul(fused_input, cum_group_range, fused_group_b,
activation)
    diff = torch.abs(ref_output - output)
    mean_diff = torch.mean(diff)
    max_diff = torch.max(diff)
    print(f"{mean_diff=}, {max_diff=}")

    assert torch.allclose(output, ref_output, atol=1e-2, rtol=0)
256 changes: 256 additions & 0 deletions vllm/model_executor/layers/moe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,256 @@
from typing import Tuple

import torch

from torch import nn
import torch.nn.functional as F

import triton
import triton.language as tl

from vllm._C import ops
from vllm.model_executor.layers.linear import ReplicatedLinear

from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_reduce)


class MoE(nn.Module):

def __init__(
self,
num_experts: int,
top_k: int,
hidden_size: int,
intermediate_size: int,
tp_size: int,
):
super().__init__()
self.num_total_experts = num_experts
self.top_k = top_k
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size // tp_size

self.gate = ReplicatedLinear(self.hidden_size,
self.num_total_experts,
bias=False,
linear_method=None)

self.w1s = nn.Parameter(
torch.empty(self.num_total_experts, self.hidden_size,
self.intermediate_size))
self.w2s = nn.Parameter(
torch.empty(self.num_total_experts, self.intermediate_size,
self.hidden_size))
self.w3s = nn.Parameter(
torch.empty(self.num_total_experts, self.hidden_size,
self.intermediate_size))

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, hidden_size = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_size)
# router_logits: (batch * sequence_length, n_experts)
router_logits, _ = self.gate(hidden_states)

routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights,
self.top_k,
dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

expanded_hidden_states, experts_range, expanded_weights, experts_indices = \
self.expand_and_permutate_hidden_states(
hidden_states, selected_experts, routing_weights)

expanded_hidden_states = self.grouped_mlp(expanded_hidden_states,
experts_range, self.w1s.data,
self.w2s.data, self.w3s.data)

expanded_hidden_states.mul_(expanded_weights.unsqueeze(-1))

        expanded_hidden_states = tensor_model_parallel_all_reduce(
            expanded_hidden_states)

return self.merge_expert_outputs(expanded_hidden_states,
experts_indices).view(
batch_size, sequence_length,
hidden_size)

def expand_and_permutate_hidden_states(
self,
hidden_states: torch.Tensor, # [batch_size, hidden_size]
selected_experts: torch.Tensor, # [batch_size, top_k_experts]
routing_weights: torch.Tensor, # [batch_size, top_k_experts]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
_, experts_indices = torch.sort(selected_experts.view(-1), dim=-1)
cum_experts_range = torch.zeros(self.num_total_experts + 1,
dtype=torch.int32,
device=hidden_states.device)
num_rows_per_expert = torch.zeros(self.num_total_experts,
dtype=torch.int32,
device=hidden_states.device)
ops.bincount(selected_experts.view(-1), num_rows_per_expert)
torch.cumsum(num_rows_per_expert, dim=0, out=cum_experts_range[1:])
expanded_weights = routing_weights.view(-1)[experts_indices]
        # Note: div_ is in-place, so the returned experts_indices now hold the
        # source token row of each expanded row, which merge_expert_outputs
        # later uses with index_add_.
        return hidden_states[experts_indices.div_(
            self.top_k, rounding_mode="floor"
        )], cum_experts_range, expanded_weights, experts_indices

def grouped_mlp(
self,
expanded_hidden_states: torch.
Tensor, # [batch_size * top_k_experts, hidden_size]
cum_experts_range: torch.Tensor, # [num_experts + 1]
w1s: torch.Tensor, # [num_experts, hidden_size, ffn_dim]
w2s: torch.Tensor, # [num_experts, ffn_dim, hidden_size]
w3s: torch.Tensor, # [num_experts, hidden_size, ffn_dim]
) -> torch.Tensor: # [batch_size * top_k_experts, hidden_size]
grouped_w1_out = grouped_matmul(expanded_hidden_states,
cum_experts_range, w1s, "silu")
grouped_w3_out = grouped_matmul(expanded_hidden_states,
cum_experts_range, w3s)

Review comment (Collaborator, on the grouped_matmul calls above): Can we merge w1s and w3s just like what we do for LlamaMLP? Merging the two weights will be highly efficient given the cost of grouped GEMM. (A hypothetical merged-weight sketch follows merge_expert_outputs below.)

grouped_w1_out.mul_(grouped_w3_out)
return grouped_matmul(grouped_w1_out, cum_experts_range, w2s)

def merge_expert_outputs(
self,
expanded_hidden_states: torch.
Tensor, # [batch_size * top_k_experts, hidden_size]
        expert_indices: torch.Tensor,  # [batch_size * top_k_experts]
) -> torch.Tensor:
out = torch.zeros(expanded_hidden_states.shape[0] // self.top_k,
self.hidden_size,
device=expanded_hidden_states.device,
dtype=expanded_hidden_states.dtype)
        out.index_add_(0, expert_indices, expanded_hidden_states)
return out
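
    # Hypothetical merged-weight variant (not part of this PR), answering the
    # review comment above: w1s and w3s would be stored concatenated along the
    # last dim (e.g. w13s = torch.cat([w1s, w3s], dim=-1), analogous to the
    # merged gate_up projection in LlamaMLP), so a single grouped GEMM produces
    # both projections per expert. The name w13s is illustrative only.
    def grouped_mlp_merged(
        self,
        expanded_hidden_states: torch.Tensor,  # [batch_size * top_k, hidden_size]
        cum_experts_range: torch.Tensor,  # [num_experts + 1]
        w13s: torch.Tensor,  # [num_experts, hidden_size, 2 * ffn_dim]
        w2s: torch.Tensor,  # [num_experts, ffn_dim, hidden_size]
    ) -> torch.Tensor:
        # One grouped GEMM yields both the gate (w1) and up (w3) halves.
        w13_out = grouped_matmul(expanded_hidden_states, cum_experts_range,
                                 w13s)
        w1_out, w3_out = w13_out.chunk(2, dim=-1)
        # silu must gate only the w1 half, so the activation moves out of the
        # grouped kernel; the elementwise product below is freshly allocated
        # and contiguous, as the second grouped GEMM expects.
        return grouped_matmul(F.silu(w1_out) * w3_out, cum_experts_range, w2s)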


@triton.jit
def grouped_matmul_kernel(
# device tensor of matrices pointers
fused_input_ptr,
cum_input_group_range,
fused_b_ptr,
fused_output_ptr,
group_size,
n,
k,
lda,
ldb,
ldc,
# number of virtual SM
NUM_SM: tl.constexpr,
# tile sizes
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
ACTIVATION: tl.constexpr,
):
tile_idx = tl.program_id(0)
last_problem_end = 0
for g in range(group_size):
# get the gemm size of the current problem
a_offset = tl.load(cum_input_group_range + g)
gm = tl.load(cum_input_group_range + g + 1) - a_offset
gn = n
gk = k
num_m_tiles = tl.cdiv(gm, BLOCK_SIZE_M)
num_n_tiles = tl.cdiv(gn, BLOCK_SIZE_N)
num_tiles = num_m_tiles * num_n_tiles
# iterate through the tiles in the current gemm problem
while (tile_idx >= last_problem_end
and tile_idx < last_problem_end + num_tiles):

# pick up a tile from the current gemm problem
k = gk
a_ptr = fused_input_ptr + a_offset * lda
b_ptr = fused_b_ptr + g * k * n
c_ptr = fused_output_ptr + a_offset * ldc
# figure out tile coordinates
tile_idx_in_gemm = tile_idx - last_problem_end
tile_m_idx = tile_idx_in_gemm // num_n_tiles
tile_n_idx = tile_idx_in_gemm % num_n_tiles

# do regular gemm here
offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + offs_am[:, None] * lda + offs_k[None, :]
b_ptrs = b_ptr + offs_k[:, None] * ldb + offs_bn[None, :]
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N),
dtype=tl.float32)
for kk in range(0, tl.cdiv(k, BLOCK_SIZE_K)):
# hint to Triton compiler to do proper loop pipelining
tl.multiple_of(a_ptrs, [16, 16])
tl.multiple_of(b_ptrs, [16, 16])

a = tl.load(a_ptrs,
mask=offs_k[None, :] < k - kk * BLOCK_SIZE_K,
other=0.0)
b = tl.load(b_ptrs,
mask=offs_k[:, None] < k - kk * BLOCK_SIZE_K,
other=0.0)
accumulator += tl.dot(a, b)
a_ptrs += BLOCK_SIZE_K
b_ptrs += BLOCK_SIZE_K * ldb

if ACTIVATION == "silu":
accumulator = silu(accumulator)
c = accumulator.to(tl.float16)

offs_cm = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + ldc * offs_cm[:, None] + offs_cn[None, :]
c_mask = (offs_cm[:, None] < gm) & (offs_cn[None, :] < gn)

tl.store(c_ptrs, c, mask=c_mask)

# go to the next tile by advancing NUM_SM
tile_idx += NUM_SM

# get ready to go to the next gemm problem
last_problem_end = last_problem_end + num_tiles


@triton.jit
def silu(x):
return x * tl.sigmoid(x)


def grouped_matmul(fused_input: torch.Tensor,
cum_group_range: torch.Tensor,
fused_group_b: torch.Tensor,
activation: str = ""):
    assert cum_group_range.shape[0] == fused_group_b.shape[0] + 1
    group_size = cum_group_range.shape[0] - 1
    output = torch.zeros(fused_input.shape[0],
                         fused_group_b.shape[2],
                         device=fused_input.device,
                         dtype=fused_input.dtype)
BLOCK_SIZE_N = 64
num_warps = 2
if fused_input.shape[0] >= 8:
num_warps = 4
BLOCK_SIZE_N = 128
    # we use a fixed number of CTAs (NUM_SM), and it is auto-tunable
grid = lambda META: (META['NUM_SM'], )
grouped_matmul_kernel[grid](fused_input,
cum_group_range,
fused_group_b,
output,
group_size,
n=fused_group_b.shape[2],
k=fused_group_b.shape[1],
lda=fused_input.stride(0),
ldb=fused_group_b.stride(1),
ldc=output.stride(0),
ACTIVATION=activation,
BLOCK_SIZE_M=16,
BLOCK_SIZE_N=BLOCK_SIZE_N,
BLOCK_SIZE_K=32,
NUM_SM=128,
num_warps=num_warps,
                                num_stages=5)

return output
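
For readers following the routing logic, here is a small CPU-only sketch (with torch.bincount and per-expert dense matmuls standing in for ops.bincount and the CUDA-only Triton grouped kernel) showing that the expand/permute/merge path implemented above matches a naive per-token loop:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_tokens, hidden, ffn, num_experts, top_k = 5, 8, 16, 4, 2
x = torch.randn(num_tokens, hidden)
w1 = torch.randn(num_experts, hidden, ffn)
w2 = torch.randn(num_experts, ffn, hidden)
w3 = torch.randn(num_experts, hidden, ffn)
logits = torch.randn(num_tokens, num_experts)
weights, selected = torch.topk(torch.softmax(logits, dim=-1), top_k, dim=-1)
weights = weights / weights.sum(dim=-1, keepdim=True)

# Expand and permute: sort the flattened expert assignments so that rows
# handled by the same expert are contiguous (cf. expand_and_permutate_hidden_states).
_, perm = torch.sort(selected.view(-1))
counts = torch.bincount(selected.view(-1), minlength=num_experts)
cum = [0] + torch.cumsum(counts, dim=0).tolist()
expanded = x[perm // top_k]
expanded_w = weights.view(-1)[perm]

# Grouped MLP: one dense matmul per expert over its contiguous slice
# (cf. grouped_mlp / grouped_matmul).
rows_out = torch.zeros(num_tokens * top_k, hidden)
for e in range(num_experts):
    rows = expanded[cum[e]:cum[e + 1]]
    rows_out[cum[e]:cum[e + 1]] = (F.silu(rows @ w1[e]) * (rows @ w3[e])) @ w2[e]

# Weight each row by its routing weight and scatter-add it back to its source
# token (cf. merge_expert_outputs and index_add_).
rows_out *= expanded_w.unsqueeze(-1)
merged = torch.zeros(num_tokens, hidden).index_add_(0, perm // top_k, rows_out)

# Reference: loop over each token's top-k experts directly.
ref = torch.zeros_like(x)
for t in range(num_tokens):
    for j in range(top_k):
        e = selected[t, j]
        ref[t] += weights[t, j] * ((F.silu(x[t] @ w1[e]) * (x[t] @ w3[e])) @ w2[e])
assert torch.allclose(merged, ref, atol=1e-4)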