Skip to content

Commit d57479c

Browse files
authored
[mlir][tosa] Update SelectOp's input names to match TOSA specification (#127833)
Updated: - pred to input1 - on_true to input2 - on_false to input3 Signed-off-by: Jerry Ge <[email protected]>
1 parent a6f48ed commit d57479c

File tree

3 files changed

+14
-14
lines changed

3 files changed

+14
-14
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1190,9 +1190,9 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> {
11901190
}];
11911191

11921192
let arguments = (ins
1193-
Tosa_I1Tensor:$pred,
1194-
Tosa_Tensor:$on_true,
1195-
Tosa_Tensor:$on_false
1193+
Tosa_I1Tensor:$input1,
1194+
Tosa_Tensor:$input2,
1195+
Tosa_Tensor:$input3
11961196
);
11971197

11981198
let results = (outs
@@ -1202,7 +1202,7 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> {
12021202
let hasFolder = 1;
12031203

12041204
let assemblyFormat = [{
1205-
operands attr-dict `:` `(` type($pred) `,` type($on_true) `,` type($on_false)
1205+
operands attr-dict `:` `(` type($input1) `,` type($input2) `,` type($input3)
12061206
`)` `->` type($output)
12071207
}];
12081208
}

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,12 @@ void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
6565
}
6666

6767
LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
68-
auto notOp = op.getPred().getDefiningOp<tosa::LogicalNotOp>();
68+
auto notOp = op.getInput1().getDefiningOp<tosa::LogicalNotOp>();
6969
if (!notOp)
7070
return failure();
7171
rewriter.modifyOpInPlace(op, [&]() {
7272
op.getOperation()->setOperands(
73-
{notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
73+
{notOp.getInput1(), op.getInput3(), op.getInput2()});
7474
});
7575
return success();
7676
}
@@ -1131,18 +1131,18 @@ OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
11311131
}
11321132

11331133
OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
1134-
if (getOnTrue() == getOnFalse())
1135-
return getOnTrue();
1134+
if (getInput2() == getInput3())
1135+
return getInput2();
11361136

11371137
auto predicate =
1138-
llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getPred());
1138+
llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getInput1());
11391139
if (!predicate)
11401140
return {};
11411141

11421142
if (!predicate.isSplat())
11431143
return {};
1144-
return predicate.getSplatValue<APInt>().getBoolValue() ? getOnTrue()
1145-
: getOnFalse();
1144+
return predicate.getSplatValue<APInt>().getBoolValue() ? getInput2()
1145+
: getInput3();
11461146
}
11471147

11481148
OpFoldResult TileOp::fold(FoldAdaptor adaptor) {

mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -169,9 +169,9 @@ struct ConvertTosaOp<tosa::SelectOp> : public OpRewritePattern<tosa::SelectOp> {
169169
LogicalResult matchAndRewrite(tosa::SelectOp tosaOp,
170170
PatternRewriter &rewriter) const override {
171171

172-
Value input1 = tosaOp.getPred();
173-
Value input2 = tosaOp.getOnTrue();
174-
Value input3 = tosaOp.getOnFalse();
172+
Value input1 = tosaOp.getInput1();
173+
Value input2 = tosaOp.getInput2();
174+
Value input3 = tosaOp.getInput3();
175175
Value output = tosaOp.getResult();
176176

177177
auto outputType = dyn_cast<RankedTensorType>(output.getType());

0 commit comments

Comments
 (0)