Skip to content

Add cuBLASMp-backed GEMM-like API to TE common #1824

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 47 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
9002214
Pick up cuBLASMp during build
vcherepanov-nv Feb 5, 2025
9e6dcc9
Saving...
vcherepanov-nv Feb 6, 2025
331ee78
Change lib order to fix link error
vcherepanov-nv Mar 20, 2025
63c0c37
Saving...
vcherepanov-nv Mar 20, 2025
c54311f
Context creation, incomplete...
vcherepanov-nv Mar 22, 2025
511a7f9
Test fixture
vcherepanov-nv Mar 26, 2025
aaadda6
Saving...
vcherepanov-nv Apr 4, 2025
c421a02
A sanity AgGemm test, failing...
vcherepanov-nv Apr 7, 2025
b7995a5
Saving...
vcherepanov-nv Apr 7, 2025
4d768b0
Fix axes
vcherepanov-nv Apr 9, 2025
c880b12
Take care of uneven distribution
vcherepanov-nv Apr 10, 2025
f798cb3
Use MPI to get position of local matrices
vcherepanov-nv Apr 10, 2025
f88a804
Refactor
vcherepanov-nv Apr 11, 2025
75d9649
Refactor & fixes
vcherepanov-nv Apr 13, 2025
ef3a92b
Saving...
vcherepanov-nv Apr 14, 2025
3cebfcb
Gemm-RS
vcherepanov-nv Apr 14, 2025
ee52826
Gemm-AR, not working...
vcherepanov-nv Apr 14, 2025
a83867f
Fixes
vcherepanov-nv Apr 14, 2025
3f2c87e
Setting all-reduce epilogue for gemm-ar
vcherepanov-nv Apr 14, 2025
d170671
Use supported shapes for GEMM-AR
vcherepanov-nv Apr 15, 2025
9f18e9d
Tweak tolerance
vcherepanov-nv Apr 15, 2025
39ce156
First shot at fp8
vcherepanov-nv Apr 18, 2025
42468da
Use TensorHolder in tests
vcherepanov-nv Apr 18, 2025
0f05a05
More test configs
vcherepanov-nv Apr 19, 2025
ba8e75a
Support comm_sm_count
vcherepanov-nv Apr 19, 2025
ee00586
Parametrize dtypes for A, B and D separately
vcherepanov-nv Apr 21, 2025
4eb68f3
Tweak scaling
vcherepanov-nv Apr 23, 2025
9979ebd
Amax ptr
vcherepanov-nv Apr 23, 2025
dfb9807
Flags parity with cublas_gemm, saving...
vcherepanov-nv Apr 24, 2025
3a1e403
Cleanup
vcherepanov-nv Apr 25, 2025
1d9395a
Bias tests
vcherepanov-nv Apr 25, 2025
69f53bc
Fix bias test
vcherepanov-nv Apr 25, 2025
229ff14
Aux, saving...
vcherepanov-nv Apr 25, 2025
877714f
aux_ld
vcherepanov-nv Apr 28, 2025
e080b55
A fix
vcherepanov-nv May 1, 2025
1f694e6
Use test::Tensor
vcherepanov-nv May 3, 2025
785b3e8
Set scale inv
vcherepanov-nv May 3, 2025
56618d4
Remove unsupported test configs
vcherepanov-nv May 5, 2025
025f14c
Tweak tests
vcherepanov-nv May 5, 2025
380e50b
Replace libcal with NCCL
vcherepanov-nv May 6, 2025
547f8c0
Add NVTX markers to API functions
vcherepanov-nv May 7, 2025
0296331
Tweak GemmAr tests
vcherepanov-nv May 14, 2025
d35c702
More test config
vcherepanov-nv May 16, 2025
8c9cde0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 27, 2025
5336efa
Fix merge fallout
vcherepanov-nv May 27, 2025
73f4ef9
Remove MPI dependency, comment API, add algo parameter
vcherepanov-nv Jun 2, 2025
9058aaa
Fix nvshmem dependency
vcherepanov-nv Jun 4, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

"""Installation script."""

from importlib import metadata
import os
import time
from pathlib import Path
Expand Down Expand Up @@ -72,6 +73,18 @@ def setup_common_extension() -> CMakeExtension:
if bool(int(os.getenv("NVTE_BUILD_ACTIVATION_WITH_FAST_MATH", "0"))):
cmake_flags.append("-DNVTE_BUILD_ACTIVATION_WITH_FAST_MATH=ON")

