Skip to content

Fix recent changes of universal_gemm in tile_engine #2344

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 37 additions & 8 deletions tile_engine/ops/gemm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,54 @@ file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/gemm_instance_blobs.txt GEMM_CODEGEN_BL
set(GEMM_CODEGEN_CPP_FILES "")
set(GEMM_CODEGEN_HPP_FILES "")


add_custom_command(
OUTPUT ${GEMM_CODEGEN_BLOBS}
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
--working_path ${CMAKE_CURRENT_BINARY_DIR}
# --config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json
--gen_blobs
)

foreach(blob ${GEMM_CODEGEN_BLOBS})
string(STRIP "${blob}" stripped_blob)

if(stripped_blob MATCHES "\\.cpp$")
message(STATUS "Adding gemm codegen file: ${stripped_blob}")
list(APPEND GEMM_CODEGEN_CPP_FILES "${stripped_blob}")
elseif(stripped_blob MATCHES "\\.hpp$")
list(APPEND GEMM_CODEGEN_HPP_FILES "${stripped_blob}")
endif()
endforeach()

add_custom_command(
OUTPUT ${GEMM_CODEGEN_BLOBS}
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
--working_path ${CMAKE_CURRENT_BINARY_DIR}
# --config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json
--gen_blobs
)

add_library(gemm_template_instances OBJECT EXCLUDE_FROM_ALL ${GEMM_CODEGEN_CPP_FILES})
set(chunk_size 9)
set(chunk_index 0)
set(intermediate_libs "")

list(LENGTH GEMM_CODEGEN_CPP_FILES total_files)
message(STATUS "Total gemm codegen files: ${total_files}")
math(EXPR num_chunks "( ${total_files} + ${chunk_size} - 1 ) / ${chunk_size}")

foreach(i RANGE 0 ${num_chunks})
set(chunk_files "")
math(EXPR start "${i} * ${chunk_size}")
math(EXPR end "${start} + ${chunk_size} - 1")
foreach(j RANGE ${start} ${end})
if(j LESS total_files)
list(GET GEMM_CODEGEN_CPP_FILES ${j} file)
list(APPEND chunk_files ${file})
endif()
endforeach()
if(chunk_files)
set(lib_name "gemm_objlib_${i}")
add_library(${lib_name} OBJECT ${chunk_files})
list(APPEND intermediate_libs $<TARGET_OBJECTS:${lib_name}>)
endif()
endforeach()
add_library(gemm_template_instances STATIC EXCLUDE_FROM_ALL ${intermediate_libs})

