Skip to content

Commit cc15bcf

Browse files
committed
[mlir][arith] Canonicalization patterns for arith.select
This adds the following canonicalization patterns: - Inverting condition: - `select(not(pred), a, b) => select(pred, b, a)` - Merging consecutive selects with the same predicate - `select(pred, select(pred, a, b), c) => select(pred, a, c)` - `select(pred, a, select(pred, b, c)) => select(pred, a, c)` - Merging consecutive selects with a common value value: - `select(predA, select(predB, x, y), y) => select(and(predA, predB), x, y)` - `select(predA, select(predB, y, x), y) => select(and(predA, not(predB)), x, y)` - `select(predA, x, select(predB, x, y)) => select(or(predA, predB), x, y)` - `select(predA, x, select(predB, y, x)) => select(or(predA, not(predB)), x, y)`
1 parent 7a46baa commit cc15bcf

File tree

3 files changed

+123
-1
lines changed

3 files changed

+123
-1
lines changed

mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,50 @@ def CmpIExtUI :
233233
CPred<"$0.getValue() == arith::CmpIPredicate::eq || "
234234
"$0.getValue() == arith::CmpIPredicate::ne">> $pred)]>;
235235

236+
//===----------------------------------------------------------------------===//
237+
// SelectOp
238+
//===----------------------------------------------------------------------===//
239+
240+
// select(not(pred), a, b) => select(pred, b, a)
241+
def SelectNotCond :
242+
Pat<(SelectOp (Arith_XOrIOp $pred, (ConstantLikeMatcher APIntAttr:$ones)), $a, $b),
243+
(SelectOp $pred, $b, $a),
244+
[(IsScalarOrSplatNegativeOne $ones)]>;
245+
246+
// select(pred, select(pred, a, b), c) => select(pred, a, c)
247+
def RedundantSelectTrue :
248+
Pat<(SelectOp $pred, (SelectOp $pred, $a, $b), $c),
249+
(SelectOp $pred, $a, $c)>;
250+
251+
// select(pred, a, select(pred, b, c)) => select(pred, a, c)
252+
def RedundantSelectFalse :
253+
Pat<(SelectOp $pred, $a, (SelectOp $pred, $b, $c)),
254+
(SelectOp $pred, $a, $c)>;
255+
256+
// select(predA, select(predB, x, y), y) => select(and(predA, predB), x, y)
257+
def SelectAndCond :
258+
Pat<(SelectOp $predA, (SelectOp $predB, $x, $y), $y),
259+
(SelectOp (Arith_AndIOp $predA, $predB), $x, $y)>;
260+
261+
// select(predA, select(predB, y, x), y) => select(and(predA, not(predB)), x, y)
262+
def SelectAndNotCond :
263+
Pat<(SelectOp $predA, (SelectOp $predB, $y, $x), $y),
264+
(SelectOp (Arith_AndIOp $predA,
265+
(Arith_XOrIOp $predB, (Arith_ConstantOp ConstantAttr<I1Attr, "1">))),
266+
$x, $y)>;
267+
268+
// select(predA, x, select(predB, x, y)) => select(or(predA, predB), x, y)
269+
def SelectOrCond :
270+
Pat<(SelectOp $predA, $x, (SelectOp $predB, $x, $y)),
271+
(SelectOp (Arith_OrIOp $predA, $predB), $x, $y)>;
272+
273+
// select(predA, x, select(predB, y, x)) => select(or(predA, not(predB)), x, y)
274+
def SelectOrNotCond :
275+
Pat<(SelectOp $predA, $x, (SelectOp $predB, $y, $x)),
276+
(SelectOp (Arith_OrIOp $predA,
277+
(Arith_XOrIOp $predB, (Arith_ConstantOp ConstantAttr<I1Attr, "1">))),
278+
$x, $y)>;
279+
236280
//===----------------------------------------------------------------------===//
237281
// IndexCastOp
238282
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2212,7 +2212,9 @@ struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
22122212

22132213
void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
22142214
MLIRContext *context) {
2215-
results.add<SelectI1Simplify, SelectToExtUI>(context);
2215+
results.add<RedundantSelectFalse, RedundantSelectTrue, SelectI1Simplify,
2216+
SelectAndCond, SelectAndNotCond, SelectOrCond, SelectOrNotCond,
2217+
SelectNotCond, SelectToExtUI>(context);
22162218
}
22172219

