nki-samples updates for NeuronSDK release 2.21 (#38)
* NeuronSDK 2.21: NKI-Samples Update (#6)

Update nki-samples for NeuronSDK 2.21 Beta Release.
1. Use the new nki.jit decorator for kernels (see the sketch below).
2. Add samples for the new direct allocation feature.
3. Improve miscellaneous tests and documentation.
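
As a hedged sketch of the decorator change in item 1 (the kernel and tensor names here are illustrative, not from this commit): under the new style, a kernel is a plain function decorated with nki.jit that allocates and returns its own output tensor, rather than writing into a caller-provided one.

import neuronxcc.nki.language as nl
from neuronxcc import nki

@nki.jit
def add_one_kernel(in_tensor):
  # For brevity, assumes in_tensor fits in a single (128, 512) SBUF tile.
  out_tensor = nl.ndarray(in_tensor.shape, dtype=in_tensor.dtype, buffer=nl.shared_hbm)
  tile = nl.load(in_tensor)
  nl.store(out_tensor, value=nl.add(tile, 1.0))
  return out_tensor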

Co-authored-by: aws-qieqingy <[email protected]>

* Sync latest kernels and unit tests 

* Add latest kernels and unit tests for NeuronSDK release 2.21

* Move code into src/nki_samples to make the unit tests executable

* Add a GitHub workflow to run simulation tests on the main branch and incoming PRs
aws-qieqingy authored Dec 23, 2024
1 parent d53cdd0 commit 05c0769
Showing 46 changed files with 2,442 additions and 1,567 deletions.
32 changes: 32 additions & 0 deletions .github/workflows/run_simulation_tests.yml
@@ -0,0 +1,32 @@
name: Run Python Simulation Tests

on:
  push:
    branches: [ "main" ]
  pull_request:
    branches: [ "main" ]

permissions:
  contents: read

jobs:
  build:

    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v4
    - name: Set up Python 3.10
      uses: actions/setup-python@v3
      with:
        python-version: "3.10"
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        python -m pip config set global.extra-index-url https://pip.repos.neuron.amazonaws.com
        python -m pip install wget awscli
        python -m pip install pytest
        python -m pip install neuronx-cc==2.*
    - name: Test with pytest
      run: |
        PYTHONPATH=$PYTHONPATH:src/ pytest test/unit/ --simulation-only
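
For reference, a minimal sketch of the kind of simulation-mode check this workflow runs (the test name, shapes, and tolerances are assumptions; --simulation-only is a pytest option defined in this repository's test suite, and nki.simulate_kernel executes the kernel on CPU without Neuron hardware):

import numpy as np
from neuronxcc import nki
from nki_samples.reference.allocated_fused_linear import allocated_fused_rms_norm_qkv

def test_allocated_fused_rms_norm_qkv_simulation():
  batch, seqlen, dim, head_dim = 1, 128, 512, 512
  hidden = np.random.rand(batch, seqlen, dim).astype(np.float16)
  weights = np.random.rand(dim, head_dim).astype(np.float16)

  # Run the kernel on the CPU simulator instead of a NeuronCore.
  out = nki.simulate_kernel(allocated_fused_rms_norm_qkv, hidden, weights)

  # NumPy golden reference: (hidden * rsqrt(mean(hidden^2) + eps)) @ weights
  h32 = hidden.astype(np.float32)
  inv_rms = 1.0 / np.sqrt(np.mean(h32 * h32, axis=-1, keepdims=True) + 1e-6)
  golden = (h32 * inv_rms) @ weights.astype(np.float32)
  assert np.allclose(out.astype(np.float32), golden, atol=2e-2, rtol=2e-2)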
10 changes: 10 additions & 0 deletions src/nki_samples/reference/__init__.py
@@ -0,0 +1,10 @@
# Copyright (c) 2023, Amazon.com. All Rights Reserved

"""
Package containing public kernels for Neuron Kernel Interface (NKI).
Kernels here are also available in the `neuronxcc.nki.kernels` namespace, and they
are synced with this repository on every Neuron SDK release.
https://github.com/aws-neuron/nki-samples
"""
283 changes: 283 additions & 0 deletions src/nki_samples/reference/allocated_attention.py

Large diffs are not rendered by default.

114 changes: 114 additions & 0 deletions src/nki_samples/reference/allocated_fused_linear.py
@@ -0,0 +1,114 @@
"""
Copyright (c) 2024, Amazon.com. All Rights Reserved
kernels - Fused normalization with linear layers
"""

import neuronxcc.nki.language as nl
import neuronxcc.nki.isa as nisa
import neuronxcc.nki.compiler as ncc
import math
import numpy as np
from neuronxcc import nki
from neuronxcc.nki.language import par_dim

@nki.jit
def allocated_fused_rms_norm_qkv(hidden, weights, norm_dtype=nl.float32, eps=1e-6):
  """
  Allocated kernel that computes RMSNorm(hidden) @ wQKV. This kernel is designed to only handle fp16/bf16 tensor types.
  Internally, normalizations are cast to fp32 to avoid NaN errors.

  Args:
    hidden (tensor): Input tensor of the attention block in BSH layout
    weights (tensor): Fused QKV linear weights, assumed to be eltwise-multiplied with the RMS norm weight vector (gamma)
    norm_dtype (dtype, optional): Data type for RMS norm, should be fp32 to avoid NaN. Defaults to nl.float32.
    eps (float, optional): RMS norm epsilon term. Defaults to 1e-6.

  Returns:
    Output tensor of shape (batch, seqlen, head_dim), allocated in shared HBM.
  """
  # Hidden should be in BSH layout.
  batch, batchless_shape = hidden.shape[0], hidden.shape[1:]
  seqlen, dim = batchless_shape
  _dim, head_dim = weights.shape

  assert dim <= 8192 and dim % 128 == 0, "Unsupported hidden dimension"
  assert _dim == dim, "Reduction dimension must match"
  assert head_dim <= 512, "Head dimension must be 512 or less"

  out_tensor = nl.ndarray((batch, seqlen, head_dim), dtype=hidden.dtype, buffer=nl.shared_hbm)

  pmax, fmax = nl.tile_size.pmax, nl.tile_size.psum_fmax  # 128, 512
  ix, iy = nl.mgrid[0:pmax, 0:dim]
  i_lhs = nl.mgrid[0:pmax, 0:pmax]
  i_rhs = nl.mgrid[0:pmax, 0:fmax]
  i_res = nl.mgrid[0:pmax, 0:fmax]
  M = math.ceil(dim / pmax)
  NUM_TRANSP_TILES = math.ceil(dim / fmax)
  NUM_TILES = math.ceil(seqlen / pmax)
  TILES_INT = math.ceil(NUM_TILES / 2)
  scale = 1 / dim

  iden_x, iden_y = nl.mgrid[0:pmax, 0:128]

  # Static SBUF address map (bytes, per partition): identity tensor at 0, bias
  # placeholder at 256, in_bufs double buffer at 260; act, square_sum, out_tile,
  # output_sbuf, and weights_buffer are packed after it via the base_addr offsets below.
  identity_a = nl.shared_constant(np.identity(n=128, dtype=np.int8), dtype=hidden.dtype)
  identity_tensor = nl.ndarray((par_dim(pmax), 128), dtype=weights.dtype, buffer=ncc.sbuf.mod_alloc(base_addr=0))
  identity_tensor[iden_x, iden_y] = nl.load(identity_a, dtype=weights.dtype)
  bias_placeholder = nl.ndarray((par_dim(pmax), 1), dtype=np.float32, buffer=ncc.sbuf.mod_alloc(base_addr=128*2))
  bias_placeholder[...] = 0

  for b in nl.affine_range(batch):
    weights_buffer = nl.ndarray((M, par_dim(pmax), fmax), dtype=weights.dtype,
                                buffer=ncc.sbuf.mod_alloc(base_addr=260+(3*dim+fmax)*2+(dim+1)*4, num_free_tiles=(M,)))
    # Preload the entire weights tensor; everything fits in SBUF for LLaMA 3.1 70B.
    for m in nl.affine_range(M):
      weights_buffer[m, i_rhs.p, i_rhs.x] = nl.load(weights[m*pmax+i_rhs.p, i_rhs.x],
                                                    mask=(m*pmax+i_rhs.p<dim) & (i_rhs.x<head_dim))
    for i in nl.affine_range(TILES_INT):
      # Double buffer the input tensor
      in_bufs = nl.ndarray((2, par_dim(pmax), dim), dtype=hidden.dtype,
                           buffer=ncc.sbuf.mod_alloc(base_addr=260, num_free_tiles=(2,)))
      for i_interleave_grp in nl.affine_range(2):
        in_bufs[i_interleave_grp] = nl.load(hidden[b, (2*i+i_interleave_grp)*pmax+ix, iy],
                                            mask=(2*i+i_interleave_grp)*pmax+ix < seqlen)
        act = nl.ndarray((par_dim(pmax), dim), dtype=norm_dtype, buffer=ncc.sbuf.mod_alloc(base_addr=260+(2*dim)*2))

        # Write the RMS and RMS reciprocal tensors back out here, in-place
        square_sum = nl.ndarray((par_dim(pmax), 1), dtype=norm_dtype,
                                buffer=ncc.sbuf.mod_alloc(base_addr=260+(2*dim)*2+(dim)*4))

        # Write the output of RMS and RMS^T (in-place) out to here
        out_tile = nl.ndarray((par_dim(pmax), dim), dtype=weights.dtype,
                              buffer=ncc.sbuf.mod_alloc(base_addr=260+(2*dim)*2+(dim+1)*4))

        # Store the final output tiles to here before sending back to DRAM
        output_sbuf = nl.ndarray((par_dim(pmax), fmax), dtype=weights.dtype,
                                 buffer=ncc.sbuf.mod_alloc(base_addr=260+(3*dim)*2+(dim+1)*4))

        # square_sum = sum(hidden^2) per row; act holds the squared activations.
        act[...] = nisa.activation_reduce(op=nl.square, data=in_bufs[i_interleave_grp], reduce_op=np.add,
                                          reduce_res=square_sum[...], bias=bias_placeholder[...])
        # square_sum = 1 / sqrt(mean(hidden^2) + eps)
        square_sum[...] = nisa.tensor_scalar(square_sum[...], np.multiply, scale, op1=np.add, operand1=eps)
        square_sum[...] = nisa.activation(op=nl.rsqrt, data=square_sum[...], bias=bias_placeholder[...])

        # All PE array ops must output to FP32 on trn1, but must match the input dtype on trn2.
        if nisa.get_nc_version() == nisa.nc_version.gen3:
          transpose_res_psum = nl.ndarray((NUM_TRANSP_TILES, par_dim(pmax), 4*pmax), dtype=weights.dtype,
                                          buffer=ncc.psum.mod_alloc(base_bank=0, num_bank_tiles=(1,)))  # FIXME: perf is better when all tiles are on bank 0?
        else:
          transpose_res_psum = nl.ndarray((NUM_TRANSP_TILES, par_dim(pmax), 4*pmax), dtype=np.float32,
                                          buffer=ncc.psum.mod_alloc(base_bank=0, num_bank_tiles=(1,)))  # FIXME: perf is better when all tiles are on bank 0?

        for m in nl.affine_range(NUM_TRANSP_TILES):
          # Perform (hidden .* RMS reciprocal)^T in tiles of fmax (512)
          out_tile[i_rhs.p, m*fmax+i_rhs.x] = nl.multiply(in_bufs[i_interleave_grp, i_rhs.p, m*fmax+i_rhs.x],
                                                          square_sum[...], dtype=weights.dtype)
          for j in nl.affine_range(4):
            transpose_res_psum[m, i_lhs.p, j*pmax+i_lhs.x] = nisa.nc_matmul(out_tile[i_lhs.p, (m*4+j)*pmax+i_lhs.x],
                                                                            identity_tensor[...], is_transpose=True)
          out_tile[i_rhs.p, m*4*pmax+i_rhs.x] = nl.copy(transpose_res_psum[m], dtype=hidden.dtype)

        # Perform (RMSNorm(hidden)^T)^T @ wQKV
        res_psum = nl.ndarray((1, par_dim(pmax), fmax), dtype=nl.float32,
                              buffer=ncc.psum.mod_alloc(base_bank=7, num_bank_tiles=(1,)))
        for m in nl.affine_range(M):
          res_psum[0] += nisa.nc_matmul(out_tile[i_lhs.p, m*pmax+i_lhs.x], weights_buffer[m, i_rhs.p, i_rhs.x])

        output_sbuf[...] = nl.copy(res_psum[0], dtype=out_tensor.dtype)
        nl.store(out_tensor[b, (2*i+i_interleave_grp)*pmax+i_res.p, i_res.x],
                 value=output_sbuf,
                 mask=((2*i+i_interleave_grp)*pmax+i_res.p<seqlen) & (i_res.x<head_dim))
  return out_tensor
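
Because the kernel is nki.jit-decorated and allocates its own output, calling it looks like an ordinary function call. A hedged usage sketch with illustrative shapes that satisfy the asserts above (running it end to end requires a Neuron device and compiler, or nki.simulate_kernel for CPU simulation):

import numpy as np

hidden = np.random.rand(1, 256, 4096).astype(np.float16)  # BSH layout
weights = np.random.rand(4096, 512).astype(np.float16)    # fused QKV, gamma pre-multiplied

qkv = allocated_fused_rms_norm_qkv(hidden, weights)       # shape (1, 256, 512)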
