Skip to content

Commit 0683cc0

Browse files
committed
define cluster_ctarank
1 parent 1c94bab commit 0683cc0

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

include/common/util.cuh

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,21 @@ struct shared_allocator {
285285
*/
286286
using tma_allocator = shared_allocator<1024>;
287287
using tma_swizzle_allocator = tma_allocator; // swizzled TMA modes require up to 1024 byte alignments :/
288+
289+
/* Get CTA ID within a cluster */
290+
__device__ static inline int3 clusterIdx() {
291+
int3 cluster_idx;
292+
asm volatile("mov.u32 %0, %clusterid.x;\n" : "=r"(cluster_idx.x));
293+
asm volatile("mov.u32 %0, %clusterid.y;\n" : "=r"(cluster_idx.y));
294+
asm volatile("mov.u32 %0, %clusterid.z;\n" : "=r"(cluster_idx.z));
295+
return cluster_idx;
296+
}
297+
__device__ static inline int cluster_ctarank() {
298+
uint32_t ctarank;
299+
asm volatile("mov.u32 %0, %cluster_ctarank;\n" : "=r"(ctarank));
300+
return ctarank;
301+
}
302+
288303
#endif
289304

290305
} // namespace kittens

kernels/group_gemm/group_gemm.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ struct matmul_template {
127127
} else {
128128
tma::expect(args.inputs_cluster_arrived, args.input.b);
129129
tma::load_async(args.input.b, args.globals.B,
130-
{args.common.group_idx, args.common.block_n_idx, args.iter}, args.inputs_arrived);
130+
{args.common.group_idx, args.common.block_n_idx, args.iter}, args.inputs_cluster_arrived);
131131
}
132132
}
133133
}

0 commit comments

Comments
 (0)