Skip to content

Commit

Permalink
generalize test generator and rename test files in prep for SM80
Browse files Browse the repository at this point in the history
  • Loading branch information
thakkarV committed Oct 30, 2021
1 parent 4ee508e commit 73571b3
Show file tree
Hide file tree
Showing 129 changed files with 21 additions and 10 deletions.
31 changes: 21 additions & 10 deletions test/device/simt_sm50.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import os
import sys
import argparse

# this file creates the test/unit/gemm/device simt tests and the CMake file to go with it
################################################################################
Expand Down Expand Up @@ -92,10 +94,10 @@

test_template = """\
#if defined(CUASR_TEST_LEVEL) and (CUASR_TEST_LEVEL >= {21})
TEST(SM50_device_{0}_{1}_{2}srgemm_{4}{5}_{6}, {10}x{11}x{12}_{13}x{14}x1_{15}x{16}_{17}x{18}_{19}x{20}) {{
TEST(SM{22}_device_{0}_{1}_{2}srgemm_{4}{5}_{6}, {10}x{11}x{12}_{13}x{14}x1_{15}x{16}_{17}x{18}_{19}x{20}) {{
using precision = {3};
using OpClass = cutlass::arch::OpClassSimt;
using SmArch = cutlass::arch::Sm50;
using SmArch = cutlass::arch::Sm{22};
using ThreadblockShape = cutlass::gemm::GemmShape<{10}, {11}, {12}>;
using WarpShape = cutlass::gemm::GemmShape<{13}, {14}, {12}>;
Expand Down Expand Up @@ -146,7 +148,8 @@ def write_test_to_file(
warp_threadsM,
warp_threadsN,
warps_per_tb,
test_level):
test_level,
sm_arch):
print("{:.0f}x{:.0f}x{:.0f}__{:.0f}x{:.0f}_{:.0f}x{:.0f}_{:.0f}x{:.0f}".format(
threadblock_tile[0], threadblock_tile[1], unroll,
thread_tileM, thread_tileN,
Expand Down Expand Up @@ -186,11 +189,12 @@ def write_test_to_file(
int(warp_threadsN), # 18
int(warps_per_tb[0]), # 19
int(warps_per_tb[1]), # 20
int(test_level) # 21
int(test_level), # 21
int(sm_arch) # 22
))


def main(output_dir: str):
def main(args):
# warps per threadblock
warps_per_threadblocks = []
for warps_per_tb0 in WARPS_PER_TB_EDGE:
Expand Down Expand Up @@ -242,12 +246,12 @@ def main(output_dir: str):
transC = "n" if column_major_C else "t"

# open file
testfile_name = "simt_{}_{}_{}srgemm_{}{}_{}_sm50.cu".format(
add_op, mult_op, precision_char,
testfile_name = "sm{}_simt_{}_{}_{}srgemm_{}{}_{}.cu".format(
args.sm_arch, add_op, mult_op, precision_char,
transA, transB, transC)
print("\n", testfile_name)

filePath = os.path.join(output_dir, testfile_name)
filePath = os.path.join(args.output_dir, testfile_name)
with open(filePath, "w") as testfile:
write_test_file_header(testfile)

Expand Down Expand Up @@ -362,10 +366,17 @@ def main(output_dir: str):
warp_threadsM,
warp_threadsN,
warps_per_tb,
test_level)
test_level,
args.sm_arch)
num_tests += 1
print("Total test count per semi-ring = {}".format(num_tests//len(semiring_operators)))


if __name__ == "__main__":
main(".")
parser = argparse.ArgumentParser()
parser.add_argument("-o", "--output-dir", type=str, required=False, default=".",
help="Path to the output dir.")
parser.add_argument("-sm", "--sm-arch", type=int, required=False, default=50, choices=[50, 80],
help="SM architecture version number,")
args = parser.parse_args(sys.argv[1:])
main(args)

0 comments on commit 73571b3

Please sign in to comment.