Skip to content

Commit 6dbc6df

Browse files
authored
Reland "[mlir][arith] Canonicalization patterns for arith.select (#67809)" (#68941)
This cherry-picks the changes in llvm-project/5bf701a6687a46fd898621f5077959ff202d716b and extends the pattern to handle vector types. To reuse `getBoolAttribute` method, it moves the static method above the include of generated file.
1 parent b1115f8 commit 6dbc6df

File tree

3 files changed

+161
-10
lines changed

3 files changed

+161
-10
lines changed

mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,55 @@ def CmpIExtUI :
233233
CPred<"$0.getValue() == arith::CmpIPredicate::eq || "
234234
"$0.getValue() == arith::CmpIPredicate::ne">> $pred)]>;
235235

236+
//===----------------------------------------------------------------------===//
237+
// SelectOp
238+
//===----------------------------------------------------------------------===//
239+
240+
def GetScalarOrVectorTrueAttribute :
241+
NativeCodeCall<"cast<TypedAttr>(getBoolAttribute($0.getType(), true))">;
242+
243+
// select(not(pred), a, b) => select(pred, b, a)
244+
def SelectNotCond :
245+
Pat<(SelectOp (Arith_XOrIOp $pred, (ConstantLikeMatcher APIntAttr:$ones)), $a, $b),
246+
(SelectOp $pred, $b, $a),
247+
[(IsScalarOrSplatNegativeOne $ones)]>;
248+
249+
// select(pred, select(pred, a, b), c) => select(pred, a, c)
250+
def RedundantSelectTrue :
251+
Pat<(SelectOp $pred, (SelectOp $pred, $a, $b), $c),
252+
(SelectOp $pred, $a, $c)>;
253+
254+
// select(pred, a, select(pred, b, c)) => select(pred, a, c)
255+
def RedundantSelectFalse :
256+
Pat<(SelectOp $pred, $a, (SelectOp $pred, $b, $c)),
257+
(SelectOp $pred, $a, $c)>;
258+
259+
// select(predA, select(predB, x, y), y) => select(and(predA, predB), x, y)
260+
def SelectAndCond :
261+
Pat<(SelectOp $predA, (SelectOp $predB, $x, $y), $y),
262+
(SelectOp (Arith_AndIOp $predA, $predB), $x, $y)>;
263+
264+
// select(predA, select(predB, y, x), y) => select(and(predA, not(predB)), x, y)
265+
def SelectAndNotCond :
266+
Pat<(SelectOp $predA, (SelectOp $predB, $y, $x), $y),
267+
(SelectOp (Arith_AndIOp $predA,
268+
(Arith_XOrIOp $predB,
269+
(Arith_ConstantOp (GetScalarOrVectorTrueAttribute $predB)))),
270+
$x, $y)>;
271+
272+
// select(predA, x, select(predB, x, y)) => select(or(predA, predB), x, y)
273+
def SelectOrCond :
274+
Pat<(SelectOp $predA, $x, (SelectOp $predB, $x, $y)),
275+
(SelectOp (Arith_OrIOp $predA, $predB), $x, $y)>;
276+
277+
// select(predA, x, select(predB, y, x)) => select(or(predA, not(predB)), x, y)
278+
def SelectOrNotCond :
279+
Pat<(SelectOp $predA, $x, (SelectOp $predB, $y, $x)),
280+
(SelectOp (Arith_OrIOp $predA,
281+
(Arith_XOrIOp $predB,
282+
(Arith_ConstantOp (GetScalarOrVectorTrueAttribute $predB)))),
283+
$x, $y)>;
284+
236285
//===----------------------------------------------------------------------===//
237286
// IndexCastOp
238287
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,14 @@ static FailureOr<APInt> getIntOrSplatIntValue(Attribute attr) {
113113
return failure();
114114
}
115115

116+
static Attribute getBoolAttribute(Type type, bool value) {
117+
auto boolAttr = BoolAttr::get(type.getContext(), value);
118+
ShapedType shapedType = llvm::dyn_cast_or_null<ShapedType>(type);
119+
if (!shapedType)
120+
return boolAttr;
121+
return DenseElementsAttr::get(shapedType, boolAttr);
122+
}
123+
116124
//===----------------------------------------------------------------------===//
117125
// TableGen'd canonicalization patterns
118126
//===----------------------------------------------------------------------===//
@@ -1696,14 +1704,6 @@ static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
16961704
llvm_unreachable("unknown cmpi predicate kind");
16971705
}
16981706