#add_library(gemm_template_instances STATIC EXCLUDE_FROM_ALL ${GEMM_CODEGEN_CPP_FILES})
# Explicitly set LINKER_LANGUAGE to avoid build config failures with Ninja.
set_target_properties(gemm_template_instances PROPERTIES LINKER_LANGUAGE CXX)
target_include_directories(gemm_template_instances PRIVATE ${CMAKE_CURRENT_LIST_DIR})
Expand Down
3 changes: 3 additions & 0 deletions tile_engine/ops/gemm/codegen_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,12 @@
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
GemmPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
Expand Down
27 changes: 9 additions & 18 deletions tile_engine/ops/gemm/configs/default_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,19 @@
},
"tile_config": {
"tile_m": {
"max": 512,
"max": 256,
"min": 64,
"step": 64,
"exclude": []
"exclude": [192]
},
"tile_n": {
"max": 512,
"max": 256,
"min": 64,
"step": 32,
"exclude": []
"step": 64,
"exclude": [192]
},
"tile_k": {
"max": 512,
"max": 256,
"min": 64,
"step": 64,
"exclude": [192]
Expand All @@ -59,7 +59,6 @@
},
"warp_n": {
"values": [
4,
2,
1
]
Expand All @@ -72,16 +71,12 @@
"warp_tile_m": {
"values": [
4,
8,
16,
32,
64
32
]
},
"warp_tile_n": {
"values": [
4,
8,
16,
32,
64
Expand All @@ -100,20 +95,16 @@
"trait_config": {
"pipeline": {
"values": [
"compv4",
"compv3",
"mem"
"compv3"
]
},
"scheduler": {
"values": [
"intrawave",
"interwave"
"intrawave"
]
},
"epilogue": {
"values": [
"default",
"cshuffle"
]
},
Expand Down
70 changes: 40 additions & 30 deletions tile_engine/ops/gemm/gemm_instance_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,19 +69,14 @@ def list_all_trait_names(self):
f.write(str(w_p / "gemm_dispatcher.hpp") + "\n")
for trait in self.valid_trait_names:
f.write(str(w_p / f"gemm_{trait}.hpp") + "\n")
file_name = set()
for trait, tile_valid_params in self.valid_trait_tile_combinations.items():
for tile in tile_valid_params:
for tile_m, tile_n, tile_k, warp_m, warp_n, warp_k, warp_tile_m, warp_tile_n, warp_tile_k in tile:
sparse = self.config.problem.datatype_map['matrix_a'] == 'fp16' and \
self.config.problem.datatype_map['matrix_b'] == 'fp16' and \
self.config.problem.datatype_map['matrix_c'] == 'fp16' and \
((warp_tile_m == 32 and warp_tile_n == 32 and warp_tile_k == 16) or
(warp_tile_m == 16 and warp_tile_n == 16 and warp_tile_k == 32))
if sparse:
f.write(str(
w_p / f"gemm_{trait}_{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}_{warp_tile_m}x{warp_tile_n}x{warp_tile_k}_true.cpp") + "\n")
f.write(str(
w_p / f"gemm_{trait}_{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}_{warp_tile_m}x{warp_tile_n}x{warp_tile_k}_false.cpp") + "\n")
file_name.add(f"gemm_{trait}_{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}.cpp")
for name in file_name:
f.write(str(
w_p / name) + "\n")

def _generate_all_traits(self):
"""Generate all possible kernel traits names."""
Expand Down Expand Up @@ -193,7 +188,7 @@ def _generate_kernel_struct(self, pipeline: str, epilogue: str, scheduler: str,
static constexpr bool kPadN = {pad_n};
static constexpr bool kPadK = {pad_k};

static float launch(ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {{
static float launch(ck_tile::GemmHostArgs<>& args, const ck_tile::stream_config& stream) {{
static constexpr bool permuteA = false;
static constexpr bool permuteB = false;
static constexpr bool DoubleSmemBuffer ={"true" if pipeline == "compv4" else "false"};
Expand Down Expand Up @@ -306,7 +301,7 @@ def _generate_kernel_struct(self, pipeline: str, epilogue: str, scheduler: str,
// clear c mem
if(args.k_batch > 1)
hipGetErrorString(hipMemsetAsync(
args.c_ptr, 0, args.M * args.N * sizeof(CDataType), stream.stream_id_));
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), stream.stream_id_));
}};
ave_time = ck_tile::launch_kernel_preprocess(
stream,
Expand Down Expand Up @@ -502,10 +497,22 @@ def get_tile_value(tile_param): return tile_param.generate_candidates(

def _generate_instantiation_source_files(self):
"""Generate kernel instance instantiation source files """
tile_map = {}
for trait, tile_valid_params in self.valid_trait_tile_combinations.items():
for tile in tile_valid_params:
for tile_m, tile_n, tile_k, warp_m, warp_n, warp_k, warp_tile_m, warp_tile_n, warp_tile_k in tile:
content = f"""
key = f'{tile_m}x{tile_n}x{tile_k}x{warp_m}x{warp_n}x{warp_k}'
value = f'{warp_tile_m}x{warp_tile_n}x{warp_tile_k}'
if key not in tile_map:
tile_map[key] = set()
tile_map[key].add(value)
#print(f"Generating {len(tile_map)} tiles and warps...")
#print(f"Valid traits: {tile_map}")
count = 0
for trait, _ in self.valid_trait_tile_combinations.items():
for block_tile, warp_tiles in tile_map.items():
tile_m, tile_n, tile_k, warp_m, warp_n, warp_k = map(int, block_tile.split('x'))
content = f"""
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

Expand All @@ -514,23 +521,26 @@ def _generate_instantiation_source_files(self):
#include "gemm_{trait}.hpp"

"""
for warp_tile in warp_tiles:
warp_tile_m, warp_tile_n, warp_tile_k = map(int, warp_tile.split('x'))

sparse = self.config.problem.datatype_map['matrix_a'] == 'fp16' and \
self.config.problem.datatype_map['matrix_b'] == 'fp16' and \
self.config.problem.datatype_map['matrix_c'] == 'fp16' and \
((warp_tile_m == 32 and warp_tile_n == 32 and warp_tile_k == 16) or
(warp_tile_m == 16 and warp_tile_n == 16 and warp_tile_k == 32))
self.config.problem.datatype_map['matrix_b'] == 'fp16' and \
self.config.problem.datatype_map['matrix_c'] == 'fp16' and \
((warp_tile_m == 32 and warp_tile_n == 32 and warp_tile_k == 16) or
(warp_tile_m == 16 and warp_tile_n == 16 and warp_tile_k == 32))
if sparse:
sparse_content = content + f"""
template struct {trait}::GemmKernel<{tile_m}, {tile_n}, {tile_k}, {warp_m}, {warp_n}, {warp_k}, {warp_tile_m}, {warp_tile_n}, {warp_tile_k}, true>;
"""
(self.output_dir /
f"gemm_{trait}_{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}_{warp_tile_m}x{warp_tile_n}x{warp_tile_k}_true.cpp").write_text(sparse_content)

no_sparse_content = content + f"""
template struct {trait}::GemmKernel<{tile_m}, {tile_n}, {tile_k}, {warp_m}, {warp_n}, {warp_k}, {warp_tile_m}, {warp_tile_n}, {warp_tile_k}, false>;
count = count + 1
content = content + f"""
template struct {trait}::GemmKernel<{tile_m}, {tile_n}, {tile_k}, {warp_m}, {warp_n}, {warp_k}, {warp_tile_m}, {warp_tile_n}, {warp_tile_k}, true>;"""
count = count + 1
content = content + f"""
template struct {trait}::GemmKernel<{tile_m}, {tile_n}, {tile_k}, {warp_m}, {warp_n}, {warp_k}, {warp_tile_m}, {warp_tile_n}, {warp_tile_k}, false>;"""
content += f"""
"""
(self.output_dir /
f"gemm_{trait}_{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}_{warp_tile_m}x{warp_tile_n}x{warp_tile_k}_false.cpp").write_text(no_sparse_content)
(self.output_dir /
f"gemm_{trait}_{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}.cpp").write_text(content)
print(f"Generated {count} kernel instances in total.")

def _generate_dispatcher_file(self):
"""Generate the code block of dispatch mechanism."""
Expand Down Expand Up @@ -570,7 +580,7 @@ def _generate_dispatcher_file(self):
// Use a static local variable
static std::unordered_map<
std::string,
std::vector<std::function<std::tuple<std::string, float>(ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>>>
std::vector<std::function<std::tuple<std::string, float>(ck_tile::GemmHostArgs<>&, const ck_tile::stream_config&)>>>
kernel_map;
return kernel_map;
}
Expand All @@ -586,7 +596,7 @@ def _generate_dispatcher_file(self):
for j in range(len(tile)):
tile_m, tile_n, tile_k, warp_m, warp_n, warp_k, warp_tile_m, warp_tile_n, warp_tile_k = tile[
j]
content += f"""[=](ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {{ """
content += f"""[=](ck_tile::GemmHostArgs<>& args, const ck_tile::stream_config& stream) {{ """
content += f"""
if(structured_sparsity){{ // SMFMA"""
sparse = self.config.problem.datatype_map['matrix_a'] == 'fp16' and \
Expand Down Expand Up @@ -615,7 +625,7 @@ def _generate_dispatcher_file(self):
content += """ }

template <typename Kernel>
static std::tuple<std::string, float> run_kernel(ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream)
static std::tuple<std::string, float> run_kernel(ck_tile::GemmHostArgs<>& args, const ck_tile::stream_config& stream)
{
std::string name = Kernel::get_name();
float avg_time = Kernel::launch(args, stream);
Expand Down
27 changes: 15 additions & 12 deletions tile_engine/ops/gemm/gemm_profiler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class GemmProfiler

void benchmark(GemmProblem& gemm_problem,
std::vector<std::function<std::tuple<std::string, float>(
ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>>& callables)
ck_tile::GemmHostArgs<>&, const ck_tile::stream_config&)>>& callables)
{
const ALayout layout_a = ALayout{};
const BLayout layout_b = BLayout{};
Expand Down Expand Up @@ -89,17 +89,20 @@ class GemmProfiler
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();

ck_tile::GemmHostArgs gemm_args;
gemm_args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
gemm_args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
gemm_args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
gemm_args.k_batch = gemm_problem.split_k_;
gemm_args.M = gemm_problem.m_;
gemm_args.N = gemm_problem.n_;
gemm_args.K = gemm_problem.k_;
gemm_args.stride_A = gemm_problem.stride_a_;
gemm_args.stride_B = gemm_problem.stride_b_;
gemm_args.stride_C = gemm_problem.stride_c_;
ck_tile::GemmHostArgs<> gemm_args = {
a_m_k_dev_buf.GetDeviceBuffer(),
b_k_n_dev_buf.GetDeviceBuffer(),
{}, // ds_ptr
c_m_n_dev_buf.GetDeviceBuffer(),
gemm_problem.split_k_,
gemm_problem.m_,
gemm_problem.n_,
gemm_problem.k_,
gemm_problem.stride_a_,
gemm_problem.stride_b_,
{}, // stride_Ds
gemm_problem.stride_c_,
};

ck_tile::HostTensor<CDataType> c_m_n_host_result(ck_tile::host_tensor_descriptor(
gemm_problem.m_, gemm_problem.n_, gemm_problem.stride_c_, is_row_major(layout_c)));
Expand Down