OV single matmul case fail #382
The root cause of this failure is similar to #360. The MLIR generated by OV for the model from the reproducer looks like this:

MLIR with `linalg.broadcast`:

```mlir
module @fragment_name attributes {"#dlti.sys_spec" = #dlti.target_system_spec<"CPU" : #dlti.target_device_spec<#dlti.dl_entry<"tile_size", 32 : i32>>>} {
func.func @entry(%arg0: memref<512x512xf16>, %arg1: memref<512x512xf16>, %arg2: memref<1x512xf16>, %arg3: memref<512x512xf16>) {
%0 = bufferization.to_tensor %arg0 restrict : memref<512x512xf16>
%1 = bufferization.to_tensor %arg1 restrict : memref<512x512xf16>
%2 = bufferization.to_tensor %arg2 restrict : memref<1x512xf16>
%3 = tensor.empty() : tensor<512x512xf16>
%cst = arith.constant 0.000000e+00 : f16
%4 = linalg.fill ins(%cst : f16) outs(%3 : tensor<512x512xf16>) -> tensor<512x512xf16>
%5 = linalg.matmul_transpose_b ins(%0, %1 : tensor<512x512xf16>, tensor<512x512xf16>) outs(%4 : tensor<512x512xf16>) -> tensor<512x512xf16>
%collapsed = tensor.collapse_shape %2 [[0, 1]] : tensor<1x512xf16> into tensor<512xf16>
%6 = tensor.empty() : tensor<512x512xf16>
%broadcasted = linalg.broadcast ins(%collapsed : tensor<512xf16>) outs(%6 : tensor<512x512xf16>) dimensions = [0]
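    // note: the collapse above flattens the 1x512 bias into a 512 vector, and
    // 'dimensions = [0]' marks the added output dim, i.e. broadcasted[i][j] = collapsed[j],
    // so the same bias row is replicated across all 512 output rows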
%7 = tensor.empty() : tensor<512x512xf16>
%8 = linalg.add ins(%5, %broadcasted : tensor<512x512xf16>, tensor<512x512xf16>) outs(%7 : tensor<512x512xf16>) -> tensor<512x512xf16>
bufferization.materialize_in_destination %8 in restrict writable %arg3 : (tensor<512x512xf16>, memref<512x512xf16>) -> ()
return
}
}
```

The main difference from the modules we've been processing before is that it takes the bias as a `1x512` tensor (`%arg2`) that has to be broadcast to the proper shape (`512x512`).

MLIR module after the `one-shot-bufferize` pass:

```mlir
func.func @entry(%arg0: memref<512x512xf16>, %arg1: memref<512x512xf16>, %arg2: memref<1x512xf16>, %arg3: memref<512x512xf16>, %arg4: memref<i8>) {
%cst = arith.constant 0.000000e+00 : f16
%alloc = memref.alloc() {alignment = 64 : i64} : memref<512x512xf16>
%collapse_shape = memref.collapse_shape %arg2 [[0, 1]] : memref<1x512xf16> into memref<512xf16>
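  // after bufferization, the tensor.collapse_shape from the original module
  // has become a memref.collapse_shape: the 1x512 bias is now viewed as a flat 512-element memref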
%c0 = arith.constant 0 : index
%c0_0 = arith.constant 0 : index
%c512 = arith.constant 512 : index
%c512_1 = arith.constant 512 : index
%c32 = arith.constant 32 : index
%c32_2 = arith.constant 32 : index
scf.parallel (%arg5, %arg6) = (%c0, %c0_0) to (%c512, %c512_1) step (%c32, %c32_2) {
%subview = memref.subview %arg0[%arg5, 0] [32, 512] [1, 1] : memref<512x512xf16> to memref<32x512xf16, strided<[512, 1], offset: ?>>
%subview_3 = memref.subview %arg1[%arg6, 0] [32, 512] [1, 1] : memref<512x512xf16> to memref<32x512xf16, strided<[512, 1], offset: ?>>
%subview_4 = memref.subview %alloc[%arg5, %arg6] [32, 32] [1, 1] : memref<512x512xf16> to memref<32x32xf16, strided<[512, 1], offset: ?>>
linalg.fill ins(%cst : f16) outs(%subview_4 : memref<32x32xf16, strided<[512, 1], offset: ?>>)
linalg.matmul_transpose_b ins(%subview, %subview_3 : memref<32x512xf16, strided<[512, 1], offset: ?>>, memref<32x512xf16, strided<[512, 1], offset: ?>>) outs(%subview_4 : memref<32x32xf16, strided<[512, 1], offset: ?>>)
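    // each (%arg5, %arg6) iteration computes one 32x32 tile:
    // C[%arg5 : %arg5+32, %arg6 : %arg6+32] = A_tile * B_tile^T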
%subview_5 = memref.subview %collapse_shape[%arg6] [32] [1] : memref<512xf16> to memref<32xf16, strided<[1], offset: ?>>
%alloc_6 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf16> // allocating buffer for the result of 'linalg.broadcast'
linalg.broadcast ins(%subview_5 : memref<32xf16, strided<[1], offset: ?>>) outs(%alloc_6 : memref<32x32xf16>) dimensions = [0]
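    // note: this materializes the full 32x32 bias tile into a temporary buffer
    // on every loop iteration; the xegpu lowering below replaces it with an
    // in-register vector.broadcast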
%subview_7 = memref.subview %arg3[%arg5, %arg6] [32, 32] [1, 1] : memref<512x512xf16> to memref<32x32xf16, strided<[512, 1], offset: ?>>
linalg.add ins(%subview_4, %alloc_6 : memref<32x32xf16, strided<[512, 1], offset: ?>>, memref<32x32xf16>) outs(%subview_7 : memref<32x32xf16, strided<[512, 1], offset: ?>>)
memref.dealloc %alloc_6 : memref<32x32xf16>
scf.reduce
}
memref.dealloc %alloc : memref<512x512xf16>
return
}
```

We can potentially lower `linalg.broadcast` to xegpu like this:

```mlir
%subview_5 = memref.subview %collapse_shape[%arg6] [32] [1] : memref<512xf16> to memref<32xf16, strided<[1], offset: ?>>
// %alloc_6 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf16> // allocating buffer for the result of 'linalg.broadcast'
// linalg.broadcast ins(%subview_5 : memref<32xf16, strided<[1], offset: ?>>) outs(%alloc_6 : memref<32x32xf16>) dimensions = [0]
%expand_shape = memref.expand_shape %subview_5 [[0, 1]] output_shape [32, 1] : memref<32xf16, strided<[1], offset: ?>> into memref<32x1xf16, strided<[1, 1], offset: ?>>
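// the 1-D bias tile is expanded to 32x1 so that a 2-D xegpu tensor descriptor
// can be created for it and the unit dim stretched by vector.broadcast below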
%51 = xegpu.create_nd_tdesc %expand_shape[%c0, %c0] : memref<32x1xf16, strided<[1, 1], offset: ?>> -> !xegpu.tensor_desc<32x1xf16, #xegpu.block_tdesc_attr<memory_scope = global, array_length = 1 : i64, boundary_check = true>>
%52 = xegpu.update_nd_offset %51, [%c0, %c0] : !xegpu.tensor_desc<32x1xf16, #xegpu.block_tdesc_attr<memory_scope = global, array_length = 1 : i64, boundary_check = true>>
%53 = xegpu.load_nd %52 : !xegpu.tensor_desc<32x1xf16, #xegpu.block_tdesc_attr<memory_scope = global, array_length = 1 : i64, boundary_check = true>> -> vector<32x1xf16>
%54 = vector.broadcast %53 : vector<32x1xf16> to vector<32x32xf16>
```

However, the chain of `memref.collapse_shape` + `memref.expand_shape` gets in the way of the xegpu lowering. In order to fix that problem we have to get rid of both reshape ops:

Getting rid of `collapse_shape` + `expand_shape`:

```mlir
// ORIGINAL:
func.func @entry(%arg0: memref<512x512xf16>, %arg1: memref<512x512xf16>, %arg2: memref<1x512xf16>, %arg3: memref<512x512xf16>, %arg4: memref<i8>) {
...
// have to get rid of this collapse_shape
%collapse_shape = memref.collapse_shape %arg2 [[0, 1]] : memref<1x512xf16> into memref<512xf16>
scf.parallel (%arg5, %arg6) = (%c0, %c0) to (%c512, %c512) step (%c32, %c32) {
...
%subview_2 = memref.subview %collapse_shape[%arg6] [32] [1] : memref<512xf16> to memref<32xf16, strided<[1], offset: ?>>
// have to get rid of this expand_shape
%expand_shape = memref.expand_shape %subview_2 [[0, 1]] output_shape [32, 1] : memref<32xf16, strided<[1], offset: ?>> into memref<32x1xf16, strided<[1, 1], offset: ?>>
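  // per the fix below, it is this collapse_shape + expand_shape pair that the
  // xegpu lowering cannot handle; subviewing %arg2 directly makes both redundant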
%51 = xegpu.create_nd_tdesc %expand_shape[%c0, %c0] : memref<32x1xf16, strided<[1, 1], offset: ?>> -> !xegpu.tensor_desc<32x1xf16, #xegpu.block_tdesc_attr<memory_scope = global, array_length = 1 : i64, boundary_check = true>>
%52 = xegpu.update_nd_offset %51, [%c0, %c0] : !xegpu.tensor_desc<32x1xf16, #xegpu.block_tdesc_attr<memory_scope = global, array_length = 1 : i64, boundary_check = true>>
%53 = xegpu.load_nd %52 : !xegpu.tensor_desc<32x1xf16, #xegpu.block_tdesc_attr<memory_scope = global, array_length = 1 : i64, boundary_check = true>> -> vector<32x1xf16>
%54 = vector.broadcast %53 : vector<32x1xf16> to vector<32x32xf16>
...
}
return
}
// EXPECTED:
func.func @entry(%arg0: memref<512x512xf16>, %arg1: memref<512x512xf16>, %arg2: memref<1x512xf16>, %arg3: memref<512x512xf16>, %arg4: memref<i8>) {
...
// removed collapse_shape
// %collapse_shape = memref.collapse_shape %arg2 [[0, 1]] : memref<1x512xf16> into memref<512xf16>
scf.parallel (%arg5, %arg6) = (%c0, %c0) to (%c512, %c512) step (%c32, %c32) {
...
// broadcast %arg2 directly, no need to call 'expand_shape'
  %subview_2 = memref.subview %arg2[%c0, %arg6] [1, 32] [1, 1] : memref<1x512xf16> to memref<1x32xf16, strided<[512, 1], offset: ?>>
  %51 = xegpu.create_nd_tdesc %subview_2[%c0, %c0] : memref<1x32xf16, strided<[512, 1], offset: ?>> -> !xegpu.tensor_desc<1x32xf16, #xegpu.block_tdesc_attr<memory_scope = global, array_length = 1 : i64, boundary_check = true>>
  %52 = xegpu.update_nd_offset %51, [%c0, %c0] : !xegpu.tensor_desc<1x32xf16, #xegpu.block_tdesc_attr<memory_scope = global, array_length = 1 : i64, boundary_check = true>>
  %53 = xegpu.load_nd %52 : !xegpu.tensor_desc<1x32xf16, #xegpu.block_tdesc_attr<memory_scope = global, array_length = 1 : i64, boundary_check = true>> -> vector<1x32xf16>
  %54 = vector.broadcast %53 : vector<1x32xf16> to vector<32x32xf16>
...
}
return
}
```
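For reference, here is a minimal, self-contained sketch of the intended bias-tile pattern using only upstream dialects (the function name, the `vector.transfer_read` standing in for `xegpu.load_nd`, and the hard-coded 32-wide tile are illustrative assumptions, not the actual pass output):

```mlir
// Sketch: produce a 32x32 bias tile from the original 1x512 bias memref
// with a direct subview and an in-register broadcast; no collapse/expand needed.
func.func @bias_tile(%bias: memref<1x512xf16>, %j: index) -> vector<32x32xf16> {
  %c0 = arith.constant 0 : index
  %pad = arith.constant 0.000000e+00 : f16
  // direct subview of the 2-D bias; layout stays strided<[512, 1], offset: ?>
  %row = memref.subview %bias[%c0, %j] [1, 32] [1, 1]
      : memref<1x512xf16> to memref<1x32xf16, strided<[512, 1], offset: ?>>
  // load the row as a 1x32 vector (xegpu.load_nd would play this role)
  %v = vector.transfer_read %row[%c0, %c0], %pad {in_bounds = [true, true]}
      : memref<1x32xf16, strided<[512, 1], offset: ?>>, vector<1x32xf16>
  // stretch the unit dim: tile[r][c] = bias[0][c]
  %tile = vector.broadcast %v : vector<1x32xf16> to vector<32x32xf16>
  return %tile : vector<32x32xf16>
}
```

The key point is that the subview of `%arg2` keeps an ordinary strided layout, so no reshape op is needed before creating the tile descriptor.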
Generate case:
Reproduce:
Error log: