Skip to content

Commit 73571b3

Browse files
committed
generalize test generator and rename test files in prep for SM80
1 parent 4ee508e commit 73571b3

File tree

129 files changed

+21
-10
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

129 files changed

+21
-10
lines changed

test/device/simt_sm50.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import os
2+
import sys
3+
import argparse
24

35
# this file creates the test/unit/gemm/device simt tests and the CMake file to go with it
46
################################################################################
@@ -92,10 +94,10 @@
9294

9395
test_template = """\
9496
#if defined(CUASR_TEST_LEVEL) and (CUASR_TEST_LEVEL >= {21})
95-
TEST(SM50_device_{0}_{1}_{2}srgemm_{4}{5}_{6}, {10}x{11}x{12}_{13}x{14}x1_{15}x{16}_{17}x{18}_{19}x{20}) {{
97+
TEST(SM{22}_device_{0}_{1}_{2}srgemm_{4}{5}_{6}, {10}x{11}x{12}_{13}x{14}x1_{15}x{16}_{17}x{18}_{19}x{20}) {{
9698
using precision = {3};
9799
using OpClass = cutlass::arch::OpClassSimt;
98-
using SmArch = cutlass::arch::Sm50;
100+
using SmArch = cutlass::arch::Sm{22};
99101
100102
using ThreadblockShape = cutlass::gemm::GemmShape<{10}, {11}, {12}>;
101103
using WarpShape = cutlass::gemm::GemmShape<{13}, {14}, {12}>;
@@ -146,7 +148,8 @@ def write_test_to_file(
146148
warp_threadsM,
147149
warp_threadsN,
148150
warps_per_tb,
149-
test_level):
151+
test_level,
152+
sm_arch):
150153
print("{:.0f}x{:.0f}x{:.0f}__{:.0f}x{:.0f}_{:.0f}x{:.0f}_{:.0f}x{:.0f}".format(
151154
threadblock_tile[0], threadblock_tile[1], unroll,
152155
thread_tileM, thread_tileN,
@@ -186,11 +189,12 @@ def write_test_to_file(
186189
int(warp_threadsN), # 18
187190
int(warps_per_tb[0]), # 19
188191
int(warps_per_tb[1]), # 20
189-
int(test_level) # 21
192+
int(test_level), # 21
193+
int(sm_arch) # 22
190194
))
191195

192196

193-
def main(output_dir: str):
197+
def main(args):
194198
# warps per threadblock
195199
warps_per_threadblocks = []
196200
for warps_per_tb0 in WARPS_PER_TB_EDGE:
@@ -242,12 +246,12 @@ def main(output_dir: str):
242246
transC = "n" if column_major_C else "t"
243247

244248
# open file
245-
testfile_name = "simt_{}_{}_{}srgemm_{}{}_{}_sm50.cu".format(
246-
add_op, mult_op, precision_char,
249+
testfile_name = "sm{}_simt_{}_{}_{}srgemm_{}{}_{}.cu".format(
250+
args.sm_arch, add_op, mult_op, precision_char,
247251
transA, transB, transC)
248252
print("\n", testfile_name)
249253

250-
filePath = os.path.join(output_dir, testfile_name)
254+
filePath = os.path.join(args.output_dir, testfile_name)
251255
with open(filePath, "w") as testfile:
252256
write_test_file_header(testfile)
253257

@@ -362,10 +366,17 @@ def main(output_dir: str):
362366
warp_threadsM,
363367
warp_threadsN,
364368
warps_per_tb,
365-
test_level)
369+
test_level,
370+
args.sm_arch)
366371
num_tests += 1
367372
print("Total test count per semi-ring = {}".format(num_tests//len(semiring_operators)))
368373

369374

370375
if __name__ == "__main__":
371-
main(".")
376+
parser = argparse.ArgumentParser()
377+
parser.add_argument("-o", "--output-dir", type=str, required=False, default=".",
378+
help="Path to the output dir.")
379+
parser.add_argument("-sm", "--sm-arch", type=int, required=False, default=50, choices=[50, 80],
380+
help="SM architecture version number,")
381+
args = parser.parse_args(sys.argv[1:])
382+
main(args)

0 commit comments

Comments
 (0)