if bool(int(os.getenv("NVTE_WITH_CUBLASMP", "1"))):
cmake_flags.append("-DNVTE_WITH_CUBLASMP=ON")
cublasmp_dir = os.getenv("CUBLASMP_HOME") or metadata.distribution(
"nvidia-cublasmp-cu12"
).locate_file("nvidia/cublasmp/cu12")
cmake_flags.append(f"-DCUBLASMP_DIR={cublasmp_dir}")
nvshmem_dir = os.getenv("NVSHMEM_HOME") or metadata.distribution(
"nvidia-nvshmem-cu12"
).locate_file("nvidia/nvshmem")
cmake_flags.append(f"-DNVSHMEM_DIR={nvshmem_dir}")
print("CMAKE_FLAGS:", cmake_flags[-2:])

# Project directory root
root_path = Path(__file__).resolve().parent

Expand Down
2 changes: 2 additions & 0 deletions tests/cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,12 @@ find_library(TE_LIB NAMES transformer_engine PATHS "${TE_LIB_PATH}/transformer_e
# Report the transformer_engine library path stored in TE_LIB.
message(STATUS "Found transformer_engine library: ${TE_LIB}")
# Directory-scoped header search paths: these apply to every target created
# in this directory and in the subdirectories added below.
include_directories(../../transformer_engine/common/include)
include_directories(../../transformer_engine/common)
include_directories(../../transformer_engine)
include_directories(${CMAKE_SOURCE_DIR})

# CUDA toolkit imported targets (CUDA::cudart, CUDA::nvrtc, ...) and the
# cuDNN helper module shipped with the bundled cudnn-frontend checkout.
find_package(CUDAToolkit REQUIRED)
include(${CMAKE_SOURCE_DIR}/../../3rdparty/cudnn-frontend/cmake/cuDNN.cmake)

# Test suites, one subdirectory each.
# NOTE(review): comm_gemm is added unconditionally, and its CMakeLists.txt
# uses REQUIRED lookups for OpenMP, MPI and NCCL, so configuring the C++
# tests fails on machines without them. Consider gating it behind an option.
add_subdirectory(comm_gemm)
add_subdirectory(operator)
add_subdirectory(util)
20 changes: 20 additions & 0 deletions tests/cpp/comm_gemm/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

# test_comm_gemm: GoogleTest binary for the communication + GEMM API.
# Besides the test source itself, it compiles transformer_engine.cpp and the
# shared test helpers (test_common.cu) directly into the executable, and links
# against the prebuilt transformer_engine library referenced by TE_LIB.
add_executable(test_comm_gemm
  test_comm_gemm.cu
  ../../../transformer_engine/common/transformer_engine.cpp
  ../test_common.cu)

# External dependencies of the test binary. All lookups are REQUIRED, so
# configuration stops here if any of them is missing.
find_package(OpenMP REQUIRED)
find_package(MPI REQUIRED)
find_library(NCCL_LIB
  NAMES nccl libnccl
  PATH_SUFFIXES lib
  REQUIRED)

target_include_directories(test_comm_gemm PRIVATE ${MPI_CXX_INCLUDE_PATH})

# Add the cuBLASMp headers only when CUBLASMP_HOME actually points somewhere.
# Expanding an unset or empty variable would otherwise yield the bogus
# absolute path "/include". Note that $ENV{} is read at configure time only.
if(NOT "$ENV{CUBLASMP_HOME}" STREQUAL "")
  target_include_directories(test_comm_gemm PRIVATE "$ENV{CUBLASMP_HOME}/include")
endif()

# An executable has no consumers, so its link dependencies are PRIVATE.
target_link_libraries(test_comm_gemm
  PRIVATE
    CUDA::cudart
    GTest::gtest
    ${TE_LIB}
    CUDA::nvrtc
    CUDNN::cudnn
    MPI::MPI_CXX
    ${NCCL_LIB}
    OpenMP::OpenMP_CXX)

# Register the individual test cases with CTest. The generous discovery
# timeout leaves room for a slow start-up of the test binary.
include(GoogleTest)
gtest_discover_tests(test_comm_gemm DISCOVERY_TIMEOUT 600)
Loading
Loading