Skip to content

Commit c926291

Browse files
authored
[MLIR][TOSA] Fix f16/bf16 support for MaxPool2D (llvm#69332)
Currently, the MaxPool2D operation in the TOSA MLIR dialect does not accept half-precision Fp16 and Bf16 tensors, converse to what is stated in the [TOSA Specification](https://www.mlplatform.org/tosa/tosa_spec.html#_max_pool2d). This patch fixes the verifier to accept the two datatypes for input/output tensors, and adds related LIT test cases in Tosa/ops.mlir
1 parent 0dca566 commit c926291

File tree

2 files changed

+17
-3
lines changed

2 files changed

+17
-3
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -691,7 +691,7 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
691691

692692
// Determine what the initial value needs to be for the max pool op.
693693
TypedAttr initialAttr;
694-
if (resultETy.isF32())
694+
if (resultETy.isF32() || resultETy.isBF16() || resultETy.isF16())
695695
initialAttr = rewriter.getFloatAttr(
696696
resultETy, APFloat::getLargest(
697697
cast<FloatType>(resultETy).getFloatSemantics(), true));

mlir/test/Dialect/Tosa/ops.mlir

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,12 +97,26 @@ func.func @test_matmul(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -
9797
}
9898

9999
// -----
100-
// CHECK-LABEL: max_pool2d
101-
func.func @test_max_pool2d(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
100+
// CHECK-LABEL: max_pool2d_f32
101+
func.func @test_max_pool2d_f32(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
102102
%0 = tosa.max_pool2d %arg0 {kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
103103
return %0 : tensor<1x32x32x8xf32>
104104
}
105105

106+
// -----
107+
// CHECK-LABEL: max_pool2d_bf16
108+
func.func @test_max_pool2d_bf16(%arg0: tensor<1x32x32x8xbf16>) -> tensor<1x32x32x8xbf16> {
109+
%0 = tosa.max_pool2d %arg0 {kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xbf16>) -> tensor<1x32x32x8xbf16>
110+
return %0 : tensor<1x32x32x8xbf16>
111+
}
112+
113+
// -----
114+
// CHECK-LABEL: max_pool2d_f16
115+
func.func @test_max_pool2d_f16(%arg0: tensor<1x32x32x8xf16>) -> tensor<1x32x32x8xf16> {
116+
%0 = tosa.max_pool2d %arg0 {kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf16>) -> tensor<1x32x32x8xf16>
117+
return %0 : tensor<1x32x32x8xf16>
118+
}
119+
106120
// -----
107121
// CHECK-LABEL: rfft2d
108122
func.func @test_rfft2d(%arg0: tensor<13x8x16xf32>) -> (tensor<13x8x9xf32>, tensor<13x8x9xf32>) {

0 commit comments

Comments
 (0)