|
1 | 1 | import os |
| 2 | +import sys |
| 3 | +import argparse |
2 | 4 |
|
3 | 5 | # this file creates the test/unit/gemm/device simt tests and the CMake file to go with it |
4 | 6 | ################################################################################ |
|
92 | 94 |
|
93 | 95 | test_template = """\ |
94 | 96 | #if defined(CUASR_TEST_LEVEL) and (CUASR_TEST_LEVEL >= {21}) |
95 | | -TEST(SM50_device_{0}_{1}_{2}srgemm_{4}{5}_{6}, {10}x{11}x{12}_{13}x{14}x1_{15}x{16}_{17}x{18}_{19}x{20}) {{ |
| 97 | +TEST(SM{22}_device_{0}_{1}_{2}srgemm_{4}{5}_{6}, {10}x{11}x{12}_{13}x{14}x1_{15}x{16}_{17}x{18}_{19}x{20}) {{ |
96 | 98 | using precision = {3}; |
97 | 99 | using OpClass = cutlass::arch::OpClassSimt; |
98 | | - using SmArch = cutlass::arch::Sm50; |
| 100 | + using SmArch = cutlass::arch::Sm{22}; |
99 | 101 |
|
100 | 102 | using ThreadblockShape = cutlass::gemm::GemmShape<{10}, {11}, {12}>; |
101 | 103 | using WarpShape = cutlass::gemm::GemmShape<{13}, {14}, {12}>; |
@@ -146,7 +148,8 @@ def write_test_to_file( |
146 | 148 | warp_threadsM, |
147 | 149 | warp_threadsN, |
148 | 150 | warps_per_tb, |
149 | | - test_level): |
| 151 | + test_level, |
| 152 | + sm_arch): |
150 | 153 | print("{:.0f}x{:.0f}x{:.0f}__{:.0f}x{:.0f}_{:.0f}x{:.0f}_{:.0f}x{:.0f}".format( |
151 | 154 | threadblock_tile[0], threadblock_tile[1], unroll, |
152 | 155 | thread_tileM, thread_tileN, |
@@ -186,11 +189,12 @@ def write_test_to_file( |
186 | 189 | int(warp_threadsN), # 18 |
187 | 190 | int(warps_per_tb[0]), # 19 |
188 | 191 | int(warps_per_tb[1]), # 20 |
189 | | - int(test_level) # 21 |
| 192 | + int(test_level), # 21 |
| 193 | + int(sm_arch) # 22 |
190 | 194 | )) |
191 | 195 |
|
192 | 196 |
|
193 | | -def main(output_dir: str): |
| 197 | +def main(args): |
194 | 198 | # warps per threadblock |
195 | 199 | warps_per_threadblocks = [] |
196 | 200 | for warps_per_tb0 in WARPS_PER_TB_EDGE: |
@@ -242,12 +246,12 @@ def main(output_dir: str): |
242 | 246 | transC = "n" if column_major_C else "t" |
243 | 247 |
|
244 | 248 | # open file |
245 | | - testfile_name = "simt_{}_{}_{}srgemm_{}{}_{}_sm50.cu".format( |
246 | | - add_op, mult_op, precision_char, |
| 249 | + testfile_name = "sm{}_simt_{}_{}_{}srgemm_{}{}_{}.cu".format( |
| 250 | + args.sm_arch, add_op, mult_op, precision_char, |
247 | 251 | transA, transB, transC) |
248 | 252 | print("\n", testfile_name) |
249 | 253 |
|
250 | | - filePath = os.path.join(output_dir, testfile_name) |
| 254 | + filePath = os.path.join(args.output_dir, testfile_name) |
251 | 255 | with open(filePath, "w") as testfile: |
252 | 256 | write_test_file_header(testfile) |
253 | 257 |
|
@@ -362,10 +366,17 @@ def main(output_dir: str): |
362 | 366 | warp_threadsM, |
363 | 367 | warp_threadsN, |
364 | 368 | warps_per_tb, |
365 | | - test_level) |
| 369 | + test_level, |
| 370 | + args.sm_arch) |
366 | 371 | num_tests += 1 |
367 | 372 | print("Total test count per semi-ring = {}".format(num_tests//len(semiring_operators))) |
368 | 373 |
|
369 | 374 |
|
370 | 375 | if __name__ == "__main__": |
371 | | - main(".") |
| 376 | + parser = argparse.ArgumentParser() |
| 377 | + parser.add_argument("-o", "--output-dir", type=str, required=False, default=".", |
| 378 | + help="Path to the output dir.") |
| 379 | + parser.add_argument("-sm", "--sm-arch", type=int, required=False, default=50, choices=[50, 80], |
| 380 | + help="SM architecture version number,") |
| 381 | + args = parser.parse_args(sys.argv[1:]) |
| 382 | + main(args) |
0 commit comments