Skip to content

Commit f42c7e4

Browse files
authored
[Linalg] Add conversion between bf16 and f16 (#3963)
Fixes issue #3962: "'arith.extf' op operand type 'bf16' and result type 'f16' are cast incompatible" by routing bf16 ↔ f16 conversions through an intermediate f32 (extf to f32, then truncf to the target type).
1 parent 5e1d68e commit f42c7e4

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

lib/Conversion/Utils/Utils.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,10 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
335335

336336
if (auto dtypeFloat = dyn_cast<mlir::FloatType>(dtype)) {
337337
if (auto scalarFloat = dyn_cast<mlir::FloatType>(scalarType)) {
338+
if (scalarFloat.getWidth() == 16 && dtypeFloat.getWidth() == 16) {
339+
auto scalarF32 = b.create<arith::ExtFOp>(loc, b.getF32Type(), scalar);
340+
return b.create<arith::TruncFOp>(loc, dtype, scalarF32);
341+
}
338342
if (scalarFloat.getWidth() > dtypeFloat.getWidth())
339343
return b.create<arith::TruncFOp>(loc, dtype, scalar);
340344
// Only scalarFloat width < dtypeFloat width can reach here.

test/Conversion/TorchToLinalg/elementwise.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,3 +102,19 @@ func.func @elementwise_sinh(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3
102102
%0 = torch.aten.sinh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
103103
return %0 : !torch.vtensor<[3],f32>
104104
}
105+
106+
// -----
107+
108+
// CHECK-LABEL: func.func @elementwise_todtype_bf162f16(
109+
// CHECK: linalg.generic
110+
// CHECK: arith.extf
111+
// CHECK-SAME: bf16 to f32
112+
// CHECK: arith.truncf
113+
// CHECK-SAME: f32 to f16
114+
func.func @elementwise_todtype_bf162f16(%arg0: !torch.vtensor<[1,?,32,128],bf16>) -> !torch.vtensor<[1,?,32,128],f16> {
115+
%int5 = torch.constant.int 5
116+
%false = torch.constant.bool false
117+
%none = torch.constant.none
118+
%0 = torch.aten.to.dtype %arg0, %int5, %false, %false, %none : !torch.vtensor<[1,?,32,128],bf16>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,?,32,128],f16>
119+
return %0 : !torch.vtensor<[1,?,32,128],f16>
120+
}

0 commit comments

Comments
 (0)