diff --git a/.github/workflows/run_simulation_tests.yml b/.github/workflows/run_simulation_tests.yml new file mode 100644 index 0000000..7ea89ef --- /dev/null +++ b/.github/workflows/run_simulation_tests.yml @@ -0,0 +1,32 @@ +name: Run Python Simulation Tests + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + +permissions: + contents: read + +jobs: + build: + + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + - name: Set up Python 3.10 + uses: actions/setup-python@v3 + with: + python-version: "3.10" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip config set global.extra-index-url https://pip.repos.neuron.amazonaws.com + python -m pip install wget awscli + python -m pip install pytest + python -m pip install neuronx-cc==2.* + - name: Test with pytest + run: | + PYTHONPATH=$PYTHONPATH:src/ pytest test/unit/ --simulation-only \ No newline at end of file diff --git a/src/nki_samples/reference/__init__.py b/src/nki_samples/reference/__init__.py new file mode 100644 index 0000000..c9e5d37 --- /dev/null +++ b/src/nki_samples/reference/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) 2023, Amazon.com. All Rights Reserved + +""" +Package containing public kernels for Neuron Kernel Interface (NKI). + +Kernels here are also available in the `neuronxcc.nki.kernels` namespace, and they +are synced with this repository on every Neuron SDK release. + +https://github.com/aws-neuron/nki-samples +""" diff --git a/src/nki_samples/reference/allocated_attention.py b/src/nki_samples/reference/allocated_attention.py new file mode 100644 index 0000000..94b513f --- /dev/null +++ b/src/nki_samples/reference/allocated_attention.py @@ -0,0 +1,283 @@ +import functools +import neuronxcc.nki as nki +import neuronxcc.nki.language as nl +import neuronxcc.nki.isa as nisa +import neuronxcc.nki.compiler as ncc +from neuronxcc.nki.language import par_dim +import numpy as np + +@nki.jit +def allocated_fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, + use_causal_mask=False, + mixed_precision=True): + """ + Allocated fused self attention kernel for small head size Stable Diffusion workload. + + Computes (softmax(Q.T@K)V).T. The wired layout is chosen to avoid transpose as + much as possible to simplify the debug. The kernel uses the direct allocation API, + and implements double buffering to achieve better performance than automatic allocation. + As of NeuronSDK 2.21, it achieves 18% better performance than auto allocated equivalent. + To see the performance gap, you can use ``force_auto_alloc`` decorator to override + manual allocation and benchmark the performance difference. + + This kernel is designed to be used for Stable Diffusion models where the + n_heads is equal to 128. Seqlen must be divisible by 1024, and smaller than 5120. + Assertion is thrown if ``n_heads`` or sequence length does not satisfy the requirement. + These restrictions are to simplify the address calculation in allocations. + + IO tensor layouts: + - q_ptr: shape (bs, d_heads, seq_q) + - k_ptr: shape (bs, d_heads, seq_k) + - v_ptr: shape (bs, seq_v, n_heads) + - out_ptr: shape (bs, d_heads, seq_q) + - We use seq_q and seq_k just for clarity, this kernel requires seq_q == seq_k + + IO tensor dtypes: + - This kernel assumes all IO tensors have the same dtype + - If mixed_precision is True, then all Tensor Engine operation will be performed in + bfloat16 and accumulation will be performed in float32. 
Otherwise the intermediates + will be in the same type as the inputs. + """ + # Use q_ref dtype as the intermediate tensor dtype + # Assume all IO tensors have the same dtype + kernel_dtype = np.float32 + pe_in_dt = nl.bfloat16 if mixed_precision else kernel_dtype + + kernel_dtype_itemsize = np.dtype(kernel_dtype).itemsize + pe_in_dt_itemsize = np.dtype(pe_in_dt).itemsize + assert q_ref.dtype == k_ref.dtype == v_ref.dtype + + # Shape checking + bs, d_head, seqlen = q_ref.shape + assert d_head <= 128, "Cannot use this kernel for d_head > 128" + assert tuple(q_ref.shape) == (bs, d_head, seqlen), 'Input shape mismatch!' + assert tuple(k_ref.shape) == (bs, d_head, seqlen), 'Input shape mismatch!' + assert tuple(v_ref.shape) == (bs, seqlen, + d_head), f'Input shape mismatch! Expected: {(bs, seqlen, d_head)} Actual: {tuple(v_ref.shape)}' + out_ref = nl.ndarray((bs, d_head, seqlen), dtype=q_ref.dtype, buffer=nl.shared_hbm) + + assert d_head == 128 + + cur_addr = 0 + + id0 = nl.arange(0, 128)[:, None] + id1 = nl.arange(0, 128)[None, :] + identity = nl.shared_constant(np.identity(128, dtype=np.int8), dtype=nl.bfloat16) + identity_load = nl.ndarray((par_dim(128), 128), dtype=pe_in_dt, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr)) + cur_addr += 128 * pe_in_dt_itemsize + identity_load[id0, id1] = nl.load(identity) + + identity_load_fp32 = nl.ndarray((par_dim(128), 128), dtype=np.float32, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr)) + cur_addr += 128 * np.dtype(np.float32).itemsize + identity_load_fp32[id0, id1] = nl.load(identity) + + # Softmax scaling factor, multiplied onto Q + softmax_scale = 0.125 + + # Different batch samples/attention heads have independent attention + batch_id = nl.program_id(axis=0) + + q_seq_n_tiles, q_seq_tile_size = seqlen // 128, 128 + k_seq_n_tiles, k_seq_tile_size = seqlen // 512, 512 + # No tiling on d_head dimension since the number of d_head fits in SB + d_head_tile_size = d_head + v_seq_n_tiles, v_seq_tile_size = seqlen // 128, 128 + + ################################### + # Step 1. 
preload tensors + ################################### + v_local = nl.ndarray((v_seq_n_tiles, par_dim(v_seq_tile_size), d_head), dtype=pe_in_dt, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(v_seq_n_tiles, ))) # 8kb + cur_addr += v_seq_n_tiles * d_head * pe_in_dt_itemsize + + for i_v_seq_tile in nl.affine_range(v_seq_n_tiles): + ip_v = nl.arange(v_seq_tile_size)[:, None] + if_v = nl.arange(d_head_tile_size)[None, :] + v_local[i_v_seq_tile, ip_v, if_v] = nl.load( + v_ref[batch_id, i_v_seq_tile * v_seq_tile_size + ip_v, if_v], + dtype=pe_in_dt) + + q_local = nl.ndarray((q_seq_n_tiles, par_dim(d_head_tile_size), q_seq_tile_size), dtype=pe_in_dt, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(q_seq_n_tiles, ))) # 8kb + cur_addr += q_seq_n_tiles * q_seq_tile_size * pe_in_dt_itemsize + ip_q = nl.arange(d_head_tile_size)[:, None] + if_q = nl.arange(q_seq_tile_size)[None, :] + for i_q_seq_tile in nl.affine_range(q_seq_n_tiles): + q_local[i_q_seq_tile, ip_q, if_q] = nl.load( + q_ref[batch_id, ip_q, i_q_seq_tile * q_seq_tile_size + if_q], + dtype=pe_in_dt) + q_local[i_q_seq_tile, ip_q, if_q] = nl.multiply(q_local[i_q_seq_tile, ip_q, if_q], softmax_scale) + + k_local = nl.ndarray((k_seq_n_tiles, par_dim(d_head_tile_size), k_seq_tile_size), dtype=pe_in_dt, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(k_seq_n_tiles, ))) # 8kb + cur_addr += k_seq_n_tiles * k_seq_tile_size * pe_in_dt_itemsize + ip_k = nl.arange(d_head_tile_size)[:, None] + if_k = nl.arange(k_seq_tile_size)[None, :] + for i_k_seq_tile in nl.affine_range(k_seq_n_tiles): + k_local[i_k_seq_tile, ip_k, if_k] = nl.load( + k_ref[batch_id, + ip_k, + i_k_seq_tile * k_seq_tile_size + if_k + ], + dtype=pe_in_dt) + + for i_q_seq_tile in nl.affine_range(q_seq_n_tiles//2): # indent = 2 + # perform activation and reduction in softmax in larger tile to amortize instruction overhead + reduction_size = 1024 + reduction_tiles = seqlen // reduction_size + + # =================================== SBUF Allocation Starts =================================== + + # The num_free_tiles is intentionally set to (1, ) to disable double buffering on the first matmul. + # From the profile, when the first matmul is double buffered, the tensor_scalar_reduce instruction that writes to this buffer + # spends long time waiting for the matmul it depends on to be executed. The instruction scheduler made a bad decision and + # clogged the pipeline when double buffering is on. This is a workaround to hint the scheduler. 
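# --- Reader's sketch (not part of the kernel): the manual SBUF offset bookkeeping
# used throughout this kernel. Each buffer is placed at the running `cur_addr`, which
# then advances by num_free_tiles * free_elements * itemsize, so num_free_tiles=(2,)
# reserves space for double buffering while num_free_tiles=(1,) (as for qk_res_buf
# below) does not. The helper and example sizes here are hypothetical, not an NKI API.
import numpy as np

def sbuf_cursor(start=0):
    state = {"addr": start}
    def reserve(num_free_tiles, free_elems, dtype):
        base = state["addr"]
        state["addr"] += num_free_tiles * free_elems * np.dtype(dtype).itemsize
        return base
    return reserve

reserve = sbuf_cursor()
qk_res_base = reserve(1, 4096, np.float32)   # single-buffered: mirrors qk_res_buf with an assumed seqlen of 4096
exp_res_base = reserve(2, 4096, np.float16)  # double-buffered: mirrors exp_res (bfloat16-sized elements)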
+ qk_res_buf = nl.ndarray((2, par_dim(q_seq_tile_size), seqlen), dtype=kernel_dtype, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(1, ))) # 32 k + cur_addr += seqlen * kernel_dtype_itemsize + exp_res = nl.ndarray((2, par_dim(q_seq_tile_size), seqlen),dtype=pe_in_dt, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(2, ))) # 16 kb + cur_addr += seqlen * 2 * pe_in_dt_itemsize + trans_softmax_res = nl.ndarray( + (2, par_dim(v_seq_tile_size), seqlen), name='trans_softmax_res', + dtype=pe_in_dt, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(2, ))) # 16kb + cur_addr += seqlen * 2 * pe_in_dt_itemsize + + sum_divisor = nl.ndarray((2, par_dim(q_seq_tile_size), d_head_tile_size), dtype=kernel_dtype, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(2, ))) # 1kb + cur_addr += 2 * d_head_tile_size * kernel_dtype_itemsize + sum_reciprocal_broadcast = nl.ndarray((2, par_dim(q_seq_tile_size), d_head_tile_size), dtype=kernel_dtype, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(2, ))) # 1kb + cur_addr += 2 * d_head_tile_size * kernel_dtype_itemsize + + attn_res_sbuf = nl.ndarray((2, par_dim(d_head_tile_size), q_seq_tile_size), dtype=kernel_dtype, + buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(2, )), name="attn_res_sbuf") # 1kb + cur_addr += 2 * q_seq_tile_size * kernel_dtype_itemsize + attn_res_div = nl.ndarray((2, par_dim(q_seq_tile_size), d_head_tile_size), dtype=kernel_dtype, + buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(2,))) # 1kb + cur_addr += 2 * d_head_tile_size * kernel_dtype_itemsize + + neg_max_res = nl.ndarray((2, par_dim(q_seq_tile_size), k_seq_n_tiles), dtype=kernel_dtype, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(2, ))) # 64b + cur_addr += 2 * k_seq_n_tiles * kernel_dtype_itemsize + partial_sum_res = nl.ndarray((2, par_dim(q_seq_tile_size), reduction_tiles), dtype=kernel_dtype, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(2, ))) # 32b + cur_addr += 2 * reduction_tiles * kernel_dtype_itemsize + neg_max_res_final = nl.ndarray((2, par_dim(q_seq_tile_size), 1), dtype=kernel_dtype, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(2, ))) # 8b + cur_addr += 2 * 1 * kernel_dtype_itemsize + sum_res = nl.ndarray((2, par_dim(q_seq_tile_size), 1), dtype=kernel_dtype, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(2, ))) # 8b + cur_addr += 2 * 1 * kernel_dtype_itemsize + sum_reciprocal = nl.ndarray((2, par_dim(q_seq_tile_size), 1), dtype=kernel_dtype, buffer=ncc.sbuf.mod_alloc(base_addr=cur_addr, num_free_tiles=(2, ))) # 8b + cur_addr += 2 * 1 * kernel_dtype_itemsize + + # =================================== SBUF Allocation End =================================== + + qk_psum = nl.ndarray((2, k_seq_n_tiles, par_dim(q_seq_tile_size), k_seq_tile_size), + dtype=np.float32, buffer=ncc.psum.mod_alloc(base_bank=0, num_bank_tiles=(2, 4))) + + assert k_seq_tile_size == 4 * v_seq_tile_size + local_tp_buf = nl.ndarray((2, k_seq_n_tiles, par_dim(q_seq_tile_size), k_seq_tile_size), dtype=np.float32, + buffer=ncc.psum.mod_alloc(base_bank=0, num_bank_tiles=(2, 4))) + + def psum_addr(bank_map, idx, pdim_size, fdim_size): + return (bank_map[idx], 0, 0) + + # Result psum buffer has the hidden dim as P + # qk_psum is using 0, 1, 2, 3 for fisrt interleave group, and 4, 5, 6, 7 for the second. 
+ # assign 1 and 5 avoid bank collision between groups + attn_res_psum = nl.ndarray((2, par_dim(d_head_tile_size), q_seq_tile_size), + dtype=np.float32, buffer=ncc.psum.alloc(functools.partial(psum_addr, bank_map={(0, ): 1, (1, ): 5}))) + + sum_local_tp_buf = nl.ndarray((2, par_dim(q_seq_tile_size), k_seq_tile_size), dtype=np.float32, + buffer=ncc.psum.alloc(functools.partial(psum_addr, bank_map={(0, ): 2, (1, ): 7}))) + + for i_interleave_grp in nl.affine_range(2): + # A SBUF buffer tile for an independent softmax tile + ip_max = nl.arange(q_seq_tile_size)[:, None] + if_max = nl.arange(k_seq_n_tiles)[None, :] + + # Loop over RHS free of matmul(stationary=tensor_q, moving=tensor_k, contract=d_head) + for i_k_seq_tile in nl.affine_range(k_seq_n_tiles): # indent = 4 + + # Tensor indices for accessing qk result in k_seq_tile_size + ip_qk = nl.arange(q_seq_tile_size)[:, None] + if_qk = nl.arange(k_seq_tile_size)[None, :] + + ############################################################## + # Step 2. matmul(stationary=tensor_q, moving=tensor_k, contract=d_head) + ############################################################## + qk_psum[i_interleave_grp, i_k_seq_tile, ip_qk, if_qk] = nisa.nc_matmul(moving=k_local[i_k_seq_tile, ip_k, if_k], + stationary=q_local[i_q_seq_tile*2+i_interleave_grp, ip_q, if_q]) + + ################################### + # Step 3. Apply optional causal mask + ################################### + if use_causal_mask: + assert not use_causal_mask, "Causal mask not supported yet!" + # Magic number -9984.0 to replace -inf similar to what Tensorizer uses + qk_res_buf[i_interleave_grp, ip_qk, i_k_seq_tile * k_seq_tile_size + if_qk] = nisa.affine_select( + pred=(i_q_seq_tile * q_seq_tile_size + ip_qk >= i_k_seq_tile * k_seq_tile_size + if_qk), + on_true_tile=qk_psum[i_interleave_grp, i_k_seq_tile,ip_qk, if_qk], on_false_value=-9984.0, dtype=kernel_dtype) + else: + # Copy result to SBUF and find partial maximum for softmax + qk_res_buf[i_interleave_grp, ip_qk, i_k_seq_tile * k_seq_tile_size + if_qk] = nisa.tensor_scalar_reduce(data=qk_psum[i_interleave_grp, i_k_seq_tile,ip_qk, if_qk], op0=np.add, operand0=1.0, + reduce_op=np.max, reduce_res=neg_max_res[i_interleave_grp, ip_max, i_k_seq_tile], dtype=kernel_dtype) + + # Find global max from tiles + neg_max_res_final[i_interleave_grp, ip_max, 0] = nisa.tensor_reduce( + np.max, data=neg_max_res[i_interleave_grp, ip_max, if_max], + axis=(1,), dtype=kernel_dtype, negate=True) + + ip_softmax = nl.arange(q_seq_tile_size)[:, None] + if_softmax = nl.arange(seqlen)[None, :] + ip_sum_res = nl.arange(q_seq_tile_size)[:, None] + if_sum_res = nl.arange(d_head_tile_size)[None, :] + + if_reduction = nl.arange(reduction_size)[None, :] + for i_exp in nl.affine_range(reduction_tiles): + exp_res[i_interleave_grp, ip_softmax, i_exp*reduction_size + if_reduction] = nisa.activation_reduce(np.exp, + data=qk_res_buf[i_interleave_grp, ip_softmax, i_exp * reduction_size + if_reduction], + reduce_op=np.sum, reduce_res=partial_sum_res[i_interleave_grp, ip_softmax, i_exp], + bias=neg_max_res_final[i_interleave_grp, ip_max, 0], scale=1.0, + ) + + sum_res[i_interleave_grp, ip_softmax, 0] = nisa.tensor_reduce(np.add, data=partial_sum_res[i_interleave_grp, :, :], axis=(1,), + dtype=kernel_dtype) + + sum_reciprocal[i_interleave_grp, ip_softmax, 0] = nl.divide(1.0, sum_res[i_interleave_grp, ip_softmax, 0]) + sum_reciprocal_broadcast[i_interleave_grp, ip_softmax, if_sum_res] = sum_reciprocal[i_interleave_grp, ip_softmax, 0].broadcast_to((q_seq_tile_size, 
d_head_tile_size)) + sum_divisor[i_interleave_grp, ip_sum_res, if_sum_res] = nl.copy(sum_reciprocal_broadcast[i_interleave_grp, ip_softmax, if_sum_res], dtype=kernel_dtype) + + ################################### + # Step 5. transpose(softmax_res) + ################################### + ip_scores_t = nl.arange(v_seq_tile_size)[:, None] + if_scores_t = nl.arange(v_seq_tile_size)[None, :] + # Loop over matmul_1 contraction + for i_v_seq_tile in nl.affine_range(v_seq_n_tiles // 4): + for i_offset in nl.affine_range(4): + ip_scores = nl.arange(v_seq_tile_size)[:, None] + if_scores = nl.arange(v_seq_tile_size)[None, :] + + local_tp_buf[i_interleave_grp, i_v_seq_tile, ip_scores, i_offset*v_seq_tile_size + if_scores] = nisa.nc_matmul( + exp_res[i_interleave_grp, ip_scores, (i_v_seq_tile*4+i_offset) * v_seq_tile_size + if_scores], + identity_load) + + if_batch = nl.arange(k_seq_tile_size)[None, :] + trans_softmax_res[i_interleave_grp, ip_scores_t, i_v_seq_tile*k_seq_tile_size + if_batch] = nl.copy(local_tp_buf[i_interleave_grp, i_v_seq_tile, ip_scores, if_batch]) + + ip_out = nl.arange(d_head_tile_size)[:, None] + if_out = nl.arange(q_seq_tile_size)[None, :] + + for i_v_seq_tile in nl.affine_range(v_seq_n_tiles): + ###################################################################### + # Step 6. matmul_1(stationary=v_local, moving=trans_softmax_res, contract=seqlen_v=seqlen_k) + ###################################################################### + ip_v_t = nl.arange(v_seq_tile_size)[:, None] + if_v_t = nl.arange(d_head_tile_size)[None, :] + attn_res_psum[i_interleave_grp, ip_out, if_out] += \ + nisa.nc_matmul(moving=trans_softmax_res[i_interleave_grp, ip_scores_t, i_v_seq_tile*v_seq_tile_size+if_scores_t], + stationary=v_local[i_v_seq_tile, ip_v_t, if_v_t]) + + attn_res_sbuf[i_interleave_grp, ip_out, if_out] = nisa.tensor_copy(attn_res_psum[i_interleave_grp, ip_out, if_out], + dtype=kernel_dtype, engine=nisa.vector_engine) + + sum_local_tp_buf[i_interleave_grp, ip_sum_res, if_sum_res] = nisa.nc_matmul(sum_divisor[i_interleave_grp, ip_sum_res, if_sum_res], identity_load_fp32) + attn_res_div[i_interleave_grp, ip_sum_res, if_sum_res] = attn_res_sbuf[i_interleave_grp, :, :] * sum_local_tp_buf[i_interleave_grp, ip_sum_res, if_sum_res] + + nl.store( + out_ref[batch_id, ip_out, (i_q_seq_tile*2+i_interleave_grp) * q_seq_tile_size + if_out], + value=attn_res_div[i_interleave_grp, :, :]) + + return out_ref \ No newline at end of file diff --git a/src/nki_samples/reference/allocated_fused_linear.py b/src/nki_samples/reference/allocated_fused_linear.py new file mode 100644 index 0000000..21e32af --- /dev/null +++ b/src/nki_samples/reference/allocated_fused_linear.py @@ -0,0 +1,114 @@ +""" +Copyright (c) 2024, Amazon.com. All Rights Reserved + +kernels - Fused normalization with linear layers + +""" + +import neuronxcc.nki.language as nl +import neuronxcc.nki.isa as nisa +import neuronxcc.nki.compiler as ncc +import math +import numpy as np +from neuronxcc import nki +from neuronxcc.nki.language import par_dim + +@nki.jit +def allocated_fused_rms_norm_qkv(hidden, weights, norm_dtype=nl.float32, eps=1e-6): + """ + Allocated kernel that computes RMSNorm(hidden) @ wQKV. This kernel is designed to only handle fp16/bf16 tensor types. + Internally, normalizations are cast to fp32 to avoid NaN errors. 
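# --- Reader's sketch (plain NumPy, not the NKI implementation): what this kernel
# computes, with the RMSNorm gamma already folded into `weights` as stated above.
# Shapes follow the Args below: hidden (batch, seqlen, dim), weights (dim, head_dim).
import numpy as np

def rms_norm_qkv_reference(hidden, weights, eps=1e-6):
    h32 = hidden.astype(np.float32)                                    # normalize in fp32 to avoid NaNs
    inv_rms = 1.0 / np.sqrt(np.mean(h32 * h32, axis=-1, keepdims=True) + eps)
    return ((h32 * inv_rms) @ weights.astype(np.float32)).astype(hidden.dtype)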
+ + Args: + hidden (_type_): Input tensor of the attention block in BSH layout + weights (_type_): Fused QKV linear weights, assumed to be eltwise-multiplied with RMS norm weight vector (gamma) + out_tensor (_type_): Output tensor + norm_dtype (_type_, optional): Data type for RMS norm, should be f32 to avoid NaN. Defaults to nl.float32. + eps (_type_, optional): RMS norm epsilon term. Defaults to 1e-6. + """ + # Hidden should be in BSH layout. + batch, batchless_shape = hidden.shape[0], hidden.shape[1:] + seqlen, dim = batchless_shape + _dim, head_dim = weights.shape + + assert dim <= 8192 and dim & 128 == 0, "Unsupported hidden dimension" + assert _dim == dim, "Reduction dimension must match" + assert head_dim <= 512, "Head dimension must be 512 or less" + + out_tensor = nl.ndarray((batch, seqlen, head_dim), dtype=hidden.dtype, buffer=nl.shared_hbm) + + pmax, fmax = nl.tile_size.pmax, nl.tile_size.psum_fmax # 128, 512 + ix, iy = nl.mgrid[0:pmax, 0:dim] + i_lhs = nl.mgrid[0:pmax, 0:pmax] + i_rhs = nl.mgrid[0:pmax, 0:fmax] + i_res = nl.mgrid[0:pmax, 0:fmax] + M = math.ceil(dim / pmax) + NUM_TRANSP_TILES = math.ceil(dim / fmax) + NUM_TILES = math.ceil(seqlen / pmax) + TILES_INT = math.ceil(NUM_TILES / 2) + scale = 1 / dim + + iden_x, iden_y = nl.mgrid[0:pmax, 0:128] + + identity_a = nl.shared_constant(np.identity(n=128, dtype=np.int8), dtype=hidden.dtype) + identity_tensor = nl.ndarray((par_dim(pmax), 128), dtype=weights.dtype, buffer=ncc.sbuf.mod_alloc(base_addr=0)) + identity_tensor[iden_x, iden_y] = nl.load(identity_a, dtype=weights.dtype) + bias_placeholder = nl.ndarray((par_dim(pmax), 1), dtype=np.float32, buffer=ncc.sbuf.mod_alloc(base_addr=128*2)) + bias_placeholder[...] = 0 + + for b in nl.affine_range(batch): + weights_buffer = nl.ndarray((M, par_dim(pmax), fmax), dtype=weights.dtype, + buffer=ncc.sbuf.mod_alloc(base_addr=260+(3*dim+fmax)*2+(dim+1)*4, num_free_tiles=(M,))) + # Preload the entire weights tensor. everything fits in SBUF for LLaMA 3.1 70B + for m in nl.affine_range(M): + weights_buffer[m, i_rhs.p, i_rhs.x] = nl.load(weights[m*pmax+i_rhs.p, i_rhs.x], + mask=(m*pmax+i_rhs.p= local_k_large_tile_idx * LARGE_TILE_SZ + k_i * B_F_SIZE + else: + multiplication_required_selection = True + + if multiplication_required_selection: + qk_psum[:, :] = nl.matmul(q_local_tile, k[:, k_i_b_f_slice], transpose_x=True) # (p(128), 512) + else: + qk_psum[:, :] = 0 + + if use_causal_mask: + left_diagonal_selection = q_tile_idx * B_P_SIZE >= local_k_large_tile_idx * LARGE_TILE_SZ + (k_i + 1) * B_F_SIZE + diagonal_and_right_selection = (q_tile_idx * B_P_SIZE < local_k_large_tile_idx * LARGE_TILE_SZ + (k_i + 1) * B_F_SIZE) + right_diagonal_selection = ((q_tile_idx + 1) * B_P_SIZE <= local_k_large_tile_idx * LARGE_TILE_SZ + k_i * B_F_SIZE) + diagonal = ((q_tile_idx * B_P_SIZE < local_k_large_tile_idx * LARGE_TILE_SZ + (k_i + 1) * B_F_SIZE) & + ((q_tile_idx + 1) * B_P_SIZE > local_k_large_tile_idx * LARGE_TILE_SZ + k_i * B_F_SIZE)) + + i_q_p, i_q_f = nl.mgrid[0:B_P_SIZE, 0:B_F_SIZE] + q_pos = q_tile_idx * B_P_SIZE + i_q_p + k_pos = local_k_large_tile_idx * LARGE_TILE_SZ + k_i * B_F_SIZE + i_q_f + pred = q_pos >= k_pos + + qk_select_tmp = nl.ndarray(qk_psum.shape, dtype=qk_psum.dtype, buffer=nl.sbuf) + + if logit_bias_tile is not None: + if right_diagonal_selection: + qk_select_tmp[...] = qk_psum + + # For tiles to the right of the diagonal, do affine_select. 
+ # Magic number -9984.0 to replace -inf similar to what Tensorizer uses + qk_res_buf[:, k_i_b_f_slice] = nisa.affine_select( + pred=pred, + on_true_tile=qk_select_tmp, on_false_value=-9984.0, dtype=acc_type) + + # For tiles on the diagonal, add logit bias and need to do affine_select. + intermediate = \ + nl.add(qk_psum, logit_bias_tile[:, k_i_b_f_slice], + dtype=acc_type, mask=diagonal) + qk_res_buf[:, k_i_b_f_slice] = nisa.affine_select( + pred=pred, + on_true_tile=intermediate, on_false_value=-9984.0, dtype=acc_type, + mask=diagonal) + + # For tiles on the left of the diagonal, just add logit bias, no select required. + qk_res_buf[:, k_i_b_f_slice] = \ + nl.add(qk_psum, logit_bias_tile[:, k_i_b_f_slice], + dtype=acc_type, mask=left_diagonal_selection) + else: + # For tiles on and to the right of the diagonal, need to do affine_select. + # Magic number -9984.0 to replace -inf similar to what Tensorizer uses + if diagonal_and_right_selection: + qk_select_tmp[...] = qk_psum + + qk_res_buf[:, k_i_b_f_slice] = nisa.affine_select( + pred=pred, + on_true_tile=qk_select_tmp, on_false_value=-9984.0, dtype=acc_type) + + # For tiles on the left of the diagonal, direct copy, no select required. + qk_res_buf[:, k_i_b_f_slice] = \ + nl.copy(qk_psum, dtype=acc_type, mask=left_diagonal_selection) + else: + if logit_bias_tile is not None: + # Simply add logit bias which copies back to sbuf at the same time + qk_res_buf[:, k_i_b_f_slice] = \ + nl.add(qk_psum, logit_bias_tile[:, k_i_b_f_slice], dtype=acc_type) + else: + # Simply send psum result back to sbuf + qk_res_buf[:, k_i_b_f_slice] = nl.copy(qk_psum, dtype=acc_type) + + # Calculate max of the current tile + max_local[:, k_i] = nisa.tensor_reduce( + np.max, qk_res_buf[:, k_i_b_f_slice], axis=(1,), dtype=acc_type, + negate=False) + + max_ = nisa.tensor_reduce(np.max, max_local[:, :], axis=(1, ), + dtype=acc_type, negate=False) + + o_previous_scaled = nl.ndarray((par_dim(B_P_SIZE), B_D_SIZE), dtype=o_buffer.dtype) + + if initialize: + m_buffer[:, 0] = nl.copy(max_) + m_current = max_ + else: + m_previous = nl.copy(m_buffer[:, 0]) + m_buffer[:, 0] = nl.maximum(m_previous, max_) # (128,1) + + m_current = m_buffer[:, 0] + # Compute scaling factor + alpha = nisa.activation(np.exp, m_current, bias=m_previous, scale=-1.0) + o_previous_scaled[...] 
= nl.multiply(o_buffer[:, :], alpha) + + p_local = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ), dtype=kernel_dtype) + REDUCTION_TILE = min(2048, LARGE_TILE_SZ // 2) + + p_partial_sum = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ // REDUCTION_TILE), dtype=acc_type) + + for k_r_i in nl.affine_range(LARGE_TILE_SZ // REDUCTION_TILE): + k_r_i_reduce_slice = nl.ds(k_r_i * REDUCTION_TILE, REDUCTION_TILE) + + # dropout + if dropout_p > 0.0: + # compute exp(qk-max) + p_local[:, k_r_i_reduce_slice] = \ + nisa.activation(np.exp, qk_res_buf[:, k_r_i_reduce_slice], + bias=-1 * m_current, scale=1.0, + dtype=kernel_dtype) + + seed_offset_base = k_r_i * (REDUCTION_TILE // B_F_SIZE) \ + + local_k_large_tile_idx * (LARGE_TILE_SZ // B_F_SIZE) \ + + q_tile_idx * seq_k_num_tiles \ + + (head_id * q_h_per_k_h + gqa_head_idx) * seq_k_num_tiles * seq_q_num_tiles \ + + batch_id * nheads * seq_k_num_tiles * seq_q_num_tiles + + dropout_p_local(p_local=p_local, dropout_p=dropout_p, + dropout_p_tensor=dropout_p_tensor, seed_tensor=seed_tensor, + seed_offset_base=seed_offset_base, k_r_i=k_r_i, + REDUCTION_TILE=REDUCTION_TILE) + + # Compute partial row-tile sum of exp(qk-max)) + # FIXME: Use activation accumulate and accumulate over k_r_i loop? + p_partial_sum[:, k_r_i] = nl.sum(p_local[:, k_r_i_reduce_slice], + axis=1, dtype=acc_type) + else: + # compute exp(qk-max) + # Compute partial row-tile sum of exp(qk-max)) + # FIXME: Use activation accumulate to accumulate over k_r_i loop? + p_local[:, k_r_i_reduce_slice] = \ + nisa.activation_reduce(np.exp, qk_res_buf[:, k_r_i_reduce_slice], + bias=-1 * m_current, scale=1.0, + reduce_op=nl.add, reduce_res=p_partial_sum[:, k_r_i], + dtype=kernel_dtype) + + ps = nl.sum(p_partial_sum, axis=1, dtype=acc_type) + + p_local_transposed = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ), dtype=kernel_dtype) + transpose_p_local(p_local_transposed=p_local_transposed, p_local=p_local, + LARGE_TILE_SZ=LARGE_TILE_SZ) + + pv_psum = nl.zeros((par_dim(B_P_SIZE), B_D_SIZE), dtype=np.float32, + buffer=nl.psum, lazy_initialization=True) + for k_i in nl.affine_range(LARGE_TILE_SZ // B_P_SIZE): + pv_psum[:, :] += nl.matmul(p_local_transposed[:, nl.ds(k_i * B_P_SIZE, B_P_SIZE)], + v[k_i, :, :], transpose_x=True) # (128, 128) (p(Br), d) + + if initialize: + o_buffer[:, :] = nl.copy(pv_psum[:, :]) + l_buffer[:, 0] = nl.add(nl.log(ps), max_) + else: + o_buffer[:, :] = nl.add(o_previous_scaled, pv_psum) + + exp = nisa.activation(nl.exp, m_current, bias=l_buffer[:, 0], scale=-1.0) + l_buffer[:, 0] = nl.add(m_current, nisa.activation(nl.log, exp, bias=ps)) + + +@nki.jit(mode='trace') +def load_v_tile(v_hbm_tile, cur_v_tile, j, v_i, config): + LARGE_TILE_SZ = config.seq_tile_size + B_P_SIZE = 128 + + if not config.should_transpose_v: + cur_v_tile[v_i, :, :] = nl.load( + v_hbm_tile[nl.ds(j * LARGE_TILE_SZ + B_P_SIZE * v_i, B_P_SIZE), :], + dtype=cur_v_tile.dtype) + return + + if nisa.get_nc_version() == nisa.nc_version.gen3: + cur_v_tile_transposed = nisa.dma_transpose( + v_hbm_tile[:, nl.ds(j * LARGE_TILE_SZ + B_P_SIZE * v_i, B_P_SIZE)]) + cur_v_tile[v_i, :, :] = nisa.tensor_copy(cur_v_tile_transposed, + dtype=cur_v_tile.dtype) + return + + cur_v_tile[v_i, :, :] = nl.load_transpose2d( + v_hbm_tile[:, nl.ds(j * LARGE_TILE_SZ + B_P_SIZE * v_i, B_P_SIZE)], + dtype=cur_v_tile.dtype) + + + +@nki.jit +def flash_fwd(q, k, v, seed, logit_bias=None, + softmax_scale=None, + use_causal_mask=True, + mixed_precision=True, + dropout_p=0.0, config=None): + """ + Flash Attention Forward kernel + + IO tensor layouts: + - q: 
shape (bs, n_heads, d, seq_q) + - k: shape (bs, nk_heads, d, seq_k) + - v: shape (bs, nv_heads, d, seq_v) if config.should_transpose_v else (bs, nv_heads, seq_v, d) + - seed: shape (1,) + - logit_bias: shape (bs, n_heads, seq_q, seq_k) + - o: shape (bs, n_heads, seq_q, d) + - lse: shape (bs, n_heads, nl.tile_size.pmax, seq // nl.tile_size.pmax) if training else None + - This kernel requires seq_k == seq_v + + IO tensor dtypes: + - This kernel assumes all IO tensors have the same dtype + - If mixed_precision is True, then all Tensor Engine operation will be performed in + bfloat16 and accumulation will be performed in float32. Otherwise the intermediates + will be in the same type as the inputs. + + Compile-time Constants: + - softmax_scale: scaling for softmax, is None, default is `1.0/(d**0.5)` + - mixed_precision: flag to set non-matmul ops in fp32 precision, default is set to `true`, if false, we use same precision as input types + - causal_mask: flag to set causal masking + - config: Instance of :class:`nki.kernels.attention.FlashConfig` with Performance config parameters for flash attention with default values + seq_tile_size: `default=2048`, size of the kv tile size for attention computation reduction + training: bool to indicate training vs inference `default=True` + + Performance Notes: + For better performance, the kernel is tiled to be of size `config.seq_tile_size`, and Flash attention math techniques are applied in unit + of `config.seq_tile_size`. Seqlen that is not divisible by `config.seq_tile_size` is not supported at the moment. + + For large seqlen, `o_buffer` will overflow the statebuf. the kernel is tile `o_buffer` based on the value of `config.attn_core_tile_size`. + This is a tradeoff between memory usage and performance. The default value of `config.attn_core_tile_size` is 256, which means the `o_buffer` + will roughly take half of the statebuf. The computes are also tiled accordingly. DMA will be rematerialized + `seqlen_q // B_P_SIZE // attn_core_tile_size times`. 
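# --- Worked example of the tiling arithmetic described in the Performance Notes above
# (illustrative values only; 131072 is an assumed sequence length and 256 is the default
# attn_core_tile_size). The K/V DMA loads are repeated once per outer o_buffer tile.
from math import ceil

B_P_SIZE = 128
attn_core_tile_size = 256
seqlen_q = 131072

n_tile_q = seqlen_q // B_P_SIZE                   # 1024 row tiles of 128 queries each
n_remat = ceil(n_tile_q / attn_core_tile_size)    # 4 = seqlen_q // B_P_SIZE // attn_core_tile_size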
+ + + + GQA support Notes: + the spmd kernel for launching kernel should be on kv_heads instead of nheads + + Example usage: + MHA: q: [b, h, d, s], k: [b, h, d, s], v: [b, h, s, d] + usage: `flash_fwd[b, h](q, k, v, ...)` + GQA: q: [b, h, d, s], k: [b, kv_h, d, s], v: [b, kv_h, s, d] + usage: `flash_fwd[b, kv_h](q, k, v, ...)` + """ + config = config or FlashConfig() + B_F_SIZE=512 + B_P_SIZE=128 + b, h, d, seqlen_q = q.shape + B_D_SIZE = d + _, k_h, _, seqlen_k = k.shape + if config.should_transpose_v: + assert tuple(v.shape) == (b, k_h, d, seqlen_k), f"Expect shape of V to be {(b, k_h, d, seqlen_k)} (batch, heads, d_head, seqlen_k) but got {v.shape}" + assert tuple(k.shape) == (b, k_h, d, seqlen_k), f"Expect shape of K to be {(b, k_h, d, seqlen_k)} (batch, heads, d_head, seqlen_k) but got {k.shape}" + else: + assert tuple(v.shape) == (b, k_h, seqlen_k, d), f"Expect shape of V to be {(b, k_h, seqlen_k, d)} (batch, heads, seqlen_k, d_head) but got {v.shape}" + assert tuple(k.shape) == (b, k_h, d, seqlen_k), f"Expect shape of K to be {(b, k_h, d, seqlen_k)} (batch, heads, d_head, seqlen_k) but got {k.shape}" + assert d <= 128, f" we do not support head_dim > 128, got head dim {d}" + kernel_dtype = nl.bfloat16 if mixed_precision else q.dtype + acc_type = np.dtype(np.float32) if mixed_precision else kernel_dtype + + o = nl.ndarray((b, h, seqlen_q, d), dtype=q.dtype, buffer=nl.shared_hbm) + if config.training: + if config.lse_dtype: + lse_dtype = getattr(nl, config.lse_dtype) + else: + lse_dtype = acc_type + lse = nl.ndarray((b, h, nl.tile_size.pmax, seqlen_q // nl.tile_size.pmax), + dtype=lse_dtype, buffer=nl.shared_hbm) + else: + lse = None + + assert nl.program_ndim() == 2,\ + f'Expect spmd grid with 2 dimensions, got {nl.program_ndim()} instead!' + batch_id = nl.program_id(axis=0) + head_id = nl.program_id(axis=1) + + softmax_scale = softmax_scale or (1.0 / (d ** 0.5)) + + n_tile_q = seqlen_q // B_P_SIZE # since q will be loaded on tensor engine + + LARGE_TILE_SZ = config.seq_tile_size + attn_core_tile_size = config.attn_core_tile_size + + # FIXME: Add masking for different seqlen values. 
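# --- Reader's sketch of the GQA head mapping implied by the launch notes above
# (assumed example: 32 query heads sharing 8 KV heads). The SPMD grid is
# (batch, kv_heads); inside the kernel each KV head serves q_h_per_k_h query heads.
h, k_h = 32, 8
q_h_per_k_h = h // k_h                                  # 4, as computed below
for head_id in range(k_h):
    q_heads = [head_id * q_h_per_k_h + i_q_h for i_q_h in range(q_h_per_k_h)]
    # head_id=0 -> query heads [0, 1, 2, 3]; matches q[batch_id, head_id*q_h_per_k_h + i_q_h]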
+ assert config.seq_tile_size >= 512, f" seq tile_size {config.seq_tile_size} cannot be less than 512" + assert seqlen_k % LARGE_TILE_SZ == 0, f"Need seqlen_k to be divisible by {LARGE_TILE_SZ} but got {seqlen_k}" + num_large_k_tile = seqlen_k // LARGE_TILE_SZ + + # inference flag, check if lse is none + inference = not config.training + if inference: + assert lse is None, "lse should be none for inference" + assert seed is None, f"seed should be None for inference, but got {seed}" + assert dropout_p==0.0, f"dropout should be 0.0 for inference but got {dropout_p}" + else: + assert lse is not None, "lse should not be none for training" + q_h_per_k_h = h // k_h + + if dropout_p > 0.0 and not inference: + seed_local = nl.load(seed[0]) + # TODO: Remove this once the dropout supports scale prob + dropout_p_tensor = nl.full((B_P_SIZE, 1), fill_value=dropout_p, dtype=np.float32) + else: + dropout_p_tensor = None + seed_local = None + + if logit_bias is not None: + b_logit_bias, h_logit_bias, _, _ = logit_bias.shape + assert b_logit_bias == 1 and h_logit_bias == 1, "only support broadcasting logit_bias with batch 1, n_heads 1" + + n_remat = div_ceil(n_tile_q, attn_core_tile_size) + attn_core_tile_size = min(n_tile_q, attn_core_tile_size) + + for i_q_h in nl.affine_range(q_h_per_k_h): + # =============== Global Flash Attention accumulators ====================== # + l_buffer = nl.zeros((par_dim(B_P_SIZE), n_tile_q), dtype=acc_type, + buffer=nl.sbuf, lazy_initialization=True) + # =============== Global Flash Attention accumulators END ================== # + + for i0 in nl.sequential_range(n_remat): + # =============== Global Flash Attention accumulators ====================== # + o_buffer = nl.zeros((attn_core_tile_size, par_dim(B_P_SIZE), d), dtype=acc_type, + buffer=nl.sbuf, lazy_initialization=True) + m_buffer = nl.zeros((attn_core_tile_size, par_dim(B_P_SIZE), 1), dtype=acc_type, + buffer=nl.sbuf, lazy_initialization=True) + # =============== Global Flash Attention accumulators END ================== # + + for j in nl.sequential_range(0, num_large_k_tile): + cur_k_tile = nl.ndarray((par_dim(B_D_SIZE), LARGE_TILE_SZ), dtype=kernel_dtype) + cur_v_tile = nl.ndarray((LARGE_TILE_SZ // B_P_SIZE, par_dim(B_P_SIZE), B_D_SIZE), dtype=kernel_dtype) + + cur_k_tile[:, :] = nl.load(k[batch_id, head_id, :, nl.ds(j*LARGE_TILE_SZ, LARGE_TILE_SZ)]) + + load_tile_size = B_P_SIZE + + v_hbm_tile = v[batch_id, head_id] + for v_i in nl.affine_range(LARGE_TILE_SZ // load_tile_size): + load_v_tile(v_hbm_tile=v_hbm_tile, cur_v_tile=cur_v_tile, j=j, v_i=v_i, + config=config) + + for i1 in nl.affine_range(attn_core_tile_size): + i = i0 * attn_core_tile_size + i1 + # mask are used to only apply computation to the lower half of the matrix, + # which reduce the arthimetic intensity by half. + # forward_mask imply initialize, i.e. 
if forward_mask is false, initialize will + # be false as well + if use_causal_mask: + forward_mask = i * B_P_SIZE >= j * LARGE_TILE_SZ + else: + forward_mask = True + + if (i < n_tile_q) & forward_mask: + q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE),dtype=kernel_dtype) + q_hbm_tile = q[batch_id, head_id * q_h_per_k_h + i_q_h] + q_sbuf_tile = nl.load(q_hbm_tile[:, nl.ds(i * B_P_SIZE, B_P_SIZE)], + dtype=kernel_dtype) # load (d, 128) tile in SBUF + q_tile[:, :] = q_sbuf_tile * softmax_scale + + logit_bias_tile = None + if logit_bias is not None: + logit_bias_tile = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ), dtype=kernel_dtype) + logit_bias_tile[:, :] = nl.load( + logit_bias[0, 0, nl.ds(i * B_P_SIZE, B_P_SIZE), + nl.ds(j * LARGE_TILE_SZ, LARGE_TILE_SZ)]) + + _flash_attention_core(q_local_tile=q_tile, k=cur_k_tile, v=cur_v_tile, + q_h_per_k_h=q_h_per_k_h, seqlen_q=seqlen_q, nheads=h, + o_buffer=o_buffer[i1], l_buffer=l_buffer[:, i], m_buffer=m_buffer[i1], + batch_id=batch_id, head_id=head_id, + gqa_head_idx=i_q_h, q_tile_idx=i, local_k_large_tile_idx=j, + kernel_dtype=kernel_dtype, acc_type=acc_type, + flash_config=config, use_causal_mask=use_causal_mask, + initialize=j == 0, + B_P_SIZE=B_P_SIZE, B_F_SIZE=B_F_SIZE, B_D_SIZE=B_D_SIZE, + dropout_p=dropout_p, dropout_p_tensor=dropout_p_tensor, + seed_tensor=seed_local, logit_bias_tile=logit_bias_tile) + + # -------- write output to buffer on HBM ------------ # + for i1 in nl.affine_range(attn_core_tile_size): + i = i0 * attn_core_tile_size + i1 + + if i < n_tile_q: + exp = nisa.activation(np.exp, l_buffer[:, i], bias=m_buffer[i1, :, :], + scale=-1.0) + out = nl.multiply(o_buffer[i1, :, :], exp, + dtype=kernel_dtype) + + nl.store(o[batch_id, head_id * q_h_per_k_h + i_q_h, + nl.ds(i*B_P_SIZE, B_P_SIZE), :], out) + + if not inference: + nl.store(lse[batch_id, head_id * q_h_per_k_h + i_q_h, :, :], l_buffer[:, :]) + + if config.training: + return o, lse + + return o + + + +@nki.jit +def flash_attn_bwd( + q_ref, k_ref, v_ref, o_ref, + dy_ref, + lse_ref, + seed_ref, + logit_bias_ref=None, + use_causal_mask=False, + mixed_precision=False, + dropout_p=0.0, + softmax_scale=None, +): + """ + Flash attention backward kernel. Compute the backward gradients. + + IO tensor layouts: + - q_ref: shape (bs, nheads, head_size, seq) + - k_ref: shape (bs, nheads, head_size, seq) + - v_ref: shape (bs, nheads, head_size, seq) + - o_ref: shape (bs, nheads, head_size, seq) + - dy_ref: shape (bs, nheads, head_size, seq) + - lse_ref: shape (bs, nheads, nl.tile_size.pmax, seq // nl.tile_size.pmax) + - seed_ref: shape (1,) + - logit_bias_ref: shape (bs, n_heads, seq_q, seq_k) + - out_dq_ref: shape (bs, nheads, head_size, seq) + - out_dk_ref: shape (bs, nheads, head_size, seq) + - out_dv_ref: shape (bs, nheads, head_size, seq) + + Detailed steps: + 1. D = rowsum(dO ◦ O) (pointwise multiply) + + 2. Recompute (softmax(Q^T@K + logic_bias)) + + 2.1 Q^T@K + 2.2 Scale the QK score + 2.3 Apply causal mask and add logit_bias + 2.4 softmax + + 3. Compute the gradients of y = score @ V with respect to the loss + + 4. Compute the gradients of y = softmax(x) + + 5. 
Compute the gradients of Q^T@K + + 4.1 Compute dQ + 4.2 Compute dK + """ + + # Use q_ref dtype as the intermediate tensor dtype + # Assume all IO tensors have the same dtype + kernel_dtype = q_ref.dtype + mixed_dtype = np.dtype(np.float32) if mixed_precision else kernel_dtype + + assert q_ref.dtype == k_ref.dtype == v_ref.dtype == o_ref.dtype == dy_ref.dtype + + # Shape checking + bs, nheads, d_head, seqlen_q = q_ref.shape + _, _, _, seqlen_k = k_ref.shape + assert tuple(k_ref.shape) == (bs, nheads, d_head, seqlen_k), \ + f"Input K shape mismatch, got {k_ref.shape}" + assert tuple(v_ref.shape) == (bs, nheads, d_head, seqlen_k), \ + f"Input V shape mismatch, got {v_ref.shape}" + assert tuple(o_ref.shape) == (bs, nheads, d_head, seqlen_q), \ + f"Input o shape mismatch, got {o_ref.shape}" + assert tuple(dy_ref.shape) == (bs, nheads, d_head, seqlen_q), \ + f"Input dy shape mismatch, got {dy_ref.shape}" + assert tuple(lse_ref.shape) == (bs, nheads, nl.tile_size.pmax, seqlen_q // nl.tile_size.pmax), \ + f"Input lse shape mismatch, got {lse_ref.shape}" + if seed_ref is not None: + assert tuple(seed_ref.shape) == (1,), \ + f"Input seed shape mismatch, got {seed_ref.shape}" + + out_dq_ref = nl.ndarray((bs, nheads, d_head, seqlen_q), dtype=q_ref.dtype, + buffer=nl.shared_hbm) + out_dk_ref = nl.ndarray((bs, nheads, d_head, seqlen_k), dtype=q_ref.dtype, + buffer=nl.shared_hbm) + out_dv_ref = nl.ndarray((bs, nheads, d_head, seqlen_k), dtype=q_ref.dtype, + buffer=nl.shared_hbm) + + # FIXME: Add masking for different seqlen values. + assert seqlen_q % 128 == 0 and seqlen_k % 128 == 0, \ + f"Input sequence lengths must be divisible by 128, got seqlen_q == {seqlen_q} and seqlen_k == {seqlen_k}" + + # Softmax scaling factor, multiplied onto Q + softmax_scale = softmax_scale or 1.0 / float(d_head ** 0.5) + + assert nl.program_ndim() == 2,\ + f'Expect spmd grid with 2 dimensions, got {nl.program_ndim()} instead!' 
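# --- Reader's sketch (plain NumPy, single batch/head, no dropout or logit bias) of the
# gradient math listed in the docstring above; a reading aid, not the NKI implementation.
# Tensors use the (head_size, seq) layout described in the IO section.
import numpy as np

def flash_bwd_reference(q, k, v, dy, softmax_scale):
    s = softmax_scale * (q.T @ k)                    # Step 2.1/2.2: scaled Q^T @ K -> (seq_q, seq_k)
    p = np.exp(s - s.max(axis=1, keepdims=True))
    p /= p.sum(axis=1, keepdims=True)                # Step 2.4: softmax
    o = v @ p.T                                      # recomputed forward output (head_size, seq_q)
    D = np.sum(dy * o, axis=0)                       # Step 1: rowsum(dO o O) per query position
    dv = dy @ p                                      # Step 3: dL/dV
    dp = dy.T @ v                                    # Step 3: dL/dsoftmax
    ds = p * (dp - D[:, None])                       # Step 4: dL/dx for y = softmax(x)
    dq = softmax_scale * (k @ ds.T)                  # Step 5: dL/dQ
    dk = softmax_scale * (q @ ds)                    # Step 5: dL/dK
    return dq, dk, dv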
+ # Different batch samples/attention heads have independent attention + batch_id = nl.program_id(axis=0) + head_id = nl.program_id(axis=1) + + assert nl.num_programs(1) == nheads, \ + f"The grid shape mismatch, got {nl.num_programs(1)} but should be {nheads}" + + if logit_bias_ref is not None: + b_logit_bias, h_logit_bias, _, _ = logit_bias_ref.shape + assert b_logit_bias == 1 and h_logit_bias == 1, "Only support broadcasting logit_bias with batch 1, n_heads 1" + + q_seq_n_tiles, q_seq_tile_size = div_ceil(seqlen_q, 128), 128 + d_head_n_tiles, d_head_tile_size = div_ceil(d_head, 128), min(d_head, 128) + + if seqlen_k >= 512: + k_seq_n_tiles, k_seq_tile_size = seqlen_k // 512, 512 + else: + k_seq_n_tiles, k_seq_tile_size = seqlen_k // 128, 128 + + k_seq_n_tiles_backward, k_seq_tile_size_backward = seqlen_k // 128, 128 + k_seq_fwd_bwd_tile_multipler = k_seq_tile_size // k_seq_tile_size_backward + + ############################################################## + # Step 2.4 Prefetch exp bias for softmax + ############################################################## + softmax_exp_bias = nl.zeros((par_dim(q_seq_tile_size), q_seq_n_tiles), dtype=mixed_dtype) + lse_local = nl.load(lse_ref[batch_id, head_id, :, :], dtype=mixed_dtype) + softmax_exp_bias[:, :] = lse_local * -1.0 + + ############################################################## + # Step 1 Compute rowsum(dO ◦ O) + ############################################################## + dy_o_sum = nl.ndarray((q_seq_n_tiles, par_dim(q_seq_tile_size), 1), dtype=mixed_dtype) + compute_rowsum(dy_o_sum=dy_o_sum, + dy_ref_hbm_tile=dy_ref[batch_id, head_id], + o_ref_hbm_tile=o_ref[batch_id, head_id], + d_head_n_tiles=d_head_n_tiles, d_head_tile_size=d_head_tile_size, + q_seq_n_tiles=q_seq_n_tiles, q_seq_tile_size=q_seq_tile_size) + + if dropout_p > 0.0: + seed_local = nl.load(seed_ref[0]) + # TODO: Remove this once the dropout supports scale prob + dropout_p_local = nl.full((q_seq_tile_size, 1), fill_value=dropout_p, dtype=np.float32) + else: + seed_local = None + dropout_p_local = None + + dq_local_reduced = nl.zeros((q_seq_n_tiles, d_head_n_tiles, par_dim(d_head_tile_size), q_seq_tile_size), + dtype=mixed_dtype) + + # affine_range give the compiler permission to vectorize instructions + # inside the loop which improves the performance. However, when using the + # the dropout we should use sequential_range to avoid setting + # seed vectorization. 
TODO: the compiler should avoid vectorizing seed setting + _range = nl.sequential_range if dropout_p > 0.0 else nl.affine_range + + for i_k_seq_tile in nl.affine_range(k_seq_n_tiles): + i_k_seq_dslice = nl.ds(i_k_seq_tile * k_seq_tile_size, k_seq_tile_size) + + # Prefetch V, K + v_local = nl.zeros((d_head_n_tiles, par_dim(d_head_tile_size), k_seq_tile_size), + dtype=kernel_dtype) + k_local = nl.zeros((d_head_n_tiles, par_dim(d_head_tile_size), k_seq_tile_size), + dtype=kernel_dtype) + transposed_k_local = nl.zeros((k_seq_fwd_bwd_tile_multipler, d_head_n_tiles, + par_dim(k_seq_tile_size_backward), d_head_tile_size), + dtype=kernel_dtype) + + load_kv(k_ref_hbm_tile=k_ref[batch_id, head_id], + v_ref_hbm_tile=v_ref[batch_id, head_id], + k_local=k_local, transposed_k_local=transposed_k_local, v_local=v_local, + d_head_n_tiles=d_head_n_tiles, d_head_tile_size=d_head_tile_size, + i_k_seq_tile=i_k_seq_tile, k_seq_tile_size=k_seq_tile_size, + k_seq_tile_size_backward=k_seq_tile_size_backward) + + # FIXME: Pass sbuf instead, we will have psum spilling in the current implementation + dv_psum = nl.zeros((d_head_n_tiles, par_dim(d_head_tile_size), k_seq_tile_size), + dtype=np.float32, buffer=nl.psum) + dk_psum = nl.zeros((d_head_n_tiles, par_dim(d_head_tile_size), k_seq_tile_size), + dtype=np.float32, buffer=nl.psum) + for i_q_seq_tile in _range(q_seq_n_tiles): + # Prefetch dy, Q + dy_local = nl.zeros((d_head_n_tiles, par_dim(d_head_tile_size), q_seq_tile_size), dtype=kernel_dtype) + q_local = nl.zeros((d_head_n_tiles, par_dim(d_head_tile_size), q_seq_tile_size), dtype=kernel_dtype) + + load_dy_q(dy_ref_hbm_tile = dy_ref[batch_id, head_id], + q_ref_hbm_tile = q_ref[batch_id, head_id], + dy_local=dy_local, q_local=q_local, d_head_n_tiles=d_head_n_tiles, + d_head_tile_size=d_head_tile_size, i_q_seq_tile=i_q_seq_tile, + q_seq_tile_size=q_seq_tile_size, softmax_scale=softmax_scale) + + logit_bias_tile = None + if logit_bias_ref is not None: + i_q_seq_dslice = nl.ds(i_q_seq_tile * q_seq_tile_size, q_seq_tile_size) + logit_bias_tile = nl.ndarray((par_dim(q_seq_tile_size), k_seq_tile_size), + buffer=nl.sbuf, dtype=kernel_dtype) + logit_bias_tile[:, :] = nl.load( + logit_bias_ref[0, 0, i_q_seq_dslice, i_k_seq_dslice]) + + _flash_attn_bwd_core( + q_local=q_local, k_local=k_local, transposed_k_local=transposed_k_local, + v_local=v_local, dy_local=dy_local, + dk_psum=dk_psum, dv_psum=dv_psum, dq_local_reduced=dq_local_reduced, + softmax_exp_bias=softmax_exp_bias, dy_o_sum=dy_o_sum, + local_i_q_seq_tile=i_q_seq_tile, local_i_k_seq_tile=i_k_seq_tile, + seqlen_q=seqlen_q, seqlen_k=seqlen_k, d_head=d_head, nheads=nheads, + use_causal_mask=use_causal_mask, + kernel_dtype=kernel_dtype, mixed_dtype=mixed_dtype, + softmax_scale=softmax_scale, + seed_local=seed_local, dropout_p=dropout_p, dropout_p_local=dropout_p_local, + logit_bias_tile=logit_bias_tile + ) + + # Write dK, dV + store_dk_dv(out_dk_ref_hbm_tile=out_dk_ref[batch_id, head_id], + out_dv_ref_hbm_tile=out_dv_ref[batch_id, head_id], + local_dk=dk_psum, local_dv=dv_psum, i_k_seq_dslice=i_k_seq_dslice, + d_head_n_tiles=d_head_n_tiles, d_head_tile_size=d_head_tile_size) + + # Write dQ + for i_q_seq_tile in nl.affine_range(q_seq_n_tiles): + for i_d_head_tile in nl.affine_range(d_head_n_tiles): + i_q_seq_dslice = nl.ds(i_q_seq_tile * q_seq_tile_size, q_seq_tile_size) + i_d_head_dslice = nl.ds(i_d_head_tile * d_head_tile_size, d_head_tile_size) + nl.store( + out_dq_ref[batch_id, head_id, i_d_head_dslice, i_q_seq_dslice], + value=dq_local_reduced[i_q_seq_tile, 
i_d_head_tile, :, :], + ) + + return out_dq_ref, out_dk_ref, out_dv_ref + + +@nki.jit(mode='trace') +def load_dy_q(dy_ref_hbm_tile, q_ref_hbm_tile, dy_local, q_local, d_head_n_tiles, d_head_tile_size, i_q_seq_tile, + q_seq_tile_size, softmax_scale): + for i_d_head_tile in nl.affine_range(d_head_n_tiles): + i_d_head_dslice = nl.ds(i_d_head_tile * d_head_tile_size, d_head_tile_size) + i_q_seq_dslice = nl.ds(i_q_seq_tile * q_seq_tile_size, q_seq_tile_size) + + dy_local[i_d_head_tile, :, :] = nl.load( + dy_ref_hbm_tile[i_d_head_dslice, i_q_seq_dslice], + dtype=dy_local.dtype) + + q_local[i_d_head_tile, :, :] = nl.load( + q_ref_hbm_tile[i_d_head_dslice, i_q_seq_dslice], + dtype=q_local.dtype) * softmax_scale + + +@nki.jit(mode='trace') +def store_dk_dv(out_dk_ref_hbm_tile, out_dv_ref_hbm_tile, local_dk, local_dv, + d_head_n_tiles, d_head_tile_size, i_k_seq_dslice): + for i in nl.affine_range(d_head_n_tiles): + i_d_head_dslice = nl.ds(i * d_head_tile_size, d_head_tile_size) + + nl.store(out_dv_ref_hbm_tile[i_d_head_dslice, i_k_seq_dslice], + value=local_dv[i, :, :]) + + nl.store(out_dk_ref_hbm_tile[i_d_head_dslice, i_k_seq_dslice], + value=local_dk[i, :, :]) + + +@nki.jit(mode='trace') +def load_kv(k_ref_hbm_tile, v_ref_hbm_tile, k_local, transposed_k_local, v_local, + d_head_n_tiles, d_head_tile_size, i_k_seq_tile, k_seq_tile_size, + k_seq_tile_size_backward): + k_seq_fwd_bwd_tile_multipler = k_seq_tile_size // k_seq_tile_size_backward + + for i in nl.affine_range(d_head_n_tiles): + i_d_head_dslice = nl.ds(i * d_head_tile_size, d_head_tile_size) + i_k_seq_dslice = nl.ds(i_k_seq_tile * k_seq_tile_size, k_seq_tile_size) + k_local[i, :, :] = nl.load(k_ref_hbm_tile[i_d_head_dslice, i_k_seq_dslice], + dtype=k_local.dtype) + v_local[i, :, :] = nl.load(v_ref_hbm_tile[i_d_head_dslice, i_k_seq_dslice], + dtype=v_local.dtype) + ############################################################## + # Prefetch k transpose for the backward too + ############################################################## + for j in nl.affine_range(k_seq_fwd_bwd_tile_multipler): + i_k_dslice = nl.ds(j * k_seq_tile_size_backward, k_seq_tile_size_backward) + transposed_k_local[j, i, :, :] = nisa.nc_transpose(k_local[i, :, i_k_dslice]) + + +@nki.jit(mode='trace') +def compute_rowsum(dy_o_sum, dy_ref_hbm_tile, o_ref_hbm_tile, d_head_n_tiles, d_head_tile_size, q_seq_n_tiles, + q_seq_tile_size): + mixed_dtype = dy_o_sum.dtype + for i in nl.affine_range(q_seq_n_tiles): + dy_o_partial = nl.zeros((par_dim(q_seq_tile_size), d_head_n_tiles), dtype=mixed_dtype) + for j in nl.affine_range(d_head_n_tiles): + d_head_dslice = nl.ds(j * d_head_tile_size, d_head_tile_size) + q_seq_dslice = nl.ds(i * q_seq_tile_size, q_seq_tile_size) + + dy_local = nl.load_transpose2d(dy_ref_hbm_tile[d_head_dslice, q_seq_dslice], + dtype=mixed_dtype) + o_local = nl.load_transpose2d(o_ref_hbm_tile[d_head_dslice, q_seq_dslice], + dtype=mixed_dtype) + + dy_o = nl.multiply(dy_local, o_local, dtype=mixed_dtype) + dy_o_partial[:, j] = nisa.tensor_reduce(np.add, data=dy_o, axis=(1,), + dtype=mixed_dtype) + + dy_o_sum[i, :, 0] = nisa.tensor_reduce( + np.add, data=dy_o_partial[:, :], axis=(1,), dtype=mixed_dtype) + + +@nki.jit(mode='trace') +def _flash_attn_bwd_core( + q_local, k_local, transposed_k_local, v_local, dy_local, + dk_psum, dv_psum, dq_local_reduced, + softmax_exp_bias, dy_o_sum, + local_i_q_seq_tile, local_i_k_seq_tile, + seqlen_q, seqlen_k, d_head, nheads, + use_causal_mask, + kernel_dtype, mixed_dtype, + softmax_scale, + seed_local, dropout_p, 
dropout_p_local, + logit_bias_tile=None): + """ + The flash backward core function to calculate the gradients of Q, K and V + of the given tiles. The result will be accumulated into the dk, dv, dq psum + """ + q_seq_n_tiles, q_seq_tile_size = div_ceil(seqlen_q, 128), 128 + d_head_n_tiles, d_head_tile_size = div_ceil(d_head, 128), min(d_head, 128) + if seqlen_k >= 512: + k_seq_n_tiles, k_seq_tile_size = seqlen_k // 512, 512 + else: + k_seq_n_tiles, k_seq_tile_size = seqlen_k // 128, 128 + k_seq_n_tiles_backward, k_seq_tile_size_backward = seqlen_k // 128, 128 + k_seq_fwd_bwd_tile_multipler = k_seq_tile_size // k_seq_tile_size_backward + + mask = local_i_q_seq_tile * q_seq_tile_size >= local_i_k_seq_tile * k_seq_tile_size if use_causal_mask else None + # PSUM buffer shape: [q_seq_tile_size P, k_seq_tile_size F] + qk_psum = nl.zeros((par_dim(q_seq_tile_size), k_seq_tile_size), + dtype=np.float32, buffer=nl.psum) + qk_res_buf = nl.ndarray((par_dim(q_seq_tile_size), k_seq_tile_size), buffer=nl.sbuf, dtype=kernel_dtype) + + batch_id = nl.program_id(axis=0) + head_id = nl.program_id(axis=1) + + # Loop over contraction dim of QK matmul + for i_d_head_tile in nl.affine_range(d_head_n_tiles): + ############################################################## + # Step 2.1 Compute Q^T@K, with matmul(stationary=tensor_q, moving=tensor_k, contract=d_head) + ############################################################## + qk_psum[:, :] += nisa.nc_matmul(q_local[i_d_head_tile, :, :], + k_local[i_d_head_tile, :, :], + mask=mask) + + ###################################### + # Step 2.2. Apply optional causal mask + ###################################### + if use_causal_mask: + iq, ik = nl.mgrid[0:q_seq_tile_size, 0:k_seq_tile_size] + causal_pred = (local_i_q_seq_tile * q_seq_tile_size + iq >= local_i_k_seq_tile * k_seq_tile_size + ik) + if logit_bias_tile is not None: + # Magic number -9984.0 to replace -inf similar to what Tensorizer uses + intermediate = \ + nl.add(qk_psum[:, :], logit_bias_tile[:, :], dtype=mixed_dtype, mask=mask) + qk_res_buf[:, :] = nisa.affine_select( + pred=causal_pred, + on_true_tile=intermediate, on_false_value=-9984.0, dtype=mixed_dtype, + mask=mask + ) + + else: + # Magic number -9984.0 to replace -inf similar to what Tensorizer uses + qk_res_buf[:, :] = nisa.affine_select( + pred=causal_pred, + on_true_tile=qk_psum[:, :], on_false_value=-9984.0, dtype=mixed_dtype, + mask=mask) + else: + if logit_bias_tile is not None: + # Simply add logit bias which copies back to sbuf at the same time + qk_res_buf[:, :] = \ + nl.add(qk_psum[:, :], logit_bias_tile[:, :], dtype=mixed_dtype) + else: + # Simply send psum result back to sbuf + qk_res_buf[:, :] = \ + nl.copy(qk_psum[:, :], dtype=mixed_dtype) + + softmax_y = nl.ndarray((par_dim(q_seq_tile_size), k_seq_tile_size), dtype=kernel_dtype, buffer=nl.sbuf) + softmax_y[:, :] = nisa.activation(np.exp, + data=qk_res_buf[:, :], + bias=softmax_exp_bias[:, local_i_q_seq_tile], + scale=1.0, + mask=mask) + ##################################################################### + # Dropout + ##################################################################### + if dropout_p > 0.0: + offset = local_i_k_seq_tile + local_i_q_seq_tile * k_seq_n_tiles \ + + head_id * k_seq_n_tiles * q_seq_n_tiles \ + + batch_id * nheads * k_seq_n_tiles * q_seq_n_tiles + offset_seed = nl.add(seed_local[0, 0], offset, mask=mask) + nl.random_seed(seed=offset_seed, mask=mask) + softmax_y[:, :] = nl.dropout(softmax_y[:, :], rate=dropout_p_local[:, 0], mask=mask) + softmax_y[:, :] = 
nl.multiply(softmax_y[:, :], 1 / (1 - dropout_p), mask=mask) + + ##################################################################### + # Step 3.1 Calculate the backward gradients dL/dV, where y=softmax@V + # in value projection with matmul(stationary=dy, moving=softmax) + ##################################################################### + for i_d_head_tile in nl.affine_range(d_head_n_tiles): + trans_dy = nisa.nc_transpose(dy_local[i_d_head_tile, :, :], + mask=mask) + dv_psum[i_d_head_tile, :, :] += \ + nisa.nc_matmul(trans_dy, softmax_y[:, :], mask=mask) + + ##################################################################### + # Step 3.2 Calculate the backward gradients dL/dsoftmax, where y=softmax@V + # in value projection with matmul(stationary=dy, moving=v) + ##################################################################### + softmax_dy_psum = nl.zeros((par_dim(q_seq_tile_size), k_seq_tile_size), + dtype=np.float32, buffer=nl.psum) + for i_d_head_tile in nl.affine_range(d_head_n_tiles): + softmax_dy_psum[:, :] += \ + nisa.nc_matmul(dy_local[i_d_head_tile, :, :], + v_local[i_d_head_tile, :, :], + mask=mask) + + softmax_dy = nl.ndarray((par_dim(q_seq_tile_size), k_seq_tile_size), dtype=kernel_dtype, buffer=nl.sbuf) + softmax_dy[:, :] = nl.copy(softmax_dy_psum[:, :], dtype=kernel_dtype, + mask=mask) + + ##################################################################### + # Step 4 Calculate the softmax backward gradients dL/dx, where y=softmax(x) + # dL/dx = y * (dL/dy - rowsum(dO_O)), where y = softmax(x) + ##################################################################### + softmax_dx_local = nl.ndarray((par_dim(q_seq_tile_size), k_seq_tile_size), dtype=kernel_dtype, buffer=nl.sbuf) + softmax_dx_local[:, :] = \ + nisa.scalar_tensor_tensor(data=softmax_dy[:, :], + op0=np.subtract, + operand0=dy_o_sum[local_i_q_seq_tile, :, 0], + op1=np.multiply, + operand1=softmax_y[:, :], + mask=mask) + + ##################################################################### + # Step 5.1 Calculate dK, with matmul(stationary=Q, moving=softmax_dx) + ##################################################################### + for i_d_head_tile in nl.affine_range(d_head_n_tiles): + trans_q_local = nisa.nc_transpose(q_local[i_d_head_tile, :, :], + mask=mask) + dk_psum[i_d_head_tile, :, :] += \ + nisa.nc_matmul(trans_q_local, + softmax_dx_local[:, :], + mask=mask) + + ##################################################################### + # Step 5.2 Calculate dQ + ##################################################################### + for i_d_head_tile in nl.affine_range(d_head_n_tiles): + dq_psum = nl.zeros((par_dim(d_head_tile_size), q_seq_tile_size), + dtype=np.float32, buffer=nl.psum) + for i_k_seq_tile_backward in nl.affine_range(k_seq_fwd_bwd_tile_multipler): + i_k_seq_dslice = nl.ds(i_k_seq_tile_backward * k_seq_tile_size_backward, + k_seq_tile_size_backward) + transposed_softmax_dx_local = \ + nisa.nc_transpose(softmax_dx_local[:, i_k_seq_dslice], + mask=mask) + dq_psum[:, :] += nisa.nc_matmul( + transposed_k_local[i_k_seq_tile_backward, i_d_head_tile, :, :], + transposed_softmax_dx_local, + mask=mask) + dq_local = nl.multiply(dq_psum[:, :], softmax_scale, dtype=kernel_dtype, mask=mask) + dq_local_reduced[local_i_q_seq_tile, i_d_head_tile, :, :] = nl.loop_reduce( + dq_local, op=np.add, loop_indices=(local_i_k_seq_tile,), + dtype=mixed_dtype, mask=mask) + + +@nki.jit +def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, use_causal_mask=False, + mixed_precision=True): + """ + Fused 
self attention kernel for small head size Stable Diffusion workload. + + Computes softmax(QK^T)V. Decoder model can optionally include a causal mask + application. Does not include QKV projection, output projection, dropout, + residual connection, etc. + + This kernel is designed to be used for Stable Diffusion models where the + n_heads is smaller or equal to 128. Assertion is thrown if `n_heads` does + not satisfy the requirement. + + IO tensor layouts: + - q_ptr: shape (bs, n_heads, seq_q) + - k_ptr: shape (bs, seq_k, n_heads) + - v_ptr: shape (bs, seq_v, n_heads) + - out_ptr: shape (bs, seq_q, n_heads) + - We use seq_q and seq_k just for clarity, this kernel requires seq_q == seq_k + + IO tensor dtypes: + - This kernel assumes all IO tensors have the same dtype + - If mixed_precision is True, then all Tensor Engine operation will be performed in + bfloat16 and accumulation will be performed in float32. Otherwise the intermediates + will be in the same type as the inputs. + """ + # Use q_ref dtype as the intermediate tensor dtype + # Assume all IO tensors have the same dtype + kernel_dtype = q_ref.dtype + pe_in_dt = nl.bfloat16 if mixed_precision else np.float32 + assert q_ref.dtype == k_ref.dtype == v_ref.dtype + + # Shape checking + bs, d_head, seqlen = q_ref.shape + assert d_head <= 128, "Cannot use this kernel for d_head > 128" + assert tuple(q_ref.shape) == (bs, d_head, seqlen), 'Input shape mismatch!' + assert tuple(k_ref.shape) == (bs, seqlen, d_head), 'Input shape mismatch!' + assert tuple(v_ref.shape) == (bs, seqlen, d_head), \ + f'Input shape mismatch! Expected: {(bs, seqlen, d_head)} Actual: {tuple(v_ref.shape)}' + + out_ref = nl.ndarray((bs, seqlen, d_head), dtype=q_ref.dtype, buffer=nl.shared_hbm) + + # Softmax scaling factor, multiplied onto Q + softmax_scale = 0.125 + + # Different batch samples/attention heads have independent attention + batch_id = nl.program_id(axis=0) + # batch_id = 0 + + # TODO: make q_seq_tile_size user input + # The matmuls currently use a fixed tile size of (128, 128). This may not achieve the best + # performance for dense attention. However, since this kernel is in preparation + # for block-sparse attention, this tile size is acceptable because the block + # size of block-sparse attention cannot be too large. + q_seq_n_tiles, q_seq_tile_size = seqlen // 128, 128 + k_seq_n_tiles, k_seq_tile_size = seqlen // 128, 128 + # No tiling on d_head dimension since the number of d_head fits in SB + d_head_tile_size = d_head + v_seq_n_tiles, v_seq_tile_size = seqlen // 128, 128 + + ################################### + # Step 1. 
transpose(tensor_v) + ################################### + # Buffer for v matrix transposed + # Pre-fetch and keep it in SBUF throughout different softmax tiles + trans_v = nl.ndarray((par_dim(v_seq_tile_size), v_seq_n_tiles, d_head), dtype=pe_in_dt) + + for i_k_seq_tile in nl.affine_range(k_seq_n_tiles): + ip_v = nl.arange(v_seq_tile_size)[:, None] + if_v = nl.arange(d_head_tile_size)[None, :] + trans_v[ip_v, i_k_seq_tile, if_v] = nl.load( + v_ref[batch_id, i_k_seq_tile * k_seq_tile_size + ip_v, if_v], + dtype=pe_in_dt) + + q_local = nl.ndarray((q_seq_n_tiles, par_dim(d_head_tile_size), q_seq_tile_size), dtype=pe_in_dt) + ip_q = nl.arange(d_head_tile_size)[:, None] + if_q = nl.arange(q_seq_tile_size)[None, :] + for i_q_seq_tile in nl.affine_range(q_seq_n_tiles): + q_local[i_q_seq_tile, ip_q, if_q] = nl.load( + q_ref[batch_id, ip_q, i_q_seq_tile * q_seq_tile_size + if_q], + dtype=pe_in_dt) * softmax_scale + + k_local = nl.ndarray((k_seq_n_tiles, par_dim(d_head_tile_size), k_seq_tile_size), dtype=pe_in_dt) + ip_k = nl.arange(d_head_tile_size)[:, None] + if_k = nl.arange(k_seq_tile_size)[None, :] + for i_k_seq_tile in nl.affine_range(k_seq_n_tiles): + k_local[i_k_seq_tile, ip_k, if_k] = nl.load_transpose2d( + k_ref[batch_id, + i_k_seq_tile * k_seq_tile_size + nl.arange(k_seq_tile_size)[:, None], + nl.arange(d_head_tile_size)[None, :]], + dtype=pe_in_dt) + + for i_q_seq_tile in nl.affine_range(q_seq_n_tiles): # indent = 2 + # A SBUF buffer for an independent softmax tile + qk_res_buf = nl.ndarray((par_dim(q_seq_tile_size), seqlen), dtype=kernel_dtype) + + neg_max_res = nl.ndarray((par_dim(q_seq_tile_size), k_seq_n_tiles), dtype=kernel_dtype) + ip_max = nl.arange(q_seq_tile_size)[:, None] + if_max = nl.arange(k_seq_n_tiles)[None, :] + + # Loop over RHS free of matmul(stationary=tensor_q, moving=tensor_k, contract=d_head) + for i_k_seq_tile in nl.affine_range(k_seq_n_tiles): # indent = 4 + + # Since the K^T tile is the RHS, the q_seq_len dimension will be P in the result + # PSUM buffer shape: [q_seq_tile_size P, k_seq_tile_size F] + qk_psum = nl.zeros((par_dim(q_seq_tile_size), k_seq_tile_size), + dtype=np.float32, buffer=nl.psum) + + # Tensor indices for accessing qk result in k_seq_tile_size + ip_qk = nl.arange(q_seq_tile_size)[:, None] + if_qk = nl.arange(k_seq_tile_size)[None, :] + + ############################################################## + # Step 2. matmul(stationary=tensor_q, moving=tensor_k, contract=d_head) + ############################################################## + qk_psum[ip_qk, if_qk] += nisa.nc_matmul(moving=k_local[i_k_seq_tile, ip_k, if_k], + stationary=q_local[i_q_seq_tile, ip_q, if_q]) + + ################################### + # Step 3. Apply optional causal mask + ################################### + if use_causal_mask: + # Magic number -9984.0 to replace -inf similar to what Tensorizer uses + qk_res_buf[ip_qk, i_k_seq_tile * k_seq_tile_size + if_qk] = nisa.affine_select( + pred=(i_q_seq_tile * q_seq_tile_size + ip_qk >= i_k_seq_tile * k_seq_tile_size + if_qk), + on_true_tile=qk_psum[ip_qk, if_qk], on_false_value=-9984.0, dtype=kernel_dtype) + else: + # Simply send psum result back to sbuf + qk_res_buf[ip_qk, i_k_seq_tile * k_seq_tile_size + if_qk] = nl.copy(qk_psum[ip_qk, if_qk], + dtype=kernel_dtype) + + ################################### + # Step 4. 
Softmax + ################################### + # TODO: use TensorScalarCacheReduce to avoid an extra copy + # We want to break this reduction in tiles because we want to overlap it with the previous matmul + neg_max_res[ip_max, i_k_seq_tile] = nisa.tensor_reduce( + np.max, data=qk_res_buf[ip_qk, i_k_seq_tile * k_seq_tile_size + if_qk], + axis=(1,), dtype=kernel_dtype, negate=True) + + neg_max_res_final = nisa.tensor_reduce( + np.min, data=neg_max_res[ip_max, if_max], + axis=(1,), dtype=kernel_dtype, negate=False) + + ip_softmax = nl.arange(q_seq_tile_size)[:, None] + if_softmax = nl.arange(seqlen)[None, :] + ip_sum_res = nl.arange(q_seq_tile_size)[:, None] + if_sum_res = nl.arange(d_head_tile_size)[None, :] + + softmax_res = nl.ndarray((par_dim(q_seq_tile_size), seqlen), dtype=pe_in_dt) + sum_divisor = nl.ndarray((par_dim(q_seq_tile_size), d_head_tile_size), dtype=kernel_dtype) + + # Simply use a large tile of seq_len in size since this is a "blocking" instruction + # Assuming the compiler will merge exp and reduce_add into a single instruction on ACT + exp_res = nisa.activation(np.exp, + data=qk_res_buf[ip_softmax, if_softmax], + bias=neg_max_res_final, scale=1.0) + + sum_res = nisa.tensor_reduce(np.add, data=exp_res, axis=(1,), + dtype=kernel_dtype) + softmax_res[ip_softmax, if_softmax] = nl.copy(exp_res, dtype=pe_in_dt) + + sum_reciprocal_broadcast = (1.0 / sum_res).broadcast_to((q_seq_tile_size, d_head_tile_size)) + sum_divisor[ip_sum_res, if_sum_res] = nl.copy(sum_reciprocal_broadcast, dtype=kernel_dtype) + + # Buffer for transposed softmax results (FP32 in PSUM) + trans_softmax_res = nl.ndarray( + (par_dim(k_seq_tile_size), k_seq_n_tiles, q_seq_tile_size), + dtype=pe_in_dt) + + # Result psum buffer has the hidden dim as P + attn_res_psum = nl.zeros((par_dim(d_head_tile_size), q_seq_tile_size), + dtype=np.float32, buffer=nl.psum) + + ip_scores_t = nl.arange(k_seq_tile_size)[:, None] + if_scores_t = nl.arange(q_seq_tile_size)[None, :] + # Loop over matmul_1 contraction + for i_k_seq_tile in nl.affine_range(k_seq_n_tiles): + ################################### + # Step 5. transpose(softmax_res) + ################################### + ip_scores = nl.arange(q_seq_tile_size)[:, None] + if_scores = nl.arange(k_seq_tile_size)[None, :] + + trans_softmax_res[ip_scores_t, i_k_seq_tile, if_scores_t] = nisa.nc_transpose( + softmax_res[ip_scores, i_k_seq_tile * k_seq_tile_size + if_scores]) + + ip_out = nl.arange(d_head_tile_size)[:, None] + if_out = nl.arange(q_seq_tile_size)[None, :] + for i_k_seq_tile in nl.affine_range(k_seq_n_tiles): + ###################################################################### + # Step 6. 
matmul_1(stationary=trans_v, moving=trans_softmax_res, contract=seqlen_v=seqlen_k) + ###################################################################### + ip_v_t = nl.arange(k_seq_tile_size)[:, None] + if_v_t = nl.arange(d_head_tile_size)[None, :] + attn_res_psum[ip_out, if_out] += \ + nisa.nc_matmul(moving=trans_softmax_res[ip_scores_t, i_k_seq_tile, if_scores_t], + stationary=trans_v[ip_v_t, i_k_seq_tile, if_v_t]) + + attn_res_sbuf = nl.copy(attn_res_psum[ip_out, if_out], dtype=kernel_dtype) + + attn_res_div = attn_res_sbuf * nisa.nc_transpose(sum_divisor[ip_sum_res, if_sum_res]) + + nl.store( + out_ref[batch_id, i_q_seq_tile * q_seq_tile_size + if_out, ip_out], + value=attn_res_div) + + return out_ref diff --git a/src/nki_samples/reference/tutorial.py b/src/nki_samples/reference/tutorial.py new file mode 100644 index 0000000..b32492b --- /dev/null +++ b/src/nki_samples/reference/tutorial.py @@ -0,0 +1,31 @@ +""" +Copyright (c) 2023, Amazon.com. All Rights Reserved + +kernels - Builtin high performance NKI kernels used in tutorial + +""" + +from neuronxcc import nki +import neuronxcc.nki.language as nl + + +@nki.jit +def add_kernel_nx8x128x512(a_ptr, b_ptr, n_elements): + c_ptr = nl.ndarray(a_ptr.shape, dtype=a_ptr.dtype, buffer=nl.shared_hbm) + + ix = nl.arange(128)[:, None] + iy = nl.arange(512)[None, :] + + tile_size = 128 * 512 + block_size = 8 * tile_size + + j = nl.program_id(axis=0) + + for i in nl.affine_range(8): + offset = j * block_size + i * tile_size + 512 * ix + iy + a = nl.load(a_ptr[j, i, ix, iy], mask=offset < n_elements) + b = nl.load(b_ptr[j, i, ix, iy], mask=offset < n_elements) + c = nl.add(a, b, mask=offset < n_elements) + nl.store(c_ptr[j, i, ix, iy], value=c, mask=offset < n_elements) + + return c_ptr diff --git a/src/reference/vision.py b/src/nki_samples/reference/vision.py similarity index 93% rename from src/reference/vision.py rename to src/nki_samples/reference/vision.py index bc54941..4899d27 100644 --- a/src/reference/vision.py +++ b/src/nki_samples/reference/vision.py @@ -8,10 +8,13 @@ import neuronxcc.nki.language as nl import neuronxcc.nki.isa as nisa +from neuronxcc import nki from neuronxcc.nki.language import par_dim import neuronxcc.nki.typing as nt -def select_and_scatter_kernel(operand_tensor, source_tensor, out_tensor): + +@nki.jit +def select_and_scatter_kernel(operand_tensor, source_tensor): """ Implementation of a select-and-scatter kernel. @@ -51,7 +54,10 @@ def select_and_scatter_kernel(operand_tensor, source_tensor, out_tensor): assert C == 64 and N % 2 == 0 kernel_dtype = operand_tensor.dtype - assert operand_tensor.dtype == source_tensor.dtype == out_tensor.dtype + assert operand_tensor.dtype == source_tensor.dtype + + out_tensor = nl.ndarray((N, C, H, W), dtype=operand_tensor.dtype, + buffer=nl.shared_hbm) p = 128 # num of partitions to use for ib in nl.affine_range(N // 2): @@ -156,8 +162,11 @@ def select_and_scatter_kernel(operand_tensor, source_tensor, out_tensor): nl.store(out_tensor[2 * ib + ib_1, 0:64, 0:H, 0:W], value=out_local[(ib_1 * 64):((ib_1 + 1) * 64), 0:H, 0:W]) + return out_tensor -def resize_nearest_fixed_dma_kernel(data_tensor, out_tensor): + +@nki.jit +def resize_nearest_fixed_dma_kernel(data_tensor, out_shape): """ Resize the input image to the given size using the nearest interpolation mode. This kernel is designed to be used when the scaling factor is not an integer. 
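For intuition, the nearest-neighbor index arithmetic behind a non-integer scaling factor can be sketched in plain NumPy. The helper below is illustrative only; its name, the floor-based index mapping, and the toy shapes are assumptions rather than details taken from the kernel.

import numpy as np

def nearest_resize_nhwc(data, out_h, out_w):
    # data: (batch, height, width, channels). Pick the nearest source
    # row/column for every output coordinate via a floor-based mapping.
    in_b, in_h, in_w, in_c = data.shape
    rows = np.floor(np.arange(out_h) * (in_h / out_h)).astype(np.int64)
    cols = np.floor(np.arange(out_w) * (in_w / out_w)).astype(np.int64)
    return data[:, rows[:, None], cols[None, :], :]

# Non-integer scale factor: each 2x3 image becomes 5x7.
x = np.arange(2 * 3, dtype=np.float32).reshape(1, 2, 3, 1)
y = nearest_resize_nhwc(x, out_h=5, out_w=7)
assert y.shape == (1, 5, 7, 1)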
@@ -174,7 +183,9 @@ def resize_nearest_fixed_dma_kernel(data_tensor, out_tensor): """ in_b, in_h, in_w, in_c = data_tensor.shape - out_b, out_h, out_w, out_c = out_tensor.shape + out_b, out_h, out_w, out_c = out_shape + out_tensor = nl.ndarray(out_shape, dtype=data_tensor.dtype, + buffer=nl.shared_hbm) assert in_b == out_b, "Input batch and output batch must be identical" assert in_c == out_c, "Input channel and output channel must be identical" @@ -198,3 +209,5 @@ def resize_nearest_fixed_dma_kernel(data_tensor, out_tensor): local_data = nl.load(target_addr) dst_addr_0 = out_tile[b_map, i, c_map] nl.store(dst_addr_0, value=local_data) + + return out_tensor diff --git a/src/tutorials/average_pool2d/average_pool2d_jax.py b/src/nki_samples/tutorials/average_pool2d/average_pool2d_jax.py similarity index 68% rename from src/tutorials/average_pool2d/average_pool2d_jax.py rename to src/nki_samples/tutorials/average_pool2d/average_pool2d_jax.py index e3b428d..139c42d 100644 --- a/src/tutorials/average_pool2d/average_pool2d_jax.py +++ b/src/nki_samples/tutorials/average_pool2d/average_pool2d_jax.py @@ -4,29 +4,22 @@ JAX implementation for average pool 2D NKI tutorial. """ -from functools import partial -from jax_neuronx import nki_call -import jax +# NKI_EXAMPLE_40_BEGIN import jax.numpy as jnp - -from average_pool2d_nki_kernels import tensor_avgpool_kernel_ - - -def tensor_avgpool_kernel(in_array, pool_size): - return nki_call( - partial(tensor_avgpool_kernel_, pool_size=pool_size), - in_array, - out_shape=jax.ShapeDtypeStruct((C, HOUT, WOUT), dtype=in_array.dtype), - ) +# NKI_EXAMPLE_40_END +from average_pool2d_nki_kernels import tensor_avgpool_kernel +# NKI_EXAMPLE_40_BEGIN # Reference JAX implementation def jax_average_pool_2D(in_tensor, pool_size): c, h_in, w_in = in_tensor.shape reshaped = in_tensor.reshape(c, h_in // pool_size, pool_size, w_in // pool_size, pool_size) return jnp.nanmean(reshaped, axis=(2, 4)) + # NKI_EXAMPLE_40_END +# NKI_EXAMPLE_41_BEGIN if __name__ == "__main__": POOL_SIZE = 2 C, HIN, WIN = 2, 6, 6 @@ -34,7 +27,9 @@ def jax_average_pool_2D(in_tensor, pool_size): in_array = jnp.arange(C * HIN * WIN, dtype=jnp.float32).reshape(C, HIN, WIN) + # NKI_EXAMPLE_39_BEGIN out_nki = tensor_avgpool_kernel(in_array, pool_size=POOL_SIZE) + # NKI_EXAMPLE_39_END out_jax = jax_average_pool_2D(in_array, pool_size=POOL_SIZE) print(in_array, out_nki, out_jax) @@ -42,4 +37,5 @@ def jax_average_pool_2D(in_tensor, pool_size): if jnp.allclose(out_nki, out_jax): print("NKI and JAX match") else: - print("NKI and JAX differ") \ No newline at end of file + print("NKI and JAX differ") + # NKI_EXAMPLE_41_END diff --git a/src/tutorials/average_pool2d/average_pool2d_nki_kernels.py b/src/nki_samples/tutorials/average_pool2d/average_pool2d_nki_kernels.py similarity index 59% rename from src/tutorials/average_pool2d/average_pool2d_nki_kernels.py rename to src/nki_samples/tutorials/average_pool2d/average_pool2d_nki_kernels.py index c81a4a5..68d3a31 100644 --- a/src/tutorials/average_pool2d/average_pool2d_nki_kernels.py +++ b/src/nki_samples/tutorials/average_pool2d/average_pool2d_nki_kernels.py @@ -5,48 +5,40 @@ """ import numpy as np +# NKI_EXAMPLE_37_BEGIN import neuronxcc.nki as nki import neuronxcc.nki.language as nl +from neuronxcc.nki.typing import tensor - -def tensor_avgpool_kernel_(in_tensor, out_tensor, pool_size): +@nki.jit +def tensor_avgpool_kernel(in_tensor, pool_size): """NKI kernel to compute a 2D avg-pool operation Args: in_tensor: an input tensor, of shape C x H x W pool_size: an integer 
representing a (square) pool-window size + + Return: out_tensor: the resulting output tensor, of shape C x (H/pool_size) x (W/pool_size) """ # Get input/output dimensions sz_cin, sz_hin, sz_win = in_tensor.shape - sz_cout, sz_hout, sz_wout = out_tensor.shape - assert sz_cin == sz_cout + sz_hout = sz_hin // pool_size + sz_wout = sz_win // pool_size + # Create output tensor shared between all SPMD instances as result tensor + out_tensor = nl.ndarray((sz_cin, sz_hout, sz_wout), dtype=in_tensor.dtype, + buffer=nl.shared_hbm) # Set relevant sizes sz_p = sz_cin sz_pool = pool_size - # Generate tensor h/w index patterns - # 3D indexing according to [C, H, W] - i_p = nl.arange(sz_p)[:, None, None] # 3D for - i_win = nl.arange(sz_win)[None, None, :] - i_hin = nl.arange(sz_hin)[None, :, None] - - i_wout = nl.arange(sz_wout)[None, None, :] - i_hout = nl.arange(sz_hout)[None, :, None] - # Generate pool index patterns (requires two extra dimensions, for the pool window) - i_0 = nl.arange(sz_p)[:, None, None, None, None] # - i_1 = nl.arange(sz_hin//sz_pool)[None, :, None, None, None] # y_outer - i_2 = nl.arange(sz_pool)[None, None, :, None, None] # y_inner - i_3 = nl.arange(sz_win//sz_pool)[None, None, None, :, None] # x_outer - i_4 = nl.arange(sz_pool)[None, None, None, None, :] # x_inner + i0, i1, i2, i3, i4 = nl.mgrid[0:sz_p, 0:sz_hin//sz_pool, 0:sz_pool, 0:sz_win//sz_pool, 0:sz_pool] # Load input data from external memory to on-chip memory - # Declare ndarray to force a 3D tensor (temporary requirement) - in_tile = nl.ndarray([sz_p, sz_hin, sz_win], dtype=in_tensor.dtype) - in_tile[:,:,:] = nl.load(in_tensor[i_p, i_hin, i_win]) + in_tile: tensor[sz_p, sz_hin, sz_win] = nl.load(in_tensor) # Perform the pooling operation: # We use numpy's advanced indexing, in order to extend in_tile to 5D, and then reduce-average two dimension. @@ -54,10 +46,15 @@ def tensor_avgpool_kernel_(in_tensor, out_tensor, pool_size): # axis[1] and axis[2] together index the rows, with axis[2] responsible for inner strides # (i.e. inside a pooling window), and axis[1] responsible for the outer strides. As such, we reduce over axis[2]. # Similarly, axis[3] and axis[4] together index the columns, and we thus reduce over axis[4]. 
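# Aside (illustration only, not part of the kernel): the same 5-D advanced-
# indexing trick in plain NumPy, with arbitrary toy sizes, cross-checked
# against a reshape-based average pool.
import numpy as np

C, H, W, P = 2, 4, 6, 2                              # channels, height, width, pool size
x = np.arange(C * H * W, dtype=np.float32).reshape(C, H, W)

i0 = np.arange(C)[:, None, None, None, None]         # channel
i1 = np.arange(H // P)[None, :, None, None, None]    # row outer
i2 = np.arange(P)[None, None, :, None, None]         # row inner (inside a window)
i3 = np.arange(W // P)[None, None, None, :, None]    # column outer
i4 = np.arange(P)[None, None, None, None, :]         # column inner (inside a window)

# Advanced indexing expands x to 5-D; averaging over the two inner axes
# (2 and 4) collapses each P x P pooling window.
pooled = x[i0, P * i1 + i2, P * i3 + i4].mean(axis=(2, 4))
assert np.allclose(pooled, x.reshape(C, H // P, P, W // P, P).mean(axis=(2, 4)))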
- out_tile = nl.sum(in_tile[i_0, sz_pool*i_1+i_2, sz_pool*i_3+i_4], axis=[2,4]) / (pool_size*pool_size) + out_tile : tensor[sz_p, sz_hout, sz_wout] = nl.sum(in_tile[i0, sz_pool*i1+i2, sz_pool*i3+i4], + axis=[2,4]) / (pool_size*pool_size) + + # Store the results back to hbm + nl.store(out_tensor, value=out_tile) - # Store the results back to external memory - nl.store(out_tensor[i_p, i_hout, i_wout], value=out_tile) + # Transfer the ownership of `out_tensor` to the caller + return out_tensor + # NKI_EXAMPLE_37_END # Reference NumPy implementation @@ -74,10 +71,8 @@ def np_average_pool_2D(in_tensor, pool_size): HOUT, WOUT = HIN//POOL_SIZE, WIN//POOL_SIZE in_tensor = np.arange(C * HIN * WIN, dtype=np.float16).reshape(C, HIN, WIN) - out_nki = np.zeros((C, HOUT, WOUT), dtype=np.float16) - tensor_avgpool_kernel_baremetal = nki.baremetal(tensor_avgpool_kernel_) - tensor_avgpool_kernel_baremetal(in_tensor, out_nki, POOL_SIZE) + out_nki = tensor_avgpool_kernel(in_tensor, POOL_SIZE) out_np = np_average_pool_2D(in_tensor, POOL_SIZE) diff --git a/src/tutorials/average_pool2d/average_pool2d_torch.py b/src/nki_samples/tutorials/average_pool2d/average_pool2d_torch.py similarity index 78% rename from src/tutorials/average_pool2d/average_pool2d_torch.py rename to src/nki_samples/tutorials/average_pool2d/average_pool2d_torch.py index 3409a31..c5fb4ea 100644 --- a/src/tutorials/average_pool2d/average_pool2d_torch.py +++ b/src/nki_samples/tutorials/average_pool2d/average_pool2d_torch.py @@ -4,13 +4,14 @@ PyTorch implementation for average pool 2D NKI tutorial. """ +# NKI_EXAMPLE_38_BEGIN import torch -from torch_neuronx import nki_jit from torch_xla.core import xla_model as xm - -from average_pool2d_nki_kernels import tensor_avgpool_kernel_ +# NKI_EXAMPLE_38_END +from average_pool2d_nki_kernels import tensor_avgpool_kernel +# NKI_EXAMPLE_38_BEGIN if __name__ == "__main__": device = xm.xla_device() @@ -22,8 +23,7 @@ in_tensor = torch.arange(C * HIN * WIN, dtype=torch.bfloat16).reshape(C, HIN, WIN).to(device=device) out_nki = torch.zeros((C, HOUT, WOUT), dtype=torch.bfloat16).to(device=device) - tensor_avgpool_kernel_torch = nki_jit(tensor_avgpool_kernel_) - tensor_avgpool_kernel_torch(in_tensor, out_nki, POOL_SIZE) + out_nki = tensor_avgpool_kernel(in_tensor, POOL_SIZE) out_torch = torch.nn.functional.avg_pool2d(in_tensor, POOL_SIZE, POOL_SIZE) @@ -33,3 +33,4 @@ print("NKI and Torch match") else: print("NKI and Torch differ") + # NKI_EXAMPLE_38_END diff --git a/src/tutorials/fused_mamba/mamba_nki_kernels.py b/src/nki_samples/tutorials/fused_mamba/mamba_nki_kernels.py similarity index 94% rename from src/tutorials/fused_mamba/mamba_nki_kernels.py rename to src/nki_samples/tutorials/fused_mamba/mamba_nki_kernels.py index 9f8af60..4ff6642 100644 --- a/src/tutorials/fused_mamba/mamba_nki_kernels.py +++ b/src/nki_samples/tutorials/fused_mamba/mamba_nki_kernels.py @@ -4,16 +4,19 @@ Mamba-v1 NKI kernel implementation. """ +# NKI_EXAMPLE_25_BEGIN import neuronxcc.nki as nki import neuronxcc.nki.language as nl import neuronxcc.nki.isa as nisa import numpy as np +# NKI_EXAMPLE_25_END import os import argparse import itertools - -def mamba_v1(delta, u, A, B, C, output): +# NKI_EXAMPLE_25_BEGIN +@nki.jit +def mamba_v1(delta, u, A, B, C): """Computes the SSM operation in the Mamba model. 
:param delta: (batch_size, channels, seq_len) @@ -24,6 +27,9 @@ def mamba_v1(delta, u, A, B, C, output): :return: (batch_size, channels, seq_len) """ batch_size, channels, seq_len = delta.shape + output = nl.ndarray((batch_size, channels, seq_len), dtype=delta.dtype, + buffer=nl.shared_hbm) + _, state_size = A.shape # We can relax this using mask paramters in all the NKI API calls @@ -84,8 +90,12 @@ def mamba_v1(delta, u, A, B, C, output): nl.store(output[i_batch, channel_start:channel_start+channel_psize, 0:seq_len], scanC_accum[i_channel_tile, 0:channel_psize, 0:seq_len]) + return output +# NKI_EXAMPLE_25_END -def mamba_v2(delta, u, A, B, C, output): +# NKI_EXAMPLE_26_BEGIN +@nki.jit +def mamba_v2(delta, u, A, B, C): """Computes the SSM operation in the Mamba model. :param delta: (batch_size, channels, seq_len) @@ -96,6 +106,8 @@ def mamba_v2(delta, u, A, B, C, output): :return: (batch_size, channels, seq_len) """ batch_size, channels, seq_len = delta.shape + output = nl.ndarray((batch_size, channels, seq_len), dtype=delta.dtype, + buffer=nl.shared_hbm) _, state_size = A.shape assert channels % 128 == 0 @@ -153,8 +165,12 @@ def mamba_v2(delta, u, A, B, C, output): nl.store(output[i_batch, channel_start:channel_start+channel_psize, 0:seq_len], scanC_accum[0:channel_psize, 0:seq_len]) + return output +# NKI_EXAMPLE_26_END + -def mamba_v3(delta, u, A, B, C, output): +@nki.jit +def mamba_v3(delta, u, A, B, C): """Computes the SSM operation in the Mamba model. :param delta: (batch_size, channels, seq_len) @@ -165,6 +181,8 @@ def mamba_v3(delta, u, A, B, C, output): :return: (batch_size, channels, seq_len) """ batch_size, channels, seq_len = delta.shape + output = nl.ndarray((batch_size, channels, seq_len), dtype=delta.dtype, + buffer=nl.shared_hbm) _, state_size = A.shape # Map channels to the partition dimension @@ -239,6 +257,7 @@ def mamba_v3(delta, u, A, B, C, output): # Store scanC_accum for a single batch to output nl.store(output[i_batch, channel_start:channel_start+channel_psize, 0:seq_len], scanC_accum[0:channel_psize, 0:seq_len]) + return output def parse_args(): @@ -310,9 +329,7 @@ def parse_args(): if args.mode == "accuracy": # v1: reference kernel print(f">>>> Running v1 (reference).") - nki_out_v1 = np.empty((batch, channels, seq_len), dtype=dtype) - nki.baremetal(mamba_v1)\ - (delta, u, A, B, C, nki_out_v1) + nki_out_v1 = mamba_v1(delta, u, A, B, C) for version in args.version: if version == "v1": @@ -321,9 +338,7 @@ def parse_args(): print(f">>>> Running version {version}.") func = func_dict[version] - nki_out_test = np.empty((batch, channels, seq_len), dtype=dtype) - nki.baremetal(func)\ - (delta, u, A, B, C, nki_out_test) + nki_out_test = func(delta, u, A, B, C) print(f">>>> mamba {version} matches?", np.all(nki_out_test == nki_out_v1)) assert np.all(nki_out_test == nki_out_v1) @@ -333,11 +348,10 @@ def parse_args(): for version in args.version: print(f">>>> Running version {version}.") func = func_dict[version] - nki_out_test = np.empty((batch, channels, seq_len), dtype=dtype) nki.benchmark(func, save_neff_name='file.neff', save_trace_name='profile.ntff')\ - (delta, u, A, B, C, nki_out_test) + (delta, u, A, B, C) # TODO: rename neff/ntff (bug in nki.benchmark with neff name) os.rename("file.neff", f"{version}_b{batch}_sl{seq_len}_c{channels}_ss{state_size}.neff") os.rename("profile.ntff", f"{version}_b{batch}_sl{seq_len}_c{channels}_ss{state_size}.ntff") diff --git a/src/tutorials/fused_mamba/mamba_torch.py b/src/nki_samples/tutorials/fused_mamba/mamba_torch.py similarity 
index 95% rename from src/tutorials/fused_mamba/mamba_torch.py rename to src/nki_samples/tutorials/fused_mamba/mamba_torch.py index a2e593f..cd94a0b 100644 --- a/src/tutorials/fused_mamba/mamba_torch.py +++ b/src/nki_samples/tutorials/fused_mamba/mamba_torch.py @@ -5,6 +5,7 @@ """ +# NKI_EXAMPLE_24_BEGIN import torch import torch_neuronx import torch_xla.core.xla_model as xm @@ -99,16 +100,14 @@ def parse_args(): torch_out = mamba_layer(delta, A, B, u, C) xm.mark_step() print(torch_out) + # NKI_EXAMPLE_24_END if args.mode == "accuracy": # Call NKI mamba_v1 kernel to check accuracy from mamba_nki_kernels import mamba_v1 - from torch_neuronx import nki_jit - - nki_out = torch.empty((batch, channels, seq_len), dtype=dtype, device=device) xm.mark_step() - nki_jit(mamba_v1)(delta, u, A, B, C, nki_out) + nki_out = mamba_v1(delta, u, A, B, C) xm.mark_step() allclose = torch.allclose(torch_out, nki_out, atol=1e-2, rtol=1e-2) diff --git a/src/tutorials/layernorm/layernorm_nki_kernel.py b/src/nki_samples/tutorials/layernorm/layernorm_nki_kernel.py similarity index 64% rename from src/tutorials/layernorm/layernorm_nki_kernel.py rename to src/nki_samples/tutorials/layernorm/layernorm_nki_kernel.py index 503ce7d..c0c235c 100644 --- a/src/tutorials/layernorm/layernorm_nki_kernel.py +++ b/src/nki_samples/tutorials/layernorm/layernorm_nki_kernel.py @@ -4,21 +4,27 @@ LayerNorm NKI kernel implementation. """ +# NKI_EXAMPLE_45_BEGIN import neuronxcc.nki as nki import neuronxcc.nki.language as nl import neuronxcc.nki.isa as nisa import numpy as np import math +# NKI_EXAMPLE_45_END import os import argparse -def nki_layernorm_kernel_v1(input_tensor, epsilon, gamma_vector, beta_vector, output_tensor): +# NKI_EXAMPLE_45_BEGIN +@nki.jit +def nki_layernorm_kernel_v1(input_tensor, epsilon, gamma_vector, beta_vector): """Computes LayerNorm. Used nki.language APIs only. """ + output_tensor = nl.ndarray(input_tensor.shape, dtype=input_tensor.dtype, + buffer=nl.shared_hbm) + # Ensure that the shapes of tensors match - assert input_tensor.shape == output_tensor.shape assert input_tensor.shape[1] == gamma_vector.shape[0] == beta_vector.shape[0] # Generate tile indices for loading/storing data @@ -58,12 +64,20 @@ def nki_layernorm_kernel_v1(input_tensor, epsilon, gamma_vector, beta_vector, ou nl.store(output_tensor[i * nl.tile_size.pmax + i_p_io, i_f_io], value=output_sb, mask=(i * nl.tile_size.pmax + i_p_io < num_rows)) -def nki_layernorm_kernel_v2(input_tensor, epsilon, gamma_vector, beta_vector, output_tensor): + return output_tensor + # NKI_EXAMPLE_45_END + + +# NKI_EXAMPLE_46_BEGIN +@nki.jit +def nki_layernorm_kernel_v2(input_tensor, epsilon, gamma_vector, beta_vector): """Computes LayerNorm. Used nki.isa APIs to calculate mean/variance and perform shift/scale. """ + output_tensor = nl.ndarray(input_tensor.shape, dtype=input_tensor.dtype, + buffer=nl.shared_hbm) + # Ensure that the shapes of tensors match - assert input_tensor.shape == output_tensor.shape assert input_tensor.shape[1] == gamma_vector.shape[0] == beta_vector.shape[0] # Generate tile indices for loading/storing data @@ -122,69 +136,66 @@ def nki_layernorm_kernel_v2(input_tensor, epsilon, gamma_vector, beta_vector, ou nl.store(output_tensor[i * nl.tile_size.pmax + i_p_io, i_f_io], value=output_sb, mask=(i * nl.tile_size.pmax + i_p_io < num_rows)) + return output_tensor + # NKI_EXAMPLE_46_END + def parse_args(): - parser = argparse.ArgumentParser( - """Run LayerNorm pytorch implementation. 
- """) - parser.add_argument("--nrows", - default=4*1024, - type=int, - help="""The number of input rows""") - parser.add_argument("--ncols", - default=8*1024, - type=int, - help="""The number of input columns""") - parser.add_argument("--mode", - choices=["accuracy", "perf"], - default="accuracy", - help="""Do accuracy test or perf test. - Accuracy test compares LayerNorm kernel against PyTorch implementation. - Perf test will generate a NEFF for the PyTorch implementation in local directory - for a manual run of neuron-profile. - """) - args = parser.parse_args() - return args + parser = argparse.ArgumentParser( + """Run LayerNorm pytorch implementation. + """) + parser.add_argument("--nrows", + default=4*1024, + type=int, + help="""The number of input rows""") + parser.add_argument("--ncols", + default=8*1024, + type=int, + help="""The number of input columns""") + parser.add_argument("--mode", + choices=["accuracy", "perf"], + default="accuracy", + help="""Do accuracy test or perf test. + Accuracy test compares LayerNorm kernel against PyTorch implementation. + Perf test will generate a NEFF for the PyTorch implementation in local directory + for a manual run of neuron-profile. + """) + args = parser.parse_args() + return args if __name__ == "__main__": - args = parse_args() - func_dict = {"v1": nki_layernorm_kernel_v1, - "v2": nki_layernorm_kernel_v2, - } - - # Generate toy example - num_rows = args.nrows - num_cols = args.ncols - input_tensor = np.random.rand(num_rows, num_cols).astype(np.float32) - gamma_vector = np.random.rand(num_cols).astype(np.float32) - beta_vector = np.random.rand(num_cols).astype(np.float32) - epsilon = 1e-5 - - if args.mode == "accuracy": - # version 1 - print(f">>>> Running version 1") - nki_out_v1 = np.empty((num_rows, num_cols), dtype=np.float32) - nki.baremetal(nki_layernorm_kernel_v1)\ - (input_tensor, epsilon, gamma_vector, beta_vector, nki_out_v1) - # version 2 - print(f">>>> Running version 2") - nki_out_v2 = np.empty((num_rows, num_cols), dtype=np.float32) - nki.baremetal(nki_layernorm_kernel_v2)\ - (input_tensor, epsilon, gamma_vector, beta_vector, nki_out_v2) - # compare - np_all = np.all(nki_out_v1 == nki_out_v1) - print(f">>>> LayerNorm V1 and V2 matches?", np_all) - assert np_all - - else: - # perf mode - for version in ["v1", "v2"]: - print(f">>>> Running version {version}.") - func = func_dict[version] - nki_out_test = np.empty((num_rows, num_cols), dtype=np.float32) - nki.benchmark(func, - save_neff_name='file.neff', - save_trace_name='profile.ntff')\ - (input_tensor, epsilon, gamma_vector, beta_vector, nki_out_test) - os.rename("file.neff", f"{version}_{num_rows}_{num_cols}.neff") - os.rename("profile.ntff", f"{version}_{num_rows}_{num_cols}.ntff") + args = parse_args() + func_dict = {"v1": nki_layernorm_kernel_v1, + "v2": nki_layernorm_kernel_v2, + } + + # Generate toy example + num_rows = args.nrows + num_cols = args.ncols + input_tensor = np.random.rand(num_rows, num_cols).astype(np.float32) + gamma_vector = np.random.rand(num_cols).astype(np.float32) + beta_vector = np.random.rand(num_cols).astype(np.float32) + epsilon = 1e-5 + + if args.mode == "accuracy": + # version 1 + print(f">>>> Running version 1") + nki_out_v1 = nki_layernorm_kernel_v1(input_tensor, epsilon, gamma_vector, beta_vector) + # version 2 + print(f">>>> Running version 2") + nki_out_v2 = nki_layernorm_kernel_v2(input_tensor, epsilon, gamma_vector, beta_vector) + # compare + np_all = np.all(nki_out_v1 == nki_out_v1) + print(f">>>> LayerNorm V1 and V2 matches?", np_all) 
+ assert np_all + + else: + # perf mode + for version in ["v1", "v2"]: + print(f">>>> Running version {version}.") + func = func_dict[version] + benchmark_kernel = nki.benchmark(func, save_neff_name='file.neff', + save_trace_name='profile.ntff') + nki_out_test = benchmark_kernel(input_tensor, epsilon, gamma_vector, beta_vector) + os.rename("file.neff", f"{version}_{num_rows}_{num_cols}.neff") + os.rename("profile.ntff", f"{version}_{num_rows}_{num_cols}.ntff") diff --git a/src/tutorials/layernorm/layernorm_torch.py b/src/nki_samples/tutorials/layernorm/layernorm_torch.py similarity index 87% rename from src/tutorials/layernorm/layernorm_torch.py rename to src/nki_samples/tutorials/layernorm/layernorm_torch.py index 59853fd..c2be186 100644 --- a/src/tutorials/layernorm/layernorm_torch.py +++ b/src/nki_samples/tutorials/layernorm/layernorm_torch.py @@ -4,9 +4,9 @@ LayerNorm NKI kernel implementation. """ +# NKI_EXAMPLE_47_BEGIN import torch from torch_xla.core import xla_model as xm -from torch_neuronx import nki_jit import argparse import os @@ -42,13 +42,16 @@ def parse_args(): args = parser.parse_args() return args + +from neuronxcc.nki.docs.examples.layernorm.layernorm_nki_kernel import nki_layernorm_kernel_v1, \ + nki_layernorm_kernel_v2 + if __name__ == "__main__": args = parse_args() - from neuronxcc.nki.docs.examples.layernorm.layernorm_nki_kernel import nki_layernorm_kernel_v1, nki_layernorm_kernel_v2 func_dict = {"v1": nki_layernorm_kernel_v1, "v2": nki_layernorm_kernel_v2, } - + device = xm.xla_device() num_rows = args.nrows num_cols = args.ncols @@ -58,7 +61,7 @@ def parse_args(): gamma_vector = torch.rand((num_cols), dtype=torch.float32) beta_vector = torch.rand((num_cols), dtype=torch.float32) epsilon = 1e-5 - + # Compute torch layernorm layer in cpu output_torch = layernorm_layer(input_tensor, epsilon, gamma_vector, beta_vector) @@ -66,17 +69,15 @@ def parse_args(): input_tensor = input_tensor.to(device=device) gamma_vector = gamma_vector.to(device=device) beta_vector = beta_vector.to(device=device) - output_nki = torch.zeros((num_rows, num_cols), dtype=torch.float32).to(device=device) print(f">>>> Running version {args.version}.") func = func_dict[args.version] # add nki_jit decorator - nki_layernorm_kernel = nki_jit(func) # Compute NKI layernorm kernel in NeuronDevice xm.mark_step() - nki_layernorm_kernel(input_tensor, epsilon, gamma_vector, beta_vector, output_nki) + output_nki = func(input_tensor, epsilon, gamma_vector, beta_vector) xm.mark_step() output_nki = output_nki.to(device='cpu') @@ -86,5 +87,6 @@ def parse_args(): print("NKI and Torch match") else: print("NKI and Torch differ") - + # NKI_EXAMPLE_47_END + assert allclose \ No newline at end of file diff --git a/src/tutorials/matrix_multiplication/matrix_multiplication_nki_kernels.py b/src/nki_samples/tutorials/matrix_multiplication/matrix_multiplication_nki_kernels.py similarity index 94% rename from src/tutorials/matrix_multiplication/matrix_multiplication_nki_kernels.py rename to src/nki_samples/tutorials/matrix_multiplication/matrix_multiplication_nki_kernels.py index 7aeb5d6..8f913f2 100644 --- a/src/tutorials/matrix_multiplication/matrix_multiplication_nki_kernels.py +++ b/src/nki_samples/tutorials/matrix_multiplication/matrix_multiplication_nki_kernels.py @@ -12,7 +12,9 @@ import numpy as np -def nki_matmul_basic_(lhsT, rhs, result): +# NKI_EXAMPLE_16_BEGIN +@nki.jit +def nki_matmul_basic_(lhsT, rhs): """NKI kernel to compute a 64x128x512 matrix multiplication operation Args: @@ -20,8 +22,11 @@ def 
nki_matmul_basic_(lhsT, rhs, result): matrix multiplication, delivered transposed for optimal performance rhs: an input tensor of shape [128,512], a right hand side argument of the matrix multiplication + Returns: result: the resulting output tensor of shape [64,512] """ + result = nl.ndarray((64, 512), dtype=lhsT.dtype, buffer=nl.shared_hbm) + # Defining indexes for input LHS.T # - Note: here we take LayoutConstraint #1 into account: # "For MatMult, contraction axis must be mapped to P-dim" @@ -53,8 +58,13 @@ def nki_matmul_basic_(lhsT, rhs, result): # This dictates which indices to use to address the result tile. nl.store(result[i_out_p, i_out_f], value=result_sbuf) + return result + # NKI_EXAMPLE_16_END + -def nki_matmul_tiled_(lhsT, rhs, result): +# NKI_EXAMPLE_18_BEGIN +@nki.jit +def nki_matmul_tiled_(lhsT, rhs): """NKI kernel to compute a matrix multiplication operation in a tiled manner Args: @@ -64,12 +74,14 @@ def nki_matmul_tiled_(lhsT, rhs, result): rhs: an input tensor of shape [K,N], where K is a multiple of 128, and N is a multiple of 512. It is the right-hand-side argument of the matrix multiplication. + Returns: result: the resulting output tensor of shape [M,N] """ K, M = lhsT.shape K_, N = rhs.shape assert K == K_, "lhsT and rhs must have the same contraction dimension" + result = nl.ndarray((M, N), dtype=lhsT.dtype, buffer=nl.shared_hbm) TILE_M = nl.tile_size.gemm_stationary_fmax # 128 TILE_K = nl.tile_size.pmax # 128 @@ -100,8 +112,13 @@ def nki_matmul_tiled_(lhsT, rhs, result): nl.store(result[m * TILE_M:(m + 1) * TILE_M, n * TILE_N:(n + 1) * TILE_N], value=res_sb) + return result + # NKI_EXAMPLE_18_END -def nki_matmul_hoist_load_(lhsT, rhs, result): + +# NKI_EXAMPLE_19_BEGIN +@nki.jit +def nki_matmul_hoist_load_(lhsT, rhs): """NKI kernel to compute a matrix multiplication operation in a tiled manner while hoisting the load of the lhsT and rhs to outer loops. @@ -112,12 +129,14 @@ def nki_matmul_hoist_load_(lhsT, rhs, result): rhs: an input tensor of shape [K,N], where K is a multiple of 128, and N is a multiple of 512. It is the right-hand-side argument of the matrix multiplication. + Returns: result: the resulting output tensor of shape [M,N] """ K, M = lhsT.shape K_, N = rhs.shape assert K == K_, "lhsT and rhs must have the same contraction dimension" + result = nl.ndarray((M, N), dtype=lhsT.dtype, buffer=nl.shared_hbm) TILE_M = nl.tile_size.gemm_stationary_fmax # 128 TILE_K = nl.tile_size.pmax # 128 @@ -163,8 +182,13 @@ def nki_matmul_hoist_load_(lhsT, rhs, result): res_sb = nl.copy(res_psum, dtype=result.dtype) nl.store(result[m * TILE_M + i_res.p, n * TILE_N + i_res.x], value=res_sb) + return result + # NKI_EXAMPLE_19_END + -def nki_matmul_block_free_dimension_(lhsT, rhs, result): +# NKI_EXAMPLE_20_BEGIN +@nki.jit +def nki_matmul_block_free_dimension_(lhsT, rhs): """NKI kernel to compute a matrix multiplication operation while blocking the free dimensions of the LHS and RHS to improve memory access pattern. @@ -175,12 +199,14 @@ def nki_matmul_block_free_dimension_(lhsT, rhs, result): rhs: an input tensor of shape [K,N], where K is a multiple of 128, and N is a multiple of 512. It is the right-hand-side argument of the matrix multiplication. 
+ Returns: result: the resulting output tensor of shape [M,N] """ K, M = lhsT.shape K_, N = rhs.shape assert K == K_, "lhsT and rhs must have the same contraction dimension" + result = nl.ndarray((M, N), dtype=lhsT.dtype, buffer=nl.shared_hbm) TILE_M = nl.tile_size.gemm_stationary_fmax # 128 TILE_K = nl.tile_size.pmax # 128 @@ -243,11 +269,15 @@ def nki_matmul_block_free_dimension_(lhsT, rhs, result): (n * TILES_IN_BLOCK_N + bn) * TILE_N + i_res.x], value=res_sb) + return result + # NKI_EXAMPLE_20_END + +# NKI_EXAMPLE_21_BEGIN +@nki.jit def nki_matmul_fully_optimized_( lhsT, rhs, - result, # Meta-parameters TILES_IN_BLOCK_M=16, TILES_IN_BLOCK_N=2, @@ -264,13 +294,15 @@ def nki_matmul_fully_optimized_( rhs: an input tensor of shape [K,N], where K is a multiple of 128 * TILES_IN_BLOCK_K and N is a multiple of 512 * TILES_IN_BLOCK_N. It is the right-hand-side argument of the matrix multiplication. - result: the resulting output tensor of shape [M,N] TILES_IN_BLOCK_*: meta parameters to control blocking dimensions + Returns: + result: the resulting output tensor of shape [M,N] """ K, M = lhsT.shape K_, N = rhs.shape assert K == K_, "lhsT and rhs must have the same contraction dimension" + result = nl.ndarray((M, N), dtype=lhsT.dtype, buffer=nl.shared_hbm) TILE_M = nl.tile_size.gemm_stationary_fmax # 128 TILE_K = nl.tile_size.pmax # 128 @@ -360,16 +392,19 @@ def nki_matmul_fully_optimized_( BLOCK_N * n + i_res_packed.x], value=result_packed[i_res_packed.p, i_res_packed.x]) + return result +# NKI_EXAMPLE_21_END + +# NKI_EXAMPLE_23_BEGIN if __name__ == "__main__": # Benchmarking with large matrices to show the differences more clearly lhsT = nt.tensor[[8192, 4096], nl.bfloat16] rhs = nt.tensor[[8192, 8192], nl.bfloat16] - output = nt.tensor[[4096, 8192], nl.bfloat16] def benchmark_nki(nki_func): bench_func = nki.benchmark(warmup=5, iters=10)(nki_func) - bench_func(lhsT, rhs, output) + bench_func(lhsT, rhs) latency_res = bench_func.benchmark_result.nc_latency p99 = latency_res.get_latency_percentile(99) print("Latency: {:.2f} ms (P99)".format(p99 / 1000.0)) @@ -385,3 +420,4 @@ def benchmark_nki(nki_func): print("Benchmarking nki_matmul_fully_optimized") benchmark_nki(nki_matmul_fully_optimized_) + # NKI_EXAMPLE_23_END diff --git a/src/tutorials/matrix_multiplication/matrix_multiplication_torch.py b/src/nki_samples/tutorials/matrix_multiplication/matrix_multiplication_torch.py similarity index 83% rename from src/tutorials/matrix_multiplication/matrix_multiplication_torch.py rename to src/nki_samples/tutorials/matrix_multiplication/matrix_multiplication_torch.py index ec0084c..de39ce8 100644 --- a/src/tutorials/matrix_multiplication/matrix_multiplication_torch.py +++ b/src/nki_samples/tutorials/matrix_multiplication/matrix_multiplication_torch.py @@ -7,23 +7,21 @@ import torch from torch_xla.core import xla_model as xm -from torch_neuronx import nki_jit from matrix_multiplication_nki_kernels import nki_matmul_basic_, nki_matmul_tiled_, nki_matmul_hoist_load_, nki_matmul_block_free_dimension_, nki_matmul_fully_optimized_ if __name__ == "__main__": + # NKI_EXAMPLE_17_BEGIN device = xm.xla_device() cpu = torch.device('cpu') # Test the small workload with basic kernel lhs_small = torch.rand((64, 128), dtype=torch.bfloat16, device=device) rhs_small = torch.rand((128, 512), dtype=torch.bfloat16, device=device) - output_small = torch.zeros((64, 512), dtype=torch.bfloat16, device=device) # Run NKI kernel - nki_matmul_basic_jit = nki_jit(nki_matmul_basic_) - nki_matmul_basic_jit(lhs_small.T, rhs_small, 
output_small) + output_small = nki_matmul_basic_(lhs_small.T, rhs_small) # Run torch reference output_small_torch = torch.matmul(lhs_small, rhs_small) @@ -34,18 +32,18 @@ print("NKI and Torch match") else: print("NKI and Torch differ") + # NKI_EXAMPLE_17_END + # NKI_EXAMPLE_22_BEGIN # Test the large workload with tiled kernels lhs = torch.rand((4096, 1024), dtype=torch.bfloat16, device=device) rhs = torch.rand((1024, 2048), dtype=torch.bfloat16, device=device) - output = torch.zeros((4096, 2048), dtype=torch.bfloat16, device=device) # Run torch reference output_torch = torch.matmul(lhs, rhs).to(device=cpu) def check_match(nki_func): - jit_func = nki_jit(nki_func) - jit_func(lhs.T, rhs, output) + output = nki_func(lhs.T, rhs) output_nki = output.to(device=cpu) if torch.allclose(output_torch, output_nki, atol=1e-4, rtol=1e-2): print("NKI and Torch match") @@ -63,3 +61,4 @@ def check_match(nki_func): print("Checking correctness of nki_matmul_fully_optimized") check_match(nki_matmul_fully_optimized_) + # NKI_EXAMPLE_22_END diff --git a/src/tutorials/rmsnorm/rmsnorm_jax.py b/src/nki_samples/tutorials/rmsnorm/rmsnorm_jax.py similarity index 84% rename from src/tutorials/rmsnorm/rmsnorm_jax.py rename to src/nki_samples/tutorials/rmsnorm/rmsnorm_jax.py index 5b412d8..f0efc20 100644 --- a/src/tutorials/rmsnorm/rmsnorm_jax.py +++ b/src/nki_samples/tutorials/rmsnorm/rmsnorm_jax.py @@ -7,9 +7,9 @@ import jax import jax.numpy as jnp -from jax_neuronx import nki_call from rmsnorm_nki_kernels import nki_rmsnorm_kernel +# NKI_EXAMPLE_44_BEGIN # Reference JAX implementation def jax_rms_norm(a_tensor, g_tensor): # Square the tensor (element-wise) @@ -26,11 +26,7 @@ def jax_rms_norm(a_tensor, g_tensor): a_tensor = jax.random.uniform(a_key, (250, 512)) g_tensor = jax.random.uniform(g_key, (512,)) -output_nki = nki_call( - nki_rmsnorm_kernel, - a_tensor, g_tensor, - out_shape=jax.ShapeDtypeStruct(a_tensor.shape, dtype=a_tensor.dtype), -) +output_nki = nki_rmsnorm_kernel(a_tensor, g_tensor) print(a_tensor) @@ -43,3 +39,4 @@ def jax_rms_norm(a_tensor, g_tensor): print("NKI and JAX match") else: print("NKI and JAX differ") + # NKI_EXAMPLE_44_END diff --git a/src/tutorials/rmsnorm/rmsnorm_nki_kernels.py b/src/nki_samples/tutorials/rmsnorm/rmsnorm_nki_kernels.py similarity index 90% rename from src/tutorials/rmsnorm/rmsnorm_nki_kernels.py rename to src/nki_samples/tutorials/rmsnorm/rmsnorm_nki_kernels.py index 140b682..402eecd 100644 --- a/src/tutorials/rmsnorm/rmsnorm_nki_kernels.py +++ b/src/nki_samples/tutorials/rmsnorm/rmsnorm_nki_kernels.py @@ -6,20 +6,23 @@ """ import numpy as np +# NKI_EXAMPLE_42_BEGIN import math import neuronxcc.nki as nki import neuronxcc.nki.language as nl -def nki_rmsnorm_kernel(a_tensor, g_tensor, out_tensor): +@nki.jit +def nki_rmsnorm_kernel(a_tensor, g_tensor): # Calculate out_tensor = a_tensor/RMS(a_tensor) * g_tensor # Where RMS(a_tensor) = sqrt((1/N) * sum(a_tensor * a_tensor)) # and N = a_tensor.shape[1] # Reduction (mean) is performed in the free (2nd) dimension + out_tensor = nl.ndarray(a_tensor.shape, dtype=a_tensor.dtype, + buffer=nl.shared_hbm) # Make sure shapes match assert a_tensor.shape[1] == g_tensor.shape[0] - assert a_tensor.shape == out_tensor.shape # Generate tensor indices to index input tensor ix = nl.arange(128)[:, None] @@ -68,14 +71,15 @@ def nki_rmsnorm_kernel(a_tensor, g_tensor, out_tensor): nl.store(out_tensor[i * 128 + ix, iy], value=out_tile, mask=(i * 128 + ix < num_rows)) + return out_tensor + # NKI_EXAMPLE_42_END + if __name__ == "__main__": a = 
np.random.rand(128, 512).astype(np.float32) g = np.random.rand(512).astype(np.float32) - output_nki = np.zeros(a.shape, dtype=a.dtype) - nki_rmsnorm_kernel_baremetal = nki.baremetal(nki_rmsnorm_kernel) - nki_rmsnorm_kernel_baremetal(a, g, output_nki) + output_nki = nki_rmsnorm_kernel(a, g) print(f"output_nki={output_nki}") # One-line numpy RMSNorm diff --git a/src/tutorials/rmsnorm/rmsnorm_torch.py b/src/nki_samples/tutorials/rmsnorm/rmsnorm_torch.py similarity index 82% rename from src/tutorials/rmsnorm/rmsnorm_torch.py rename to src/nki_samples/tutorials/rmsnorm/rmsnorm_torch.py index 71ced3e..c9bfc69 100644 --- a/src/tutorials/rmsnorm/rmsnorm_torch.py +++ b/src/nki_samples/tutorials/rmsnorm/rmsnorm_torch.py @@ -5,11 +5,11 @@ """ -from torch_neuronx.xla_impl.ops import nki_jit import torch import os from rmsnorm_nki_kernels import nki_rmsnorm_kernel +# NKI_EXAMPLE_43_BEGIN # Reference torch implementation def torch_rmsnorm_kernel(a_tensor, g_tensor): # Square the tensor (element-wise) @@ -25,13 +25,10 @@ def torch_rmsnorm_kernel(a_tensor, g_tensor): from torch_xla.core import xla_model as xm device = xm.xla_device() -nki_rmsnorm_kernel = nki_jit(nki_rmsnorm_kernel) - a_tensor = torch.rand((250, 512), dtype=torch.float32).to(device=device) g_tensor = torch.rand((512), dtype=torch.float32).to(device=device) -output_nki = torch.zeros((250, 512), dtype=torch.float32).to(device=device) -nki_rmsnorm_kernel(a_tensor, g_tensor, output_nki) +output_nki = nki_rmsnorm_kernel(a_tensor, g_tensor) print(f"output_nki={output_nki}") output_torch = torch_rmsnorm_kernel(a_tensor, g_tensor) @@ -41,3 +38,4 @@ def torch_rmsnorm_kernel(a_tensor, g_tensor): print("NKI and Torch match") else: print("NKI and Torch differ") +# NKI_EXAMPLE_43_END diff --git a/src/tutorials/sd_attention/sd_attention_nki_kernels.py b/src/nki_samples/tutorials/sd_attention/sd_attention_nki_kernels.py similarity index 81% rename from src/tutorials/sd_attention/sd_attention_nki_kernels.py rename to src/nki_samples/tutorials/sd_attention/sd_attention_nki_kernels.py index e5eec25..6d1f781 100644 --- a/src/tutorials/sd_attention/sd_attention_nki_kernels.py +++ b/src/nki_samples/tutorials/sd_attention/sd_attention_nki_kernels.py @@ -12,7 +12,9 @@ import argparse from scipy.special import softmax -def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, out_ref, use_causal_mask=False, +# NKI_EXAMPLE_31_BEGIN +@nki.jit +def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, use_causal_mask=False, mixed_percision=True): """ Fused self attention kernel for small head dimension Stable Diffusion workload, @@ -44,7 +46,7 @@ def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, out_ref, use_cau # Assume all IO tensors have the same dtype kernel_dtype = q_ref.dtype pe_in_dt = nl.bfloat16 if mixed_percision else np.float32 - assert q_ref.dtype == k_ref.dtype == v_ref.dtype == out_ref.dtype + assert q_ref.dtype == k_ref.dtype == v_ref.dtype # Shape checking seqlen, d_head = q_ref.shape @@ -53,7 +55,7 @@ def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, out_ref, use_cau assert tuple(k_ref.shape) == (seqlen, d_head), 'Input shape mismatch!' assert tuple(v_ref.shape) == (seqlen,d_head), \ f'Input shape mismatch! Expected: {(seqlen, d_head)} Actual: {tuple(v_ref.shape)}' - assert tuple(out_ref.shape) == (seqlen, d_head), 'Output shape mismatch!' 
+ out_ref = nl.ndarray((seqlen, d_head), dtype=q_ref.dtype, buffer=nl.shared_hbm) # Softmax scaling factor, multiplied onto Q softmax_scale = 0.125 @@ -210,58 +212,61 @@ def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, out_ref, use_cau out_ref[i_q_seq_tile * q_seq_tile_size + if_out, ip_out], value=attn_res_div) + return out_ref +# NKI_EXAMPLE_31_END + def parse_args(): - parser = argparse.ArgumentParser("Run Stable Diffusion Attention NKI kernel.") - parser.add_argument("--mode", - choices=["accuracy", "perf"], - default="accuracy", - help="""Do accuracy test or perf test. - Accuracy test uses cpu golden output as golden reference. - """) + parser = argparse.ArgumentParser("Run Stable Diffusion Attention NKI kernel.") + parser.add_argument("--mode", + choices=["accuracy", "perf"], + default="accuracy", + help="""Do accuracy test or perf test. + Accuracy test uses cpu golden output as golden reference. + """) + + args = parser.parse_args() + return args - args = parser.parse_args() - return args def cpu_golden_attn(q, k, v): - softmax_scale = 0.125 + softmax_scale = 0.125 - q_scaled = q * softmax_scale - raw_score = np.matmul(q_scaled, k.transpose()) - norm_score = softmax(raw_score, axis=-1) + q_scaled = q * softmax_scale + raw_score = np.matmul(q_scaled, k.transpose()) + norm_score = softmax(raw_score, axis=-1) - return np.matmul(norm_score, v) + return np.matmul(norm_score, v) if __name__ == "__main__": - args = parse_args() - - print(f"Running {args.mode} mode.") - - seqlen, d_head = 4096, 64 - - # Set up input tensors - dtype = np.float32 - q_tensor = np.random.rand(seqlen, d_head).astype(dtype) - k_tensor = np.random.rand(seqlen, d_head).astype(dtype) - v_tensor = np.random.rand(seqlen, d_head).astype(dtype) - output_nki = np.empty((seqlen, d_head), dtype=dtype) - output_golden = cpu_golden_attn(q_tensor, k_tensor, v_tensor) - - if args.mode == "accuracy": - nki.baremetal(fused_self_attn_for_SD_small_head_size)\ - (q_tensor, k_tensor, v_tensor, output_nki) - allclose = np.allclose(output_nki, output_golden, atol=1e-5, rtol=1e-3) - print(f">>>> SD attention matches CPU reference? {allclose}") - assert allclose, "Accuracy check fails!" - - else: - benchmark_func = nki.benchmark(fused_self_attn_for_SD_small_head_size, - save_neff_name='file.neff', - save_trace_name='profile.ntff') - benchmark_func(q_tensor, k_tensor, v_tensor, output_nki) - - metrics = benchmark_func.benchmark_result.nc_latency - print(">>>> SD attention benchmark results") - print("latency.p50 = " + str(metrics.get_latency_percentile(50))) - print("latency.p99 = " + str(metrics.get_latency_percentile(99))) \ No newline at end of file + args = parse_args() + + print(f"Running {args.mode} mode.") + + seqlen, d_head = 4096, 64 + + # Set up input tensors + dtype = np.float32 + q_tensor = np.random.rand(seqlen, d_head).astype(dtype) + k_tensor = np.random.rand(seqlen, d_head).astype(dtype) + v_tensor = np.random.rand(seqlen, d_head).astype(dtype) + output_nki = np.empty((seqlen, d_head), dtype=dtype) + output_golden = cpu_golden_attn(q_tensor, k_tensor, v_tensor) + + if args.mode == "accuracy": + output_nki = fused_self_attn_for_SD_small_head_size(q_tensor, k_tensor, v_tensor) + allclose = np.allclose(output_nki, output_golden, atol=1e-5, rtol=1e-3) + print(f">>>> SD attention matches CPU reference? {allclose}") + assert allclose, "Accuracy check fails!" 
+ + else: + benchmark_func = nki.benchmark(fused_self_attn_for_SD_small_head_size, + save_neff_name='file.neff', + save_trace_name='profile.ntff') + benchmark_func(q_tensor, k_tensor, v_tensor) + + metrics = benchmark_func.benchmark_result.nc_latency + print(">>>> SD attention benchmark results") + print("latency.p50 = " + str(metrics.get_latency_percentile(50))) + print("latency.p99 = " + str(metrics.get_latency_percentile(99))) \ No newline at end of file diff --git a/src/tutorials/sd_attention/sd_attention_torch.py b/src/nki_samples/tutorials/sd_attention/sd_attention_torch.py similarity index 79% rename from src/tutorials/sd_attention/sd_attention_torch.py rename to src/nki_samples/tutorials/sd_attention/sd_attention_torch.py index f124607..639e5cf 100644 --- a/src/tutorials/sd_attention/sd_attention_torch.py +++ b/src/nki_samples/tutorials/sd_attention/sd_attention_torch.py @@ -5,8 +5,8 @@ """ +# NKI_EXAMPLE_32_BEGIN import torch -from torch_neuronx.xla_impl.ops import nki_jit from torch_xla.core import xla_model as xm from sd_attention_nki_kernels import fused_self_attn_for_SD_small_head_size @@ -28,10 +28,8 @@ def cpu_golden_attn(q, k, v): q_tensor = torch.rand((4096, 64), dtype=torch.float32).to(device=device) k_tensor = torch.rand((4096, 64), dtype=torch.float32).to(device=device) v_tensor = torch.rand((4096, 64), dtype=torch.float32).to(device=device) - output_nki = torch.zeros((4096, 64), dtype=torch.float32).to(device=device) - nki_func = nki_jit(func=fused_self_attn_for_SD_small_head_size) - nki_func(q_tensor, k_tensor, v_tensor, output_nki) + output_nki = fused_self_attn_for_SD_small_head_size(q_tensor, k_tensor, v_tensor) output_torch = cpu_golden_attn(q_tensor, k_tensor, v_tensor) @@ -42,4 +40,5 @@ def cpu_golden_attn(q, k, v): else: print("NKI and Torch differ") - assert allclose \ No newline at end of file + assert allclose + # NKI_EXAMPLE_32_END diff --git a/src/nki_samples/tutorials/tensor_addition/tensor_addition_jax.py b/src/nki_samples/tutorials/tensor_addition/tensor_addition_jax.py new file mode 100644 index 0000000..e40f962 --- /dev/null +++ b/src/nki_samples/tutorials/tensor_addition/tensor_addition_jax.py @@ -0,0 +1,35 @@ +""" +Copyright (C) 2024, Amazon.com. All Rights Reserved + +JAX implementation for tensor addition NKI tutorial. 
+ +""" +# NKI_EXAMPLE_30_BEGIN +import jax +import jax.numpy as jnp +# NKI_EXAMPLE_30_END + +from tensor_addition_nki_kernels import nki_tensor_add + + +# NKI_EXAMPLE_30_BEGIN +if __name__ == "__main__": + + seed_a, seed_b = jax.random.split(jax.random.PRNGKey(42)) + a = jax.random.uniform(seed_a, (256, 1024), dtype=jnp.bfloat16) + b = jax.random.uniform(seed_b, (256, 1024), dtype=jnp.bfloat16) + + output_nki = nki_tensor_add(a, b) + print(f"output_nki={output_nki}") + + output_jax = a + b + print(f"output_jax={output_jax}") + + allclose = jnp.allclose(output_jax, output_nki, atol=1e-4, rtol=1e-2) + if allclose: + print("NKI and JAX match") + else: + print("NKI and JAX differ") + + assert allclose + # NKI_EXAMPLE_30_END diff --git a/src/tutorials/tensor_addition/tensor_addition_nki_kernels.py b/src/nki_samples/tutorials/tensor_addition/tensor_addition_nki_kernels.py similarity index 76% rename from src/tutorials/tensor_addition/tensor_addition_nki_kernels.py rename to src/nki_samples/tutorials/tensor_addition/tensor_addition_nki_kernels.py index 2b49237..ea72488 100644 --- a/src/tutorials/tensor_addition/tensor_addition_nki_kernels.py +++ b/src/nki_samples/tutorials/tensor_addition/tensor_addition_nki_kernels.py @@ -5,20 +5,26 @@ """ import numpy as np +# NKI_EXAMPLE_27_BEGIN import neuronxcc.nki as nki import neuronxcc.nki.language as nl -def nki_tensor_add_kernel_(a_input, b_input, c_output): +@nki.jit +def nki_tensor_add_kernel_(a_input, b_input): """NKI kernel to compute element-wise addition of two input tensors - This kernel assumes strict input/output tile-sizes, of up-to [128,512] + This kernel assumes strict input/output sizes can be uniformly tiled to [128,512] Args: - a_input: a first input tensor, of shape [128,512] - b_input: a second input tensor, of shape [128,512] - c_output: an output tensor, of shape [128,512] + a_input: a first input tensor + b_input: a second input tensor + + Returns: + c_output: an output tensor """ + # Create output tensor shared between all SPMD instances as result tensor + c_output = nl.ndarray(a_input.shape, dtype=a_input.dtype, buffer=nl.shared_hbm) # Calculate tile offsets based on current 'program' offset_i_x = nl.program_id(0) * 128 @@ -39,7 +45,12 @@ def nki_tensor_add_kernel_(a_input, b_input, c_output): # store the addition results back to device memory (c_output) nl.store(c_output[ix, iy], value=c_tile) + # Transfer the ownership of `c_output` to the caller + return c_output + # NKI_EXAMPLE_27_END + +# NKI_EXAMPLE_28_BEGIN def nki_tensor_add(a_input, b_input): """NKI kernel caller to compute element-wise addition of two input tensors @@ -57,12 +68,9 @@ def nki_tensor_add(a_input, b_input): # In this case, we use a 2D grid where the size of each invocation is 128x512 grid_x = a_input.shape[0] // 128 grid_y = a_input.shape[1] // 512 - c_output = np.zeros(a_input.shape, dtype=a_input.dtype) - - nki_tensor_add_kernel_baremetal = nki.baremetal(nki_tensor_add_kernel_) - nki_tensor_add_kernel_baremetal[grid_x, grid_y](a_input, b_input, c_output) - return c_output + return nki_tensor_add_kernel_[grid_x, grid_y](a_input, b_input) + # NKI_EXAMPLE_28_END if __name__ == "__main__": diff --git a/src/nki_samples/tutorials/tensor_addition/tensor_addition_torch.py b/src/nki_samples/tutorials/tensor_addition/tensor_addition_torch.py new file mode 100644 index 0000000..83673e5 --- /dev/null +++ b/src/nki_samples/tutorials/tensor_addition/tensor_addition_torch.py @@ -0,0 +1,35 @@ +""" +Copyright (C) 2024, Amazon.com. 
All Rights Reserved + +PyTorch implementation for tensor addition NKI tutorial. + +""" +# NKI_EXAMPLE_29_BEGIN +import torch +from torch_xla.core import xla_model as xm +# NKI_EXAMPLE_29_END + +from tensor_addition_nki_kernels import nki_tensor_add + + +# NKI_EXAMPLE_29_BEGIN +if __name__ == "__main__": + device = xm.xla_device() + + a = torch.rand((256, 1024), dtype=torch.bfloat16).to(device=device) + b = torch.rand((256, 1024), dtype=torch.bfloat16).to(device=device) + + output_nki = nki_tensor_add(a, b) + print(f"output_nki={output_nki}") + + output_torch = a + b + print(f"output_torch={output_torch}") + + allclose = torch.allclose(output_torch, output_nki, atol=1e-4, rtol=1e-2) + if allclose: + print("NKI and Torch match") + else: + print("NKI and Torch differ") + + assert allclose + # NKI_EXAMPLE_29_END diff --git a/src/tutorials/transpose2d/transpose2d_jax.py b/src/nki_samples/tutorials/transpose2d/transpose2d_jax.py similarity index 65% rename from src/tutorials/transpose2d/transpose2d_jax.py rename to src/nki_samples/tutorials/transpose2d/transpose2d_jax.py index 024782c..f23ceef 100644 --- a/src/tutorials/transpose2d/transpose2d_jax.py +++ b/src/nki_samples/tutorials/transpose2d/transpose2d_jax.py @@ -5,25 +5,18 @@ """ +# NKI_EXAMPLE_36_BEGIN import jax import jax.numpy as jnp -from functools import partial -from jax_neuronx import nki_call +# NKI_EXAMPLE_36_END from transpose2d_nki_kernels import tensor_transpose2D_kernel_ - -def transpose2D(in_tensor, shape2D): - return nki_call( - partial(tensor_transpose2D_kernel_, shape2D=shape2D), - in_tensor, - out_shape=jax.ShapeDtypeStruct(in_tensor.shape, dtype=in_tensor.dtype) - ) - +# NKI_EXAMPLE_36_BEGIN if __name__ == "__main__": P, X, Y = 5, 37, 44 a = jax.random.uniform(jax.random.PRNGKey(42), (P, X * Y)) - a_t_nki = transpose2D(a, (X, Y)) + a_t_nki = tensor_transpose2D_kernel_(a, shape2D=(X, Y)) a_t_jax = jnp.transpose(a.reshape(P, X, Y), axes=(0, 2, 1)).reshape(P, X * Y) print(a, a_t_nki, a_t_jax) @@ -35,3 +28,4 @@ def transpose2D(in_tensor, shape2D): print("NKI and JAX differ") assert allclose +# NKI_EXAMPLE_36_END diff --git a/src/tutorials/transpose2d/transpose2d_nki_kernels.py b/src/nki_samples/tutorials/transpose2d/transpose2d_nki_kernels.py similarity index 90% rename from src/tutorials/transpose2d/transpose2d_nki_kernels.py rename to src/nki_samples/tutorials/transpose2d/transpose2d_nki_kernels.py index d993c7e..171e6ed 100644 --- a/src/tutorials/transpose2d/transpose2d_nki_kernels.py +++ b/src/nki_samples/tutorials/transpose2d/transpose2d_nki_kernels.py @@ -5,11 +5,13 @@ """ import numpy as np +# NKI_EXAMPLE_33_BEGIN import neuronxcc.nki as nki import neuronxcc.nki.language as nl -def tensor_transpose2D_kernel_(in_tensor, out_tensor, shape2D): +@nki.jit +def tensor_transpose2D_kernel_(in_tensor, shape2D): """ NKI kernel to reorder the elements on axis[1] of the input tensor. 
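Concretely, the reordering treats each row of the (P, X*Y) input as an X-by-Y matrix and transposes it, so element in[p, i*Y + j] ends up at out[p, j*X + i]. A minimal plain-NumPy sketch with toy sizes (illustrative only, not part of the kernel):

import numpy as np

P, X, Y = 1, 2, 3
a = np.arange(P * X * Y).reshape(P, X * Y)             # [[0 1 2 3 4 5]]

# View each row as an X-by-Y matrix, transpose it, and flatten it back.
a_t = a.reshape(P, X, Y).transpose(0, 2, 1).reshape(P, X * Y)
print(a_t)                                             # [[0 3 1 4 2 5]]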
@@ -36,6 +38,8 @@ def tensor_transpose2D_kernel_(in_tensor, out_tensor, shape2D): shape2D: tuple representing the dimensions to be transposed: (#rows, #cols) out_tensor: an output (transposed) tensor """ + out_tensor = nl.ndarray(in_tensor.shape, dtype=in_tensor.dtype, + buffer=nl.shared_hbm) # Gather input shapes sz_p, _ = in_tensor.shape @@ -64,14 +68,15 @@ def tensor_transpose2D_kernel_(in_tensor, out_tensor, shape2D): # Finally, we store out_tile to external memory nl.store(out_tensor, value=out_tile) + return out_tensor + # NKI_EXAMPLE_33_END + if __name__ == "__main__": P, X, Y = 5, 3, 4 a = np.arange(P*X*Y, dtype=np.int8).reshape((P, X*Y)) - a_t_nki = np.zeros((P, Y*X), dtype=np.int8) - tensor_transpose2D_kernel_torch = nki.baremetal(tensor_transpose2D_kernel_) - tensor_transpose2D_kernel_torch(a, a_t_nki, (X, Y)) + a_t_nki = tensor_transpose2D_kernel_(a, (X, Y)) a_t_np = np.transpose(a.reshape(P, X, Y), (0, 2, 1)).reshape(P, X * Y) diff --git a/src/tutorials/transpose2d/transpose2d_torch.py b/src/nki_samples/tutorials/transpose2d/transpose2d_torch.py similarity index 82% rename from src/tutorials/transpose2d/transpose2d_torch.py rename to src/nki_samples/tutorials/transpose2d/transpose2d_torch.py index 71083d7..61fe367 100644 --- a/src/tutorials/transpose2d/transpose2d_torch.py +++ b/src/nki_samples/tutorials/transpose2d/transpose2d_torch.py @@ -4,13 +4,15 @@ PyTorch implementation for transpose2d NKI tutorial. """ +# NKI_EXAMPLE_34_BEGIN import torch from torch_xla.core import xla_model as xm -from torch_neuronx import nki_jit +# NKI_EXAMPLE_34_END from transpose2d_nki_kernels import tensor_transpose2D_kernel_ +# NKI_EXAMPLE_34_BEGIN if __name__ == "__main__": device = xm.xla_device() @@ -18,8 +20,7 @@ a = torch.arange(P*X*Y, dtype=torch.int8).reshape((P, X*Y)).to(device=device) a_t_nki = torch.zeros((P, Y*X), dtype=torch.int8).to(device=device) - tensor_transpose2D_kernel_torch = nki_jit(tensor_transpose2D_kernel_) - tensor_transpose2D_kernel_torch(a, a_t_nki, (X, Y)) + a_t_nki = tensor_transpose2D_kernel_(a, (X, Y)) a_t_torch = torch.transpose(a.reshape(P, X, Y), 1, 2).reshape(P, X * Y) @@ -32,3 +33,4 @@ print("NKI and PyTorch differ") assert allclose + # NKI_EXAMPLE_34_END diff --git a/src/reference/__init__.py b/src/reference/__init__.py deleted file mode 100644 index ad4a18a..0000000 --- a/src/reference/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -# Copyright (c) 2023, Amazon.com. All Rights Reserved - -""" -Package containing public kernels for Neuron Kernel Interface (NKI). - -Kernels here are the same to the ones available in the -NKI Github Sample Repo. - -TODO: Insert link to Github Repo when available -""" -from neuronxcc.nki.kernels.attention import fused_self_attn_for_SD_small_head_size, flash_attn_bwd, flash_fwd -from neuronxcc.nki.kernels.vision import resize_nearest_fixed_dma_kernel, select_and_scatter_kernel diff --git a/src/reference/attention.py b/src/reference/attention.py deleted file mode 100644 index 81704b5..0000000 --- a/src/reference/attention.py +++ /dev/null @@ -1,1031 +0,0 @@ -""" -Copyright (c) 2023, Amazon.com. 
All Rights Reserved - -kernels - Builtin high performance attention kernels - -""" -import numpy as np - -from neuronxcc.nki import trace -import neuronxcc.nki.isa as nisa -import neuronxcc.nki.language as nl - -from neuronxcc.nki.language import par_dim -from dataclasses import dataclass - -def div_ceil(n, d): - return (n + d - 1) // d - -@dataclass(frozen=True) -class FlashConfig: - """ - Config class for flash attention with default values - """ - seq_tile_size:int = 2048 - training:bool = True - should_transpose_v:bool = False - - __annotations__ = { - 'seq_tile_size': int, - 'training': bool, - 'should_transpose_v': bool - } - -@trace -def _flash_attention_core(q_local_tile, k, v, - q_h_per_k_h, - o_buffer, l_buffer, m_buffer, - batch_id, head_id, gqa_head_idx, q_tile_idx, - local_k_large_tile_idx, - kernel_dtype, acc_type, - flash_config: FlashConfig, - olm_buffer_idx=None, - global_k_large_tile_idx=None, - use_causal_mask=False, initialize=False, - B_P_SIZE=128, B_F_SIZE=512, B_D_SIZE=128, - dropout_p=0.0, dropout_p_tensor=None, seed_tensor=None - ): - """ - The flash attention core function to calcualte self attention between a tile of q and a block of K and V. - The q_local_tile has (B_P_SIZE, B_F_SIZE), which is loaded into the SBUF already. The block size of K and V - is defined in the seq_tile_size of the flash_config. The results are stored in the following there buffers - o_buffer: (num_large_k_tile, B_P_SIZE, d) - l_buffer: (num_large_k_tile, B_P_SIZE, 1) - m_buffer: (num_large_k_tile, B_P_SIZE, 1) - """ - LARGE_TILE_SZ = flash_config.seq_tile_size - REDUCTION_TILE = min(2048, LARGE_TILE_SZ // 2) - num_k_tile_per_large_tile = LARGE_TILE_SZ // B_F_SIZE - seq_len = k.shape[-1] - seq_q_num_tiles = seq_len // B_P_SIZE - - # Indices used by the distributed attention - if global_k_large_tile_idx is None: - global_k_large_tile_idx = local_k_large_tile_idx - if olm_buffer_idx is None: - olm_buffer_idx = local_k_large_tile_idx - - i_q_p = nl.arange(B_P_SIZE)[:, None] - i_q_f = nl.arange(B_F_SIZE)[None, :] - i_d_p = nl.arange(B_D_SIZE)[:, None] - i_d_f = nl.arange(B_D_SIZE)[None, :] - i_f_128 = nl.arange(B_P_SIZE)[None, :] - i_f_k_tiles = nl.arange(num_k_tile_per_large_tile)[None, :] - - # mask are used to only apply computation to the lower half of the matrix, - # which reduce the arthimetic intensity by half - forward_mask = q_tile_idx * B_P_SIZE >= global_k_large_tile_idx * LARGE_TILE_SZ if use_causal_mask else None - # Negation mask is the negation of `forward_mask`, which is used for the - # instructions executed on the blocks in the upper triangular section - # of the matrix. - # These instructions should not be executed when causual mask is disabled. - # - # For example, the o_buffer still needs to be propagated from o[j-1] to o[j] in - # the upper triangular of the matrix. 
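[Editorial aside] For reference, the running-softmax update that this block loop maintains in the o/l/m buffers can be written in a few lines of NumPy. This is a sketch for one query tile and one K/V block, ignoring causal masking and dropout; the function name and shapes are illustrative, not part of the kernel:

import numpy as np

def online_softmax_step(scores, v_block, o_prev, l_prev, m_prev):
    # scores: (B_P_SIZE, LARGE_TILE_SZ) raw Q.T@K for this block
    # v_block: (LARGE_TILE_SZ, d); o/l/m: running output, log-sum-exp, and row max
    m_cur = np.maximum(m_prev, scores.max(axis=-1, keepdims=True))
    p = np.exp(scores - m_cur)                     # exp(qk - max), cf. p_local below
    alpha = np.exp(m_prev - m_cur)                 # rescales the previous accumulator
    o_cur = alpha * o_prev + p @ v_block           # cf. o_previous_scaled + pv_psum
    l_cur = m_cur + np.log(np.exp(l_prev - m_cur) + p.sum(axis=-1, keepdims=True))
    return o_cur, l_cur, m_cur

The final attention output is then o * exp(m - l), which matches the normalization applied when flash_fwd writes the accumulator back to HBM.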
- negation_mask = q_tile_idx * B_P_SIZE < global_k_large_tile_idx * LARGE_TILE_SZ if use_causal_mask else None - - qk_res_buf = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ), buffer=nl.sbuf, dtype=acc_type) - max_local = nl.ndarray((par_dim(B_P_SIZE), num_k_tile_per_large_tile), dtype=acc_type) - for k_i in nl.affine_range(num_k_tile_per_large_tile): - qk_psum = nl.zeros((par_dim(B_P_SIZE), B_F_SIZE), - dtype=np.float32, buffer=nl.psum) # (128, 512) - multiplication_required_selection = global_k_large_tile_idx * LARGE_TILE_SZ + k_i * B_F_SIZE <= q_tile_idx * B_P_SIZE if use_causal_mask else None - qk_psum[i_q_p, i_q_f] += nl.matmul(q_local_tile, k[i_d_p, k_i * B_F_SIZE + i_q_f], transpose_x=True, - mask=multiplication_required_selection) # (p(128), 512) - - if use_causal_mask: - left_diagonal_selection = q_tile_idx * B_P_SIZE >= global_k_large_tile_idx * LARGE_TILE_SZ + (k_i + 1) * B_F_SIZE - diagonal_and_right_selection = (q_tile_idx * B_P_SIZE < global_k_large_tile_idx * LARGE_TILE_SZ + (k_i + 1) * B_F_SIZE) & forward_mask - - q_pos = q_tile_idx * B_P_SIZE + i_q_p - k_pos = global_k_large_tile_idx * LARGE_TILE_SZ + k_i * B_F_SIZE + i_q_f - pred = q_pos >= k_pos - # For tiles on and on the right of the diagonal, need to do affine_select. - # Magic number -9984.0 to replace -inf similar to what Tensorizer uses - qk_res_buf[i_q_p, k_i * B_F_SIZE + i_q_f] = nisa.affine_select( - pred=pred, - on_true_tile=qk_psum[i_q_p, i_q_f], on_false_value=-9984.0, dtype=kernel_dtype, - mask=diagonal_and_right_selection) - - # For tiles on the left of the diagonal, direct copy, no select required. - qk_res_buf[i_q_p, k_i * B_F_SIZE + i_q_f] = \ - nl.copy(qk_psum[i_q_p, i_q_f], dtype=kernel_dtype, mask=left_diagonal_selection) - else: - # Simply send psum result back to sbuf - qk_res_buf[i_q_p, k_i * B_F_SIZE + i_q_f] = \ - nl.copy(qk_psum[i_q_p, i_q_f], dtype=kernel_dtype) - - # Calculate max of the current tile - max_local[i_q_p, k_i] = nisa.tensor_reduce(np.max, qk_res_buf[i_q_p, k_i * B_F_SIZE + i_q_f], axis=(1,), - dtype=acc_type, negate=False, mask=forward_mask) - - max_ = nisa.tensor_reduce(np.max, max_local[i_q_p, i_f_k_tiles], axis=(1, ), - dtype=acc_type, negate=False, mask=forward_mask) - if not initialize: - m_previous = nl.copy(m_buffer[olm_buffer_idx - 1, i_q_p, 0]) - m_buffer[olm_buffer_idx, i_q_p, 0] = nl.maximum(m_previous, max_, mask=forward_mask) # (128,1) - if use_causal_mask: - m_buffer[olm_buffer_idx, i_q_p, 0] = nl.copy(m_previous, mask=negation_mask) - - m_current = m_buffer[olm_buffer_idx, i_q_p, 0] - # Compute scaling factor - alpha = nisa.activation(np.exp, m_previous, bias=-1*m_current, scale=1.0, mask=forward_mask) - o_previous = nl.copy(o_buffer[olm_buffer_idx-1, i_q_p, i_d_f], mask=forward_mask) - o_previous_scaled = nl.multiply(o_previous, alpha, mask=forward_mask) - else: - m_buffer[0, i_q_p, 0] = nl.copy(max_) - m_current = max_ - - p_local = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ), dtype=kernel_dtype) - i_r_f = nl.arange(REDUCTION_TILE)[None,: ] - p_partial_sum = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ // REDUCTION_TILE), dtype=acc_type) - for k_r_i in nl.affine_range(LARGE_TILE_SZ // REDUCTION_TILE): - # compute exp(qk-max) - p_local[i_q_p, k_r_i * REDUCTION_TILE + i_r_f] = \ - nisa.activation(np.exp, - qk_res_buf[i_q_p, k_r_i * REDUCTION_TILE + i_r_f], - bias=-1 * m_current, - scale=1.0, - dtype=kernel_dtype, - mask=forward_mask) - - # dropout - if dropout_p > 0.0: - for k_d_i in nl.sequential_range(REDUCTION_TILE // B_F_SIZE): - offset = k_d_i + k_r_i * 
(REDUCTION_TILE // B_F_SIZE) \ - + global_k_large_tile_idx * (LARGE_TILE_SZ // B_F_SIZE) \ - + q_tile_idx * (seq_len // B_F_SIZE) \ - + (head_id * q_h_per_k_h + gqa_head_idx) * (seq_len // B_F_SIZE) * seq_q_num_tiles \ - + batch_id * nl.num_programs(1) * (seq_len // B_F_SIZE) * seq_q_num_tiles - offset_seed = nl.add(seed_tensor[0, 0], offset, mask=forward_mask) - nl.random_seed(seed=offset_seed, mask=forward_mask) - softmax_dropout = nl.dropout(p_local[i_q_p, k_r_i * REDUCTION_TILE + k_d_i * B_F_SIZE + i_q_f], - rate=dropout_p_tensor[i_q_p, 0], - mask=forward_mask) - p_local[i_q_p, k_r_i * REDUCTION_TILE + k_d_i * B_F_SIZE + i_q_f] = \ - nl.multiply(softmax_dropout, 1 / (1 - dropout_p), mask=forward_mask) - - # Compute partial row-tile sum of exp(qk-max)) - p_partial_sum[i_q_p, k_r_i] = nl.sum(p_local[i_q_p, k_r_i * REDUCTION_TILE + i_r_f], axis=1, dtype=acc_type, mask=forward_mask) - - p_local_transposed = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ), dtype=kernel_dtype) - for i_p_t in nl.affine_range(LARGE_TILE_SZ // 512): - p_local_t_tmp = nl.ndarray((par_dim(B_P_SIZE), 512), buffer=nl.psum, dtype=np.float32) - for i_p_t_local in nl.affine_range(512//128): - p_local_t_tmp[i_q_p, i_p_t_local*128 + i_f_128] = nisa.nc_transpose(p_local[i_q_p, i_p_t*512+i_p_t_local * B_P_SIZE + i_f_128]) - i_f_512 = nl.arange(512)[None, :] - p_local_transposed[i_q_p, i_p_t * 512 + i_f_512 ] = nl.copy(p_local_t_tmp[i_q_p, i_f_512], dtype=kernel_dtype) - - ps = nl.sum(p_partial_sum, axis=1, dtype=acc_type, mask=forward_mask) - pv_psum = nl.zeros((par_dim(B_P_SIZE), B_D_SIZE), dtype=np.float32, buffer=nl.psum) - for k_i in nl.affine_range(LARGE_TILE_SZ // B_P_SIZE): - pv_psum[i_q_p, i_d_f] += nl.matmul(p_local_transposed[i_q_p, k_i * B_P_SIZE + i_f_128], - v[k_i, i_q_p, i_d_f], - transpose_x=True, - mask=forward_mask) # (128, 128) (p(Br), d) - - if initialize: - o_buffer[olm_buffer_idx, i_q_p, i_d_f] = nl.copy(pv_psum[i_q_p, i_d_f]) - l_buffer[olm_buffer_idx, i_q_p, 0] = nl.add(nl.log(ps), max_) - else: - if use_causal_mask: - o_buffer[olm_buffer_idx, i_q_p, i_d_f] = nl.copy(o_buffer[olm_buffer_idx-1, i_q_p, i_d_f], mask=negation_mask) - o_buffer[olm_buffer_idx, i_q_p, i_d_f] = nl.add(o_previous_scaled, pv_psum, mask=forward_mask) - - l_prev = l_buffer[olm_buffer_idx-1, i_q_p, 0] - l_exp = nl.add(nl.exp(nl.subtract(l_prev, m_current, mask=forward_mask), mask=forward_mask), ps, mask=forward_mask) - l_buffer[olm_buffer_idx, i_q_p, 0] = nl.add(m_current, nl.log(l_exp, mask=forward_mask), mask=forward_mask) - if use_causal_mask: - l_buffer[olm_buffer_idx, i_q_p, 0] = nl.copy(l_buffer[olm_buffer_idx-1, i_q_p, 0], mask=negation_mask) - - -def flash_fwd(q, k, v, seed, o, lse=None, - softmax_scale=None, - use_causal_mask=True, - mixed_precision=True, - dropout_p=0.0, config=None): - """ - Flash Attention Forward kernel - - IO tensor layouts: - - q: shape (bs, n_heads, d, seq_q) - - k: shape (bs, nk_heads, d, seq_k) - - v: shape (bs, nv_heads, d, seq_v) if config.should_transpose_v else (bs, nv_heads, seq_v, d) - - seed: shape (1,) - - o: shape (bs, n_heads, seq_q, d) - - lse: shape (bs, nheads, nl.tile_size.pmax, seq // nl.tile_size.pmax) if training else None - - We use seq_q and seq_k just for clarity, this kernel requires seq_q == seq_k - - IO tensor dtypes: - - This kernel assumes all IO tensors have the same dtype - - If mixed_percision is True, then all Tensor Engine operation will be performed in - bfloat16 and accumulation will be performed in float32. 
Otherwise the intermediates - will be in the same type as the inputs. - - Compile-time Constants: - - softmax_scale: scaling for softmax, is None, default is `1.0/(d**0.5)` - - mixed_precision: flag to set non-matmul ops in fp32 precision, defualt is set to `true`, if false, we use same precision as input types - - causal_mask: flag to set causal masking - - config: Instance of dataclass :class:`nki.kernels.attention.FlashConfig` with Performance config parameters for flash attention with default values - seq_tile_size: `default=2048`, size of the kv tile size for attention computation reduction - training: bool to indicate training vs inference `default=True` - - Performance Notes: - For better performance, the kernel is tiled to be of size `LARGE_TILE_SZ`, and Flash attention math techniques are applied in unit - of `LARGE_TILE_SZ`. Seqlen that is not divisible by `LARGE_TILE_SZ` is not supported at the moment. - - GQA support Notes: - the spmd kernel for launching kernel should be on kv_heads instead of nheads - - Example usage: - MHA: q: [b, h, d, s], k: [b, h, d, s], v: [b, h, s, d] - usage: `flash_fwd[b, h](q, k, v, ...)` - GQA: q: [b, h, d, s], k: [b, kv_h, d, s], v: [b, kv_h, s, d] - usage: `flash_fwd[b, kv_h](q, k, v, ...)` - """ - config = config or FlashConfig() - B_F_SIZE=512 - B_P_SIZE=128 - b , h, d, n = q.shape - B_D_SIZE = d - k_h = k.shape[1] - v_shape = v.shape - if config.should_transpose_v: - assert tuple(v_shape) == (b, k_h, d, n), f"V shape does not match layout requirements, expect: {(b, k_h, d, n)} but got {v_shape}" - assert tuple(k.shape) == (b, k_h, d, n), f" k and v shape does not match the layout defined in the function, but got {k.shape}" - else: - assert tuple(v_shape) == (b, k_h, n, d), f"V shape does not match layout requirements, expect: {(b, k_h, n, d)} but got {v_shape}" - assert tuple(k.shape) == (b,k_h, d, n), f" k and v shape does not match the layout defined in the function, but got {k.shape}" - assert d <= 128, f" we do not support head_dim > 128, got head dim {d}" - kernel_dtype = nl.bfloat16 if mixed_precision else q.dtype - acc_type = np.dtype(np.float32) if mixed_precision else kernel_dtype - - i_q_p = nl.arange(B_P_SIZE)[:,None] - i_0_f = nl.arange(1)[None, :] - n_tile_q = n//B_P_SIZE # since q will be loaded on PE - - batch_id = nl.program_id(axis=0) - head_id = nl.program_id(axis=1) - softmax_scale = softmax_scale or (1.0 / (d ** 0.5)) - - LARGE_TILE_SZ = config.seq_tile_size - # FIXME: Add masking for different seqlen values. 
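[Editorial aside] As the GQA note above explains, the SPMD grid for this kernel is (batch, kv_heads), and each program instance then iterates over q_h_per_k_h = n_heads // kv_heads query heads that share one K/V head, indexing Q with head_id * q_h_per_k_h + i_q_h. A plain-Python illustration of that mapping (the head counts below are made up):

n_heads, kv_heads = 32, 8
q_h_per_k_h = n_heads // kv_heads
for head_id in range(kv_heads):
    # head_id 0 serves query heads [0, 1, 2, 3], head_id 1 serves [4, 5, 6, 7], ...
    q_heads = [head_id * q_h_per_k_h + i_q_h for i_q_h in range(q_h_per_k_h)]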
- assert config.seq_tile_size >= 512, f" seq tile_size {config.seq_tile_size} cannot be less than 512" - assert n % LARGE_TILE_SZ == 0, f"seqlen is not divisible by {LARGE_TILE_SZ}" - num_large_k_tile = n // LARGE_TILE_SZ - - # inference flag, check if lse is none - inference = not(config.training) - if inference: - assert lse is None, "lse should be none for inference" - assert seed is None, f"seed should be None for inference, but got {seed}" - assert dropout_p==0.0, f"dropout should be 0.0 for inference but got {dropout_p}" - else: - assert lse is not None, "lse should not be none for training" - q_h_per_k_h = h // k_h - - if dropout_p > 0.0 and not inference: - seed_local = nl.load(seed[0]) - # TODO: Remove this once the dropout supports scale prob - dropout_p_tensor = nl.full((B_P_SIZE, 1), fill_value=dropout_p, dtype=np.float32) - else: - dropout_p_tensor = None - seed_local = None - - for i_q_h in nl.affine_range(q_h_per_k_h): - - # =============== Global Flash Attention accumulators ====================== # - o_buffer = nl.full((n_tile_q, num_large_k_tile, par_dim(B_P_SIZE), d), 0.0, dtype=acc_type, buffer=nl.sbuf) - l_buffer = nl.full((n_tile_q, num_large_k_tile, par_dim(B_P_SIZE), 1), 0.0, dtype=acc_type, buffer=nl.sbuf) - m_buffer = nl.full((n_tile_q, num_large_k_tile, par_dim(B_P_SIZE), 1), 0.0, dtype=acc_type) - # =============== Global Flash Attention accumulators END ================== # - - j = 0 - cur_k_tile = nl.ndarray((par_dim(B_D_SIZE), LARGE_TILE_SZ), dtype=kernel_dtype) - cur_v_tile = nl.ndarray((LARGE_TILE_SZ//B_P_SIZE, par_dim(B_P_SIZE), B_D_SIZE), dtype=kernel_dtype) - load_tile_size = B_P_SIZE - for k_i in nl.affine_range(LARGE_TILE_SZ // load_tile_size): - load_p = nl.arange(B_D_SIZE)[:, None] - load_f = nl.arange(load_tile_size)[None, :] - cur_k_tile[load_p, load_tile_size*k_i+load_f] = nl.load( - k[batch_id, head_id, load_p, load_tile_size*k_i+load_f] - ) - if config.should_transpose_v: - for v_i in nl.affine_range(LARGE_TILE_SZ // load_tile_size): - load_p = nl.arange(B_D_SIZE)[:, None] - load_f = nl.arange(B_P_SIZE)[None, :] - - loaded = nl.load(v[batch_id, head_id, load_p, B_P_SIZE*v_i+load_f], dtype=kernel_dtype) - store_p = nl.arange(B_P_SIZE)[:, None] - store_f = nl.arange(B_D_SIZE)[None, :] - cur_v_tile[v_i, store_p, store_f] = nisa.nc_transpose(loaded) - else: - for v_i in nl.affine_range(LARGE_TILE_SZ // load_tile_size): - load_p = nl.arange(B_P_SIZE)[:, None] - load_f = nl.arange(B_D_SIZE)[None, :] - - cur_v_tile[v_i, load_p, load_f] = nl.load(v[batch_id, head_id, B_P_SIZE*v_i+load_p, load_f], dtype=kernel_dtype) - - for i in nl.affine_range(n_tile_q): - i_f_128 = nl.arange(B_P_SIZE)[None, :] - i_f_d = nl.arange(B_D_SIZE)[None, :] - i_p_d = nl.arange(B_D_SIZE)[:,None] - q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE),dtype=kernel_dtype) - q_tile[i_p_d, i_f_128] = nl.load(q[batch_id, head_id * q_h_per_k_h + i_q_h, i_p_d, i*B_P_SIZE+i_f_128], dtype=kernel_dtype) \ - * softmax_scale # load (d, 128) tile in SBUF - # handle first tile and compute max and lse explicitly by passing initialize=True - _flash_attention_core(q_local_tile=q_tile, k=cur_k_tile, v=cur_v_tile, - q_h_per_k_h=q_h_per_k_h, - o_buffer=o_buffer[i], l_buffer=l_buffer[i], m_buffer=m_buffer[i], - batch_id=batch_id, head_id=head_id, - gqa_head_idx=i_q_h, q_tile_idx=i, local_k_large_tile_idx=0, - kernel_dtype=kernel_dtype, acc_type=acc_type, - flash_config=config, use_causal_mask=use_causal_mask, - initialize=True, - B_P_SIZE=B_P_SIZE, B_F_SIZE=B_F_SIZE, B_D_SIZE=B_D_SIZE, - dropout_p=dropout_p, 
dropout_p_tensor=dropout_p_tensor, seed_tensor=seed_local) - - for j in nl.sequential_range(1, num_large_k_tile): - cur_k_tile = nl.ndarray((par_dim(B_D_SIZE), LARGE_TILE_SZ), dtype=kernel_dtype) - cur_v_tile = nl.ndarray((LARGE_TILE_SZ//B_P_SIZE, par_dim(B_P_SIZE), B_D_SIZE), dtype=kernel_dtype) - load_tile_size = B_P_SIZE - for k_i in nl.affine_range(LARGE_TILE_SZ // load_tile_size): - load_p = nl.arange(B_D_SIZE)[:, None] - load_f = nl.arange(load_tile_size)[None, :] - cur_k_tile[load_p, load_tile_size*k_i+load_f] = nl.load( - k[batch_id, head_id, load_p, j*LARGE_TILE_SZ+load_tile_size*k_i+load_f] - ) - if config.should_transpose_v: - for v_i in nl.affine_range(LARGE_TILE_SZ // load_tile_size): - load_p = nl.arange(B_D_SIZE)[:, None] - load_f = nl.arange(B_P_SIZE)[None, :] - - loaded = nl.load(v[batch_id, head_id, load_p, j*LARGE_TILE_SZ+B_P_SIZE*v_i+load_f], dtype=kernel_dtype) - store_p = nl.arange(B_P_SIZE)[:, None] - store_f = nl.arange(B_D_SIZE)[None, :] - cur_v_tile[v_i, store_p, store_f] = nisa.nc_transpose(loaded) - else: - for v_i in nl.affine_range(LARGE_TILE_SZ // load_tile_size): - load_p = nl.arange(B_P_SIZE)[:, None] - load_f = nl.arange(B_D_SIZE)[None, :] - - cur_v_tile[v_i, load_p, load_f] = nl.load(v[batch_id, head_id, j*LARGE_TILE_SZ+B_P_SIZE*v_i+load_p, load_f], dtype=kernel_dtype) - - for i in nl.affine_range(n_tile_q): - i_f_128 = nl.arange(B_P_SIZE)[None, :] - i_f_d = nl.arange(B_D_SIZE)[None, :] - i_p_d = nl.arange(B_D_SIZE)[:,None] - q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE),dtype=kernel_dtype) - q_tile[i_p_d, i_f_128] = nl.load(q[batch_id, head_id * q_h_per_k_h + i_q_h, i_p_d, i*B_P_SIZE+i_f_128], dtype=kernel_dtype) \ - * softmax_scale # load (d, 128) tile in SBUF - _flash_attention_core(q_local_tile=q_tile, k=cur_k_tile, v=cur_v_tile, - q_h_per_k_h=q_h_per_k_h, - o_buffer=o_buffer[i], l_buffer=l_buffer[i], m_buffer=m_buffer[i], - batch_id=batch_id, head_id=head_id, - gqa_head_idx=i_q_h, q_tile_idx=i, local_k_large_tile_idx=j, - kernel_dtype=kernel_dtype, acc_type=acc_type, - flash_config=config, use_causal_mask=use_causal_mask, - initialize=False, - B_P_SIZE=B_P_SIZE, B_F_SIZE=B_F_SIZE, B_D_SIZE=B_D_SIZE, - dropout_p=dropout_p, dropout_p_tensor=dropout_p_tensor, seed_tensor=seed_local) - - # -------- write output to buffer on HBM ------------ # - for i in nl.affine_range(n_tile_q): - out = nl.ndarray((par_dim(B_P_SIZE), B_D_SIZE), dtype=kernel_dtype) - out[i_q_p, i_f_d] = nl.multiply(o_buffer[i, num_large_k_tile - 1, i_q_p, i_f_d], - nl.exp(m_buffer[i, num_large_k_tile - 1, i_q_p, i_0_f] - l_buffer[i, num_large_k_tile - 1, i_q_p, i_0_f]), - dtype=kernel_dtype) - - nl.store(o[batch_id, head_id * q_h_per_k_h + i_q_h, i * B_P_SIZE + i_q_p, i_f_d], out[i_q_p, i_f_d]) - if not inference: - lse_local = nl.zeros((par_dim(B_P_SIZE), 1), dtype=acc_type) - lse_local[i_q_p, i_0_f] = nl.copy(l_buffer[i, num_large_k_tile - 1, i_q_p, i_0_f], dtype=acc_type) - nl.store(lse[batch_id, head_id * q_h_per_k_h + i_q_h, i_q_p, i + i_0_f], lse_local[i_q_p, i_0_f]) - - -def flash_attn_bwd( - q_ref, k_ref, v_ref, o_ref, - dy_ref, - lse_ref, - seed_ref, - out_dq_ref, out_dk_ref, out_dv_ref, - use_causal_mask=False, - mixed_precision=False, - dropout_p=0.0, - softmax_scale=None, -): - """ - Flash attention backward kernel. Compute the backward gradients. 
- - IO tensor layouts: - - q_ref: shape (bs, nheads, head_size, seq) - - k_ref: shape (bs, nheads, head_size, seq) - - v_ref: shape (bs, nheads, head_size, seq) - - o_ref: shape (bs, nheads, head_size, seq) - - dy_ref: shape (bs, nheads, head_size, seq) - - lse_ref: shape (bs, nheads, nl.tile_size.pmax, seq // nl.tile_size.pmax) - - seed_ref: shape (1,) - - out_dq_ref: shape (bs, nheads, head_size, seq) - - out_dk_ref: shape (bs, nheads, head_size, seq) - - out_dv_ref: shape (bs, nheads, head_size, seq) - - Detailed steps: - 1. D = rowsum(dO ◦ O) (pointwise multiply) - - 2. Recompute (softmax(Q^T@K)) - - 2.1 Q^T@K - 2.2 Scale the QK score - 2.3 Apply causal mask - 2.4 softmax - - 3. Compute the gradients of y = score @ V with respect to the loss - - 4. Compute the gradients of y = softmax(x) - - 5. Compute the gradients of Q^T@K - - 4.1 Compute dQ - 4.2 Compute dK - """ - - # Use q_ref dtype as the intermediate tensor dtype - # Assume all IO tensors have the same dtype - kernel_dtype = q_ref.dtype - mixed_dtype = np.dtype(np.float32) if mixed_precision else kernel_dtype - - assert q_ref.dtype == k_ref.dtype == v_ref.dtype == o_ref.dtype == dy_ref.dtype \ - == out_dq_ref.dtype == out_dk_ref.dtype == out_dv_ref.dtype - assert lse_ref.dtype == mixed_dtype - - # Shape checking - bs, nheads, d_head, seqlen = q_ref.shape - assert tuple(k_ref.shape) == (bs, nheads, d_head, seqlen), \ - f"Input K shape mismatch, got {k_ref.shape}" - assert tuple(v_ref.shape) == (bs, nheads, d_head, seqlen), \ - f"Input V shape mismatch, got {v_ref.shape}" - assert tuple(o_ref.shape) == (bs, nheads, d_head, seqlen), \ - f"Input o shape mismatch, got {o_ref.shape}" - assert tuple(dy_ref.shape) == (bs, nheads, d_head, seqlen), \ - f"Input dy shape mismatch, got {dy_ref.shape}" - assert tuple(lse_ref.shape) == (bs, nheads, nl.tile_size.pmax, seqlen // nl.tile_size.pmax), \ - f"Input lse shape mismatch, got {lse_ref.shape}" - if seed_ref is not None: - assert tuple(seed_ref.shape) == (1,), \ - f"Input seed shape mismatch, got {seed_ref.shape}" - - assert tuple(out_dq_ref.shape) == (bs, nheads, d_head, seqlen), \ - f"Output dQ shape mismatch, got {out_dq_ref.shape}" - assert tuple(out_dk_ref.shape) == (bs, nheads, d_head, seqlen), \ - f"Output dK shape mismatch, got {out_dk_ref.shape}" - assert tuple(out_dv_ref.shape) == (bs, nheads, d_head, seqlen), \ - f"Output dV shape mismatch, got {out_dv_ref.shape}" - - # FIXME: Add masking for different seqlen values. 
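[Editorial aside] Steps 1 and 4 of the docstring above compress the softmax backward pass into dX = P ◦ (dP − rowsum(dO ◦ O)), where P = softmax(X), O = P @ V and dP = dY @ V^T; the rowsum term is the dy_o_sum computed below. A small NumPy reference for those two steps (a sketch, not part of the kernel):

import numpy as np

def softmax_bwd_reference(p, dp, o, dy):
    # p: softmax probabilities, dp: gradient w.r.t. p (dy @ v.T)
    # o: attention output (p @ v), dy: gradient w.r.t. o
    d = (dy * o).sum(axis=-1, keepdims=True)   # Step 1: rowsum(dO ◦ O)
    return p * (dp - d)                        # Step 4: dL/dX fed into the dQ/dK matmuls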
- assert seqlen % 128 == 0, \ - f"Input sequence length must be divisible by 128, got {seqlen}" - - # Softmax scaling factor, multiplied onto Q - softmax_scale = softmax_scale or 1.0 / float(d_head ** 0.5) - - # Different batch samples/attention heads have independent attention - batch_id = nl.program_id(axis=0) - head_id = nl.program_id(axis=1) - - assert nl.num_programs(1) == nheads, \ - f"The grid shape mismatch, got {nl.num_programs(1)} but should be {nheads}" - - q_seq_n_tiles, q_seq_tile_size = div_ceil(seqlen, 128), 128 - d_head_n_tiles, d_head_tile_size = div_ceil(d_head, 128), min(d_head, 128) - - if seqlen >= 512: - k_seq_n_tiles, k_seq_tile_size = seqlen // 512, 512 - else: - k_seq_n_tiles, k_seq_tile_size = seqlen // 128, 128 - - k_seq_n_tiles_backward, k_seq_tile_size_backward = seqlen // 128, 128 - k_seq_fwd_bwd_tile_multipler = k_seq_tile_size // k_seq_tile_size_backward - - ############################################################## - # Step 2.4 Prefetch exp bias for softmax - ############################################################## - softmax_exp_bias = nl.zeros((q_seq_n_tiles, par_dim(q_seq_tile_size), 1), dtype=mixed_dtype) - for i_q_seq_tile in nl.affine_range(q_seq_n_tiles): - ip_qk = nl.arange(q_seq_tile_size)[:, None] - lse_local = nl.load( - lse_ref[batch_id, head_id, ip_qk, i_q_seq_tile], - dtype=mixed_dtype) - softmax_exp_bias[i_q_seq_tile, ip_qk, 0] = lse_local * -1.0 - - ############################################################## - # Step 1 Compute rowsum(dO ◦ O) - ############################################################## - dy_o_sum = nl.ndarray((q_seq_n_tiles, par_dim(q_seq_tile_size), 1), dtype=mixed_dtype) - for i_q_seq_tile in nl.affine_range(q_seq_n_tiles): - ip_reduce = nl.arange(q_seq_tile_size)[:, None] - dy_o_partial = nl.zeros((par_dim(q_seq_tile_size), d_head_n_tiles), dtype=mixed_dtype) - for i_d_head_tile in nl.affine_range(d_head_n_tiles): - ip_load = nl.arange(d_head_tile_size)[:, None] - if_q = nl.arange(q_seq_tile_size)[None, :] - dy_local = nl.load_transpose2d( - dy_ref[batch_id, head_id, i_d_head_tile * d_head_tile_size + ip_load, i_q_seq_tile * q_seq_tile_size + if_q], - dtype=mixed_dtype) - o_local = nl.load_transpose2d( - o_ref[batch_id, head_id, i_d_head_tile * d_head_tile_size + ip_load, i_q_seq_tile * q_seq_tile_size + if_q], - dtype=mixed_dtype - ) - - dy_o_partial[ip_reduce, i_d_head_tile] = nisa.tensor_reduce( - np.add, data=dy_local*o_local, axis=(1,), dtype=mixed_dtype - ) - - dy_o_sum[i_q_seq_tile, ip_reduce, 0] = nisa.tensor_reduce( - np.add, data=dy_o_partial[ip_reduce, nl.arange(d_head_n_tiles)[None, :]], - axis=(1,), dtype=mixed_dtype - ) - - # Indices for prefetch - ip_qk = nl.arange(d_head_tile_size)[:, None] - if_q = nl.arange(q_seq_tile_size)[None, :] - if_k = nl.arange(k_seq_tile_size)[None, :] - - if dropout_p > 0.0: - seed_local = nl.load(seed_ref[0]) - # TODO: Remove this once the dropout supports scale prob - dropout_p_local = nl.full((q_seq_tile_size, 1), fill_value=dropout_p, dtype=np.float32) - else: - seed_local = None - dropout_p_local = None - - dq_local_reduced = nl.zeros((q_seq_n_tiles, d_head_n_tiles, par_dim(d_head_tile_size), q_seq_tile_size), - dtype=mixed_dtype) - - # affine_range give the compiler permission to vectorize instructions - # inside the loop which improves the performance. However, when using the - # the dropout we should use sequential_range to avoid setting - # seed vectorization. 
TODO: the compiler should avoid vectorizing seed setting - _range = nl.sequential_range if dropout_p > 0.0 else nl.affine_range - - for i_k_seq_tile in nl.affine_range(k_seq_n_tiles): - # Prefetch V, K - v_local = nl.zeros((d_head_n_tiles, par_dim(d_head_tile_size), k_seq_tile_size), dtype=kernel_dtype) - k_local = nl.zeros((d_head_n_tiles, par_dim(d_head_tile_size), k_seq_tile_size), dtype=kernel_dtype) - transposed_k_local = nl.zeros((k_seq_fwd_bwd_tile_multipler, d_head_n_tiles, par_dim(k_seq_tile_size_backward), d_head_tile_size), dtype=kernel_dtype) - for i_d_head_tile in nl.affine_range(d_head_n_tiles): - k_local[i_d_head_tile, ip_qk, if_k] = nl.load( - k_ref[batch_id, head_id, i_d_head_tile * d_head_tile_size + ip_qk, i_k_seq_tile * k_seq_tile_size + if_k], - dtype=kernel_dtype) - v_local[i_d_head_tile, ip_qk, if_k] = nl.load( - v_ref[batch_id, head_id, i_d_head_tile * d_head_tile_size + ip_qk, i_k_seq_tile * k_seq_tile_size + if_k], - dtype=kernel_dtype) - ############################################################## - # Prefetch k transpose for the backward too - ############################################################## - if_k_backward = nl.arange(k_seq_tile_size_backward)[None, :] - ip_k_backward = nl.arange(k_seq_tile_size_backward)[:, None] - if_d_head = nl.arange(d_head_tile_size)[None, :] - for i_k_seq_tile_backward in nl.affine_range(k_seq_fwd_bwd_tile_multipler): - transposed_k_local[i_k_seq_tile_backward, i_d_head_tile, ip_k_backward, if_d_head] = \ - nisa.nc_transpose(k_local[i_d_head_tile, ip_qk, - i_k_seq_tile_backward * k_seq_tile_size_backward + if_k_backward]) - - dv_psum = nl.zeros((d_head_n_tiles, par_dim(d_head_tile_size), k_seq_tile_size), - dtype=np.float32, buffer=nl.psum) - dk_psum = nl.zeros((d_head_n_tiles, par_dim(d_head_tile_size), k_seq_tile_size), - dtype=np.float32, buffer=nl.psum) - for i_q_seq_tile in _range(q_seq_n_tiles): - # Prefetch dy, Q - dy_local = nl.zeros((d_head_n_tiles, par_dim(d_head_tile_size), q_seq_tile_size), dtype=kernel_dtype) - q_local = nl.zeros((d_head_n_tiles, par_dim(d_head_tile_size), q_seq_tile_size), dtype=kernel_dtype) - for i_d_head_tile in nl.affine_range(d_head_n_tiles): - ip_qk = nl.arange(d_head_tile_size)[:, None] - if_q = nl.arange(q_seq_tile_size)[None, :] - - dy_local[i_d_head_tile, ip_qk, if_q] = nl.load( - dy_ref[batch_id, head_id, i_d_head_tile * d_head_tile_size + ip_qk, i_q_seq_tile * q_seq_tile_size + if_q], - dtype=kernel_dtype) - - q_local[i_d_head_tile, ip_qk, if_q] = nl.load( - q_ref[batch_id, head_id, i_d_head_tile * d_head_tile_size + ip_qk, i_q_seq_tile * q_seq_tile_size + if_q], - dtype=kernel_dtype) * softmax_scale - - _flash_attn_bwd_core( - q_local=q_local, k_local=k_local, transposed_k_local=transposed_k_local, - v_local=v_local, dy_local=dy_local, - dk_psum=dk_psum, dv_psum=dv_psum, dq_local_reduced=dq_local_reduced, - softmax_exp_bias=softmax_exp_bias, dy_o_sum=dy_o_sum, - local_i_q_seq_tile=i_q_seq_tile, local_i_k_seq_tile=i_k_seq_tile, - seqlen=seqlen, d_head=d_head, - use_causal_mask=use_causal_mask, - kernel_dtype=kernel_dtype, mixed_dtype=mixed_dtype, - softmax_scale=softmax_scale, - seed_local=seed_local, dropout_p=dropout_p, dropout_p_local=dropout_p_local, - ) - - # Write dK, dV - for i_d_head_tile in nl.affine_range(d_head_n_tiles): - ip_dkv = nl.arange(d_head_tile_size)[:, None] - if_dkv = nl.arange(k_seq_tile_size)[None, :] - - nl.store( - out_dv_ref[batch_id, head_id, - i_d_head_tile * d_head_tile_size + ip_dkv, - i_k_seq_tile * k_seq_tile_size + if_dkv], - 
value=dv_psum[i_d_head_tile, ip_dkv, if_dkv], - ) - - nl.store( - out_dk_ref[batch_id, head_id, - i_d_head_tile * d_head_tile_size + ip_dkv, - i_k_seq_tile * k_seq_tile_size + if_dkv], - value=dk_psum[i_d_head_tile, ip_dkv, if_dkv], - ) - - # Write dQ - for i_q_seq_tile in nl.affine_range(q_seq_n_tiles): - for i_d_head_tile in nl.affine_range(d_head_n_tiles): - ip_dq = nl.arange(d_head_tile_size)[:, None] - if_dq = nl.arange(q_seq_tile_size)[None, :] - - nl.store( - out_dq_ref[batch_id, head_id, - i_d_head_tile * d_head_tile_size + ip_dq, - i_q_seq_tile * q_seq_tile_size + if_dq], - value=dq_local_reduced[i_q_seq_tile, i_d_head_tile, ip_dq, if_dq], - ) - -@trace -def _flash_attn_bwd_core( - q_local, k_local, transposed_k_local, v_local, dy_local, - dk_psum, dv_psum, dq_local_reduced, - softmax_exp_bias, dy_o_sum, - local_i_q_seq_tile, local_i_k_seq_tile, - seqlen, d_head, - use_causal_mask, - kernel_dtype, mixed_dtype, - softmax_scale, - seed_local, dropout_p, dropout_p_local, - global_i_q_seq_tile = None, - global_i_k_seq_tile = None, -): - """ - The flash backward core funciton to calculate the gradients of Q, K and V - of the given tiles. The result will be accumulated into the dk, dv, dq psum - """ - q_seq_n_tiles, q_seq_tile_size = div_ceil(seqlen, 128), 128 - d_head_n_tiles, d_head_tile_size = div_ceil(d_head, 128), min(d_head, 128) - if seqlen >= 512: - k_seq_n_tiles, k_seq_tile_size = seqlen // 512, 512 - else: - k_seq_n_tiles, k_seq_tile_size = seqlen // 128, 128 - k_seq_n_tiles_backward, k_seq_tile_size_backward = seqlen // 128, 128 - k_seq_fwd_bwd_tile_multipler = k_seq_tile_size // k_seq_tile_size_backward - - if global_i_q_seq_tile is None: - global_i_q_seq_tile = local_i_q_seq_tile - global_i_k_seq_tile = local_i_k_seq_tile - - mask = global_i_q_seq_tile * q_seq_tile_size >= global_i_k_seq_tile * k_seq_tile_size if use_causal_mask else None - # PSUM buffer shape: [q_seq_tile_size P, k_seq_tile_size F] - qk_psum = nl.zeros((par_dim(q_seq_tile_size), k_seq_tile_size), - dtype=np.float32, buffer=nl.psum) - qk_res_buf = nl.ndarray((par_dim(q_seq_tile_size), k_seq_tile_size), buffer=nl.sbuf, dtype=kernel_dtype) - - batch_id = nl.program_id(axis=0) - head_id = nl.program_id(axis=1) - # Tensor indices for accessing qk result in k_seq_tile_size - if_q = nl.arange(q_seq_tile_size)[None, :] - ip_qk = nl.arange(d_head_tile_size)[:, None] - - ip_q = nl.arange(q_seq_tile_size)[:, None] - if_k = nl.arange(k_seq_tile_size)[None, :] - - # Loop over contraction dim of QK matmul - for i_d_head_tile in nl.affine_range(d_head_n_tiles): - ############################################################## - # Step 2.1 Compute Q^T@K, with matmul(stationary=tensor_q, moving=tensor_k, contract=d_head) - ############################################################## - qk_psum[ip_q, if_k] += nisa.nc_matmul(q_local[i_d_head_tile, ip_qk, if_q], - k_local[i_d_head_tile, ip_qk, if_k], - mask=mask) - - ###################################### - # Step 2.2. 
Apply optional causal mask - ###################################### - if use_causal_mask: - # Magic number -9984.0 to replace -inf similar to what Tensorizer uses - qk_res_buf[ip_q, if_k] = nisa.affine_select( - pred=(global_i_q_seq_tile * q_seq_tile_size + ip_q >= global_i_k_seq_tile * k_seq_tile_size + if_k), - on_true_tile=qk_psum[ip_q, if_k], on_false_value=-9984.0, dtype=mixed_dtype, - mask=mask) - else: - # Simply send psum result back to sbuf - qk_res_buf[ip_q, if_k] = \ - nl.copy(qk_psum[ip_q, if_k], dtype=mixed_dtype) - - softmax_y = nl.ndarray((par_dim(q_seq_tile_size), k_seq_tile_size), dtype=kernel_dtype, buffer=nl.sbuf) - softmax_y[ip_q, if_k] = nisa.activation(np.exp, - data=qk_res_buf[ip_q, if_k], - bias=softmax_exp_bias[local_i_q_seq_tile, ip_q, 0], - scale=1.0, - mask=mask) - ##################################################################### - # Dropout - ##################################################################### - if dropout_p > 0.0: - offset = global_i_k_seq_tile + global_i_q_seq_tile * k_seq_n_tiles \ - + head_id * k_seq_n_tiles * q_seq_n_tiles \ - + batch_id * nl.num_programs(1) * k_seq_n_tiles * q_seq_n_tiles - offset_seed = nl.add(seed_local[0, 0], offset, mask=mask) - nl.random_seed(seed=offset_seed, mask=mask) - softmax_y[ip_q, if_k] = nl.dropout(softmax_y[ip_q, if_k], rate=dropout_p_local[ip_q, 0], mask=mask) - softmax_y[ip_q, if_k] = nl.multiply(softmax_y[ip_q, if_k], 1 / (1 - dropout_p), mask=mask) - - ##################################################################### - # Step 3.1 Calculate the backward gradients dL/dV, where y=softmax@V - # in value projection with matmul(stationary=dy, moving=softmax) - ##################################################################### - for i_d_head_tile in nl.affine_range(d_head_n_tiles): - ip_dv = nl.arange(d_head_tile_size)[:, None] - if_dv = nl.arange(k_seq_tile_size)[None, :] - if_trans_dy = nl.arange(q_seq_tile_size)[None, :] - trans_dy = nisa.nc_transpose(dy_local[i_d_head_tile, ip_dv, if_trans_dy], - mask=mask) - dv_psum[i_d_head_tile, ip_dv, if_dv] += \ - nisa.nc_matmul(trans_dy, softmax_y[ip_q, if_k], mask=mask) - - ##################################################################### - # Step 3.2 Calculate the backward gradients dL/dsoftmax, where y=softmax@V - # in value projection with matmul(stationary=dy, moving=v) - ##################################################################### - softmax_dy_psum = nl.zeros((par_dim(q_seq_tile_size), k_seq_tile_size), - dtype=np.float32, buffer=nl.psum) - for i_d_head_tile in nl.affine_range(d_head_n_tiles): - ip_softmax_dy = nl.arange(d_head_tile_size)[:, None] - if_dy = nl.arange(q_seq_tile_size)[None, :] - softmax_dy_psum[ip_q, if_k] += \ - nisa.nc_matmul(dy_local[i_d_head_tile, ip_softmax_dy, if_dy], - v_local[i_d_head_tile, ip_softmax_dy, if_k], - mask=mask) - - softmax_dy = nl.ndarray((par_dim(q_seq_tile_size), k_seq_tile_size), dtype=kernel_dtype, buffer=nl.sbuf) - softmax_dy[ip_q, if_k] = nl.copy(softmax_dy_psum[ip_q, if_k], dtype=kernel_dtype, - mask=mask) - - ##################################################################### - # Step 4 Calculate the softmax backward gradients dL/dx, where y=softmax(x) - # dL/dx = y * (dL/dy - rowsum(dO_O)), where y = softmax(x) - ##################################################################### - softmax_dx_local = nl.ndarray((par_dim(q_seq_tile_size), k_seq_tile_size), dtype=kernel_dtype, buffer=nl.sbuf) - softmax_dx_local[ip_q, if_k] = \ - nisa.tensor_scalar(data=softmax_dy[ip_q, if_k], - 
op0=np.subtract, - operand0=dy_o_sum[local_i_q_seq_tile, ip_q, 0], - op1=np.multiply, - operand1=softmax_y[ip_q, if_k], - mask=mask) - - ##################################################################### - # Step 5.1 Calculate dK, with matmul(stationary=Q, moving=softmax_dx) - ##################################################################### - for i_d_head_tile in nl.affine_range(d_head_n_tiles): - ip_trans_q = nl.arange(d_head_tile_size)[:, None] - if_trans_q = nl.arange(q_seq_tile_size)[None, :] - ip_dk = nl.arange(d_head_tile_size)[:, None] - trans_q_local = nisa.nc_transpose(q_local[i_d_head_tile, ip_trans_q, if_trans_q], - mask=mask) - dk_psum[i_d_head_tile, ip_dk, if_k] += \ - nisa.nc_matmul(trans_q_local, - softmax_dx_local[ip_q, if_k], - mask=mask) - - ##################################################################### - # Step 5.2 Calculate dQ - ##################################################################### - if_k = nl.arange(k_seq_tile_size_backward)[None, :] - ip_dq = nl.arange(d_head_tile_size)[:, None] - if_dq = nl.arange(q_seq_tile_size)[None, :] - if_d = nl.arange(d_head_tile_size)[None, :] - ip_transposed_k = nl.arange(k_seq_tile_size_backward)[:, None] - for i_d_head_tile in nl.affine_range(d_head_n_tiles): - dq_psum = nl.zeros((par_dim(d_head_tile_size), q_seq_tile_size), - dtype=np.float32, buffer=nl.psum) - for i_k_seq_tile_backward in nl.affine_range(k_seq_fwd_bwd_tile_multipler): - transposed_softmax_dx_local = \ - nisa.nc_transpose(softmax_dx_local[ip_q, i_k_seq_tile_backward * k_seq_tile_size_backward + if_k], - mask=mask) - dq_psum[ip_dq, if_dq] += nisa.nc_matmul( - transposed_k_local[i_k_seq_tile_backward, i_d_head_tile, ip_transposed_k, if_d], - transposed_softmax_dx_local, - mask=mask) - dq_local = nl.multiply(dq_psum[ip_dq, if_dq], softmax_scale, dtype=kernel_dtype, mask=mask) - dq_local_reduced[local_i_q_seq_tile, i_d_head_tile, ip_dq, if_dq] = nl.loop_reduce( - dq_local, op=np.add, loop_indices=(local_i_k_seq_tile,), - dtype=mixed_dtype, mask=mask) - -def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, out_ref, use_causal_mask=False, - mixed_percision=True): - """ - Fused self attention kernel for small head size Stable Diffusion workload. - - Computes softmax(QK^T)V. Decoder model can optionally include a causal mask - application. Does not include QKV rojection, output projection, dropout, - residual connection, etc. - - This kernel is designed to be used for Stable Diffusion models where the - n_heads is smaller or equal to 128. Assertion is thrown if `n_heads` does - not satisfy the requirement. - - IO tensor layouts: - - q_ptr: shape (bs, n_heads, seq_q) - - k_ptr: shape (bs, seq_k, n_heads) - - v_ptr: shape (bs, seq_v, n_heads) - - out_ptr: shape (bs, seq_q, n_heads) - - We use seq_q and seq_k just for clarity, this kernel requires seq_q == seq_k - - IO tensor dtypes: - - This kernel assumes all IO tensors have the same dtype - - If mixed_percision is True, then all Tensor Engine operation will be performed in - bfloat16 and accumulation will be performed in float32. Otherwise the intermediates - will be in the same type as the inputs. 
- """ - # Use q_ref dtype as the intermediate tensor dtype - # Assume all IO tensors have the same dtype - kernel_dtype = q_ref.dtype - pe_in_dt = nl.bfloat16 if mixed_percision else np.float32 - assert q_ref.dtype == k_ref.dtype == v_ref.dtype == out_ref.dtype - - # Shape checking - bs, d_head, seqlen = q_ref.shape - assert d_head <= 128, "Cannot use this kernel for d_head > 128" - assert tuple(q_ref.shape) == (bs, d_head, seqlen), 'Input shape mismatch!' - assert tuple(k_ref.shape) == (bs, seqlen, d_head), 'Input shape mismatch!' - assert tuple(v_ref.shape) == (bs, seqlen, - d_head), f'Input shape mismatch! Expected: {(bs, seqlen, d_head)} Actual: {tuple(v_ref.shape)}' - assert tuple(out_ref.shape) == (bs, seqlen, d_head), 'Output shape mismatch!' - - # Softmax scaling factor, multiplied onto Q - softmax_scale = 0.125 - - # Different batch samples/attention heads have independent attention - batch_id = nl.program_id(axis=0) - # batch_id = 0 - - # TODO: make q_seq_tile_size user input - # The matmuls currently use a fixed tile size of (128, 128). This may not achieve the best - # performance for dense attention. However, since this kernel is in preparation - # for block-sparse attention, this tile size is acceptable because the block - # size of block-sparse attention cannot be too large. - q_seq_n_tiles, q_seq_tile_size = seqlen // 128, 128 - k_seq_n_tiles, k_seq_tile_size = seqlen // 128, 128 - # No tiling on d_head dimension since the number of d_head fits in SB - d_head_tile_size = d_head - v_seq_n_tiles, v_seq_tile_size = seqlen // 128, 128 - - ################################### - # Step 1. transpose(tensor_v) - ################################### - # Buffer for v matrix transposed - # Pre-fetch and keep it in SBUF throughout different softmax tiles - trans_v = nl.ndarray((par_dim(v_seq_tile_size), v_seq_n_tiles, d_head), dtype=pe_in_dt) - - for i_k_seq_tile in nl.affine_range(k_seq_n_tiles): - ip_v = nl.arange(v_seq_tile_size)[:, None] - if_v = nl.arange(d_head_tile_size)[None, :] - trans_v[ip_v, i_k_seq_tile, if_v] = nl.load( - v_ref[batch_id, i_k_seq_tile * k_seq_tile_size + ip_v, if_v], - dtype=pe_in_dt) - - q_local = nl.ndarray((q_seq_n_tiles, par_dim(d_head_tile_size), q_seq_tile_size), dtype=pe_in_dt) - ip_q = nl.arange(d_head_tile_size)[:, None] - if_q = nl.arange(q_seq_tile_size)[None, :] - for i_q_seq_tile in nl.affine_range(q_seq_n_tiles): - q_local[i_q_seq_tile, ip_q, if_q] = nl.load( - q_ref[batch_id, ip_q, i_q_seq_tile * q_seq_tile_size + if_q], - dtype=pe_in_dt) * softmax_scale - - k_local = nl.ndarray((k_seq_n_tiles, par_dim(d_head_tile_size), k_seq_tile_size), dtype=pe_in_dt) - ip_k = nl.arange(d_head_tile_size)[:, None] - if_k = nl.arange(k_seq_tile_size)[None, :] - for i_k_seq_tile in nl.affine_range(k_seq_n_tiles): - k_local[i_k_seq_tile, ip_k, if_k] = nl.load_transpose2d( - k_ref[batch_id, - i_k_seq_tile * k_seq_tile_size + nl.arange(k_seq_tile_size)[:, None], - nl.arange(d_head_tile_size)[None, :]], - dtype=pe_in_dt) - - for i_q_seq_tile in nl.affine_range(q_seq_n_tiles): # indent = 2 - # A SBUF buffer for an independent softmax tile - qk_res_buf = nl.ndarray((par_dim(q_seq_tile_size), seqlen), dtype=kernel_dtype) - - neg_max_res = nl.ndarray((par_dim(q_seq_tile_size), k_seq_n_tiles), dtype=kernel_dtype) - ip_max = nl.arange(q_seq_tile_size)[:, None] - if_max = nl.arange(k_seq_n_tiles)[None, :] - - # Loop over RHS free of matmul(stationary=tensor_q, moving=tensor_k, contract=d_head) - for i_k_seq_tile in nl.affine_range(k_seq_n_tiles): # indent = 4 - - # Since 
the K^T tile is the RHS, the q_seq_len dimension will be P in the result - # PSUM buffer shape: [q_seq_tile_size P, k_seq_tile_size F] - qk_psum = nl.zeros((par_dim(q_seq_tile_size), k_seq_tile_size), - dtype=np.float32, buffer=nl.psum) - - # Tensor indices for accessing qk result in k_seq_tile_size - ip_qk = nl.arange(q_seq_tile_size)[:, None] - if_qk = nl.arange(k_seq_tile_size)[None, :] - - ############################################################## - # Step 2. matmul(stationary=tensor_q, moving=tensor_k, contract=d_head) - ############################################################## - qk_psum[ip_qk, if_qk] += nisa.nc_matmul(moving=k_local[i_k_seq_tile, ip_k, if_k], - stationary=q_local[i_q_seq_tile, ip_q, if_q]) - - ################################### - # Step 3. Apply optional causal mask - ################################### - if use_causal_mask: - # Magic number -9984.0 to replace -inf similar to what Tensorizer uses - qk_res_buf[ip_qk, i_k_seq_tile * k_seq_tile_size + if_qk] = nisa.affine_select( - pred=(i_q_seq_tile * q_seq_tile_size + ip_qk >= i_k_seq_tile * k_seq_tile_size + if_qk), - on_true_tile=qk_psum[ip_qk, if_qk], on_false_value=-9984.0, dtype=kernel_dtype) - else: - # Simply send psum result back to sbuf - qk_res_buf[ip_qk, i_k_seq_tile * k_seq_tile_size + if_qk] = nl.copy(qk_psum[ip_qk, if_qk], - dtype=kernel_dtype) - - ################################### - # Step 4. Softmax - ################################### - # TODO: use TensorScalarCacheReduce to avoid an extra copy - # We want to break this reduction in tiles because we want to overlap it with the previous matmul - neg_max_res[ip_max, i_k_seq_tile] = nisa.tensor_reduce( - np.max, data=qk_res_buf[ip_qk, i_k_seq_tile * k_seq_tile_size + if_qk], - axis=(1,), dtype=kernel_dtype, negate=True) - - neg_max_res_final = nisa.tensor_reduce( - np.min, data=neg_max_res[ip_max, if_max], - axis=(1,), dtype=kernel_dtype, negate=False) - - ip_softmax = nl.arange(q_seq_tile_size)[:, None] - if_softmax = nl.arange(seqlen)[None, :] - ip_sum_res = nl.arange(q_seq_tile_size)[:, None] - if_sum_res = nl.arange(d_head_tile_size)[None, :] - - softmax_res = nl.ndarray((par_dim(q_seq_tile_size), seqlen), dtype=pe_in_dt) - sum_divisor = nl.ndarray((par_dim(q_seq_tile_size), d_head_tile_size), dtype=kernel_dtype) - - # Simply use a large tile of seq_len in size since this is a "blocking" instruction - # Assuming the compiler will merge exp and reduce_add into a single instruction on ACT - exp_res = nisa.activation(np.exp, - data=qk_res_buf[ip_softmax, if_softmax], - bias=neg_max_res_final, scale=1.0) - - sum_res = nisa.tensor_reduce(np.add, data=exp_res, axis=(1,), - dtype=kernel_dtype) - softmax_res[ip_softmax, if_softmax] = nl.copy(exp_res, dtype=pe_in_dt) - - sum_reciprocal_broadcast = (1.0 / sum_res).broadcast_to((q_seq_tile_size, d_head_tile_size)) - sum_divisor[ip_sum_res, if_sum_res] = nl.copy(sum_reciprocal_broadcast, dtype=kernel_dtype) - - # Buffer for transposed softmax results (FP32 in PSUM) - trans_softmax_res = nl.ndarray( - (par_dim(k_seq_tile_size), k_seq_n_tiles, q_seq_tile_size), - dtype=pe_in_dt) - - # Result psum buffer has the hidden dim as P - attn_res_psum = nl.zeros((par_dim(d_head_tile_size), q_seq_tile_size), - dtype=np.float32, buffer=nl.psum) - - ip_scores_t = nl.arange(k_seq_tile_size)[:, None] - if_scores_t = nl.arange(q_seq_tile_size)[None, :] - # Loop over matmul_1 contraction - for i_k_seq_tile in nl.affine_range(k_seq_n_tiles): - ################################### - # Step 5. 
transpose(softmax_res) - ################################### - ip_scores = nl.arange(q_seq_tile_size)[:, None] - if_scores = nl.arange(k_seq_tile_size)[None, :] - - trans_softmax_res[ip_scores_t, i_k_seq_tile, if_scores_t] = nisa.nc_transpose( - softmax_res[ip_scores, i_k_seq_tile * k_seq_tile_size + if_scores]) - - ip_out = nl.arange(d_head_tile_size)[:, None] - if_out = nl.arange(q_seq_tile_size)[None, :] - for i_k_seq_tile in nl.affine_range(k_seq_n_tiles): - ###################################################################### - # Step 6. matmul_1(stationary=trans_v, moving=trans_softmax_res, contract=seqlen_v=seqlen_k) - ###################################################################### - ip_v_t = nl.arange(k_seq_tile_size)[:, None] - if_v_t = nl.arange(d_head_tile_size)[None, :] - attn_res_psum[ip_out, if_out] += \ - nisa.nc_matmul(moving=trans_softmax_res[ip_scores_t, i_k_seq_tile, if_scores_t], - stationary=trans_v[ip_v_t, i_k_seq_tile, if_v_t]) - - attn_res_sbuf = nl.copy(attn_res_psum[ip_out, if_out], dtype=kernel_dtype) - - attn_res_div = attn_res_sbuf * nisa.nc_transpose(sum_divisor[ip_sum_res, if_sum_res]) - - nl.store( - out_ref[batch_id, i_q_seq_tile * q_seq_tile_size + if_out, ip_out], - value=attn_res_div) - \ No newline at end of file diff --git a/src/reference/tutorial.py b/src/reference/tutorial.py deleted file mode 100644 index 4f3ebef..0000000 --- a/src/reference/tutorial.py +++ /dev/null @@ -1,29 +0,0 @@ -""" -Copyright (c) 2023, Amazon.com. All Rights Reserved - -kernels - Builtin high performance NKI kernels used in tutorial - -""" - -import neuronxcc.nki.language as nl - -def add_kernel_nx8x128x512(a_ptr, b_ptr, c_ptr, n_elements): - ix = nl.arange(128)[:, None] - iy = nl.arange(512)[None, :] - - tile_size = 128 * 512 - block_size = 8 * tile_size - - j = nl.program_id(axis=0) - - for i in nl.affine_range(8): - offset = j * block_size + i * tile_size + 512 * ix + iy - mask = offset < n_elements - a_ptr = a_ptr.ptr + offset - b_ptr = b_ptr.ptr + offset - c_ptr = c_ptr.ptr + offset - - a = nl.load(a_ptr, mask=mask) - b = nl.load(b_ptr, mask=mask) - c = a + b - nl.store(c_ptr, value=c, mask=mask) \ No newline at end of file diff --git a/src/tutorials/tensor_addition/tensor_addition_jax.py b/src/tutorials/tensor_addition/tensor_addition_jax.py deleted file mode 100644 index 9655b84..0000000 --- a/src/tutorials/tensor_addition/tensor_addition_jax.py +++ /dev/null @@ -1,61 +0,0 @@ -""" -Copyright (C) 2024, Amazon.com. All Rights Reserved - -JAX implementation for tensor addition NKI tutorial. - -""" -import jax -import jax.numpy as jnp -from jax_neuronx import nki_call - -from tensor_addition_nki_kernels import nki_tensor_add_kernel_ - - -def nki_tensor_add(a_input, b_input): - """NKI kernel caller to compute element-wise addition of two input tensors - - This kernel caller lifts tile-size restriction, by applying the kernel on tiles of the inputs/outputs - - Args: - a_input: a first input tensor, of shape [N*128, M*512] - b_input: a second input tensor, of shape [N*128, M*512] - - Returns: - a tensor of shape [N*128, M*512], the result of a_input + b_input - """ - - # The SPMD launch grid denotes the number of kernel instances. 
- # In this case, we use a 2D grid where the size of each invocation is 128x512 - grid_x = a_input.shape[0] // 128 - grid_y = a_input.shape[1] // 512 - - out_shape = jax.ShapeDtypeStruct((a_input.shape[0], a_input.shape[1]), dtype=a_input.dtype) - - return nki_call( - nki_tensor_add_kernel_, - a_input, - b_input, - grid=(grid_x, grid_y), - out_shape=out_shape, - ) - - -if __name__ == "__main__": - - seed_a, seed_b = jax.random.split(jax.random.PRNGKey(42)) - a = jax.random.uniform(seed_a, (256, 1024), dtype=jnp.bfloat16) - b = jax.random.uniform(seed_b, (256, 1024), dtype=jnp.bfloat16) - - output_nki = nki_tensor_add(a, b) - print(f"output_nki={output_nki}") - - output_jax = a + b - print(f"output_jax={output_jax}") - - allclose = jnp.allclose(output_jax, output_nki, atol=1e-4, rtol=1e-2) - if allclose: - print("NKI and JAX match") - else: - print("NKI and JAX differ") - - assert allclose diff --git a/src/tutorials/tensor_addition/tensor_addition_torch.py b/src/tutorials/tensor_addition/tensor_addition_torch.py deleted file mode 100644 index 942e728..0000000 --- a/src/tutorials/tensor_addition/tensor_addition_torch.py +++ /dev/null @@ -1,57 +0,0 @@ -""" -Copyright (C) 2024, Amazon.com. All Rights Reserved - -PyTorch implementation for tensor addition NKI tutorial. - -""" -import torch -from torch_xla.core import xla_model as xm -from torch_neuronx import nki_jit - -from tensor_addition_nki_kernels import nki_tensor_add_kernel_ - - -def nki_tensor_add(a_input, b_input): - """NKI kernel caller to compute element-wise addition of two input tensors - - This kernel caller lifts tile-size restriction, by applying the kernel on tiles of the inputs/outputs - - Args: - a_input: a first input tensor, of shape [N*128, M*512] - b_input: a second input tensor, of shape [N*128, M*512] - - Returns: - a tensor of shape [N*128, M*512], the result of a_input + b_input - """ - - # The SPMD launch grid denotes the number of kernel instances. 
- # In this case, we use a 2D grid where the size of each invocation is 128x512 - grid_x = a_input.shape[0] // 128 - grid_y = a_input.shape[1] // 512 - c_output = torch.zeros(a_input.shape, dtype=a_input.dtype).to(device=device) - - # Decorate the NKI kernel for PyTorch tracing - nki_tensor_add_kernel_torch = nki_jit(nki_tensor_add_kernel_) - nki_tensor_add_kernel_torch[grid_x, grid_y](a_input, b_input, c_output) - - return c_output - -if __name__ == "__main__": - device = xm.xla_device() - - a = torch.rand((256, 1024), dtype=torch.bfloat16).to(device=device) - b = torch.rand((256, 1024), dtype=torch.bfloat16).to(device=device) - - output_nki = nki_tensor_add(a, b) - print(f"output_nki={output_nki}") - - output_torch = a + b - print(f"output_torch={output_torch}") - - allclose = torch.allclose(output_torch, output_nki, atol=1e-4, rtol=1e-2) - if allclose: - print("NKI and Torch match") - else: - print("NKI and Torch differ") - - assert allclose diff --git a/test/integration/flash_attention/flash_attention_benchmark.py b/test/integration/flash_attention/flash_attention_benchmark.py index 5aa2e40..918a14f 100644 --- a/test/integration/flash_attention/flash_attention_benchmark.py +++ b/test/integration/flash_attention/flash_attention_benchmark.py @@ -14,6 +14,8 @@ from flash_attention import nki_flash_attn_func +parent_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) +sys.path.append(parent_dir) from perf_utils.LatencyCollector import benchmark if len(sys.argv) != 2: diff --git a/test/integration/fused_sd_attention_small_head/sd2_512_benchmark.py b/test/integration/fused_sd_attention_small_head/sd2_512_benchmark.py index e7fd205..5d63424 100644 --- a/test/integration/fused_sd_attention_small_head/sd2_512_benchmark.py +++ b/test/integration/fused_sd_attention_small_head/sd2_512_benchmark.py @@ -8,6 +8,8 @@ from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler from diffusers.models.unet_2d_condition import UNet2DConditionOutput +parent_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) +sys.path.append(parent_dir) from perf_utils.LatencyCollector import benchmark import sys diff --git a/test/integration/resize_nearest_fixed_dma_kernel/sd2_inpainting_936_624_benchmark.py b/test/integration/resize_nearest_fixed_dma_kernel/sd2_inpainting_936_624_benchmark.py index 3ba0eab..4970f72 100644 --- a/test/integration/resize_nearest_fixed_dma_kernel/sd2_inpainting_936_624_benchmark.py +++ b/test/integration/resize_nearest_fixed_dma_kernel/sd2_inpainting_936_624_benchmark.py @@ -23,6 +23,8 @@ else: from diffusers.models.cross_attention import CrossAttention +parent_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) +sys.path.append(parent_dir) from perf_utils.LatencyCollector import benchmark import sys diff --git a/test/unit/README.md b/test/unit/README.md index e0835de..55dc937 100644 --- a/test/unit/README.md +++ b/test/unit/README.md @@ -1 +1,7 @@ -Tests under this folder are unit tests for the kernels in `neuronxcc.nki.kernels`, and they are part of the nki-samples Github Repo. Only public APIs can be used for tests in this folder. \ No newline at end of file +Tests under this folder are unit tests for the kernels in `src/nki_samples`. + +To execute the tests, we need to include `src/nki_samples` in the `PYTHONPATH`. 
+ +For example: + +`PYTHONPATH=$PYTHONPATH:/home/ubuntu/nki-samples/src/ pytest test_flash_attn_fwd.py` \ No newline at end of file diff --git a/test/unit/__main__.py b/test/unit/__main__.py deleted file mode 100644 index 34fee3a..0000000 --- a/test/unit/__main__.py +++ /dev/null @@ -1,14 +0,0 @@ -import os -import sys - -# This file is basically a hack around the fact that pytest has a bug where it does not discover conftest.py correctly if you launch the test using --pyargs. -# https://github.com/pytest-dev/pytest/issues/1596 - - -# Todo: Using __file__ isn't the most robust. Figure out how to do this using importlib or similar. -test_root = os.path.dirname(__file__) - -if __name__ == "__main__": - import pytest - errcode = pytest.main([test_root] + sys.argv[1:]) - sys.exit(errcode) \ No newline at end of file diff --git a/test/unit/conftest.py b/test/unit/conftest.py new file mode 100644 index 0000000..cd663ae --- /dev/null +++ b/test/unit/conftest.py @@ -0,0 +1,28 @@ +import pytest + +def pytest_addoption(parser): + parser.addoption( + "--simulation-only", action="store_true", default=False, help="Run in simulation mode; only tests with the `simulation` marker are collected and executed" + ) + +def pytest_configure(config): + config.addinivalue_line( + "markers", "simulation: mark a simulation test that can be executed without a NeuronDevice" + ) + +@pytest.fixture +def simulation_only(request): + return request.config.getoption("--simulation-only") + +def pytest_collection_modifyitems(session, config, items): + if config.getoption("--simulation-only"): + # Only run cases with the `simulation` marker + result = [] + for item in items: + for marker in item.iter_markers(): + if marker.name == 'simulation': + result.append(item) + break + items.clear() + items.extend(result) + \ No newline at end of file diff --git a/test/unit/test_SD_attention_small_head.py b/test/unit/test_SD_attention_small_head.py index 5480fa4..1a54a4b 100--- a/test/unit/test_SD_attention_small_head.py +++ b/test/unit/test_SD_attention_small_head.py @@ -3,15 +3,14 @@ """ import os import pytest -from neuronxcc.nki.kernels.attention import fused_self_attn_for_SD_small_head_size -from neuronxcc.nki import benchmark, baremetal +from nki_samples.reference.attention import fused_self_attn_for_SD_small_head_size +from neuronxcc.nki import benchmark, baremetal, simulate_kernel import neuronxcc.nki.language as nl import numpy as np from scipy.special import softmax test_trace_file_path='local_trace.ntff' -numeric_func = baremetal(fused_self_attn_for_SD_small_head_size) -bench_func = benchmark(warmup=5, iters=10, save_trace_name=test_trace_file_path)(fused_self_attn_for_SD_small_head_size) +bench_func = benchmark(warmup=5, iters=20, save_trace_name=test_trace_file_path)(fused_self_attn_for_SD_small_head_size) def cpu_golden_attn(q, k, v): softmax_scale = 0.125 @@ -34,33 +33,37 @@ def test_attention_for_SD_perf(self, bs, seqlen, d, dtype, latency): q = np.random.random_sample((bs, d, seqlen)).astype(np.float32) k = np.random.random_sample((bs, seqlen, d)).astype(np.float32) v = np.random.random_sample((bs, seqlen, d)).astype(np.float32) - out = nl.static_cast(np.ndarray(shape=(bs, seqlen, d)), dtype) - + q_dev = nl.static_cast(q, dtype) k_dev = nl.static_cast(k, dtype) v_dev = nl.static_cast(v, dtype) - bench_func[bs](q_dev, k_dev, v_dev, out) - latency_res = bench_func.benchmark_result.nc_latency - p99 = latency_res.get_latency_percentile(99) - assert p99 <= latency + bench_func_ = bench_func[bs] + bench_func_(q_dev, k_dev, v_dev) + 
latency_res = bench_func_.benchmark_result.nc_latency + p50 = latency_res.get_latency_percentile(50) + assert p50 <= latency * 1.05 # short-running kernels are subject to hardware fluctuation assert os.path.getsize(test_trace_file_path) > 0 + @pytest.mark.simulation @pytest.mark.parametrize("bs,seqlen,d,dtype", [ [1, 4096, 128, np.float32], [1, 4096, 128, nl.bfloat16] ]) - def test_attention_for_SD_numberic(self, bs, seqlen, d, dtype): + def test_attention_for_SD_numeric(self, simulation_only, bs, seqlen, d, dtype): q = np.random.random_sample((bs, d, seqlen)).astype(np.float32) k = np.random.random_sample((bs, seqlen, d)).astype(np.float32) v = np.random.random_sample((bs, seqlen, d)).astype(np.float32) - out = nl.static_cast(np.ndarray(shape=(bs, seqlen, d)), dtype) - + q_dev = nl.static_cast(q, dtype) k_dev = nl.static_cast(k, dtype) v_dev = nl.static_cast(v, dtype) - numeric_func[bs](q_dev, k_dev, v_dev, out) + numeric_func = baremetal(fused_self_attn_for_SD_small_head_size) + if simulation_only: + out = simulate_kernel(numeric_func[bs], q_dev, k_dev, v_dev) + else: + out = numeric_func[bs](q_dev, k_dev, v_dev) out = nl.static_cast(out, np.float32) golden_result = cpu_golden_attn(q, k, v) assert np.allclose(out, golden_result, atol=1e-2) diff --git a/test/unit/test_allocated_SD_attention_small_head.py b/test/unit/test_allocated_SD_attention_small_head.py new file mode 100644 index 0000000..712148f --- /dev/null +++ b/test/unit/test_allocated_SD_attention_small_head.py @@ -0,0 +1,72 @@ +""" +Copyright (c) 2023, Amazon.com. All Rights Reserved +""" +import os +import pytest +from nki_samples.reference.allocated_attention import allocated_fused_self_attn_for_SD_small_head_size +from neuronxcc.nki import benchmark, baremetal, simulate_kernel +import neuronxcc.nki as nki +import neuronxcc.nki.language as nl +import numpy as np +from scipy.special import softmax + +test_trace_file_path='local_trace.ntff' + +bench_func = benchmark(warmup=5, iters=20, save_trace_name=test_trace_file_path)(allocated_fused_self_attn_for_SD_small_head_size) + +def cpu_golden_attn(q, k, v): + softmax_scale = 0.125 + q_scaled = q * softmax_scale + raw_score = np.matmul(q_scaled.transpose(0, 2, 1), k) + + norm_score = softmax(raw_score, axis=-1) + + # Transpose the result so it has the same layout as ours + return np.matmul(norm_score, v).transpose(0, 2, 1) + +class TestAttention: + + @pytest.mark.parametrize("bs,seqlen,d,dtype,latency", [ + [1, 4096, 128, np.float32, 410], + [1, 4096, 128, nl.bfloat16, 350], + [1, 5120, 128, nl.bfloat16, 586] + ]) + def test_allocated_attention_for_SD_perf(self, bs, seqlen, d, dtype, latency): + q = np.random.random_sample((bs, d, seqlen)).astype(np.float32) + k = np.random.random_sample((bs, d, seqlen)).astype(np.float32) + v = np.random.random_sample((bs, seqlen, d)).astype(np.float32) + + q_dev = nl.static_cast(q, dtype) + k_dev = nl.static_cast(k, dtype) + v_dev = nl.static_cast(v, dtype) + + bench_func_ = bench_func[bs] + bench_func_(q_dev, k_dev, v_dev) + latency_res = bench_func_.benchmark_result.nc_latency + p50 = latency_res.get_latency_percentile(50) + assert p50 <= latency * 1.05 # short-running kernels are subject to hardware fluctuation + assert os.path.getsize(test_trace_file_path) > 0 + + @pytest.mark.simulation + @pytest.mark.parametrize("bs,seqlen,d,dtype", [ + [1, 4096, 128, np.float32], + [1, 4096, 128, nl.bfloat16], + [1, 5120, 128, nl.bfloat16] ]) + def test_allocated_attention_for_SD_numeric(self, simulation_only, bs, seqlen, d, dtype): + q = 
np.random.random_sample((bs, d, seqlen)).astype(np.float32) + k = np.random.random_sample((bs, d, seqlen)).astype(np.float32) + v = np.random.random_sample((bs, seqlen, d)).astype(np.float32) + + q_dev = nl.static_cast(q, dtype) + k_dev = nl.static_cast(k, dtype) + v_dev = nl.static_cast(v, dtype) + + numeric_func = baremetal(allocated_fused_self_attn_for_SD_small_head_size) + if simulation_only: + out = simulate_kernel(numeric_func[bs], q_dev, k_dev, v_dev) + else: + out = numeric_func[bs](q_dev, k_dev, v_dev) + out = nl.static_cast(out, np.float32) + golden_result = cpu_golden_attn(q, k, v) + assert np.allclose(out, golden_result, atol=1e-2) diff --git a/test/unit/test_flash_attn_bwd.py b/test/unit/test_flash_attn_bwd.py index a55abbe..0f45f9f 100644 --- a/test/unit/test_flash_attn_bwd.py +++ b/test/unit/test_flash_attn_bwd.py @@ -2,12 +2,14 @@ Copyright (c) 2023, Amazon.com. All Rights Reserved """ import pytest -from neuronxcc.nki.kernels.attention import flash_attn_bwd -from neuronxcc.nki import benchmark, baremetal +from nki_samples.reference.attention import flash_attn_bwd +from neuronxcc.nki import benchmark, baremetal, simulate_kernel import neuronxcc.nki.language as nl import numpy as np -numeric_func = baremetal(flash_attn_bwd) +xfail = pytest.mark.arch_specific_xfail + + bench_func = benchmark(warmup=5, iters=10)(flash_attn_bwd) def softmax(x: np.ndarray, dim: int, zero_max_mode=False, @@ -85,6 +87,7 @@ def mixed_precision_matmul(a, b): class TestAttention: + @xfail # P167481231 @pytest.mark.parametrize("bs, nheads, seqlen, d, dtype, latency", [ [1, 4, 32*1024, 128, nl.bfloat16, 117000], ]) @@ -97,30 +100,24 @@ def test_flash_attn_bwd_perf(self, bs, nheads, seqlen, d, dtype, latency): lse = np.random.random_sample([bs, nheads, nl.tile_size.pmax, seqlen // nl.tile_size.pmax]).astype(np.float32) seed = None - out_dq = np.zeros(shape=[bs, nheads, d, seqlen], dtype=dtype) - out_dk = np.zeros(shape=[bs, nheads, d, seqlen], dtype=dtype) - out_dv = np.zeros(shape=[bs, nheads, d, seqlen], dtype=dtype) - q = nl.static_cast(q, dtype) k = nl.static_cast(k, dtype) v = nl.static_cast(v, dtype) o_proj = nl.static_cast(o_proj, dtype) dy = nl.static_cast(dy, dtype) - out_dq = nl.static_cast(out_dq, dtype) - out_dk = nl.static_cast(out_dk, dtype) - out_dv = nl.static_cast(out_dv, dtype) - - bench_func[bs, nheads](q, k, v, o_proj, dy, lse, seed, - out_dq, out_dk, out_dv, - use_causal_mask=True, mixed_precision=True) - latency_res = bench_func.benchmark_result.nc_latency - p99 = latency_res.get_latency_percentile(99) + + bench_func_ = bench_func[bs, nheads] + bench_func_(q, k, v, o_proj, dy, lse, seed, + use_causal_mask=True, mixed_precision=True) + latency_res = bench_func_.benchmark_result.nc_latency + p99 = latency_res.get_latency_percentile(50) assert p99 <= latency + @pytest.mark.simulation @pytest.mark.parametrize("bs, nheads, seqlen, d, dtype", [ [1, 4, 4096, 128, np.float32], ]) - def test_flash_attn_bwd_numerical(self, bs, nheads, seqlen, d, dtype): + def test_flash_attn_bwd_numerical(self, simulation_only, bs, nheads, seqlen, d, dtype): q = (np.random.random_sample([bs, nheads, d, seqlen]) - 0.5) * 2 k = (np.random.random_sample([bs, nheads, d, seqlen]) - 0.5) * 2 v = (np.random.random_sample([bs, nheads, d, seqlen]) - 0.5) * 2 @@ -130,10 +127,7 @@ def test_flash_attn_bwd_numerical(self, bs, nheads, seqlen, d, dtype): v = nl.static_cast(v, dtype) dy = nl.static_cast(dy, dtype) seed = None - out_dq = np.zeros(shape=[bs, nheads, d, seqlen], dtype=dtype) - out_dk = np.zeros(shape=[bs, 
nheads, d, seqlen], dtype=dtype) - out_dv = np.zeros(shape=[bs, nheads, d, seqlen], dtype=dtype) - + dq_golden, dk_golden, dv_golden, cached_negative_max, cached_sum_reciprocal, o_proj = \ cpu_attention_backward(q, k, v, dy, use_causal_mask=True) cached_negative_max = cached_negative_max.reshape(bs, nheads, seqlen // nl.tile_size.pmax, @@ -142,9 +136,15 @@ def test_flash_attn_bwd_numerical(self, bs, nheads, seqlen, d, dtype): nl.tile_size.pmax).transpose(0, 1, 3, 2) lse = -1.0 * (cached_negative_max + np.log(cached_sum_reciprocal)) - numeric_func[bs, nheads](q, k, v, o_proj, dy, lse, seed, - out_dq, out_dk, out_dv, - use_causal_mask=True, mixed_precision=True) + numeric_func = baremetal(flash_attn_bwd) + if simulation_only: + out_dq, out_dk, out_dv = simulate_kernel(numeric_func[bs, nheads], q, k, v, o_proj, dy, lse, seed, + use_causal_mask=True, + mixed_precision=True) + else: + out_dq, out_dk, out_dv = numeric_func[bs, nheads](q, k, v, o_proj, dy, lse, seed, + use_causal_mask=True, + mixed_precision=True) assert np.allclose(out_dq, dq_golden, atol=1e-2) assert np.allclose(out_dk, dk_golden, atol=1e-2) diff --git a/test/unit/test_flash_attn_fwd.py b/test/unit/test_flash_attn_fwd.py index 4d91164..e52354d 100644 --- a/test/unit/test_flash_attn_fwd.py +++ b/test/unit/test_flash_attn_fwd.py @@ -2,12 +2,11 @@ Copyright (c) 2023, Amazon.com. All Rights Reserved """ import pytest -from neuronxcc.nki.kernels.attention import flash_fwd, FlashConfig -from neuronxcc.nki import benchmark, baremetal +from nki_samples.reference.attention import flash_fwd, FlashConfig +from neuronxcc.nki import benchmark, baremetal, simulate_kernel import neuronxcc.nki.language as nl import numpy as np - -numeric_func = baremetal(flash_fwd) + bench_func = benchmark(warmup=5, iters=10)(flash_fwd) def softmax(x: np.ndarray, dim: int, zero_max_mode=False, @@ -63,75 +62,93 @@ def mixed_precision_matmul(a, b): class TestAttention: - @pytest.mark.parametrize("bs, nheads, seqlen, d, dtype, use_causal_mask,\ + @pytest.mark.parametrize("bs, nheads, seqlen_q, seqlen_k, d, dtype, use_causal_mask,\ mixed_precision, training, tile_size, kv_heads, should_transpose_v, latency", [ - [1, 6, 32*1024, 96, nl.bfloat16, True, True, True, 2048, 3, False, 87000000000], - [1, 1, 32*1024, 96, nl.bfloat16, True, True, False, 2048, None, False, 15100000000], + [1, 6, 32*1024, 32*1024, 96, nl.bfloat16, True, True, True, 2048, 3, False, 87000000000], + [1, 1, 32*1024, 32*1024, 96, nl.bfloat16, True, True, False, 2048, None, False, 15100000000], + # Non-square + [1, 3, 32*1024, 16*1024, 96, nl.bfloat16, True, True, False, 2048, None, False, 7550000000], + [1, 3, 16*1024, 32*1024, 96, nl.bfloat16, True, True, False, 2048, None, False, 7550000000], ]) - def test_flash_attn_fwd_perf(self, bs, nheads, seqlen, d, dtype, use_causal_mask, + def test_flash_attn_fwd_perf(self, bs, nheads, seqlen_q, seqlen_k, d, dtype, use_causal_mask, mixed_precision, training, tile_size, kv_heads, should_transpose_v,latency): - q = (np.random.random_sample([bs, nheads, d, seqlen]) - 0.5) * 2 - k = (np.random.random_sample([bs, nheads, d, seqlen]) - 0.5) * 2 + q = (np.random.random_sample([bs, nheads, d, seqlen_q]) - 0.5) * 2 + k = (np.random.random_sample([bs, nheads, d, seqlen_k]) - 0.5) * 2 if should_transpose_v: - v = (np.random.random_sample([bs, nheads, d, seqlen]) - 0.5) * 2 + v = (np.random.random_sample([bs, nheads, d, seqlen_k]) - 0.5) * 2 else: - v = (np.random.random_sample([bs, nheads, seqlen, d]) - 0.5) * 2 - o_proj = np.zeros(shape=[bs, nheads, seqlen, d], 
dtype=dtype) - out_lse = np.zeros(shape=[bs, nheads, int(nl.tile_size.pmax), seqlen // nl.tile_size.pmax], + v = (np.random.random_sample([bs, nheads, seqlen_k, d]) - 0.5) * 2 + o_proj = np.zeros(shape=[bs, nheads, seqlen_q, d], dtype=dtype) + out_lse = np.zeros(shape=[bs, nheads, int(nl.tile_size.pmax), seqlen_q // nl.tile_size.pmax], dtype=nl.float32 if mixed_precision else dtype) if training else None seed = None q = nl.static_cast(q, dtype) k = nl.static_cast(k, dtype) v = nl.static_cast(v, dtype) - o_proj = nl.static_cast(o_proj, dtype) config = FlashConfig(**{'seq_tile_size':tile_size, 'training':training, 'should_transpose_v':should_transpose_v}) heads = nheads if kv_heads is None else kv_heads - bench_func[bs, heads](q, k, v, seed, o_proj, out_lse, - use_causal_mask=use_causal_mask, mixed_precision=mixed_precision, config=config) - latency_res = bench_func.benchmark_result.nc_latency - p99 = latency_res.get_latency_percentile(99) + bench_func_ = bench_func[bs, heads] + bench_func_(q, k, v, seed, use_causal_mask=use_causal_mask, + mixed_precision=mixed_precision, config=config) + latency_res = bench_func_.benchmark_result.nc_latency + p99 = latency_res.get_latency_percentile(50) assert p99 <= latency - - @pytest.mark.parametrize("bs, nheads, seqlen, d, dtype, use_causal_mask,\ + + @pytest.mark.simulation + @pytest.mark.parametrize("bs, nheads, seqlen_q, seqlen_k, d, dtype, use_causal_mask,\ training, tile_size, kv_heads, should_transpose_v", [ - [1, 6, 4096, 128, np.float32, True, True, 2048, 3, False], - [1, 1, 4096, 128, np.float32, True, False, 2048, None, False], + [1, 6, 4096, 4096, 128, np.float32, True, True, 2048, 3, False], + [1, 1, 4096, 4096, 128, np.float32, True, False, 2048, None, False], + [1, 1, 8192, 4096, 128, np.float32, True, False, 2048, None, False], + [1, 1, 4096, 8192, 128, np.float32, True, False, 2048, None, False], ]) - def test_flash_attn_fwd_numerical(self, bs, nheads, seqlen, d, dtype, use_causal_mask, + def test_flash_attn_fwd_numerical(self, simulation_only, bs, nheads, seqlen_q, seqlen_k, d, dtype, use_causal_mask, training, tile_size, kv_heads, should_transpose_v): - q = (np.random.random_sample([bs, nheads, d, seqlen]) - 0.5) * 2 - k = (np.random.random_sample([bs, kv_heads or nheads, d, seqlen]) - 0.5) * 2 + q = (np.random.random_sample([bs, nheads, d, seqlen_q]) - 0.5) * 2 + k = (np.random.random_sample([bs, kv_heads or nheads, d, seqlen_k]) - 0.5) * 2 if should_transpose_v: - v = (np.random.random_sample([bs, nheads, d, seqlen]) - 0.5) * 2 + v = (np.random.random_sample([bs, nheads, d, seqlen_k]) - 0.5) * 2 cpu_permute = (0, 1, 2, 3) else: - v = (np.random.random_sample([bs, kv_heads or nheads, seqlen, d]) - 0.5) * 2 + v = (np.random.random_sample([bs, kv_heads or nheads, seqlen_k, d]) - 0.5) * 2 cpu_permute = (0, 1, 3, 2) - o_proj = np.zeros(shape=[bs, nheads, seqlen, d], dtype=dtype) + q = nl.static_cast(q, dtype) k = nl.static_cast(k, dtype) v = nl.static_cast(v, dtype) seed = None - out_lse = np.zeros(shape=[bs, nheads, int(nl.tile_size.pmax), seqlen // nl.tile_size.pmax], - dtype=np.float32) if training else None o_proj_golden, cached_negative_max, cached_sum_reciprocal = \ cpu_attention_forward(q, k, v.transpose(cpu_permute), use_causal_mask=use_causal_mask,mixed_precision=True) o_proj_golden = o_proj_golden.transpose(0,1,3,2) # (b,h, d, seq) - cached_negative_max = cached_negative_max.reshape(bs, nheads, seqlen // nl.tile_size.pmax, + cached_negative_max = cached_negative_max.reshape(bs, nheads, seqlen_q // nl.tile_size.pmax, 
nl.tile_size.pmax).transpose(0, 1, 3, 2) - cached_sum_reciprocal = cached_sum_reciprocal.reshape(bs, nheads, seqlen // nl.tile_size.pmax, + cached_sum_reciprocal = cached_sum_reciprocal.reshape(bs, nheads, seqlen_q // nl.tile_size.pmax, nl.tile_size.pmax).transpose(0, 1, 3, 2) lse_golden = -1.0 * (cached_negative_max + np.log(cached_sum_reciprocal)) if training else None config = FlashConfig(**{'seq_tile_size':tile_size, 'training':training, 'should_transpose_v':should_transpose_v}) heads = nheads if kv_heads is None else kv_heads - numeric_func[bs, heads](q, k, v, seed, o_proj, out_lse, seed, - use_causal_mask=use_causal_mask, mixed_precision=True, config=config) - assert np.allclose(o_proj, o_proj_golden, atol=1e-2) + numeric_func = baremetal(flash_fwd) + if simulation_only: + results = simulate_kernel(numeric_func[bs, heads], q, k, v, seed, + use_causal_mask=use_causal_mask, + mixed_precision=True, + config=config) + else: + results = numeric_func[bs, heads](q, k, v, seed, + use_causal_mask=use_causal_mask, + mixed_precision=True, + config=config) + if training: + o_proj, out_lse = results + assert np.allclose(o_proj, o_proj_golden, atol=1e-2) assert np.allclose(out_lse, lse_golden, atol=1e-2) + else: + o_proj = results + assert np.allclose(o_proj, o_proj_golden, atol=1e-2) diff --git a/test/unit/test_neuron_profile.py b/test/unit/test_neuron_profile.py new file mode 100644 index 0000000..e607705 --- /dev/null +++ b/test/unit/test_neuron_profile.py @@ -0,0 +1,86 @@ +from neuronxcc.nki import benchmark +from neuronxcc.nki import profile +import neuronxcc.nki.language as nl +import numpy as np +import pytest +import os +import shutil +import tempfile + + +WORKING_DIRECTORY = tempfile.mkdtemp() +SAVE_NEFF_NAME = "cus_file123.neff" +SAVE_TRACE_NAME = "profile-custom.ntff" +NUM_EXECS = 20 +PROFILE_NTH = 10 +JSON_REPORTS = "json_reports" + +@profile(working_directory=WORKING_DIRECTORY, save_neff_name=SAVE_NEFF_NAME, overwrite=False , save_trace_name=SAVE_TRACE_NAME, num_execs=NUM_EXECS, profile_nth=PROFILE_NTH) +def nki_tensor_tensor_add(a_tensor, b_tensor): + c_output = nl.ndarray(a_tensor.shape, dtype=a_tensor.dtype, buffer=nl.shared_hbm) + + a = nl.load(a_tensor) + b = nl.load(b_tensor) + + c_tile = a + b + + nl.store(c_output, value=c_tile) + + return c_output + +class TestNeuronProfile: + def _get_ntff_path(self, trace_val): + """ + Prepares ntff file name based on execution trace number + """ + if trace_val == 1: + return os.path.join(WORKING_DIRECTORY, f"{os.path.splitext(os.path.basename(SAVE_TRACE_NAME))[0]}.ntff") + else: + return os.path.join(WORKING_DIRECTORY, f"{os.path.splitext(os.path.basename(SAVE_TRACE_NAME))[0]}_exec_{trace_val}.ntff") + + @pytest.fixture + def traces(self): + ret = [] + if NUM_EXECS < PROFILE_NTH: + ret.append(self._get_ntff_path(PROFILE_NTH)) + else: + curr = PROFILE_NTH + while curr <= NUM_EXECS: + ret.append(self._get_ntff_path(curr)) + curr += PROFILE_NTH + return ret + + @pytest.fixture + def num_reports(self): + if NUM_EXECS < PROFILE_NTH: + return 1 + else: + return NUM_EXECS // PROFILE_NTH + + def test_output_artifacts_created(self, traces, num_reports): + # delete artifact directory, only testing non-overwrite functionality + if os.path.exists(WORKING_DIRECTORY): + shutil.rmtree(WORKING_DIRECTORY) + + # creates dummy input to invoke profile kernel + a = np.zeros([128, 1024]).astype(np.float16) + b = np.random.random_sample([128, 1024]).astype(np.float16) + + output_nki = nki_tensor_tensor_add(a, b) + + # now asserting artifacts are correctly 
created + assert os.path.exists(os.path.join(WORKING_DIRECTORY, SAVE_NEFF_NAME)) # neff + + for trace in traces: + assert os.path.exists(trace) # trace + + # json reports + report_dir = os.path.join(WORKING_DIRECTORY, JSON_REPORTS) + + assert os.path.exists(report_dir) # actually exists + assert len(os.listdir(report_dir)) == num_reports # report all iterations queried + + # post condition cleanup + if os.path.exists(WORKING_DIRECTORY): + shutil.rmtree(WORKING_DIRECTORY) + diff --git a/test/unit/test_resize_nearest.py b/test/unit/test_resize_nearest.py index a77968b..72e7aef 100644 --- a/test/unit/test_resize_nearest.py +++ b/test/unit/test_resize_nearest.py @@ -3,14 +3,14 @@ """ import pytest -from neuronxcc.nki.kernels.vision import resize_nearest_fixed_dma_kernel -from neuronxcc.nki import benchmark, baremetal +from nki_samples.reference.vision import resize_nearest_fixed_dma_kernel +from neuronxcc.nki import benchmark, baremetal, simulate_kernel import neuronxcc.nki.language as nl import numpy as np -numeric_func = baremetal(resize_nearest_fixed_dma_kernel) bench_func = benchmark(warmup=5, iters=10)(resize_nearest_fixed_dma_kernel) + def cpu_golden_result(data_tensor, output_shape): in_b, in_h, in_w, in_c = data_tensor.shape out_b, out_h, out_w, out_c = output_shape @@ -36,33 +36,37 @@ def cpu_golden_result(data_tensor, output_shape): class TestResizeNearest: @pytest.mark.parametrize("in_b, in_h, in_w, in_c, out_b, out_h, out_w, out_c, dtype, latency", [ - [10, 30, 20, 1280, 10, 59, 38, 1280, np.float32, 1722], + [10, 30, 20, 1280, 10, 59, 38, 1280, np.float32, 1740], [1, 30, 20, 1280, 1, 59, 38, 1280, nl.float16, 659], [1, 30, 20, 1280, 1, 59, 38, 1280, nl.bfloat16, 659], ]) def test_resize_nearest_for_perf(self, in_b, in_h, in_w, in_c, out_b, out_h, out_w, out_c, dtype, latency): input_tensor = np.random.random_sample((in_b, in_h, in_w, in_c)).astype(np.float32) - output_tensor = nl.static_cast(np.ndarray(shape=(out_b, out_h, out_w, out_c)), dtype) - + input_dev = nl.static_cast(input_tensor, dtype) - bench_func[in_b](input_dev, output_tensor) - latency_res = bench_func.benchmark_result.nc_latency - p99 = latency_res.get_latency_percentile(99) + bench_func_ = bench_func[in_b] + bench_func_(input_dev, (out_b, out_h, out_w, out_c)) + latency_res = bench_func_.benchmark_result.nc_latency + p99 = latency_res.get_latency_percentile(50) assert p99 <= latency + @pytest.mark.simulation @pytest.mark.parametrize("in_b, in_h, in_w, in_c, out_b, out_h, out_w, out_c, dtype", [ [10, 30, 20, 1280, 10, 59, 38, 1280, np.float32], [1, 30, 20, 1280, 1, 59, 38, 1280, nl.float16], [1, 30, 20, 1280, 1, 59, 38, 1280, nl.bfloat16], ]) - def test_resize_nearest_for_numberic(self, in_b, in_h, in_w, in_c, out_b, out_h, out_w, out_c, dtype): + def test_resize_nearest_for_numberic(self, simulation_only, in_b, in_h, in_w, in_c, out_b, out_h, out_w, out_c, dtype): input_tensor = np.random.random_sample((in_b, in_h, in_w, in_c)).astype(np.float32) - output_tensor = nl.static_cast(np.ndarray(shape=(out_b, out_h, out_w, out_c)), dtype) - + input_dev = nl.static_cast(input_tensor, dtype) - numeric_func[in_b](input_dev, output_tensor) + numeric_func = baremetal(resize_nearest_fixed_dma_kernel) + if simulation_only: + output_tensor = simulate_kernel(numeric_func[in_b], input_dev, (out_b, out_h, out_w, out_c)) + else: + output_tensor = numeric_func[in_b](input_dev, (out_b, out_h, out_w, out_c)) output_tensor = nl.static_cast(output_tensor, np.float32) golden_result = cpu_golden_result(input_tensor, output_tensor.shape) 
assert np.allclose(output_tensor, golden_result, atol=1e-2) diff --git a/test/unit/test_rmsnorm_qkv.py b/test/unit/test_rmsnorm_qkv.py new file mode 100644 index 0000000..28838d1 --- /dev/null +++ b/test/unit/test_rmsnorm_qkv.py @@ -0,0 +1,69 @@ +""" +Copyright (c) 2024, Amazon.com. All Rights Reserved +""" +import pytest +from nki_samples.reference.allocated_fused_linear import allocated_fused_rms_norm_qkv +from neuronxcc.nki import benchmark, baremetal, simulate_kernel +import neuronxcc.nki.language as nl +import numpy as np + +bench_func = benchmark(warmup=5, iters=10)(allocated_fused_rms_norm_qkv) + +np.random.seed(0) + + +def rms_norm(hidden, gamma, eps=1e-6): + rms = np.sqrt(np.mean(np.square(hidden), axis=-1, keepdims=True)) + norm = hidden * np.reciprocal(rms + eps) + if gamma is not None: + norm *= gamma + return norm + +def cpu_golden_result(hidden, gamma, qkv_weights, dtype, do_norm=True): + if do_norm: + hidden = rms_norm(hidden, gamma) + qkv_out = (hidden @ qkv_weights).astype(dtype) + return qkv_out + +class TestRMSNormQKV: + @pytest.mark.parametrize("batch, seqlen, dim, d_head, dtype, latency", [ + [1, 128, 512, 512, np.float16, 25], + [1, 512, 1024, 384, nl.bfloat16, 40], + [1, 128, 1024, 512, nl.bfloat16, 28], + # [1, 1024, 8192, 512, nl.bfloat16, 301 * 1.02], # FIXME: performance is flaky + ]) + def test_allocated_rmsnorm_qkv_perf(self, batch, seqlen, dim, d_head, dtype, latency): + hidden = np.random.random_sample((batch, seqlen, dim)).astype(np.float32) + weights = np.random.random_sample((dim, d_head)).astype(np.float32) + + hidden = nl.static_cast(hidden, dtype) + weights = nl.static_cast(weights, dtype) + + bench_func(hidden, weights) + latency_res = bench_func.benchmark_result.nc_latency + p99 = latency_res.get_latency_percentile(50) + assert p99 <= latency + + @pytest.mark.simulation + @pytest.mark.parametrize("batch, seqlen, dim, d_head, dtype", [ + [1, 128, 512, 512, np.float16], + [1, 512, 1024, 384, nl.bfloat16], + [1, 128, 1024, 512, nl.bfloat16], + [1, 1024, 8192, 512, nl.bfloat16] + ]) + def test_allocated_rmsnorm_qkv_numeric(self, simulation_only, batch, seqlen, dim, d_head, dtype): + hidden = np.random.random_sample((batch, seqlen, dim)) + weights = np.random.random_sample((dim, d_head)) + + hidden_dev = nl.static_cast(hidden, dtype) + weights_dev = nl.static_cast(weights, dtype) + + numeric_func = baremetal(allocated_fused_rms_norm_qkv) + if simulation_only: + out = simulate_kernel(numeric_func, hidden_dev, weights_dev) + else: + out = numeric_func(hidden_dev, weights_dev) + out = nl.static_cast(out, np.float32) + golden_res = nl.static_cast(cpu_golden_result(hidden, None, weights, dtype, do_norm=True), np.float32) + assert np.allclose(out, golden_res, atol=1e-2, rtol=1e-2) + diff --git a/test/unit/test_select_and_scatter.py b/test/unit/test_select_and_scatter.py index 70f7a7c..08e787f 100644 --- a/test/unit/test_select_and_scatter.py +++ b/test/unit/test_select_and_scatter.py @@ -1,11 +1,10 @@ import pytest -from neuronxcc.nki.kernels.vision import select_and_scatter_kernel -from neuronxcc.nki import benchmark, baremetal +from nki_samples.reference.vision import select_and_scatter_kernel +from neuronxcc.nki import benchmark, baremetal, simulate_kernel import neuronxcc.nki.language as nl import numpy as np -numeric_func = baremetal(select_and_scatter_kernel) bench_func = benchmark(warmup=5, iters=10)(select_and_scatter_kernel) np.random.seed(0) @@ -39,7 +38,6 @@ def cpu_golden_result(operand_tensor, source_tensor, window_dimensions=(3, 3), w out_h = h * 
stride_h + local_h - padding[0] out_w = w * stride_w + local_w - padding[1] output_tensor[n, c, out_h, out_w] += source_tensor[n, c, h, w] - return output_tensor class TestSelectAndScatter: @@ -47,31 +45,33 @@ class TestSelectAndScatter: [8, 64, 112, 112, 56, 56, np.float32, 4500], ]) def test_select_and_scatter_for_perf(self, n, c, operand_h, operand_w, source_h, source_w, dtype, latency): - operand_tensor = np.random.random_sample((n, c, operand_h, operand_w)).astype(np.float32) - source_tensor = np.random.random_sample((n, c, source_h, source_w)).astype(np.float32) - output_tensor = nl.static_cast(np.ndarray(shape=(n, c, operand_h, operand_w)), dtype) - - operand_dev = nl.static_cast(operand_tensor, dtype) - source_dev = nl.static_cast(source_tensor, dtype) + operand_dev = nl.static_cast(np.random.random_sample((n, c, operand_h, operand_w)), dtype) + source_dev = nl.static_cast(np.random.random_sample((n, c, source_h, source_w)), dtype) - bench_func(operand_dev, source_dev, output_tensor) + bench_func(operand_dev, source_dev) latency_res = bench_func.benchmark_result.nc_latency - p99 = latency_res.get_latency_percentile(99) + p99 = latency_res.get_latency_percentile(50) assert p99 <= latency + @pytest.mark.simulation @pytest.mark.parametrize("n, c, operand_h, operand_w, source_h, source_w, dtype", [ [8, 64, 112, 112, 56, 56, np.float32], - pytest.param(8, 64, 112, 112, 56, 56, nl.bfloat16, marks=pytest.mark.xfail), + [8, 64, 112, 112, 56, 56, nl.bfloat16], ]) - def test_select_and_scatter_for_numeric(self, n, c, operand_h, operand_w, source_h, source_w, dtype): - operand_tensor = np.random.random_sample((n, c, operand_h, operand_w)).astype(np.float32) - source_tensor = np.random.random_sample((n, c, source_h, source_w)).astype(np.float32) - output_tensor = nl.static_cast(np.ndarray(shape=(n, c, operand_h, operand_w)), dtype) - - operand_dev = nl.static_cast(operand_tensor, dtype) - source_dev = nl.static_cast(source_tensor, dtype) + def test_select_and_scatter_for_numeric(self,simulation_only, n, c, operand_h, operand_w, source_h, source_w, dtype): + operand_dev = nl.static_cast(np.random.random_sample((n, c, operand_h, operand_w)), dtype) + source_dev = nl.static_cast(np.random.random_sample((n, c, source_h, source_w)), dtype) + + sw = nl.static_cast(np.ndarray(shape=(n, c, source_h, source_w, 3, 3)), dtype) + operand_tensor = nl.static_cast(operand_dev, np.float32) + source_tensor = nl.static_cast(source_dev, np.float32) - numeric_func(operand_dev, source_dev, output_tensor) + numeric_func = baremetal(select_and_scatter_kernel) + if simulation_only: + output_dev = simulate_kernel(numeric_func, operand_dev, source_dev) + else: + output_dev = numeric_func(operand_dev, source_dev) golden_result = cpu_golden_result(operand_tensor, source_tensor) - output_tensor = nl.static_cast(output_tensor, np.float32) - assert np.allclose(output_tensor, golden_result) \ No newline at end of file + nki_result = nl.static_cast(output_dev, np.float32) + + assert np.allclose(nki_result, golden_result, rtol=1e-2, atol=1e-2)
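The unit tests in this change converge on one pattern for device-free runs: mark the case with `@pytest.mark.simulation`, accept the `simulation_only` fixture from `conftest.py`, build the callable with `baremetal`, and route the call through `simulate_kernel` when `--simulation-only` is passed. A minimal sketch of that pattern for a hypothetical new test follows; the kernel `scale_by_two_kernel` and its module path are placeholders, and it is assumed, like the reference kernels above, to return its output rather than write to an `out` argument.

import pytest
import numpy as np
import neuronxcc.nki.language as nl
from neuronxcc.nki import baremetal, simulate_kernel

# Hypothetical kernel used only for illustration; assumed to return 2 * input.
from nki_samples.reference.example import scale_by_two_kernel


class TestScaleByTwo:

    @pytest.mark.simulation
    @pytest.mark.parametrize("shape, dtype", [
        [(128, 512), np.float32],
        [(128, 512), nl.bfloat16],
    ])
    def test_scale_by_two_numeric(self, simulation_only, shape, dtype):
        # Generate float32 reference data, then cast to the dtype under test.
        a = np.random.random_sample(shape).astype(np.float32)
        a_dev = nl.static_cast(a, dtype)

        numeric_func = baremetal(scale_by_two_kernel)
        if simulation_only:
            # Selected by `pytest --simulation-only`; no NeuronDevice required.
            out = simulate_kernel(numeric_func, a_dev)
        else:
            out = numeric_func(a_dev)

        out = nl.static_cast(out, np.float32)
        assert np.allclose(out, 2 * a, atol=1e-2)

With the `conftest.py` added here, `pytest --simulation-only` collects only the `simulation`-marked cases, while a plain `pytest` run executes every test, including the performance checks, on a NeuronDevice.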