@@ -111,11 +111,11 @@ def ref_program(A, B):
111111@pytest .mark .parametrize (
112112 "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads" ,
113113 [
114- (512 , 1024 , 768 , False , True , T .float16 , T .float16 , T .float16 , 128 , 128 , 32 , 2 , 128 ),
115- (512 , 1024 , 768 , False , False , T .float16 , T .float16 , T .float16 , 128 , 128 , 32 , 2 , 128 ),
116- (512 , 1024 , 768 , True , False , T .float16 , T .float16 , T .float16 , 128 , 128 , 32 , 2 , 128 ),
117- (512 , 1024 , 768 , True , True , T .float16 , T .float16 , T .float16 , 128 , 128 , 32 , 2 , 128 ),
118- (128 , 8 , 32 , False , True , T .float16 , T .float16 , T .float16 , 128 , 8 , 32 , 0 , 128 ),
114+ (512 , 1024 , 768 , False , True , T .float16 , T .float16 , T .float32 , 128 , 128 , 32 , 2 , 128 ),
115+ (512 , 1024 , 768 , False , False , T .float16 , T .float16 , T .float32 , 128 , 128 , 32 , 2 , 128 ),
116+ (512 , 1024 , 768 , True , False , T .float16 , T .float16 , T .float32 , 128 , 128 , 32 , 2 , 128 ),
117+ (512 , 1024 , 768 , True , True , T .float16 , T .float16 , T .float32 , 128 , 128 , 32 , 2 , 128 ),
118+ (128 , 8 , 32 , False , True , T .float16 , T .float16 , T .float32 , 128 , 8 , 32 , 0 , 128 ),
119119 (128 , 128 , 128 , False , True , T .int8 , T .int8 , T .int32 , 128 , 128 , 32 , 2 , 128 ),
120120 (128 , 128 , 128 , False , False , T .int8 , T .int8 , T .int32 , 128 , 128 , 32 , 2 , 128 ),
121121 (128 , 128 , 128 , True , False , T .int8 , T .int8 , T .int32 , 128 , 128 , 32 , 2 , 128 ),
@@ -244,11 +244,11 @@ def ref_program(A, B):
244244@pytest .mark .parametrize (
245245 "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads" ,
246246 [
247- (512 , 1024 , 768 , False , False , T .float16 , T .float16 , T .float16 , 128 , 256 , 32 , 2 , 128 ),
248- (512 , 1024 , 768 , False , True , T .float16 , T .float16 , T .float16 , 128 , 256 , 32 , 2 , 128 ),
249- (512 , 1024 , 768 , True , False , T .float16 , T .float16 , T .float16 , 128 , 256 , 32 , 2 , 128 ),
250- (512 , 1024 , 768 , True , True , T .float16 , T .float16 , T .float16 , 128 , 256 , 32 , 2 , 128 ),
251- (128 , 8 , 32 , False , True , T .float16 , T .float16 , T .float16 , 128 , 8 , 32 , 0 , 128 ),
247+ (512 , 1024 , 768 , False , False , T .float16 , T .float16 , T .float32 , 128 , 256 , 32 , 2 , 128 ),
248+ (512 , 1024 , 768 , False , True , T .float16 , T .float16 , T .float32 , 128 , 256 , 32 , 2 , 128 ),
249+ (512 , 1024 , 768 , True , False , T .float16 , T .float16 , T .float32 , 128 , 256 , 32 , 2 , 128 ),
250+ (512 , 1024 , 768 , True , True , T .float16 , T .float16 , T .float32 , 128 , 256 , 32 , 2 , 128 ),
251+ (128 , 8 , 32 , False , True , T .float16 , T .float16 , T .float32 , 128 , 8 , 32 , 0 , 128 ),
252252 (128 , 128 , 128 , False , True , T .int8 , T .int8 , T .int32 , 128 , 128 , 32 , 2 , 128 ),
253253 (128 , 128 , 128 , False , False , T .int8 , T .int8 , T .int32 , 128 , 128 , 32 , 2 , 128 ),
254254 (128 , 128 , 128 , True , False , T .int8 , T .int8 , T .int32 , 128 , 128 , 32 , 2 , 128 ),
@@ -376,11 +376,11 @@ def ref_program(A, B):
376376@pytest .mark .parametrize (
377377 "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads" ,
378378 [
379- (512 , 1024 , 768 , False , False , T .float16 , T .float16 , T .float16 , 128 , 256 , 32 , 2 , 128 ),
380- (512 , 1024 , 768 , False , True , T .float16 , T .float16 , T .float16 , 128 , 256 , 32 , 2 , 128 ),
381- (512 , 1024 , 768 , True , False , T .float16 , T .float16 , T .float16 , 128 , 256 , 32 , 2 , 128 ),
382- (512 , 1024 , 768 , True , True , T .float16 , T .float16 , T .float16 , 128 , 256 , 32 , 2 , 128 ),
383- (128 , 8 , 32 , False , True , T .float16 , T .float16 , T .float16 , 128 , 8 , 32 , 0 , 128 ),
379+ (512 , 1024 , 768 , False , False , T .float16 , T .float16 , T .float32 , 128 , 256 , 32 , 2 , 128 ),
380+ (512 , 1024 , 768 , False , True , T .float16 , T .float16 , T .float32 , 128 , 256 , 32 , 2 , 128 ),
381+ (512 , 1024 , 768 , True , False , T .float16 , T .float16 , T .float32 , 128 , 256 , 32 , 2 , 128 ),
382+ (512 , 1024 , 768 , True , True , T .float16 , T .float16 , T .float32 , 128 , 256 , 32 , 2 , 128 ),
383+ (128 , 8 , 32 , False , True , T .float16 , T .float16 , T .float32 , 128 , 8 , 32 , 0 , 128 ),
384384 (128 , 128 , 32 , False , True , T .int8 , T .int8 , T .int32 , 128 , 128 , 32 , 2 , 128 ),
385385 (128 , 128 , 32 , False , False , T .int8 , T .int8 , T .int32 , 128 , 128 , 32 , 2 , 128 ),
386386 (128 , 128 , 32 , True , False , T .int8 , T .int8 , T .int32 , 128 , 128 , 32 , 2 , 128 ),
@@ -514,12 +514,12 @@ def ref_program(A, B):
514514@pytest .mark .parametrize (
515515 "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads" ,
516516 [
517- (512 , 1024 , 768 , False , False , T .float16 , T .float16 , T .float16 , 128 , 256 , 32 , 2 , 128 ),
518- (512 , 1024 , 768 , False , True , T .float16 , T .float16 , T .float16 , 128 , 256 , 32 , 2 , 128 ),
519- (512 , 1024 , 768 , True , False , T .float16 , T .float16 , T .float16 , 128 , 256 , 32 , 2 , 128 ),
520- (512 , 1024 , 768 , True , True , T .float16 , T .float16 , T .float16 , 128 , 256 , 32 , 2 , 128 ),
517+ (512 , 1024 , 768 , False , False , T .float16 , T .float16 , T .float32 , 128 , 256 , 32 , 2 , 128 ),
518+ (512 , 1024 , 768 , False , True , T .float16 , T .float16 , T .float32 , 128 , 256 , 32 , 2 , 128 ),
519+ (512 , 1024 , 768 , True , False , T .float16 , T .float16 , T .float32 , 128 , 256 , 32 , 2 , 128 ),
520+ (512 , 1024 , 768 , True , True , T .float16 , T .float16 , T .float32 , 128 , 256 , 32 , 2 , 128 ),
521521 (512 , 1024 , 768 , False , True , T .bfloat16 , T .bfloat16 , T .float , 128 , 256 , 32 , 2 , 128 ),
522- (128 , 8 , 128 , False , True , T .float16 , T .float16 , T .float16 , 128 , 8 , 32 , 2 , 128 ),
522+ (128 , 8 , 128 , False , True , T .float16 , T .float16 , T .float32 , 128 , 8 , 32 , 2 , 128 ),
523523 (128 , 8 , 128 , False , True , T .int8 , T .int8 , T .int32 , 128 , 8 , 32 , 2 , 128 ),
524524 (128 , 128 , 128 , False , True , T .int8 , T .int8 , T .int32 , 128 , 128 , 32 , 2 , 128 ),
525525 (128 , 128 , 128 , False , False , T .int8 , T .int8 , T .int32 , 128 , 128 , 32 , 2 , 128 ),
0 commit comments