22182220
OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {

mlir/test/Dialect/Arith/canonicalize.mlir

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,82 @@ func.func @selToArith(%arg0: i1, %arg1 : i1, %arg2 : i1) -> i1 {
128128
return %res : i1
129129
}
130130

131+
// CHECK-LABEL: @redundantSelectTrue
132+
// CHECK-NEXT: %[[res:.+]] = arith.select %arg0, %arg1, %arg3
133+
// CHECK-NEXT: return %[[res]]
134+
func.func @redundantSelectTrue(%arg0: i1, %arg1 : i32, %arg2 : i32, %arg3 : i32) -> i32 {
135+
%0 = arith.select %arg0, %arg1, %arg2 : i32
136+
%res = arith.select %arg0, %0, %arg3 : i32
137+
return %res : i32
138+
}
139+
140+
// CHECK-LABEL: @redundantSelectFalse
141+
// CHECK-NEXT: %[[res:.+]] = arith.select %arg0, %arg3, %arg2
142+
// CHECK-NEXT: return %[[res]]
143+
func.func @redundantSelectFalse(%arg0: i1, %arg1 : i32, %arg2 : i32, %arg3 : i32) -> i32 {
144+
%0 = arith.select %arg0, %arg1, %arg2 : i32
145+
%res = arith.select %arg0, %arg3, %0 : i32
146+
return %res : i32
147+
}
148+
149+
// CHECK-LABEL: @selNotCond
150+
// CHECK-NEXT: %[[res1:.+]] = arith.select %arg0, %arg2, %arg1
151+
// CHECK-NEXT: %[[res2:.+]] = arith.select %arg0, %arg4, %arg3
152+
// CHECK-NEXT: return %[[res1]], %[[res2]]
153+
func.func @selNotCond(%arg0: i1, %arg1 : i32, %arg2 : i32, %arg3 : i32, %arg4 : i32) -> (i32, i32) {
154+
%one = arith.constant 1 : i1
155+
%cond1 = arith.xori %arg0, %one : i1
156+
%cond2 = arith.xori %one, %arg0 : i1
157+
158+
%res1 = arith.select %cond1, %arg1, %arg2 : i32
159+
%res2 = arith.select %cond2, %arg3, %arg4 : i32
160+
return %res1, %res2 : i32, i32
161+
}
162+
163+
// CHECK-LABEL: @selAndCond
164+
// CHECK-NEXT: %[[and:.+]] = arith.andi %arg1, %arg0
165+
// CHECK-NEXT: %[[res:.+]] = arith.select %[[and]], %arg2, %arg3
166+
// CHECK-NEXT: return %[[res]]
167+
func.func @selAndCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32 {
168+
%sel = arith.select %arg0, %arg2, %arg3 : i32
169+
%res = arith.select %arg1, %sel, %arg3 : i32
170+
return %res : i32
171+
}
172+
173+
// CHECK-LABEL: @selAndNotCond
174+
// CHECK-NEXT: %[[one:.+]] = arith.constant true
175+
// CHECK-NEXT: %[[not:.+]] = arith.xori %arg0, %[[one]]
176+
// CHECK-NEXT: %[[and:.+]] = arith.andi %arg1, %[[not]]
177+
// CHECK-NEXT: %[[res:.+]] = arith.select %[[and]], %arg3, %arg2
178+
// CHECK-NEXT: return %[[res]]
179+
func.func @selAndNotCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32 {
180+
%sel = arith.select %arg0, %arg2, %arg3 : i32
181+
%res = arith.select %arg1, %sel, %arg2 : i32
182+
return %res : i32
183+
}
184+
185+
// CHECK-LABEL: @selOrCond
186+
// CHECK-NEXT: %[[or:.+]] = arith.ori %arg1, %arg0
187+
// CHECK-NEXT: %[[res:.+]] = arith.select %[[or]], %arg2, %arg3
188+
// CHECK-NEXT: return %[[res]]
189+
func.func @selOrCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32 {
190+
%sel = arith.select %arg0, %arg2, %arg3 : i32
191+
%res = arith.select %arg1, %arg2, %sel : i32
192+
return %res : i32
193+
}
194+
195+
// CHECK-LABEL: @selOrNotCond
196+
// CHECK-NEXT: %[[one:.+]] = arith.constant true
197+
// CHECK-NEXT: %[[not:.+]] = arith.xori %arg0, %[[one]]
198+
// CHECK-NEXT: %[[or:.+]] = arith.ori %arg1, %[[not]]
199+
// CHECK-NEXT: %[[res:.+]] = arith.select %[[or]], %arg3, %arg2
200+
// CHECK-NEXT: return %[[res]]
201+
func.func @selOrNotCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32 {
202+
%sel = arith.select %arg0, %arg2, %arg3 : i32
203+
%res = arith.select %arg1, %arg3, %sel : i32
204+
return %res : i32
205+
}
206+
131207
// Test case: Folding of comparisons with equal operands.
132208
// CHECK-LABEL: @cmpi_equal_operands
133209
// CHECK-DAG: %[[T:.*]] = arith.constant true

0 commit comments

Comments
 (0)