Commit 5b86104

Merge remote-tracking branch 'upstream/main' into HEAD
2 parents 0d3a5f5 + 83481ce commit 5b86104

109 files changed: 6009 additions (+), 2485 deletions (−)


.buildkite/test-pipeline.yaml

Lines changed: 10 additions & 2 deletions
@@ -117,7 +117,7 @@ steps:
   - pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
   - pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process
   - pytest -v -s entrypoints/llm/test_guided_generate.py # it needs a clean process
-  - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py
+  - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/correctness/
   - pytest -v -s entrypoints/test_chat_utils.py
   - pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests

@@ -205,7 +205,7 @@ steps:
   - VLLM_USE_V1=1 pytest -v -s v1/e2e
   # Integration test for streaming correctness (requires special branch).
   - pip install -U git+https://github.com/robertgshaw2-neuralmagic/lm-evaluation-harness.git@streaming-api
-  - pytest -v -s entrypoints/openai/test_accuracy.py::test_lm_eval_accuracy_v1_engine
+  - pytest -v -s entrypoints/openai/correctness/test_lmeval.py::test_lm_eval_accuracy_v1_engine
 
 - label: Examples Test # 25min
   working_dir: "/vllm-workspace/examples"

@@ -339,6 +339,14 @@ steps:
   - export VLLM_WORKER_MULTIPROC_METHOD=spawn
   - bash ./run-tests.sh -c configs/models-small.txt -t 1
 
+- label: OpenAI API correctness
+  source_file_dependencies:
+  - csrc/
+  - vllm/entrypoints/openai/
+  - vllm/model_executor/models/whisper.py
+  commands: # LMEval+Transcription WER check
+  - pytest -s entrypoints/openai/correctness/
+
 - label: Encoder Decoder tests # 5min
   source_file_dependencies:
   - vllm/

.pre-commit-config.yaml

Lines changed: 2 additions & 1 deletion
@@ -19,7 +19,8 @@ repos:
   rev: v2.4.0
   hooks:
   - id: codespell
-    exclude: 'benchmarks/sonnet.txt|(build|tests/(lora/data|models/fixtures|prompts))/.*|vllm/third_party/.*'
+    additional_dependencies: ['tomli']
+    args: ['--toml', 'pyproject.toml']
 - repo: https://github.com/PyCQA/isort
   rev: 5.13.2
   hooks:

CMakeLists.txt

Lines changed: 5 additions & 5 deletions
@@ -228,7 +228,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")
 
   # Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case.
-  set(CUTLASS_REVISION "v3.6.0" CACHE STRING "CUTLASS revision to use")
+  # Please keep this in sync with FetchContent_Declare line below.
+  set(CUTLASS_REVISION "v3.7.0" CACHE STRING "CUTLASS revision to use")
 
   # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided
   if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR})
@@ -245,6 +246,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     FetchContent_Declare(
         cutlass
         GIT_REPOSITORY https://github.com/nvidia/cutlass.git
+        # Please keep this in sync with CUTLASS_REVISION line above.
        GIT_TAG v3.7.0
        GIT_PROGRESS TRUE
 
@@ -266,7 +268,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
     "csrc/quantization/fp4/nvfp4_quant_entry.cu"
     "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
-    "csrc/sparse/cutlass/sparse_compressor_entry.cu"
     "csrc/cutlass_extensions/common.cpp")
 
   set_gencode_flags_for_srcs(
@@ -359,8 +360,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   # The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor
   # require CUDA 12.2 or later (and only work on Hopper, 9.0a for now).
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS)
-    set(SRCS "csrc/sparse/cutlass/sparse_compressor_c3x.cu"
-             "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
+    set(SRCS "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
     set_gencode_flags_for_srcs(
       SRCS "${SRCS}"
       CUDA_ARCHS "${SCALED_MM_3X_ARCHS}")
@@ -476,7 +476,7 @@ define_gpu_extension_target(
   SOURCES ${VLLM_EXT_SRC}
   COMPILE_FLAGS ${VLLM_GPU_FLAGS}
   ARCHITECTURES ${VLLM_GPU_ARCHES}
-  INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
+  INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR}
   USE_SABI 3
   WITH_SOABI)

csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp

Lines changed: 68 additions & 1 deletion
@@ -16,6 +16,30 @@ namespace vllm::c3x {
 
 using namespace cute;
 
+template <typename T>
+struct identity {
+  CUTLASS_HOST_DEVICE
+  T operator()(T lhs) const { return lhs; }
+};
+
+template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
+struct TrivialEpilogue {
+ private:
+  using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
+  using Compute = cutlass::epilogue::fusion::Sm90Compute<
+      cutlass::epilogue::thread::Identity, ElementD, ElementAcc,
+      cutlass::FloatRoundStyle::round_to_nearest>;
+
+ public:
+  using EVTCompute = cutlass::epilogue::fusion::Sm90EVT<Compute, Accum>;
+  using ArgumentType = typename EVTCompute::Arguments;
+
+  template <typename... Args>
+  static ArgumentType prepare_args(Args... args) {
+    return {};
+  }
+};
+
 /*
  * This class provides the common load descriptors for the
  * ScaledEpilogue[...] classes
@@ -174,6 +198,49 @@ struct ScaledEpilogueBias
   }
 };
 
+/*
+ * This epilogue performs the same operation as ScaledEpilogueBias, but the
+ * bias is a column vector instead of a row vector. Useful e.g. if we are
+ * computing a GEMM via C^T += B^T A^T. This happens in the 2:4 sparse kernels.
+ */
+template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
+struct ScaledEpilogueColumnBias
+    : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
+ private:
+  using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
+  using Accum = typename SUPER::Accum;
+  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
+  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
+  using Bias = typename SUPER::template ColLoad<ElementD>;
+
+  using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
+      cutlass::multiplies, float, float,
+      cutlass::FloatRoundStyle::round_to_nearest>;
+
+  using EVTCompute0 =
+      cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
+
+  using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
+      cutlass::multiply_add, ElementD, float,
+      cutlass::FloatRoundStyle::round_to_nearest>;
+
+ public:
+  using EVTCompute =
+      cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
+
+  using ArgumentType = typename EVTCompute::Arguments;
+  static ArgumentType prepare_args(torch::Tensor const& a_scales,
+                                   torch::Tensor const& b_scales,
+                                   torch::Tensor const& bias) {
+    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
+    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
+    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
+
+    typename EVTCompute0::Arguments evt0_args{b_args};
+    return ArgumentType{a_args, evt0_args, bias_args};
+  }
+};
+
 /*
  * This epilogue directly supports per-tensor azp in int32 form.
  * As opposed to the per-token epilogue below, this epilogue only has an azp_adj
@@ -314,4 +381,4 @@ struct ScaledEpilogueBiasAzpToken
   }
 };
 
-};  // namespace vllm::c3x
+};  // namespace vllm::c3x
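Note: reading the EVT tree above, ScaleA and Bias broadcast one value per row of D while ScaleB broadcasts one value per column, so the fused epilogue computes D[i][j] = a_scale[i] * (b_scale[j] * acc[i][j]) + bias[i]. The host-side sketch below spells out that arithmetic on plain arrays; it is illustrative only (not vLLM code) and ignores the scalar-broadcast variants of the loads and the final conversion to ElementD.

// Host-side reference for the arithmetic fused by ScaledEpilogueColumnBias.
// Illustrative sketch only, not vLLM code:
//   D[i][j] = a_scale[i] * (b_scale[j] * acc[i][j]) + bias[i]
// a_scale and bias broadcast per row ("column" vectors); b_scale broadcasts per column.
#include <cstddef>
#include <vector>

std::vector<float> scaled_column_bias_epilogue_ref(
    const std::vector<float>& acc,      // M x N accumulator, row-major
    const std::vector<float>& a_scale,  // M per-row scales
    const std::vector<float>& b_scale,  // N per-column scales
    const std::vector<float>& bias,     // M per-row bias (the "column" bias)
    std::size_t M, std::size_t N) {
  std::vector<float> d(M * N);
  for (std::size_t i = 0; i < M; ++i) {
    for (std::size_t j = 0; j < N; ++j) {
      d[i * N + j] = a_scale[i] * (b_scale[j] * acc[i * N + j]) + bias[i];
    }
  }
  return d;
}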

csrc/moe/moe_align_sum_kernels.cu

Lines changed: 37 additions & 15 deletions
@@ -198,26 +198,27 @@ __global__ void moe_align_block_size_global_mem_kernel(
 }
 
 // taken from
-// https://github.com/sgl-project/sglang/commit/ded9fcd09a43d5e7d5bb31a2bc3e9fc21bf65d2a
+// https://github.com/sgl-project/sglang/commit/cdae77b03dfc6fec3863630550b45bbfc789f957
 template <typename scalar_t>
 __global__ void sgl_moe_align_block_size_kernel(
     scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
     int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts,
     int32_t block_size, size_t numel, int32_t* cumsum) {
   __shared__ int32_t shared_counts[32][8];
-  __shared__ int32_t local_offsets[256];
 
   const int warp_id = threadIdx.x / 32;
-  const int lane_id = threadIdx.x % 32;
   const int experts_per_warp = 8;
   const int my_expert_start = warp_id * experts_per_warp;
 
+  // Initialize shared_counts for this warp's experts
   for (int i = 0; i < experts_per_warp; ++i) {
     if (my_expert_start + i < num_experts) {
       shared_counts[warp_id][i] = 0;
     }
   }
 
+  __syncthreads();
+
   const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
   const size_t start_idx = threadIdx.x * tokens_per_thread;
 
@@ -230,6 +231,7 @@ __global__ void sgl_moe_align_block_size_kernel(
 
   __syncthreads();
 
+  // Single thread computes cumulative sum and total tokens
   if (threadIdx.x == 0) {
     cumsum[0] = 0;
     for (int i = 1; i <= num_experts; ++i) {
@@ -246,19 +248,28 @@ __global__ void sgl_moe_align_block_size_kernel(
 
   __syncthreads();
 
+  // Assign expert IDs to blocks
   if (threadIdx.x < num_experts) {
     for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
          i += block_size) {
       expert_ids[i / block_size] = threadIdx.x;
     }
-    local_offsets[threadIdx.x] = cumsum[threadIdx.x];
   }
+}
 
-  __syncthreads();
-
-  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
+// taken from
+// https://github.com/sgl-project/sglang/commit/cdae77b03dfc6fec3863630550b45bbfc789f957
+template <typename scalar_t>
+__global__ void sgl_moe_token_sort_kernel(scalar_t* __restrict__ topk_ids,
+                                          int32_t* sorted_token_ids,
+                                          int32_t* cumsum_buffer,
+                                          size_t numel) {
+  const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
+  const size_t stride = blockDim.x * gridDim.x;
+
+  for (size_t i = tid; i < numel; i += stride) {
     int32_t expert_id = topk_ids[i];
-    int32_t rank_post_pad = atomicAdd(&local_offsets[expert_id], 1);
+    int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1);
     sorted_token_ids[rank_post_pad] = i;
   }
 }
@@ -377,23 +388,34 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                               torch::Tensor experts_ids,
                               torch::Tensor num_tokens_post_pad) {
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  TORCH_CHECK(num_experts == 256,
+              "sgl_moe_align_block_size kernel only supports deepseek v3.");
+
   VLLM_DISPATCH_INTEGRAL_TYPES(
       topk_ids.scalar_type(), "sgl_moe_align_block_size_kernel", [&] {
-        // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
-        // tensors
+        // calc needed amount of shared mem for `cumsum` tensors
        auto options_int =
            torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device());
-        // torch::Tensor token_cnts_buffer =
-        //     torch::empty({(num_experts + 1) * num_experts}, options_int);
        torch::Tensor cumsum_buffer =
-            torch::empty({num_experts + 1}, options_int);
+            torch::zeros({num_experts + 1}, options_int);
 
-        auto kernel = vllm::moe::sgl_moe_align_block_size_kernel<scalar_t>;
-        kernel<<<1, 1024, 0, stream>>>(
+        auto align_kernel =
+            vllm::moe::sgl_moe_align_block_size_kernel<scalar_t>;
+        align_kernel<<<1, 1024, 0, stream>>>(
            topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
            experts_ids.data_ptr<int32_t>(),
            num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
            topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>());
+
+        const int block_threads = 256;
+        const int num_blocks =
+            (topk_ids.numel() + block_threads - 1) / block_threads;
+        const int max_blocks = 65535;
+        const int actual_blocks = std::min(num_blocks, max_blocks);
+        auto sort_kernel = vllm::moe::sgl_moe_token_sort_kernel<scalar_t>;
+        sort_kernel<<<actual_blocks, block_threads, 0, stream>>>(
+            topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
+            cumsum_buffer.data_ptr<int32_t>(), topk_ids.numel());
       });
 }
 
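Note: the refactor splits the old single-block kernel into two launches: sgl_moe_align_block_size_kernel counts tokens per expert and builds a block_size-aligned prefix sum, and the new grid-strided sgl_moe_token_sort_kernel scatters token indices by bumping that (now zero-initialized) cumsum_buffer with atomicAdd. The serial sketch below mirrors that two-pass layout on the host; the function name and the sentinel used for padded slots are mine, not the CUDA code.

// Serial, host-side sketch of the two-pass token layout (illustrative only):
// pass 1 builds block_size-aligned per-expert offsets, pass 2 scatters token
// indices while advancing those offsets, mirroring the atomicAdd on cumsum_buffer.
#include <cstdint>
#include <vector>

void moe_align_and_sort_ref(const std::vector<int32_t>& topk_ids, int num_experts,
                            int block_size, std::vector<int32_t>& sorted_token_ids,
                            std::vector<int32_t>& expert_ids, int32_t& total_post_pad) {
  auto align_up = [&](int32_t x) { return (x + block_size - 1) / block_size * block_size; };

  // Pass 1: per-expert counts -> block-aligned exclusive prefix sum.
  std::vector<int32_t> cumsum(num_experts + 1, 0);
  for (int32_t e : topk_ids) ++cumsum[e + 1];
  for (int i = 1; i <= num_experts; ++i)
    cumsum[i] = cumsum[i - 1] + align_up(cumsum[i]);
  total_post_pad = cumsum[num_experts];

  // Each block_size-wide block belongs to exactly one expert.
  expert_ids.assign(total_post_pad / block_size, 0);
  for (int e = 0; e < num_experts; ++e)
    for (int32_t i = cumsum[e]; i < cumsum[e + 1]; i += block_size)
      expert_ids[i / block_size] = e;

  // Pass 2: scatter token indices; padded slots keep a sentinel (here: token count).
  sorted_token_ids.assign(total_post_pad, static_cast<int32_t>(topk_ids.size()));
  std::vector<int32_t> offsets(cumsum.begin(), cumsum.end() - 1);
  for (int32_t i = 0; i < static_cast<int32_t>(topk_ids.size()); ++i)
    sorted_token_ids[offsets[topk_ids[i]]++] = i;
}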

csrc/ops.h

Lines changed: 1 addition & 2 deletions
@@ -176,8 +176,7 @@ void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a,
                               torch::Tensor const& b_scales,
                               std::optional<torch::Tensor> const& bias);
 
-bool cutlass_sparse_compress_entry(torch::Tensor& a_compressed,
-                                   torch::Tensor& e, torch::Tensor const& a);
+std::vector<torch::Tensor> cutlass_sparse_compress(torch::Tensor const& a);
 #endif
 
 void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,

csrc/quantization/awq/gemm_kernels.cu

Lines changed: 1 addition & 1 deletion
@@ -334,7 +334,7 @@ __global__ void __launch_bounds__(64)
   }
 
   // TODO: Shang: Hoist loop invariance.
-  for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) {
+  for (int ax1_0_1 = 0; ax1_0_1 < (N / 32); ++ax1_0_1) {
     for (int local_id = 0; local_id < 8; ++local_id) {
       int row_offset = (((int)blockIdx_y) / j_factors1) * 16 +
                        ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
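Note: each ax1_0_1 iteration appears to cover a 32-column slice of the block's output tile, so the hard-coded bound of 4 was only correct for the N == 128 instantiation. The tiny compile-time check below is not part of the kernel; it just restates that arithmetic.

// Illustrative only: the trip count the new loop bound encodes.
constexpr int trip_count(int n) { return n / 32; }  // one iteration per 32 output columns
static_assert(trip_count(128) == 4, "the old hard-coded 4 matched only N == 128");
static_assert(trip_count(64) == 2, "other instantiations of N need a different bound");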

csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh

Lines changed: 9 additions & 4 deletions
@@ -53,12 +53,17 @@ struct cutlass_3x_gemm {
 
   using EVTCompute = typename Epilogue::EVTCompute;
 
+  // These are the minimum alignments needed for the kernels to compile
+  static constexpr int AlignmentAB =
+      128 / cutlass::sizeof_bits<ElementAB>::value;
+  static constexpr int AlignmentCD = 4;
+
   using CollectiveEpilogue =
       typename cutlass::epilogue::collective::CollectiveBuilder<
           cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
           ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
-          ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4,
-          EpilogueSchedule, EVTCompute>::CollectiveOp;
+          ElementAcc, float, ElementC, StrideC, AlignmentCD, ElementD, StrideD,
+          AlignmentCD, EpilogueSchedule, EVTCompute>::CollectiveOp;
 
   static constexpr size_t CEStorageSize =
       sizeof(typename CollectiveEpilogue::SharedStorage);
@@ -69,8 +74,8 @@ struct cutlass_3x_gemm {
   using CollectiveMainloop =
       typename cutlass::gemm::collective::CollectiveBuilder<
           cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
-          ElementAB, cutlass::layout::RowMajor, 16,
-          ElementAB, cutlass::layout::ColumnMajor, 16,
+          ElementAB, cutlass::layout::RowMajor, AlignmentAB,
+          ElementAB, cutlass::layout::ColumnMajor, AlignmentAB,
          ElementAcc, TileShape, ClusterShape,
          Stages,
          KernelSchedule>::CollectiveOp;
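Note: AlignmentAB = 128 / sizeof_bits&lt;ElementAB&gt; expresses an alignment of 128 bits (16 bytes) in units of elements, so the previously hard-coded 16 was only right for 8-bit A/B types; for a 32-bit C/D element the previously hard-coded 4 corresponds to the same 16-byte access. The stand-alone check below restates that arithmetic without pulling in CUTLASS; it covers byte-sized types only (cutlass::sizeof_bits also handles sub-byte types).

// Illustrative only: the alignment-in-elements arithmetic behind AlignmentAB / AlignmentCD.
#include <climits>
#include <cstdint>

template <typename T>
constexpr int min_alignment_elems() {
  return 128 / (static_cast<int>(sizeof(T)) * CHAR_BIT);  // 128 bits per vectorized access
}

static_assert(min_alignment_elems<std::int8_t>() == 16, "8-bit A/B: the previously hard-coded 16");
static_assert(min_alignment_elems<std::uint16_t>() == 8, "16-bit A/B types would need 8");
static_assert(min_alignment_elems<float>() == 4, "for 32-bit C/D this matches AlignmentCD = 4");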

csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh

Lines changed: 8 additions & 3 deletions
@@ -103,14 +103,19 @@ struct cutlass_2x_gemm {
 
   using EVTD = cutlass::epilogue::threadblock::Sm80EVT<D, EVTCompute>;
 
+  // These are the minimum alignments needed for the kernels to compile
+  static constexpr int AlignmentAB =
+      128 / cutlass::sizeof_bits<ElementAB>::value;
+  static constexpr int AlignmentCD = 4;
+
   // clang-format off
   using RowMajor = typename cutlass::layout::RowMajor;
   using ColumnMajor = typename cutlass::layout::ColumnMajor;
   using KernelType =
     ArchGuard<typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
-      ElementAB, RowMajor, cutlass::ComplexTransform::kNone, 16,
-      ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, 16,
-      float, cutlass::layout::RowMajor, 4,
+      ElementAB, RowMajor, cutlass::ComplexTransform::kNone, AlignmentAB,
+      ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, AlignmentAB,
+      float, cutlass::layout::RowMajor, AlignmentCD,
       ElementAcc, float, cutlass::arch::OpClassTensorOp,
       Arch,
       TileShape, WarpShape, InstructionShape,
