diff --git a/src/nki_samples/reference/allocated_attention.py b/src/nki_samples/reference/allocated_attention.py index e6e4119..1413100 100644 --- a/src/nki_samples/reference/allocated_attention.py +++ b/src/nki_samples/reference/allocated_attention.py @@ -60,8 +60,7 @@ def allocated_fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, cur_addr = 0 - id0 = nl.arange(0, 128)[:, None] - id1 = nl.arange(0, 128)[None, :] + id0, id1 = nl.mgrid[0:128, 0:128] identity = nl.shared_constant(np.identity(128, dtype=np.int8), dtype=nl.bfloat16) identity_load = nl.ndarray((par_dim(128), 128), dtype=pe_in_dt, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr)) cur_addr += 128 * pe_in_dt_itemsize @@ -90,16 +89,14 @@ def allocated_fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, cur_addr += v_seq_n_tiles * d_head * pe_in_dt_itemsize for i_v_seq_tile in nl.affine_range(v_seq_n_tiles): - ip_v = nl.arange(v_seq_tile_size)[:, None] - if_v = nl.arange(d_head_tile_size)[None, :] + ip_v, if_v = nl.mgrid[0:v_seq_tile_size, 0:d_head_tile_size] v_local[i_v_seq_tile, ip_v, if_v] = nl.load( v_ref[batch_id, i_v_seq_tile * v_seq_tile_size + ip_v, if_v], dtype=pe_in_dt) q_local = nl.ndarray((q_seq_n_tiles, par_dim(d_head_tile_size), q_seq_tile_size), dtype=pe_in_dt, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(q_seq_n_tiles, ))) # 8kb cur_addr += q_seq_n_tiles * q_seq_tile_size * pe_in_dt_itemsize - ip_q = nl.arange(d_head_tile_size)[:, None] - if_q = nl.arange(q_seq_tile_size)[None, :] + ip_q, if_q = nl.mgrid[0:d_head_tile_size, 0:q_seq_tile_size] for i_q_seq_tile in nl.affine_range(q_seq_n_tiles): q_local[i_q_seq_tile, ip_q, if_q] = nl.load( q_ref[batch_id, ip_q, i_q_seq_tile * q_seq_tile_size + if_q], @@ -108,8 +105,7 @@ def allocated_fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, k_local = nl.ndarray((k_seq_n_tiles, par_dim(d_head_tile_size), k_seq_tile_size), dtype=pe_in_dt, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(k_seq_n_tiles, ))) # 8kb cur_addr += k_seq_n_tiles * k_seq_tile_size * pe_in_dt_itemsize - ip_k = nl.arange(d_head_tile_size)[:, None] - if_k = nl.arange(k_seq_tile_size)[None, :] + ip_k, if_k = nl.mgrid[0:d_head_tile_size, 0:k_seq_tile_size] for i_k_seq_tile in nl.affine_range(k_seq_n_tiles): k_local[i_k_seq_tile, ip_k, if_k] = nl.load( k_ref[batch_id, @@ -184,15 +180,13 @@ def psum_addr(bank_map, idx, pdim_size, fdim_size): for i_interleave_grp in nl.affine_range(2): # A SBUF buffer tile for an independent softmax tile - ip_max = nl.arange(q_seq_tile_size)[:, None] - if_max = nl.arange(k_seq_n_tiles)[None, :] + ip_max, if_max = nl.mgrid[0:q_seq_tile_size, 0:k_seq_n_tiles] # Loop over RHS free of matmul(stationary=tensor_q, moving=tensor_k, contract=d_head) for i_k_seq_tile in nl.affine_range(k_seq_n_tiles): # indent = 4 # Tensor indices for accessing qk result in k_seq_tile_size - ip_qk = nl.arange(q_seq_tile_size)[:, None] - if_qk = nl.arange(k_seq_tile_size)[None, :] + ip_qk, if_qk = nl.mgrid[0:q_seq_tile_size, 0:k_seq_tile_size] ############################################################## # Step 2. matmul(stationary=tensor_q, moving=tensor_k, contract=d_head) @@ -219,12 +213,10 @@ def psum_addr(bank_map, idx, pdim_size, fdim_size): np.max, data=neg_max_res[i_interleave_grp, ip_max, if_max], axis=(1,), dtype=kernel_dtype, negate=True) - ip_softmax = nl.arange(q_seq_tile_size)[:, None] - if_softmax = nl.arange(seqlen)[None, :] - ip_sum_res = nl.arange(q_seq_tile_size)[:, None] - if_sum_res = nl.arange(d_head_tile_size)[None, :] + ip_softmax, if_softmax = nl.mgrid[0:q_seq_tile_size, 0:seqlen] + ip_sum_res, if_sum_res = nl.mgrid[0:q_seq_tile_size, 0:d_head_tile_size] - if_reduction = nl.arange(reduction_size)[None, :] + _, if_reduction = nl.mgrid[0:1, 0:reduction_size] for i_exp in nl.affine_range(reduction_tiles): exp_res[i_interleave_grp, ip_softmax, i_exp*reduction_size + if_reduction] = nisa.activation_reduce(np.exp, data=qk_res_buf[i_interleave_grp, ip_softmax, i_exp * reduction_size + if_reduction], @@ -242,30 +234,25 @@ def psum_addr(bank_map, idx, pdim_size, fdim_size): ################################### # Step 5. transpose(softmax_res) ################################### - ip_scores_t = nl.arange(v_seq_tile_size)[:, None] - if_scores_t = nl.arange(v_seq_tile_size)[None, :] + ip_scores_t, if_scores_t = nl.mgrid[0:v_seq_tile_size, 0:v_seq_tile_size] # Loop over matmul_1 contraction for i_v_seq_tile in nl.affine_range(v_seq_n_tiles // 4): for i_offset in nl.affine_range(4): - ip_scores = nl.arange(v_seq_tile_size)[:, None] - if_scores = nl.arange(v_seq_tile_size)[None, :] - + ip_scores, if_scores = nl.mgrid[0:v_seq_tile_size, 0:v_seq_tile_size] local_tp_buf[i_interleave_grp, i_v_seq_tile, ip_scores, i_offset*v_seq_tile_size + if_scores] = nisa.nc_matmul( exp_res[i_interleave_grp, ip_scores, (i_v_seq_tile*4+i_offset) * v_seq_tile_size + if_scores], identity_load) - if_batch = nl.arange(k_seq_tile_size)[None, :] + _, if_batch = nl.mgrid[0:1, 0:k_seq_tile_size] trans_softmax_res[i_interleave_grp, ip_scores_t, i_v_seq_tile*k_seq_tile_size + if_batch] = nl.copy(local_tp_buf[i_interleave_grp, i_v_seq_tile, ip_scores, if_batch]) - ip_out = nl.arange(d_head_tile_size)[:, None] - if_out = nl.arange(q_seq_tile_size)[None, :] + ip_out, if_out = nl.mgrid[0:d_head_tile_size, 0:q_seq_tile_size] for i_v_seq_tile in nl.affine_range(v_seq_n_tiles): ###################################################################### # Step 6. matmul_1(stationary=v_local, moving=trans_softmax_res, contract=seqlen_v=seqlen_k) ###################################################################### - ip_v_t = nl.arange(v_seq_tile_size)[:, None] - if_v_t = nl.arange(d_head_tile_size)[None, :] + ip_v_t, if_v_t = nl.mgrid[0:v_seq_tile_size, 0:d_head_tile_size] attn_res_psum[i_interleave_grp, ip_out, if_out] += \ nisa.nc_matmul(moving=trans_softmax_res[i_interleave_grp, ip_scores_t, i_v_seq_tile*v_seq_tile_size+if_scores_t], stationary=v_local[i_v_seq_tile, ip_v_t, if_v_t]) diff --git a/src/nki_samples/reference/attention.py b/src/nki_samples/reference/attention.py index a6e1f9a..ffe8cad 100644 --- a/src/nki_samples/reference/attention.py +++ b/src/nki_samples/reference/attention.py @@ -1029,37 +1029,42 @@ def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, use_causal_mask= trans_v = nl.ndarray((par_dim(v_seq_tile_size), v_seq_n_tiles, d_head), dtype=pe_in_dt) for i_k_seq_tile in nl.affine_range(k_seq_n_tiles): - ip_v = nl.arange(v_seq_tile_size)[:, None] - if_v = nl.arange(d_head_tile_size)[None, :] + ip_v, if_v = nl.mgrid[0:v_seq_tile_size, 0:d_head_tile_size] trans_v[ip_v, i_k_seq_tile, if_v] = nl.load( v_ref[batch_id, i_k_seq_tile * k_seq_tile_size + ip_v, if_v], dtype=pe_in_dt) q_local = nl.ndarray((q_seq_n_tiles, par_dim(d_head_tile_size), q_seq_tile_size), dtype=pe_in_dt) - ip_q = nl.arange(d_head_tile_size)[:, None] - if_q = nl.arange(q_seq_tile_size)[None, :] + ip_q, if_q = nl.mgrid[0:d_head_tile_size, 0:q_seq_tile_size] for i_q_seq_tile in nl.affine_range(q_seq_n_tiles): q_local[i_q_seq_tile, ip_q, if_q] = nl.load( q_ref[batch_id, ip_q, i_q_seq_tile * q_seq_tile_size + if_q], dtype=pe_in_dt) * softmax_scale k_local = nl.ndarray((k_seq_n_tiles, par_dim(d_head_tile_size), k_seq_tile_size), dtype=pe_in_dt) - ip_k = nl.arange(d_head_tile_size)[:, None] - if_k = nl.arange(k_seq_tile_size)[None, :] + ip_k, if_k = nl.mgrid[0:d_head_tile_size, 0:k_seq_tile_size] for i_k_seq_tile in nl.affine_range(k_seq_n_tiles): + idx_1, idx_2 = nl.mgrid[0:k_seq_tile_size, 0:d_head_tile_size] + k_local[i_k_seq_tile, ip_k, if_k] = nl.load_transpose2d( + k_ref[batch_id, + i_k_seq_tile * k_seq_tile_size + idx_1, + idx_2], + dtype=pe_in_dt) + for i_k_seq_tile in nl.affine_range(k_seq_n_tiles): + idx_1, idx_2 = nl.mgrid[0:k_seq_tile_size, 0:d_head_tile_size] k_local[i_k_seq_tile, ip_k, if_k] = nl.load_transpose2d( - k_ref[batch_id, - i_k_seq_tile * k_seq_tile_size + nl.arange(k_seq_tile_size)[:, None], - nl.arange(d_head_tile_size)[None, :]], - dtype=pe_in_dt) + k_ref[batch_id, + i_k_seq_tile * k_seq_tile_size + idx_1, + idx_2], + dtype=pe_in_dt) + for i_q_seq_tile in nl.affine_range(q_seq_n_tiles): # indent = 2 # A SBUF buffer for an independent softmax tile qk_res_buf = nl.ndarray((par_dim(q_seq_tile_size), seqlen), dtype=kernel_dtype) neg_max_res = nl.ndarray((par_dim(q_seq_tile_size), k_seq_n_tiles), dtype=kernel_dtype) - ip_max = nl.arange(q_seq_tile_size)[:, None] - if_max = nl.arange(k_seq_n_tiles)[None, :] + ip_max, if_max = nl.mgrid[0:q_seq_tile_size, 0:k_seq_n_tiles] # Loop over RHS free of matmul(stationary=tensor_q, moving=tensor_k, contract=d_head) for i_k_seq_tile in nl.affine_range(k_seq_n_tiles): # indent = 4 @@ -1070,8 +1075,7 @@ def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, use_causal_mask= dtype=np.float32, buffer=nl.psum) # Tensor indices for accessing qk result in k_seq_tile_size - ip_qk = nl.arange(q_seq_tile_size)[:, None] - if_qk = nl.arange(k_seq_tile_size)[None, :] + ip_qk, if_qk = nl.mgrid[0:q_seq_tile_size, 0:k_seq_tile_size] ############################################################## # Step 2. matmul(stationary=tensor_q, moving=tensor_k, contract=d_head) @@ -1105,10 +1109,8 @@ def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, use_causal_mask= np.min, data=neg_max_res[ip_max, if_max], axis=(1,), dtype=kernel_dtype, negate=False) - ip_softmax = nl.arange(q_seq_tile_size)[:, None] - if_softmax = nl.arange(seqlen)[None, :] - ip_sum_res = nl.arange(q_seq_tile_size)[:, None] - if_sum_res = nl.arange(d_head_tile_size)[None, :] + ip_softmax, if_softmax = nl.mgrid[0:q_seq_tile_size, 0:seqlen] + ip_sum_res, if_sum_res = nl.mgrid[0:q_seq_tile_size, 0:d_head_tile_size] softmax_res = nl.ndarray((par_dim(q_seq_tile_size), seqlen), dtype=pe_in_dt) sum_divisor = nl.ndarray((par_dim(q_seq_tile_size), d_head_tile_size), dtype=kernel_dtype) @@ -1135,27 +1137,23 @@ def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, use_causal_mask= attn_res_psum = nl.zeros((par_dim(d_head_tile_size), q_seq_tile_size), dtype=np.float32, buffer=nl.psum) - ip_scores_t = nl.arange(k_seq_tile_size)[:, None] - if_scores_t = nl.arange(q_seq_tile_size)[None, :] + ip_scores_t, if_scores_t = nl.mgrid[0:k_seq_tile_size, 0:q_seq_tile_size] # Loop over matmul_1 contraction for i_k_seq_tile in nl.affine_range(k_seq_n_tiles): ################################### # Step 5. transpose(softmax_res) ################################### - ip_scores = nl.arange(q_seq_tile_size)[:, None] - if_scores = nl.arange(k_seq_tile_size)[None, :] + ip_scores, if_scores = nl.mgrid[0:q_seq_tile_size, 0:k_seq_tile_size] trans_softmax_res[ip_scores_t, i_k_seq_tile, if_scores_t] = nisa.nc_transpose( softmax_res[ip_scores, i_k_seq_tile * k_seq_tile_size + if_scores]) - ip_out = nl.arange(d_head_tile_size)[:, None] - if_out = nl.arange(q_seq_tile_size)[None, :] + ip_out, if_out = nl.mgrid[0:d_head_tile_size, 0:q_seq_tile_size] for i_k_seq_tile in nl.affine_range(k_seq_n_tiles): ###################################################################### # Step 6. matmul_1(stationary=trans_v, moving=trans_softmax_res, contract=seqlen_v=seqlen_k) ###################################################################### - ip_v_t = nl.arange(k_seq_tile_size)[:, None] - if_v_t = nl.arange(d_head_tile_size)[None, :] + ip_v_t, if_v_t = nl.mgrid[0:k_seq_tile_size, 0:d_head_tile_size] attn_res_psum[ip_out, if_out] += \ nisa.nc_matmul(moving=trans_softmax_res[ip_scores_t, i_k_seq_tile, if_scores_t], stationary=trans_v[ip_v_t, i_k_seq_tile, if_v_t]) diff --git a/src/nki_samples/reference/tutorial.py b/src/nki_samples/reference/tutorial.py index b32492b..6a02121 100644 --- a/src/nki_samples/reference/tutorial.py +++ b/src/nki_samples/reference/tutorial.py @@ -13,8 +13,7 @@ def add_kernel_nx8x128x512(a_ptr, b_ptr, n_elements): c_ptr = nl.ndarray(a_ptr.shape, dtype=a_ptr.dtype, buffer=nl.shared_hbm) - ix = nl.arange(128)[:, None] - iy = nl.arange(512)[None, :] + ix, iy = nl.mgrid[0:128, 0:512] tile_size = 128 * 512 block_size = 8 * tile_size diff --git a/src/nki_samples/reference/vision.py b/src/nki_samples/reference/vision.py index 4899d27..763a953 100644 --- a/src/nki_samples/reference/vision.py +++ b/src/nki_samples/reference/vision.py @@ -201,8 +201,7 @@ def resize_nearest_fixed_dma_kernel(data_tensor, out_shape): data_tile = data_tensor.reshape(shape=(in_b, in_seqlen, in_c)) out_tile = out_tensor.reshape(shape=(out_b, out_seqlen, out_c)) - b_map = nl.arange(in_b)[:, None] - c_map = nl.arange(out_c)[None, :] + b_map, c_map = nl.mgrid[0:in_b, 0:out_c] for i in nl.static_range(len(map)): target_addr = data_tile[b_map, map[i], c_map] diff --git a/src/nki_samples/tutorials/layernorm/layernorm_nki_kernel.py b/src/nki_samples/tutorials/layernorm/layernorm_nki_kernel.py index c0c235c..596bf30 100644 --- a/src/nki_samples/tutorials/layernorm/layernorm_nki_kernel.py +++ b/src/nki_samples/tutorials/layernorm/layernorm_nki_kernel.py @@ -28,9 +28,7 @@ def nki_layernorm_kernel_v1(input_tensor, epsilon, gamma_vector, beta_vector): assert input_tensor.shape[1] == gamma_vector.shape[0] == beta_vector.shape[0] # Generate tile indices for loading/storing data - i_p_io = nl.arange(nl.tile_size.pmax)[:, None] - i_f_io = nl.arange(input_tensor.shape[1])[None, :] - i_p_param = nl.arange(1)[:, None] + i_p_io, i_f_io, i_p_param = nl.mgrid[0:nl.tile_size.pmax, 0:input_tensor.shape[1], 0:1] # Number of rows in the input tensor num_rows = input_tensor.shape[0] @@ -81,9 +79,7 @@ def nki_layernorm_kernel_v2(input_tensor, epsilon, gamma_vector, beta_vector): assert input_tensor.shape[1] == gamma_vector.shape[0] == beta_vector.shape[0] # Generate tile indices for loading/storing data - i_p_io = nl.arange(nl.tile_size.pmax)[:, None] - i_f_io = nl.arange(input_tensor.shape[1])[None, :] - i_p_param = nl.arange(1)[:, None] + i_p_io, i_f_io, i_p_param = nl.mgrid[0:nl.tile_size.pmax, 0:input_tensor.shape[1], 0:1] # Number of rows in the input tensor num_rows = input_tensor.shape[0] @@ -104,8 +100,7 @@ def nki_layernorm_kernel_v2(input_tensor, epsilon, gamma_vector, beta_vector): # Tile free dimension of the input tensor by nl.tile_size.bn_stats_fmax, # as bn_stats has a free dimension size limit - i_f_bn = nl.arange(nl.tile_size.bn_stats_fmax)[None, :] - i_f_stats = nl.arange(6)[None, :] + i_f_bn, i_f_stats = nl.mgrid[0:nl.tile_size.bn_stats_fmax, 0:6] num_bn_stats = math.ceil(input_tensor.shape[1]/nl.tile_size.bn_stats_fmax) stats_results = nl.ndarray((nl.tile_size.pmax, 6*num_bn_stats), dtype=np.float32) for j in nl.affine_range(num_bn_stats): diff --git a/src/nki_samples/tutorials/rmsnorm/rmsnorm_nki_kernels.py b/src/nki_samples/tutorials/rmsnorm/rmsnorm_nki_kernels.py index 402eecd..59662ee 100644 --- a/src/nki_samples/tutorials/rmsnorm/rmsnorm_nki_kernels.py +++ b/src/nki_samples/tutorials/rmsnorm/rmsnorm_nki_kernels.py @@ -25,9 +25,7 @@ def nki_rmsnorm_kernel(a_tensor, g_tensor): assert a_tensor.shape[1] == g_tensor.shape[0] # Generate tensor indices to index input tensor - ix = nl.arange(128)[:, None] - iw = nl.arange(1)[:, None] - iy = nl.arange(a_tensor.shape[1])[None, :] + ix, iw, iy = nl.mgrid[0:128, 0:1, 0:a_tensor.shape[1]] num_rows = a_tensor.shape[0] diff --git a/src/nki_samples/tutorials/sd_attention/sd_attention_nki_kernels.py b/src/nki_samples/tutorials/sd_attention/sd_attention_nki_kernels.py index dd5509c..9c7a991 100644 --- a/src/nki_samples/tutorials/sd_attention/sd_attention_nki_kernels.py +++ b/src/nki_samples/tutorials/sd_attention/sd_attention_nki_kernels.py @@ -74,38 +74,33 @@ def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, use_causal_mask= trans_v = nl.ndarray((par_dim(v_seq_tile_size), v_seq_n_tiles, d_head), dtype=pe_in_dt) for i_k_seq_tile in nl.affine_range(k_seq_n_tiles): - ip_v = nl.arange(v_seq_tile_size)[:, None] - if_v = nl.arange(d_head_tile_size)[None, :] - trans_v[ip_v, i_k_seq_tile, if_v] = nl.load( - v_ref[i_k_seq_tile * k_seq_tile_size + ip_v, if_v], - dtype=pe_in_dt) + ip_v_, if_v_ = nl.mgrid[0:v_seq_tile_size, 0:d_head_tile_size] + trans_v[ip_v_, i_k_seq_tile, if_v_] = nl.load( + v_ref[i_k_seq_tile * k_seq_tile_size + ip_v_, if_v_], + dtype=pe_in_dt) q_local = nl.ndarray((q_seq_n_tiles, par_dim(d_head_tile_size), q_seq_tile_size), dtype=pe_in_dt) - ip_q = nl.arange(d_head_tile_size)[:, None] - if_q = nl.arange(q_seq_tile_size)[None, :] + ip_q_, if_q_ = nl.mgrid[0:d_head_tile_size, 0:q_seq_tile_size] for i_q_seq_tile in nl.affine_range(q_seq_n_tiles): - q_local[i_q_seq_tile, ip_q, if_q] = nl.load_transpose2d( - q_ref[i_q_seq_tile * q_seq_tile_size + nl.arange(q_seq_tile_size)[:, None], - nl.arange(d_head_tile_size)[None, :] - ], - dtype=pe_in_dt) * softmax_scale + idx_1, idx_2 = nl.mgrid[0:q_seq_tile_size, 0:d_head_tile_size] + q_local[i_q_seq_tile, ip_q_, if_q_] = nl.load_transpose2d( + q_ref[i_q_seq_tile * q_seq_tile_size + idx_1, idx_2], + dtype=pe_in_dt) * softmax_scale k_local = nl.ndarray((k_seq_n_tiles, par_dim(d_head_tile_size), k_seq_tile_size), dtype=pe_in_dt) - ip_k = nl.arange(d_head_tile_size)[:, None] - if_k = nl.arange(k_seq_tile_size)[None, :] + ip_k_, if_k_ = nl.mgrid[0:d_head_tile_size, 0:k_seq_tile_size] for i_k_seq_tile in nl.affine_range(k_seq_n_tiles): - k_local[i_k_seq_tile, ip_k, if_k] = nl.load_transpose2d( - k_ref[i_k_seq_tile * k_seq_tile_size + nl.arange(k_seq_tile_size)[:, None], - nl.arange(d_head_tile_size)[None, :]], - dtype=pe_in_dt) + idx_1, idx_2 = nl.mgrid[0:k_seq_tile_size, 0:d_head_tile_size] + k_local[i_k_seq_tile, ip_k_, if_k_] = nl.load_transpose2d( + k_ref[i_k_seq_tile * k_seq_tile_size + idx_1, idx_2], + dtype=pe_in_dt) for i_q_seq_tile in nl.affine_range(q_seq_n_tiles): # indent = 2 # A SBUF buffer for an independent softmax tile qk_res_buf = nl.ndarray((par_dim(q_seq_tile_size), seqlen), dtype=kernel_dtype) neg_max_res = nl.ndarray((par_dim(q_seq_tile_size), k_seq_n_tiles), dtype=kernel_dtype) - ip_max = nl.arange(q_seq_tile_size)[:, None] - if_max = nl.arange(k_seq_n_tiles)[None, :] + ip_max, if_max = nl.mgrid[0:q_seq_tile_size, 0:k_seq_n_tiles] # Loop over RHS free of matmul(stationary=tensor_q, moving=tensor_k, contract=d_head) for i_k_seq_tile in nl.affine_range(k_seq_n_tiles): # indent = 4 @@ -116,8 +111,7 @@ def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, use_causal_mask= dtype=np.float32, buffer=nl.psum) # Tensor indices for accessing qk result in k_seq_tile_size - ip_qk = nl.arange(q_seq_tile_size)[:, None] - if_qk = nl.arange(k_seq_tile_size)[None, :] + ip_qk, if_qk = nl.mgrid[0:q_seq_tile_size, 0:k_seq_tile_size] ############################################################## # Step 2. matmul(stationary=tensor_q, moving=tensor_k, contract=d_head) @@ -149,10 +143,8 @@ def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, use_causal_mask= np.min, data=neg_max_res[ip_max, if_max], axis=(1,), dtype=kernel_dtype, negate=False) - ip_softmax = nl.arange(q_seq_tile_size)[:, None] - if_softmax = nl.arange(seqlen)[None, :] - ip_sum_res = nl.arange(q_seq_tile_size)[:, None] - if_sum_res = nl.arange(d_head_tile_size)[None, :] + ip_softmax, if_softmax = nl.mgrid[0:q_seq_tile_size, 0:seqlen] + ip_sum_res, if_sum_res = nl.mgrid[0:q_seq_tile_size, 0:d_head_tile_size] softmax_res = nl.ndarray((par_dim(q_seq_tile_size), seqlen), dtype=pe_in_dt) sum_divisor = nl.ndarray((par_dim(q_seq_tile_size), d_head_tile_size), dtype=kernel_dtype) @@ -179,27 +171,24 @@ def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, use_causal_mask= attn_res_psum = nl.zeros((par_dim(d_head_tile_size), q_seq_tile_size), dtype=np.float32, buffer=nl.psum) - ip_scores_t = nl.arange(k_seq_tile_size)[:, None] - if_scores_t = nl.arange(q_seq_tile_size)[None, :] + ip_scores_t, if_scores_t = nl.mgrid[0:k_seq_tile_size, 0:q_seq_tile_size] + # Loop over matmul_1 contraction for i_k_seq_tile in nl.affine_range(k_seq_n_tiles): ################################### # Step 5. transpose(softmax_res) ################################### - ip_scores = nl.arange(q_seq_tile_size)[:, None] - if_scores = nl.arange(k_seq_tile_size)[None, :] + ip_scores, if_scores = nl.mgrid[0:q_seq_tile_size, 0:k_seq_tile_size] trans_softmax_res[ip_scores_t, i_k_seq_tile, if_scores_t] = nisa.nc_transpose( softmax_res[ip_scores, i_k_seq_tile * k_seq_tile_size + if_scores]) - ip_out = nl.arange(d_head_tile_size)[:, None] - if_out = nl.arange(q_seq_tile_size)[None, :] + ip_out, if_out = nl.mgrid[0:d_head_tile_size, 0:q_seq_tile_size] for i_k_seq_tile in nl.affine_range(k_seq_n_tiles): ###################################################################### # Step 6. matmul_1(stationary=trans_v, moving=trans_softmax_res, contract=seqlen_v=seqlen_k) ###################################################################### - ip_v_t = nl.arange(k_seq_tile_size)[:, None] - if_v_t = nl.arange(d_head_tile_size)[None, :] + ip_v_t, if_v_t = nl.mgrid[0:k_seq_tile_size, 0:d_head_tile_size] attn_res_psum[ip_out, if_out] += \ nisa.nc_matmul(moving=trans_softmax_res[ip_scores_t, i_k_seq_tile, if_scores_t], stationary=trans_v[ip_v_t, i_k_seq_tile, if_v_t]) diff --git a/src/nki_samples/tutorials/softmax/softmax_nki_kernels.py b/src/nki_samples/tutorials/softmax/softmax_nki_kernels.py index 0dc8be1..a7289db 100644 --- a/src/nki_samples/tutorials/softmax/softmax_nki_kernels.py +++ b/src/nki_samples/tutorials/softmax/softmax_nki_kernels.py @@ -11,8 +11,7 @@ def nki_softmax_kernel(a_tensor): buffer=nl.shared_hbm) # Generate tensor indices to index input tensor - ix = nl.arange(128)[:, None] - iy = nl.arange(a_tensor.shape[1])[None, :] + ix, iy = nl.mgrid[0:128, 0:a_tensor.shape[1]] num_rows = a_tensor.shape[0] diff --git a/src/nki_samples/tutorials/tensor_addition/spmd_tensor_addition_nki_kernels.py b/src/nki_samples/tutorials/tensor_addition/spmd_tensor_addition_nki_kernels.py index 508d5c4..c4f1c4d 100644 --- a/src/nki_samples/tutorials/tensor_addition/spmd_tensor_addition_nki_kernels.py +++ b/src/nki_samples/tutorials/tensor_addition/spmd_tensor_addition_nki_kernels.py @@ -31,8 +31,9 @@ def nki_tensor_add_kernel_(a_input, b_input): offset_i_y = nl.program_id(1) * 512 # Generate tensor indices to index tensors a and b - ix = offset_i_x + nl.arange(128)[:, None] - iy = offset_i_y + nl.arange(512)[None, :] + ix_, iy_ = nl.mgrid[0:128, 0:512] + ix = offset_i_x + ix_ + iy = offset_i_y + iy_ # Load input data from device memory (HBM) to on-chip memory (SBUF) # We refer to an indexed portion of a tensor as an intermediate tensor diff --git a/src/nki_samples/tutorials/tensor_addition/tensor_addition_nki_kernels.py b/src/nki_samples/tutorials/tensor_addition/tensor_addition_nki_kernels.py index ea72488..788dcbe 100644 --- a/src/nki_samples/tutorials/tensor_addition/tensor_addition_nki_kernels.py +++ b/src/nki_samples/tutorials/tensor_addition/tensor_addition_nki_kernels.py @@ -31,8 +31,9 @@ def nki_tensor_add_kernel_(a_input, b_input): offset_i_y = nl.program_id(1) * 512 # Generate tensor indices to index tensors a and b - ix = offset_i_x + nl.arange(128)[:, None] - iy = offset_i_y + nl.arange(512)[None, :] + ix_, iy_ = nl.mgrid[0:128, 0:512] + ix = offset_i_x + ix_ + iy = offset_i_y + iy_ # Load input data from device memory (HBM) to on-chip memory (SBUF) # We refer to an indexed portion of a tensor as an intermediate tensor diff --git a/src/nki_samples/tutorials/transpose2d/transpose2d_nki_kernels.py b/src/nki_samples/tutorials/transpose2d/transpose2d_nki_kernels.py index 171e6ed..3ca2796 100644 --- a/src/nki_samples/tutorials/transpose2d/transpose2d_nki_kernels.py +++ b/src/nki_samples/tutorials/transpose2d/transpose2d_nki_kernels.py @@ -54,9 +54,7 @@ def tensor_transpose2D_kernel_(in_tensor, shape2D): # We're going to need 3 indices to perform f1:f2 transpose. # - i_p0 is the parallel index # - i_f1 and i_f2 are both free-dim indices, and will be used to transpose between the f1/f2 axes - i_p0 = nl.arange(sz_p)[:, None, None] - i_f1 = nl.arange(sz_f1)[None, :, None] - i_f2 = nl.arange(sz_f2)[None, None, :] + i_p0, i_f1, i_f2 = nl.mgrid[0:sz_p, 0:sz_f1, 0:sz_f2] # Perform the transposition via a SBUF-to-SBUF copy, with access-pattern manipulation # Note that we have 2D tensors and 3 indices, since we need to represent a 2D access pattern *per partition*