Skip to content

Commit 2c9ddfc

Browse files
[mlir][Tosa] fix fp16/bf16 support for AvgPool2d (#68718)
Currently, the AvgPool2d operation in the TOSA MLIR dialect does not accept half-precision Fp16 and Bf16 tensors, conversely to what stated in the [TOSA specification](https://www.mlplatform.org/tosa/tosa_spec.html#_avg_pool2d). This issue was previously raised: #63424 here on Github and it is due to a bug in the AvgPool2d verifier. This patch fixes the AvgPool2d verifier to accept fp16 & bf16 datatype for input/output tensors and accumulator, and it adds related LIT test cases in Tosa/ops.mlir.
1 parent 74c5e47 commit 2c9ddfc

File tree

2 files changed

+24
-8
lines changed

2 files changed

+24
-8
lines changed

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -247,18 +247,20 @@ LogicalResult tosa::AvgPool2dOp::verify() {
247247
if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
248248
return emitOpError("accumulator type for integer tensor is not i32");
249249

250-
if ((inputETy.isBF16() || inputETy.isF16()) &&
251-
!(accType.isF16() || accType.isF32()))
252-
return emitOpError("accumulator type for f16/bf16 tensor is not f16/f32");
250+
if (inputETy.isF16() && !(accType.isF16() || accType.isF32()))
251+
return emitOpError("accumulator type for f16 tensor is not f16/f32");
252+
253+
if (inputETy.isBF16() && !accType.isF32())
254+
return emitOpError("accumulator type for bf16 tensor is not f32");
253255

254256
if (inputETy.isF32() && !accType.isF32())
255257
return emitOpError("accumulator type for f32 tensor is not f32");
256258

257-
if (inputETy.isF32() && resultETy.isF32())
258-
return success();
259-
if (inputETy.isInteger(8) && resultETy.isInteger(8))
260-
return success();
261-
if (inputETy.isInteger(16) && resultETy.isInteger(16))
259+
if ((inputETy.isF32() && resultETy.isF32()) ||
260+
(inputETy.isF16() && resultETy.isF16()) ||
261+
(inputETy.isBF16() && resultETy.isBF16()) ||
262+
(inputETy.isInteger(8) && resultETy.isInteger(8)) ||
263+
(inputETy.isInteger(16) && resultETy.isInteger(16)))
262264
return success();
263265

264266
return emitOpError("input/output element types are incompatible.");

mlir/test/Dialect/Tosa/ops.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,20 @@ func.func @test_avg_pool2d_f32(%arg0: tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32
1616
return %0 : tensor<1x7x7x9xf32>
1717
}
1818

19+
// -----
20+
// CHECK-LABEL: avg_pool2d_f16
21+
func.func @test_avg_pool2d_f16(%arg0: tensor<1x7x7x9xf16>) -> tensor<1x7x7x9xf16> {
22+
%0 = tosa.avg_pool2d %arg0 {acc_type = f16, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xf16>) -> tensor<1x7x7x9xf16>
23+
return %0 : tensor<1x7x7x9xf16>
24+
}
25+
26+
// -----
27+
// CHECK-LABEL: avg_pool2d_f16_accumf32
28+
func.func @test_avg_pool2d_f16_accumf32(%arg0: tensor<1x7x7x9xf16>) -> tensor<1x7x7x9xf16> {
29+
%0 = tosa.avg_pool2d %arg0 {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xf16>) -> tensor<1x7x7x9xf16>
30+
return %0 : tensor<1x7x7x9xf16>
31+
}
32+
1933
// -----
2034
// CHECK-LABEL: avg_pool2d_i8
2135
func.func @test_avg_pool2d_i8(%arg0: tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi8> {

0 commit comments

Comments
 (0)