Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/programming_guides/instructions.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ Memory allocation and descriptors
- `T.alloc_shared(shape, dtype, scope='shared.dyn')`: Allocate shared buffer.
- `T.alloc_fragment(shape, dtype, scope='local.fragment')`: Allocate fragment.
- `T.alloc_var(dtype, [init], scope='local.var')`: Scalar var buffer (1 elem).
- `T.alloc_barrier(arrive_count)`: Shared barrier buffer.
- `T.alloc_barrier(arrive_count)`: Allocate and initialize one or more mbarriers.
- `T.alloc_tmem(shape, dtype)`: Tensor memory (TMEM) buffer (Hopper+).
- `T.alloc_reducer(shape, dtype, op='sum', replication=None)`: Reducer buf.
- `T.alloc_descriptor(kind, dtype)`: Generic descriptor allocator.
Expand Down Expand Up @@ -155,7 +155,7 @@ Custom intrinsics
- `T.loop_break()`: Break from current loop via intrinsic.

Barriers, TMA, warp‑group
- Barriers: `T.create_list_of_mbarrier(...)`, `T.get_mbarrier(i)`.
- Barriers: `T.alloc_barrier(arrive_count)`.
- Parity ops: `T.mbarrier_wait_parity(barrier, parity)`, `T.mbarrier_arrive(barrier)`.
- Expect tx: `T.mbarrier_expect_tx(...)`; sugar: `T.barrier_wait(id, parity=None)`.
- TMA: `T.create_tma_descriptor(...)`, `T.tma_load(...)`,
Expand Down
22 changes: 11 additions & 11 deletions examples/minference/example_vertical_slash_sparse_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def vs_sparse_flashattn_ws(
column_count = T.alloc_var(dtype=int_dtype)
column_index = T.alloc_shared([vertical_size_round], int_dtype, scope="shared")

T.create_list_of_mbarrier([128] * 9)
mbars = T.alloc_barrier([128] * 9)

block_count = BlockCount[bz, by, bx]
column_count = ColumnCount[bz, by, bx]
Expand All @@ -153,29 +153,29 @@ def vs_sparse_flashattn_ws(
if tid >= 128:
T.annotate_producer_reg_dealloc()
T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
T.mbarrier_arrive(mbarrier=8)
T.mbarrier_arrive(mbarrier=mbars[8])
for bi in T.serial(block_count):
k = block_offset[bi]
T.mbarrier_wait_parity(mbarrier=bi % 2 + 4, parity=(((bi & 3) >> 1) ^ 1))
T.mbarrier_wait_parity(mbarrier=mbars[bi % 2 + 4], parity=(((bi & 3) >> 1) ^ 1))
T.copy(K[bz, by, k : k + block_N, :], K_shared[bi % 2, :, :])
T.mbarrier_arrive(mbarrier=bi % 2)
T.mbarrier_wait_parity(mbarrier=bi % 2 + 6, parity=(((bi & 3) >> 1) ^ 1))
T.mbarrier_arrive(mbarrier=mbars[bi % 2])
T.mbarrier_wait_parity(mbarrier=mbars[bi % 2 + 6], parity=(((bi & 3) >> 1) ^ 1))
T.copy(V[bz, by, k : k + block_N, :], V_shared[bi % 2, :, :])
T.mbarrier_arrive(mbarrier=bi % 2 + 2)
T.mbarrier_arrive(mbarrier=mbars[bi % 2 + 2])
else:
T.annotate_consumer_reg_alloc()
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
T.mbarrier_wait_parity(mbarrier=8, parity=0)
T.mbarrier_wait_parity(mbarrier=mbars[8], parity=0)
for bi in T.serial(block_count):
k = block_offset[bi]
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k + j, 0, -T.infinity(acc_s.dtype))

T.mbarrier_wait_parity(mbarrier=bi % 2, parity=((bi & 3) >> 1))
T.mbarrier_wait_parity(mbarrier=mbars[bi % 2], parity=((bi & 3) >> 1))
T.gemm(Q_shared, K_shared[bi % 2, :, :], acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.mbarrier_arrive(mbarrier=bi % 2 + 4)
T.mbarrier_arrive(mbarrier=mbars[bi % 2 + 4])

T.copy(scores_max, scores_max_prev)

Expand All @@ -191,10 +191,10 @@ def vs_sparse_flashattn_ws(
acc_o[i, j] = acc_o[i, j] * scores_scale[i]

T.copy(acc_s, acc_s_cast)
T.mbarrier_wait_parity(mbarrier=bi % 2 + 2, parity=((bi & 3) >> 1))
T.mbarrier_wait_parity(mbarrier=mbars[bi % 2 + 2], parity=((bi & 3) >> 1))
T.gemm(acc_s_cast, V_shared[bi % 2, :, :], acc_o, policy=T.GemmWarpPolicy.FullRow)

T.mbarrier_arrive(mbarrier=bi % 2 + 6)
T.mbarrier_arrive(mbarrier=mbars[bi % 2 + 6])

T.reduce_sum(acc_s, scores_sum, dim=1)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,21 @@ def main(
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

# create mbarrier for tma
T.create_list_of_mbarrier(mbarrier_list)
mbars = T.alloc_barrier(mbarrier_list)

with T.ws(0):
T.clear(C_local)

for ko in range(T.ceildiv(K, block_K)):
with T.ws(1):
T.mbarrier_wait_parity(mbarrier=ko % num_stages + num_stages, parity=((ko // num_stages) % num_stages) ^ 1)
T.mbarrier_wait_parity(mbarrier=mbars[ko % num_stages + num_stages], parity=((ko // num_stages) % num_stages) ^ 1)
T.copy(A[by * block_M : (by + 1) * block_M, ko * block_K : (ko + 1) * block_K], A_shared[ko % num_stages, :, :])
T.copy(B[ko * block_K : (ko + 1) * block_K, bx * block_N : (bx + 1) * block_N], B_shared[ko % num_stages, :, :])
T.mbarrier_arrive(mbarrier=ko % num_stages)
T.mbarrier_arrive(mbarrier=mbars[ko % num_stages])
with T.ws(0):
T.mbarrier_wait_parity(mbarrier=ko % num_stages, parity=(ko // num_stages) % num_stages)
T.mbarrier_wait_parity(mbarrier=mbars[ko % num_stages], parity=(ko // num_stages) % num_stages)
T.gemm(A_shared[ko % num_stages, :, :], B_shared[ko % num_stages, :, :], C_local)
T.mbarrier_arrive(mbarrier=ko % num_stages + num_stages)
T.mbarrier_arrive(mbarrier=mbars[ko % num_stages + num_stages])

with T.ws(0):
T.copy(C_local, C[by * block_M, bx * block_N])
Expand Down
8 changes: 4 additions & 4 deletions src/op/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,17 +183,17 @@ TVM_DLL const Op &create_tma_descriptor();
TVM_DLL const Op &create_tma_im2col_descriptor();

/*!
* \brief Create a list of mbarrier with num_threads
 * \brief Create a list of mbarriers, one arrive count per barrier
*
* create_list_of_mbarrier(num_threads0, num_threads1, ...)
* create_list_of_mbarrier(arrive_counts0, arrive_counts1, ...)
*
*/
TVM_DLL const Op &create_list_of_mbarrier();

/*!
* \brief Get the mbarrier with barrier_id
 * \brief Get the mbarrier injected by the compiler via barrier_id
*
* int64_t* GetMBarrier(barrier_id)
* int64_t* get_mbarrier(barrier_id)
*
*/
TVM_DLL const Op &get_mbarrier();
Expand Down
35 changes: 12 additions & 23 deletions src/target/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1113,7 +1113,7 @@ void CodeGenTileLangCUDA::PrintStorageScope(const std::string &scope,
<< "Cannot allocate global memory when targeting CUDA. You must pass "
"all global arrays as input instead";
if (scope == "shared" || scope == "shared.barrier") {
os << "__shared__ ";
os << "__shared__ __align__(" << barrier_alignment_bytes_ << ") ";
} else if (scope == "shared.dyn") {
os << "extern __shared__ __align__(1024) ";
}
Expand Down Expand Up @@ -1689,19 +1689,6 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
this->stream << ss.str();
this->stream << ");\n";
};
auto print_mbarrier_obj = [&](PrimExpr barrier_id) {
std::ostringstream ss;
if (barrier_id.as<IntImmNode>()) {
// incase the barrier_id is an integer, we need to print the barrier_id as
// an integer
ss << mbarrier_name_ << "[" << barrier_id << "]";
} else {
// otherwise may be a T.get_mbarrier() call or BufferLoad Node
// we need to print the barrier_id as a string
ss << this->PrintExpr(barrier_id);
}
return ss.str();
};
if (op->op.same_as(builtin::ptx_cp_async())) {
// args[0] = dst_access_ptr, args[1] = src_access_ptr, args[2] = bytes,
// args[3] = predicate (optional)
Expand Down Expand Up @@ -1756,23 +1743,25 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
this->PrintIndent();
int barrier_count = Downcast<IntImm>(op->args[0])->value;
auto mbarrier_storage_name = mbarrier_name_ + "_mem";
this->stream << "__shared__ uint64_t " << mbarrier_storage_name << "["
this->stream << "__shared__ __align__(" << barrier_alignment_bytes_
<< ") uint64_t " << mbarrier_storage_name << "["
<< barrier_count << "];\n";
this->PrintIndent();
this->stream << "auto " << mbarrier_name_ << " = reinterpret_cast<"
<< mbarrier_dtype_ << "*>(" << mbarrier_storage_name << ");\n";
} else if (op->op.same_as(tl::get_mbarrier())) {
// get the mbarrier injected by compiler via barrier_id
ICHECK_EQ(op->args.size(), 1);
std::string barrier_id = this->PrintExpr(op->args[0]);
os << mbarrier_name_ + "[" + barrier_id + "]";
} else if (op->op.same_as(builtin::ptx_arrive_barrier())) {
if (op->args.size() == 1) {
this->PrintIndent();
auto mbarrier_obj = print_mbarrier_obj(op->args[0]);
auto mbarrier_obj = this->PrintExpr(op->args[0]);
this->stream << mbarrier_obj << ".arrive();\n";
} else if (op->args.size() == 3) {
this->PrintIndent();
auto mbarrier_obj = print_mbarrier_obj(op->args[0]);
auto mbarrier_obj = this->PrintExpr(op->args[0]);
auto cta_id = this->PrintExpr(op->args[1]);
auto pred = this->PrintExpr(op->args[2]);
this->stream << mbarrier_obj << ".arrive(" << cta_id << ", " << pred
Expand All @@ -1784,19 +1773,19 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
} else if (op->op.same_as(builtin::ptx_init_barrier_thread_count())) {
ICHECK_EQ(op->args.size(), 2);
this->PrintIndent();
auto mbarrier_obj = print_mbarrier_obj(op->args[0]);
auto mbarrier_obj = this->PrintExpr(op->args[0]);
auto arrive_count = this->PrintExpr(op->args[1]);
this->stream << mbarrier_obj << ".init(" << arrive_count << ");\n";
} else if (op->op.same_as(builtin::ptx_arrive_barrier_expect_tx())) {
if (op->args.size() == 2) {
this->PrintIndent();
auto mbarrier_obj = print_mbarrier_obj(op->args[0]);
auto mbarrier_obj = this->PrintExpr(op->args[0]);
auto transaction_bytes = this->PrintExpr(op->args[1]);
this->stream << mbarrier_obj << ".arrive_and_expect_tx("
<< transaction_bytes << ");\n";
} else if (op->args.size() == 4) {
this->PrintIndent();
auto mbarrier_obj = print_mbarrier_obj(op->args[0]);
auto mbarrier_obj = this->PrintExpr(op->args[0]);
auto transaction_bytes = this->PrintExpr(op->args[1]);
auto cta_id = this->PrintExpr(op->args[2]);
auto pred = this->PrintExpr(op->args[3]);
Expand All @@ -1816,14 +1805,14 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
} else if (op->op.same_as(tl::mbarrier_expect_tx())) {
ICHECK_EQ(op->args.size(), 2);
this->PrintIndent();
auto mbarrier_obj = print_mbarrier_obj(op->args[0]);
auto mbarrier_obj = this->PrintExpr(op->args[0]);
auto transaction_bytes = this->PrintExpr(op->args[1]);
this->stream << mbarrier_obj << ".expect_transaction(" << transaction_bytes
<< ");\n";
} else if (op->op.same_as(tl::mbarrier_wait_parity())) {
ICHECK_EQ(op->args.size(), 2);
this->PrintIndent();
auto mbarrier_obj = print_mbarrier_obj(op->args[0]);
auto mbarrier_obj = this->PrintExpr(op->args[0]);
auto phase = this->PrintExpr(op->args[1]);
this->stream << mbarrier_obj << ".wait(" << phase << ");\n";
} else if (op->op.same_as(tl::ptx_init_tensor_memory())) {
Expand All @@ -1846,7 +1835,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
}
auto desc = op->args[0];
ss << this->PrintExpr(desc) << ", ";
ss << print_mbarrier_obj(op->args[1]) << ", ";
ss << this->PrintExpr(op->args[1]) << ", ";
for (size_t i = 2; i < op->args.size() - 1; i++) {
if (i > 2)
ss << ", ";
Expand Down
1 change: 1 addition & 0 deletions src/target/codegen_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ class CodeGenTileLangCUDA final : public CodeGenC {
// The size of the barrier array in shared memory
int barrier_count_ = -1;
// The name of the mbarrier array in shared memory
// The same as injected_mbarrier_name_ in transform/common/mbarrier.h
const std::string mbarrier_name_ = "mbarrier";
// The type name of the mbarrier array
const std::string mbarrier_dtype_ = "Barrier";
Expand Down
33 changes: 10 additions & 23 deletions src/target/codegen_cutedsl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -297,20 +297,6 @@ void CodeGenTileLangCuTeDSL::VisitExpr_(const CallNode *op,
stream << ")\n";
};

auto print_mbarrier_obj = [&](PrimExpr barrier_id) {
std::ostringstream ss;
if (barrier_id.as<IntImmNode>()) {
// incase the barrier_id is an integer, we need to print the barrier_id as
// an integer
ss << "(" << mbarrier_name_ << "+" << barrier_id << ")";
} else {
// otherwise may be a T.get_mbarrier() call or BufferLoad Node
// we need to print the barrier_id as a string
ss << PrintExpr_(barrier_id);
}
return ss.str();
};

if (op->op.same_as(builtin::ptx_cp_async())) {
std::string dst = PrintExpr_(op->args[0]);
std::string dst_offset = PrintExpr_(op->args[1]);
Expand Down Expand Up @@ -347,11 +333,11 @@ void CodeGenTileLangCuTeDSL::VisitExpr_(const CallNode *op,
} else if (op->op.same_as(builtin::ptx_arrive_barrier())) {
if (op->args.size() == 1) {
PrintIndent();
auto mbarrier_obj = print_mbarrier_obj(op->args[0]);
auto mbarrier_obj = PrintExpr_(op->args[0]);
stream << "tl.mbarrier_arrive(" << mbarrier_obj << ")\n";
} else if (op->args.size() == 3) {
PrintIndent();
auto mbarrier_obj = print_mbarrier_obj(op->args[0]);
auto mbarrier_obj = PrintExpr_(op->args[0]);
auto cta_id = PrintExpr_(op->args[1]);
auto pred = PrintExpr_(op->args[2]);
stream << "tl.mbarrier_arrive(" << mbarrier_obj << ", " << cta_id << ", "
Expand All @@ -363,20 +349,20 @@ void CodeGenTileLangCuTeDSL::VisitExpr_(const CallNode *op,
} else if (op->op.same_as(builtin::ptx_init_barrier_thread_count())) {
ICHECK_EQ(op->args.size(), 2);
PrintIndent();
auto mbarrier_obj = print_mbarrier_obj(op->args[0]);
auto mbarrier_obj = PrintExpr_(op->args[0]);
auto arrive_count = PrintExpr_(op->args[1]);
stream << "tl.mbarrier_init(" << mbarrier_obj << ", " << arrive_count
<< ")\n";
} else if (op->op.same_as(builtin::ptx_arrive_barrier_expect_tx())) {
if (op->args.size() == 2) {
PrintIndent();
auto mbarrier_obj = print_mbarrier_obj(op->args[0]);
auto mbarrier_obj = PrintExpr_(op->args[0]);
auto transaction_bytes = PrintExpr_(op->args[1]);
stream << "tl.arrive_and_expect_tx(" << mbarrier_obj << ", "
<< transaction_bytes << ")\n";
} else if (op->args.size() == 4) {
PrintIndent();
auto mbarrier_obj = print_mbarrier_obj(op->args[0]);
auto mbarrier_obj = PrintExpr_(op->args[0]);
auto transaction_bytes = PrintExpr_(op->args[1]);
auto cta_id = PrintExpr_(op->args[2]);
auto pred = PrintExpr_(op->args[3]);
Expand All @@ -395,14 +381,14 @@ void CodeGenTileLangCuTeDSL::VisitExpr_(const CallNode *op,
} else if (op->op.same_as(tl::mbarrier_expect_tx())) {
ICHECK_EQ(op->args.size(), 2);
PrintIndent();
auto mbarrier_obj = print_mbarrier_obj(op->args[0]);
auto mbarrier_obj = PrintExpr_(op->args[0]);
auto transaction_bytes = PrintExpr_(op->args[1]);
stream << "tl.mbarrier_expect_tx(" << mbarrier_obj << ", "
<< transaction_bytes << ")\n";
} else if (op->op.same_as(tl::mbarrier_wait_parity())) {
ICHECK_EQ(op->args.size(), 2);
PrintIndent();
auto mbarrier_obj = print_mbarrier_obj(op->args[0]);
auto mbarrier_obj = PrintExpr_(op->args[0]);
auto phase = PrintExpr_(op->args[1]);
stream << "tl.mbarrier_wait(" << mbarrier_obj << ", " << phase << ")\n";
} else if (op->op.same_as(tl::ptx_init_tensor_memory())) {
Expand All @@ -428,7 +414,7 @@ void CodeGenTileLangCuTeDSL::VisitExpr_(const CallNode *op,
}
auto desc = op->args[0];
ss << PrintExpr_(desc) << ", ";
ss << print_mbarrier_obj(op->args[1]) << ", ";
ss << PrintExpr_(op->args[1]) << ", ";
ss << PrintExpr_(op->args[2]) << ", (";
for (size_t i = 3; i < op->args.size() - 1; i++) {
if (i > 3)
Expand Down Expand Up @@ -818,7 +804,8 @@ void CodeGenTileLangCuTeDSL::VisitStmt_(const AllocateNode *op) {
PrintType(op->dtype, stream);
stream << ", " << constant_size << "), (" << constant_size << ",))\n";
} else if (scope == "shared.barrier") {
ICHECK(false) << "Unsupported scope: " << scope;
stream << vid << " = tl.alloc_smem(cutlass.Uint64, size_in_elems="
Copy link
Collaborator Author

@Rachmanino Rachmanino Jan 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please review this — I'm not very familiar with CuteDSL details, though CI passed.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be better to have a simple mbarrier codegen test for this modification.

<< constant_size << ")\n";
} else if (scope == "local") {
stream << vid << " = tl.make_rmem_tensor((" << constant_size << "),";
PrintType(op->dtype, stream);
Expand Down
35 changes: 35 additions & 0 deletions src/transform/common/mbarrier.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#ifndef TVM_TL_TRANSFORM_COMMON_MBARRIER_H_
#define TVM_TL_TRANSFORM_COMMON_MBARRIER_H_

#include "../../op/builtin.h"
#include <string>
#include <tvm/ir/expr.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt.h>

namespace tvm {
namespace tl {

using namespace tir;

/*!
 * \brief Create an mbarrier buffer with shared.barrier storage scope.
 *
 * The buffer holds `num_barriers` 64-bit words; downstream codegen treats
 * this storage as the runtime mbarrier objects.
 *
 * \param name The name of the buffer.
 * \param num_barriers The number of barriers in the buffer.
 * \return A Buffer object for mbarrier with shared.barrier scope.
 */
inline Buffer CreateMBarrierBuffer(const std::string &name, int num_barriers) {
  Var data(name, PointerType(PrimType(DataType::UInt(64)), "shared.barrier"));
  return Buffer(data, DataType::UInt(64),
                {IntImm(DataType::Int(32), num_barriers)}, {}, PrimExpr(), name,
                0, 0, kDefault);
}

// Name of the mbarrier array injected by the compiler; must stay in sync with
// mbarrier_name_ in codegen. `inline` (C++17) gives a single definition shared
// by every translation unit that includes this header, instead of one
// static-initialized copy per TU as a plain namespace-scope `const` would.
// TODO: avoid conflict with user-defined mbarriers.
inline const std::string injected_mbarrier_name_ = "mbarrier";

} // namespace tl
} // namespace tvm
#endif // TVM_TL_TRANSFORM_COMMON_MBARRIER_H_
Loading
Loading