Skip to content

Commit 7c11d05

Browse files
authored
[mlir][TosaToLinalg] Exit after notifyMatchFailure (#132012)
This PR adds `return nullptr` when the shift value of `tosa.mul` is not constant to prevent a crash. Fixes #131766.
1 parent 541d6c3 commit 7c11d05

File tree

2 files changed

+13
-5
lines changed

2 files changed

+13
-5
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -136,14 +136,14 @@ static Value createLinalgBodyCalculationForElementwiseOp(
136136

137137
// tosa::MulOp
138138
if (isa<tosa::MulOp>(op)) {
139-
auto shift_val = cast<tosa::MulOp>(op).getShift();
140-
DenseElementsAttr shift_elem;
141-
if (!shift_val.getImpl() ||
142-
!matchPattern(shift_val, m_Constant(&shift_elem))) {
139+
auto shiftVal = cast<tosa::MulOp>(op).getShift();
140+
DenseElementsAttr shiftElem;
141+
if (!matchPattern(shiftVal, m_Constant(&shiftElem))) {
143142
(void)rewriter.notifyMatchFailure(op, "shift value of mul not found");
143+
return nullptr;
144144
}
145145

146-
int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
146+
int32_t shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
147147

148148
if (isa<FloatType>(elementTy)) {
149149
if (shift != 0) {

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,11 @@ func.func @unranked_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>)
7373
%0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi32>) -> tensor<*xf32>
7474
return %0 : tensor<*xf32>
7575
}
76+
77+
// -----
78+
79+
func.func @mul_no_const_shift(%arg0: tensor<2x3xi32>, %arg1: tensor<2x3xi32>, %arg2: tensor<1xi8>) -> tensor<2x3xi32> {
80+
// expected-error@+1 {{failed to legalize operation 'tosa.mul'}}
81+
%0 = tosa.mul %arg0, %arg1, %arg2 : (tensor<2x3xi32>, tensor<2x3xi32>, tensor<1xi8>) -> tensor<2x3xi32>
82+
return %0 : tensor<2x3xi32>
83+
}

0 commit comments

Comments
 (0)