tensor parallel MOE implementation #2293

Closed
wants to merge 50 commits
50 commits (changes from all commits shown)
92c3a3c
expert parallel moe
scv119 Dec 27, 2023
d24d9dd
update
scv119 Dec 27, 2023
408ed6d
update
scv119 Dec 28, 2023
ec913db
update
scv119 Dec 28, 2023
56f3220
update
scv119 Dec 28, 2023
6067ec6
update
scv119 Dec 28, 2023
869e0c5
update
scv119 Dec 28, 2023
ea44a0f
update
scv119 Dec 28, 2023
90e223a
update
scv119 Dec 28, 2023
69b5a55
update
scv119 Dec 28, 2023
357b046
update
scv119 Dec 28, 2023
86f7e1e
update
scv119 Dec 28, 2023
40f0f59
update
scv119 Dec 28, 2023
baa90d2
update
scv119 Dec 29, 2023
4367d6a
update
scv119 Dec 29, 2023
b4df657
update
scv119 Dec 29, 2023
1ac4890
update
scv119 Dec 29, 2023
14f29b3
update
scv119 Dec 30, 2023
b3a1b77
update
scv119 Dec 30, 2023
8045832
update
scv119 Dec 30, 2023
5eb304a
update
scv119 Dec 30, 2023
a172a7c
update
scv119 Dec 30, 2023
a56b2df
update
scv119 Dec 30, 2023
ca7110e
update
scv119 Dec 30, 2023
82de999
update
scv119 Dec 30, 2023
20fcbc0
update
scv119 Dec 30, 2023
92709c1
update
scv119 Dec 30, 2023
22daa9b
update
scv119 Dec 31, 2023
42c659d
update
scv119 Dec 31, 2023
4e84e02
update
scv119 Dec 31, 2023
d586474
update
scv119 Dec 31, 2023
15f820f
update
scv119 Dec 31, 2023
e0d9440
update
scv119 Dec 31, 2023
817e0bb
update
scv119 Dec 31, 2023
285e7af
Apply suggestions from code review
scv119 Jan 4, 2024
d850834
update
scv119 Jan 4, 2024
920209f
update
scv119 Jan 4, 2024
fdd5b77
update
scv119 Jan 4, 2024
17d17f1
update
scv119 Jan 4, 2024
6c60c3b
update
scv119 Jan 4, 2024
9749e64
update
scv119 Jan 4, 2024
4bee472
update
scv119 Jan 4, 2024
0fe75f3
update
scv119 Jan 4, 2024
cce13fb
update
scv119 Jan 4, 2024
f0f1d5e
update
scv119 Jan 4, 2024
42b3cc3
update
scv119 Jan 4, 2024
0a8069b
update
scv119 Jan 4, 2024
f955162
update
scv119 Jan 4, 2024
1089dd8
reorder operations
scv119 Jan 17, 2024
43ec685
Merge remote-tracking branch 'origin/main' into moe
scv119 Jan 17, 2024
11 changes: 11 additions & 0 deletions csrc/dispatch_utils.h
@@ -14,3 +14,14 @@
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)

#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
35 changes: 35 additions & 0 deletions csrc/misc_kernels.cu
@@ -0,0 +1,35 @@
// code adapted from https://github.com/rusty1s/pytorch_bincount
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

#include <ATen/ATen.h>
#include <THC/THCAtomics.cuh>

#include "cuda_compat.h"
#include "dispatch_utils.h"

#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS

namespace vllm {
template <typename scalar_t>
__global__ void bincount_kernel(scalar_t *__restrict__ src, int32_t *out,
size_t numel) {
const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
const size_t stride = blockDim.x * gridDim.x;
for (ptrdiff_t i = index; i < numel; i += stride) {
atomicAdd(out + (ptrdiff_t)src[i], 1);
}
}
}

// create a custom bincount since pytorch's bincount is
// not cudagraph capturable.
void vllm_bincount(torch::Tensor src, torch::Tensor out) {
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_INTEGRAL_TYPES(
src.scalar_type(), "bincount_kernel", [&] {
vllm::bincount_kernel<scalar_t><<<BLOCKS(src.numel()), THREADS, 0, stream>>>(
src.data<scalar_t>(), out.data<int32_t>(), src.numel());
});
}
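
The kernel above is a simple atomic-add histogram; the sketch below is a torch-level reference for what it computes and for how a per-expert token count is typically consumed in MoE routing. This is illustrative only: it uses plain torch ops whose output shape is fixed up front (the property that makes the custom kernel CUDA-graph capturable, unlike torch.bincount), and it does not call the PR's binding itself (registered as "bincount" on the extension's ops module in csrc/pybind.cpp below).

import torch

# Reference for the custom kernel's semantics: an atomic-add histogram written
# with fixed-shape torch ops, so the output size never depends on data values.
def bincount_reference(src: torch.Tensor, num_bins: int) -> torch.Tensor:
    out = torch.zeros(num_bins, dtype=torch.int32, device=src.device)
    ones = torch.ones_like(src, dtype=torch.int32)
    out.scatter_add_(0, src.long(), ones)
    return out

# Example: 8 tokens routed to 4 experts; the counts are each expert's batch size.
expert_ids = torch.tensor([0, 2, 2, 1, 3, 0, 2, 1])
tokens_per_expert = bincount_reference(expert_ids, num_bins=4)
# tokens_per_expert == tensor([2, 2, 3, 1], dtype=torch.int32)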
4 changes: 4 additions & 0 deletions csrc/ops.h
@@ -89,3 +89,7 @@ torch::Tensor gptq_gemm(
void gptq_shuffle(
torch::Tensor q_weight,
torch::Tensor q_perm);

void vllm_bincount(
torch::Tensor src,
torch::Tensor out);
4 changes: 4 additions & 0 deletions csrc/pybind.cpp
@@ -55,6 +55,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
ops.def(
"bincount",
&vllm_bincount,
"cuda-graph compatible bincount implementation");

// Cache ops
pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
1 change: 1 addition & 0 deletions setup.py
@@ -221,6 +221,7 @@ def get_torch_arch_list() -> Set[str]:
"csrc/quantization/squeezellm/quant_cuda_kernel.cu",
"csrc/quantization/gptq/q_gemm.cu",
"csrc/cuda_utils_kernels.cu",
"csrc/misc_kernels.cu",
"csrc/pybind.cpp",
]

57 changes: 57 additions & 0 deletions tests/kernels/test_moe_grouped_matmul.py
@@ -0,0 +1,57 @@
import itertools
import random
import pytest
import torch

from vllm.model_executor.layers.moe import grouped_matmul


def ref_grouped_matmul(
fused_input: torch.Tensor,
cum_group_range: torch.Tensor,
fused_group_b: torch.Tensor,
activation: str = "",
) -> torch.Tensor:
num_groups = cum_group_range.shape[0] - 1
output = torch.zeros(fused_input.shape[0],
fused_group_b.shape[2],
device=fused_input.device,
dtype=fused_input.dtype)
for i in range(num_groups):
group_i = fused_input[cum_group_range[i]:cum_group_range[i + 1]]
group_i_b = fused_group_b[i]
output[cum_group_range[i]:cum_group_range[i + 1]] = group_i @ group_i_b

if activation == "silu":
output = torch.nn.functional.silu(output)
return output


@pytest.mark.parametrize("group_size", [1, 2, 4, 8, 16])
@pytest.mark.parametrize("m", [1, 5, 33, 81])
@pytest.mark.parametrize("n", [128, 1024, 2000])
@pytest.mark.parametrize("k", [128, 1024, 2000])
@pytest.mark.parametrize("activation", ["", "silu"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_moe_grouped_matmul(
group_size: int,
m: int,
n: int,
k: int,
activation: str,
dtype: torch.dtype,
):
groups = [random.randint(1, m) for _ in range(group_size)]
batch_size = sum(groups)
fused_input = torch.randn(batch_size, k, dtype=dtype, device="cuda")
cum_group_range = torch.tensor([0] + list(itertools.accumulate(groups)),
dtype=torch.int32,
device="cuda")
fused_group_b = torch.randn(group_size, k, n, dtype=dtype, device="cuda")

ref_output = ref_grouped_matmul(fused_input, cum_group_range,
fused_group_b, activation)

output = grouped_matmul(fused_input, cum_group_range, fused_group_b,
activation)
assert torch.allclose(output, ref_output, rtol=0.01, atol=1e-3)
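
To show how the pieces fit together, here is a minimal torch-only sketch of an MoE forward pass built on these primitives: tokens are sorted by their routed expert, per-expert counts (the role of the bincount kernel) become cumulative ranges, and each contiguous slice is multiplied by its expert's weight, mirroring ref_grouped_matmul above. This is a sketch under those assumptions, not the PR's actual MoE layer code; the function and variable names here are made up for illustration, and the Python loop stands in for the fused grouped_matmul kernel.

import torch

def moe_forward_sketch(hidden: torch.Tensor,      # [num_tokens, k]
                       expert_ids: torch.Tensor,  # [num_tokens], values in [0, E)
                       expert_w: torch.Tensor) -> torch.Tensor:  # [E, k, n]
    num_experts = expert_w.shape[0]
    # Sort tokens so tokens routed to the same expert are contiguous.
    sorted_ids, perm = torch.sort(expert_ids)
    grouped = hidden[perm]
    # Per-expert token counts and the cumulative ranges grouped_matmul consumes.
    counts = torch.bincount(sorted_ids, minlength=num_experts)
    cum_group_range = torch.cat(
        [torch.zeros(1, dtype=torch.long, device=counts.device),
         counts.cumsum(0)])
    # One GEMM per expert over its contiguous token slice (what the fused
    # grouped matmul kernel performs in a single launch).
    out = torch.zeros(grouped.shape[0], expert_w.shape[2],
                      dtype=hidden.dtype, device=hidden.device)
    for e in range(num_experts):
        s, t = cum_group_range[e], cum_group_range[e + 1]
        out[s:t] = grouped[s:t] @ expert_w[e]
    # Undo the sort so outputs line up with the original token order.
    unpermuted = torch.empty_like(out)
    unpermuted[perm] = out
    return unpermuted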