1699-
static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) {
1700-
auto boolAttr = BoolAttr::get(ctx, value);
1701-
ShapedType shapedType = llvm::dyn_cast_or_null<ShapedType>(type);
1702-
if (!shapedType)
1703-
return boolAttr;
1704-
return DenseElementsAttr::get(shapedType, boolAttr);
1705-
}
1706-
17071707
static std::optional<int64_t> getIntegerWidth(Type t) {
17081708
if (auto intType = llvm::dyn_cast<IntegerType>(t)) {
17091709
return intType.getWidth();
@@ -1718,7 +1718,7 @@ OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
17181718
// cmpi(pred, x, x)
17191719
if (getLhs() == getRhs()) {
17201720
auto val = applyCmpPredicateToEqualOperands(getPredicate());
1721-
return getBoolAttribute(getType(), getContext(), val);
1721+
return getBoolAttribute(getType(), val);
17221722
}
17231723

17241724
if (matchPattern(adaptor.getRhs(), m_Zero())) {
@@ -2212,7 +2212,9 @@ struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
22122212

22132213
void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
22142214
MLIRContext *context) {
2215-
results.add<SelectI1Simplify, SelectToExtUI>(context);
2215+
results.add<RedundantSelectFalse, RedundantSelectTrue, SelectI1Simplify,
2216+
SelectAndCond, SelectAndNotCond, SelectOrCond, SelectOrNotCond,
2217+
SelectNotCond, SelectToExtUI>(context);
22162218
}
22172219

22182220
OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {

mlir/test/Dialect/Arith/canonicalize.mlir

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,106 @@ func.func @selToArith(%arg0: i1, %arg1 : i1, %arg2 : i1) -> i1 {
128128
return %res : i1
129129
}
130130

131+
// CHECK-LABEL: @redundantSelectTrue
132+
// CHECK-NEXT: %[[res:.+]] = arith.select %arg0, %arg1, %arg3
133+
// CHECK-NEXT: return %[[res]]
134+
func.func @redundantSelectTrue(%arg0: i1, %arg1 : i32, %arg2 : i32, %arg3 : i32) -> i32 {
135+
%0 = arith.select %arg0, %arg1, %arg2 : i32
136+
%res = arith.select %arg0, %0, %arg3 : i32
137+
return %res : i32
138+
}
139+
140+
// CHECK-LABEL: @redundantSelectFalse
141+
// CHECK-NEXT: %[[res:.+]] = arith.select %arg0, %arg3, %arg2
142+
// CHECK-NEXT: return %[[res]]
143+
func.func @redundantSelectFalse(%arg0: i1, %arg1 : i32, %arg2 : i32, %arg3 : i32) -> i32 {
144+
%0 = arith.select %arg0, %arg1, %arg2 : i32
145+
%res = arith.select %arg0, %arg3, %0 : i32
146+
return %res : i32
147+
}
148+
149+
// CHECK-LABEL: @selNotCond
150+
// CHECK-NEXT: %[[res1:.+]] = arith.select %arg0, %arg2, %arg1
151+
// CHECK-NEXT: %[[res2:.+]] = arith.select %arg0, %arg4, %arg3
152+
// CHECK-NEXT: return %[[res1]], %[[res2]]
153+
func.func @selNotCond(%arg0: i1, %arg1 : i32, %arg2 : i32, %arg3 : i32, %arg4 : i32) -> (i32, i32) {
154+
%one = arith.constant 1 : i1
155+
%cond1 = arith.xori %arg0, %one : i1
156+
%cond2 = arith.xori %one, %arg0 : i1
157+
158+
%res1 = arith.select %cond1, %arg1, %arg2 : i32
159+
%res2 = arith.select %cond2, %arg3, %arg4 : i32
160+
return %res1, %res2 : i32, i32
161+
}
162+
163+
// CHECK-LABEL: @selAndCond
164+
// CHECK-NEXT: %[[and:.+]] = arith.andi %arg1, %arg0
165+
// CHECK-NEXT: %[[res:.+]] = arith.select %[[and]], %arg2, %arg3
166+
// CHECK-NEXT: return %[[res]]
167+
func.func @selAndCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32 {
168+
%sel = arith.select %arg0, %arg2, %arg3 : i32
169+
%res = arith.select %arg1, %sel, %arg3 : i32
170+
return %res : i32
171+
}
172+
173+
// CHECK-LABEL: @selAndNotCond
174+
// CHECK-NEXT: %[[one:.+]] = arith.constant true
175+
// CHECK-NEXT: %[[not:.+]] = arith.xori %arg0, %[[one]]
176+
// CHECK-NEXT: %[[and:.+]] = arith.andi %arg1, %[[not]]
177+
// CHECK-NEXT: %[[res:.+]] = arith.select %[[and]], %arg3, %arg2
178+
// CHECK-NEXT: return %[[res]]
179+
func.func @selAndNotCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32 {
180+
%sel = arith.select %arg0, %arg2, %arg3 : i32
181+
%res = arith.select %arg1, %sel, %arg2 : i32
182+
return %res : i32
183+
}
184+
185+
// CHECK-LABEL: @selAndNotCondVec
186+
// CHECK-NEXT: %[[one:.+]] = arith.constant dense<true> : vector<4xi1>
187+
// CHECK-NEXT: %[[not:.+]] = arith.xori %arg0, %[[one]]
188+
// CHECK-NEXT: %[[and:.+]] = arith.andi %arg1, %[[not]]
189+
// CHECK-NEXT: %[[res:.+]] = arith.select %[[and]], %arg3, %arg2
190+
// CHECK-NEXT: return %[[res]]
191+
func.func @selAndNotCondVec(%arg0: vector<4xi1>, %arg1: vector<4xi1>, %arg2 : vector<4xi32>, %arg3 : vector<4xi32>) -> vector<4xi32> {
192+
%sel = arith.select %arg0, %arg2, %arg3 : vector<4xi1>, vector<4xi32>
193+
%res = arith.select %arg1, %sel, %arg2 : vector<4xi1>, vector<4xi32>
194+
return %res : vector<4xi32>
195+
}
196+
197+
// CHECK-LABEL: @selOrCond
198+
// CHECK-NEXT: %[[or:.+]] = arith.ori %arg1, %arg0
199+
// CHECK-NEXT: %[[res:.+]] = arith.select %[[or]], %arg2, %arg3
200+
// CHECK-NEXT: return %[[res]]
201+
func.func @selOrCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32 {
202+
%sel = arith.select %arg0, %arg2, %arg3 : i32
203+
%res = arith.select %arg1, %arg2, %sel : i32
204+
return %res : i32
205+
}
206+
207+
// CHECK-LABEL: @selOrNotCond
208+
// CHECK-NEXT: %[[one:.+]] = arith.constant true
209+
// CHECK-NEXT: %[[not:.+]] = arith.xori %arg0, %[[one]]
210+
// CHECK-NEXT: %[[or:.+]] = arith.ori %arg1, %[[not]]
211+
// CHECK-NEXT: %[[res:.+]] = arith.select %[[or]], %arg3, %arg2
212+
// CHECK-NEXT: return %[[res]]
213+
func.func @selOrNotCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32 {
214+
%sel = arith.select %arg0, %arg2, %arg3 : i32
215+
%res = arith.select %arg1, %arg3, %sel : i32
216+
return %res : i32
217+
}
218+
219+
// CHECK-LABEL: @selOrNotCondVec
220+
// CHECK-NEXT: %[[one:.+]] = arith.constant dense<true> : vector<4xi1>
221+
// CHECK-NEXT: %[[not:.+]] = arith.xori %arg0, %[[one]]
222+
// CHECK-NEXT: %[[or:.+]] = arith.ori %arg1, %[[not]]
223+
// CHECK-NEXT: %[[res:.+]] = arith.select %[[or]], %arg3, %arg2
224+
// CHECK-NEXT: return %[[res]]
225+
func.func @selOrNotCondVec(%arg0: vector<4xi1>, %arg1: vector<4xi1>, %arg2 : vector<4xi32>, %arg3 : vector<4xi32>) -> vector<4xi32> {
226+
%sel = arith.select %arg0, %arg2, %arg3 : vector<4xi1>, vector<4xi32>
227+
%res = arith.select %arg1, %arg3, %sel : vector<4xi1>, vector<4xi32>
228+
return %res : vector<4xi32>
229+
}
230+
131231
// Test case: Folding of comparisons with equal operands.
132232
// CHECK-LABEL: @cmpi_equal_operands
133233
// CHECK-DAG: %[[T:.*]] = arith.constant true

0 commit comments

Comments
 (0)