Skip to content

Commit 93ffe17

Browse files
authored
[mlir][tosa] Only match rfft2d of floats in linalg conversion (#93432)
This prevents an assertion being triggered by the cast to FloatType. Fixes #92064
1 parent c091dd4 commit 93ffe17

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2324,7 +2324,10 @@ struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
23242324
auto loc = rfft2d.getLoc();
23252325
auto input = rfft2d.getInput();
23262326
auto elementType =
2327-
cast<FloatType>(cast<ShapedType>(input.getType()).getElementType());
2327+
dyn_cast<FloatType>(cast<ShapedType>(input.getType()).getElementType());
2328+
if (!elementType)
2329+
return rewriter.notifyMatchFailure(rfft2d,
2330+
"only supports float element types");
23282331

23292332
// Compute the output type and set of dynamic sizes
23302333
llvm::SmallVector<Value> dynamicSizes;

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,12 @@ func.func @unranked_add(%arg0 : tensor<10x10xf32> , %arg1 : tensor<10x10xf32>, %
2727
%2 = tosa.reshape %0 {new_shape = array<i64: 10, 10>} : (tensor<*xf32>) -> tensor<10x10xf32>
2828
return %2 : tensor<10x10xf32>
2929
}
30+
31+
// -----
32+
33+
// CHECK-LABEL: @rfft2d_with_non_float_type
34+
func.func @rfft2d_with_non_float_type(%arg0 : tensor<1x1x1xi32>) -> (tensor<1x1x1xi32>, tensor<1x1x1xi32>) {
35+
// expected-error@+1 {{failed to legalize operation 'tosa.rfft2d'}}
36+
%real, %imag = tosa.rfft2d %arg0 : (tensor<1x1x1xi32>) -> (tensor<1x1x1xi32>, tensor<1x1x1xi32>)
37+
return %real, %imag : tensor<1x1x1xi32>, tensor<1x1x1xi32>
38+
}

0 commit comments

Comments
 (0)