NKI matmul - result and store semantics #36

Closed
praveen-velliengiri opened this issue Nov 21, 2024 · 5 comments

praveen-velliengiri commented Nov 21, 2024

Hi, I'm following the matmul tutorial to write matmul kernels in NKI.
Below is my NKI matmul kernel:

import neuronxcc.nki.language as nl

def nki_matmul_base(lhsT, rhs, result):
  # lhsT: (128, 64), rhs: (128, 512); this implements a 64 x 128 x 512 matmul

  i_lhsT_p, i_lhsT_f = nl.mgrid[0:128, 0:64]

  i_rhs_p,  i_rhs_f  = nl.mgrid[0:128, 0:512]

  i_out_p,  i_out_f  = nl.mgrid[0:64, 0:512]  # index grid for the result tile (what I think of as the SBUF result)

  # load both operands from HBM into SBUF
  lhs_tile = nl.load(lhsT[i_lhsT_p, i_lhsT_f])
  rhs_tile = nl.load(rhs[i_rhs_p, i_rhs_f])

  # matmul accumulates into PSUM; result tile is (i_lhsT_f x i_rhs_f)
  result_psum = nl.matmul(lhs_tile, rhs_tile, transpose_x=True)

  # store the result tile back to the output tensor
  nl.store(result[i_out_p, i_out_f], value=result_psum)

In this code, I have two basic questions:

  1. I was expecting the shape of result_psum to be i_rhs_f x i_lhsT_f (512 x 64) instead of i_lhsT_f x i_rhs_f (64 x 512); I think that would be consistent with the image presented on the tutorial page. I just want to know whether I am misinterpreting the dimensions of the result. The image gives me the impression that the PSUM buffer should be i_rhs_f x i_lhsT_f (512 x 64), and that while copying the PSUM buffer to SBUF we should rearrange the layout to i_lhsT_f x i_rhs_f (64 x 512). I believe I am misinterpreting the layouts between the API and the image; it would be very helpful if you could explain this so I can get a correct understanding.

  2. In the last line, nl.store takes a buffer in PSUM instead of SBUF, yet it still works correctly. I'm not sure what is happening here. I verified that the results match between the NKI and PyTorch code.

The full repo is here: mm.py

thank you

@aws-qieqingy (Contributor) commented

Hi @preejackie! Thanks for filing the issue. For question 1, in the code snippet, since we have transpose_x=True, lhs_tile is in fact lhs_T, which has shape [p(K), M] = (128, 64). As shown in part (b) of the figure you linked, M = 64 becomes the partition dimension of the output. We will look into updating the figure for better clarity.
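
As a plain-NumPy analogy of these semantics (an illustrative sketch, not NKI code): with transpose_x=True the stationary operand is contracted along its partition dimension K, so M ends up as the leading (partition) dimension of the output.

import numpy as np

lhsT = np.random.rand(128, 64)    # (K, M): stationary operand, already transposed
rhs  = np.random.rand(128, 512)   # (K, N): moving operand
out  = lhsT.T @ rhs               # contract over K
print(out.shape)                  # (64, 512): M = 64 leads, matching the PSUM tile shape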

For question 2, the compiler legalizes the operation and inserts a copy from PSUM to SBUF before storing to the output tensor. This behaviour may change in future releases, so I recommend explicitly copying PSUM into SBUF before calling nl.store.
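
For example, a minimal sketch of the explicit-copy pattern, reusing the names from the snippet above:

result_psum = nl.matmul(lhs_tile, rhs_tile, transpose_x=True)  # result accumulates in PSUM
result_sbuf = nl.copy(result_psum, dtype=result.dtype)         # explicit PSUM -> SBUF copy
nl.store(result[i_out_p, i_out_f], value=result_sbuf)          # store from SBUF to the output tensor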

@aws-qieqingy aws-qieqingy transferred this issue from aws-neuron/aws-neuron-sdk Nov 21, 2024
@aws-qieqingy aws-qieqingy added the documentation Improvements or additions to documentation label Nov 21, 2024
@praveen-velliengiri (Author) commented

@aws-qieqingy thanks a lot for your answer; it is very helpful. If you don't mind, may I ask another question? I'm trying to write a tiled version of the matrix multiplication from the tutorial.

import neuronxcc.nki.language as nl

'''
We could swap the m and n loops, but I think that would require moving tiles
in and out of the PE array many times.
for m in range(0, M, 128):
  for n in range(0, N, 512):
    accum = (128, 512)
    for k in range(0, K, 128):
      lhs_tile = lhs[k:k+128, m:m+128]
      rhs_tile = rhs[k:k+128, n:n+512]
'''

#ToDo: implement without transpose, let the API do transpose
#ToDo: implement SPMD version
def nki_matmul_tiled(lhsT, rhs, result, tile_count):
  k, m  = lhsT.shape
  k_, n  = rhs.shape

  assert k == k_, "contraction dim should be same"
  
  TILE_k = nl.tile_size.pmax #128
  TILE_m = nl.tile_size.gemm_stationary_fmax
  TILE_n = nl.tile_size.gemm_moving_fmax

  print(f"m tiles: {m // TILE_m}, n tiles: {n // TILE_n}, k tiles: {k // TILE_k}")

  for m in nl.affine_range(m // TILE_m):
    for n in nl.affine_range(n // TILE_n):
      res_psum = nl.zeros((TILE_m, TILE_n), dtype=nl.float32, buffer=nl.psum) #change to sbuf and see

      for k in nl.affine_range(k // TILE_k):
        print("tile")
        # slice: k*TILE_k : k*TILE_k + TILE_k
        stationary_tile = nl.load(lhsT[k*TILE_k : (k+1)*TILE_k, m*TILE_m : (m+1) * TILE_m])
        moving_tile     = nl.load(rhs[k*TILE_k : (k+1)*TILE_k, n*TILE_n:(n+1) * TILE_n])

        res_psum += nl.matmul(stationary_tile, moving_tile, transpose_x = True)
      
      res_sbuf = nl.copy(res_psum, dtype=result.dtype)
      nl.store(result[m*TILE_m : (m+1) * TILE_m, n*TILE_n:(n+1)*TILE_n], value=res_sbuf)
  

import torch
from torch_neuronx import nki_jit
from torch_xla.core import xla_model as xm


if __name__ == "__main__":
  device = xm.xla_device()
  cpu    = torch.device('cpu')

  lhs_small = torch.rand((512, 128), dtype=torch.bfloat16, device =  device)
  rhs_small = torch.rand((128, 512), dtype=torch.bfloat16, device = device)
  output_small = torch.rand((512, 512), dtype=torch.bfloat16, device = device)
  count = [0]
  nki_matmul_tiled_jit = nki_jit(nki_matmul_tiled)
  nki_matmul_tiled_jit(lhs_small.T, rhs_small, output_small, count)

  output_small_torch = torch.matmul(lhs_small, rhs_small)
  print("tile counter: ", count)
  if torch.allclose(output_small, output_small_torch, atol = 1e-4, rtol = 1e-2):
    print("match")
  else:
    print("differ")

I understand that associative reductions are not loop-carried dependencies and can be executed in parallel. Based on the free dimension of the lhs_small matrix, 4 tile matrix multiplications should be performed, and I believe the above code does exactly that. However, I wanted to check programmatically how many tiles are processed by the above code, so I added print("tile") in the innermost loop, but when I execute it there is always one output instead of 4 (one per tile). It would be very helpful to know how such tiled matmuls are handled by the Neuron device.
We are planning to write a transpiler that generates NKI automatically for a tensor-program super optimizer, so knowing this detail will help us understand NKI better.
thanks

@aws-zhehongb commented

The print statement happens at the "trace" time of NKI, not at the actual "execution" time. Because NKI only traces the loop body once, you will only see one print. For now, you need to compute the number of tiles analytically from the tensor shapes and tile sizes. We may introduce a reflection/inspection API in the future for your research.
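
For example, a host-side sketch (plain Python, using the shapes and tile-size constants from your script) of the analytic count:

K, M = 128, 512                          # lhsT shape (k, m) after lhs_small.T
_, N = 128, 512                          # rhs shape (k, n)
TILE_K, TILE_M, TILE_N = 128, 128, 512   # pmax, gemm_stationary_fmax, gemm_moving_fmax
num_tiles = (M // TILE_M) * (N // TILE_N) * (K // TILE_K)
print(num_tiles)                         # 4 matmul tiles for these shapes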

For the super optimizer, you could also check https://github.com/awslabs/nki-autotune/blob/main/test/matmul.py

@praveen-velliengiri (Author) commented

@aws-zhehongb thanks a lot for your explanation; it makes sense now.


This issue is now closed. Comments on closed issues are hard for our team to see.
If you need more assistance, please open a new issue that references this one.
