diff --git a/tests/dataflow/test_packed_systolic.py b/tests/dataflow/test_packed_systolic.py index ab4f5981..052e976d 100644 --- a/tests/dataflow/test_packed_systolic.py +++ b/tests/dataflow/test_packed_systolic.py @@ -7,14 +7,16 @@ import allo.backend.hls as hls import numpy as np -L, D = 2, 2 -M, N, K = L, 1 * D, D -PP = 2 +M, N, K = 8, 8, 4 +PP = 4 P0, P1 = M // PP + 2, N + 2 if PP == 2: np_type = np.int16 allo_type = int16 +elif PP == 4: + np_type = np.int32 + allo_type = int32 else: raise ValueError(f"Unsupported packing factor: {PP}") @@ -26,9 +28,9 @@ def top(): @df.kernel(mapping=[P0, P1]) def gemm( - X_packed: allo_type[L // PP, D], - W_packed: allo_type[D, 1 * D // PP], - Z_packed: allo_type[L // PP, 1 * D], + X_packed: allo_type[M, K // PP], + W_packed: allo_type[K // PP, N], + Z_packed: allo_type[M // PP, N], ): i, j = df.get_pid() # Peripheral kernels @@ -37,11 +39,11 @@ def gemm( with allo.meta_elif(j == 0): # i > 0 for k in range(K): - fifo_A[i, j + 1].put(X_packed[i - 1, k]) + fifo_A[i, j + 1].put(X_packed[(i - 1) * PP, k]) with allo.meta_elif(i == 0): # j > 0 for k in range(K): - fifo_B[i + 1, j].put(W_packed[j // PP, 0]) + fifo_B[i + 1, j].put(W_packed[k // PP, j - 1]) # drain with allo.meta_elif(i == M // PP + 1 and j > 0): @@ -68,14 +70,14 @@ def gemm( def test_packed_systolic(): - X = np.random.randint(-4, 4, size=(L, D)).astype(np.int8) - W_A_cst = np.random.randint(-4, 4, size=(D, 1 * D)).astype(np.int8) + X = np.random.randint(-4, 4, size=(M, K)).astype(np.int8) + W_A_cst = np.random.randint(-4, 4, size=(K, N)).astype(np.int8) - packed_X = np.ascontiguousarray(np.ascontiguousarray(X).view(np_type).transpose()) + packed_X = np.ascontiguousarray(np.ascontiguousarray(X).view(np_type)) W_A_packed = np.ascontiguousarray( np.ascontiguousarray(W_A_cst.transpose()).view(np_type).transpose() ) - Z_packed = np.zeros((L // PP, 1 * D), dtype=np_type) + Z_packed = np.zeros((M // PP, N), dtype=np_type) mod = df.build(top) if hls.is_available("vitis_hls"): mod(packed_X, W_A_packed, Z_packed)