|
1 | 1 | import os
|
| 2 | +import sys |
| 3 | +import argparse |
2 | 4 |
|
3 | 5 | # this file creates the test/unit/gemm/device simt tests and the CMake file to go with it
|
4 | 6 | ################################################################################
|
|
92 | 94 |
|
93 | 95 | test_template = """\
|
94 | 96 | #if defined(CUASR_TEST_LEVEL) and (CUASR_TEST_LEVEL >= {21})
|
95 |
| -TEST(SM50_device_{0}_{1}_{2}srgemm_{4}{5}_{6}, {10}x{11}x{12}_{13}x{14}x1_{15}x{16}_{17}x{18}_{19}x{20}) {{ |
| 97 | +TEST(SM{22}_device_{0}_{1}_{2}srgemm_{4}{5}_{6}, {10}x{11}x{12}_{13}x{14}x1_{15}x{16}_{17}x{18}_{19}x{20}) {{ |
96 | 98 | using precision = {3};
|
97 | 99 | using OpClass = cutlass::arch::OpClassSimt;
|
98 |
| - using SmArch = cutlass::arch::Sm50; |
| 100 | + using SmArch = cutlass::arch::Sm{22}; |
99 | 101 |
|
100 | 102 | using ThreadblockShape = cutlass::gemm::GemmShape<{10}, {11}, {12}>;
|
101 | 103 | using WarpShape = cutlass::gemm::GemmShape<{13}, {14}, {12}>;
|
@@ -146,7 +148,8 @@ def write_test_to_file(
|
146 | 148 | warp_threadsM,
|
147 | 149 | warp_threadsN,
|
148 | 150 | warps_per_tb,
|
149 |
| - test_level): |
| 151 | + test_level, |
| 152 | + sm_arch): |
150 | 153 | print("{:.0f}x{:.0f}x{:.0f}__{:.0f}x{:.0f}_{:.0f}x{:.0f}_{:.0f}x{:.0f}".format(
|
151 | 154 | threadblock_tile[0], threadblock_tile[1], unroll,
|
152 | 155 | thread_tileM, thread_tileN,
|
@@ -186,11 +189,12 @@ def write_test_to_file(
|
186 | 189 | int(warp_threadsN), # 18
|
187 | 190 | int(warps_per_tb[0]), # 19
|
188 | 191 | int(warps_per_tb[1]), # 20
|
189 |
| - int(test_level) # 21 |
| 192 | + int(test_level), # 21 |
| 193 | + int(sm_arch) # 22 |
190 | 194 | ))
|
191 | 195 |
|
192 | 196 |
|
193 |
| -def main(output_dir: str): |
| 197 | +def main(args): |
194 | 198 | # warps per threadblock
|
195 | 199 | warps_per_threadblocks = []
|
196 | 200 | for warps_per_tb0 in WARPS_PER_TB_EDGE:
|
@@ -242,12 +246,12 @@ def main(output_dir: str):
|
242 | 246 | transC = "n" if column_major_C else "t"
|
243 | 247 |
|
244 | 248 | # open file
|
245 |
| - testfile_name = "simt_{}_{}_{}srgemm_{}{}_{}_sm50.cu".format( |
246 |
| - add_op, mult_op, precision_char, |
| 249 | + testfile_name = "sm{}_simt_{}_{}_{}srgemm_{}{}_{}.cu".format( |
| 250 | + args.sm_arch, add_op, mult_op, precision_char, |
247 | 251 | transA, transB, transC)
|
248 | 252 | print("\n", testfile_name)
|
249 | 253 |
|
250 |
| - filePath = os.path.join(output_dir, testfile_name) |
| 254 | + filePath = os.path.join(args.output_dir, testfile_name) |
251 | 255 | with open(filePath, "w") as testfile:
|
252 | 256 | write_test_file_header(testfile)
|
253 | 257 |
|
@@ -362,10 +366,17 @@ def main(output_dir: str):
|
362 | 366 | warp_threadsM,
|
363 | 367 | warp_threadsN,
|
364 | 368 | warps_per_tb,
|
365 |
| - test_level) |
| 369 | + test_level, |
| 370 | + args.sm_arch) |
366 | 371 | num_tests += 1
|
367 | 372 | print("Total test count per semi-ring = {}".format(num_tests//len(semiring_operators)))
|
368 | 373 |
|
369 | 374 |
|
370 | 375 | if __name__ == "__main__":
|
371 |
| - main(".") |
| 376 | + parser = argparse.ArgumentParser() |
| 377 | + parser.add_argument("-o", "--output-dir", type=str, required=False, default=".", |
| 378 | + help="Path to the output dir.") |
| 379 | + parser.add_argument("-sm", "--sm-arch", type=int, required=False, default=50, choices=[50, 80], |
| 380 | + help="SM architecture version number,") |
| 381 | + args = parser.parse_args(sys.argv[1:]) |
| 382 | + main(args) |
0 commit comments