Skip to content

Commit 52ff01d

Browse files
committed
fix floordivi expand error logic
1 parent 40faadb commit 52ff01d

File tree

3 files changed

+81
-52
lines changed

3 files changed

+81
-52
lines changed

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -692,6 +692,10 @@ OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
692692
bool overflowOrDiv = false;
693693
auto result = constFoldBinaryOp<IntegerAttr>(
694694
adaptor.getOperands(), [&](APInt a, const APInt &b) {
695+
if (b.isZero()) {
696+
overflowOrDiv = true;
697+
return a;
698+
}
695699
return a.sfloordiv_ov(b, overflowOrDiv);
696700
});
697701

mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,10 +110,53 @@ struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> {
110110
}
111111
};
112112

113+
/// Expands FloorDivSIOp (x, y) into
114+
/// z = x / y
115+
/// if (z * y != x && (x < 0) != (y < 0)) {
116+
/// return z - 1;
117+
/// } else {
118+
/// return z;
119+
/// }
120+
struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> {
121+
using OpRewritePattern::OpRewritePattern;
122+
LogicalResult matchAndRewrite(arith::FloorDivSIOp op,
123+
PatternRewriter &rewriter) const final {
124+
Location loc = op.getLoc();
125+
Type type = op.getType();
126+
Value a = op.getLhs();
127+
Value b = op.getRhs();
128+
129+
Value quotient = rewriter.create<arith::DivSIOp>(loc, a, b);
130+
Value product = rewriter.create<arith::MulIOp>(loc, quotient, b);
131+
Value notEqualDivisor = rewriter.create<arith::CmpIOp>(
132+
loc, arith::CmpIPredicate::ne, a, product);
133+
Value zero = createConst(loc, type, 0, rewriter);
134+
135+
Value aNeg =
136+
rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero);
137+
Value bNeg =
138+
rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
139+
140+
Value signOpposite = rewriter.create<arith::CmpIOp>(
141+
loc, arith::CmpIPredicate::ne, aNeg, bNeg);
142+
Value cond =
143+
rewriter.create<arith::AndIOp>(loc, notEqualDivisor, signOpposite);
144+
145+
Value minusOne = createConst(loc, type, -1, rewriter);
146+
Value quotientMinusOne =
147+
rewriter.create<arith::SubIOp>(loc, quotient, minusOne);
148+
149+
rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cond, quotientMinusOne,
150+
quotient);
151+
return success();
152+
}
153+
};
154+
113155
/// Expands FloorDivSIOp (n, m) into
114156
/// 1) x = (m<0) ? 1 : -1
115157
/// 2) return (n*m<0) ? - ((-n+x) / m) -1 : n / m
116-
struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> {
158+
struct AggressiveFloorDivSIOpConverter
159+
: public OpRewritePattern<arith::FloorDivSIOp> {
117160
using OpRewritePattern::OpRewritePattern;
118161
LogicalResult matchAndRewrite(arith::FloorDivSIOp op,
119162
PatternRewriter &rewriter) const final {

mlir/test/Dialect/Arith/expand-ops.mlir

Lines changed: 33 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -66,23 +66,17 @@ func.func @ceildivi_index(%arg0: index, %arg1: index) -> (index) {
6666
func.func @floordivi(%arg0: i32, %arg1: i32) -> (i32) {
6767
%res = arith.floordivsi %arg0, %arg1 : i32
6868
return %res : i32
69-
// CHECK: [[ONE:%.+]] = arith.constant 1 : i32
70-
// CHECK: [[ZERO:%.+]] = arith.constant 0 : i32
71-
// CHECK: [[MIN1:%.+]] = arith.constant -1 : i32
72-
// CHECK: [[CMP1:%.+]] = arith.cmpi slt, [[ARG1]], [[ZERO]] : i32
73-
// CHECK: [[X:%.+]] = arith.select [[CMP1]], [[ONE]], [[MIN1]] : i32
74-
// CHECK: [[TRUE1:%.+]] = arith.subi [[X]], [[ARG0]] : i32
75-
// CHECK: [[TRUE2:%.+]] = arith.divsi [[TRUE1]], [[ARG1]] : i32
76-
// CHECK: [[TRUE3:%.+]] = arith.subi [[MIN1]], [[TRUE2]] : i32
77-
// CHECK: [[FALSE:%.+]] = arith.divsi [[ARG0]], [[ARG1]] : i32
78-
// CHECK: [[NNEG:%.+]] = arith.cmpi slt, [[ARG0]], [[ZERO]] : i32
79-
// CHECK: [[NPOS:%.+]] = arith.cmpi sgt, [[ARG0]], [[ZERO]] : i32
80-
// CHECK: [[MNEG:%.+]] = arith.cmpi slt, [[ARG1]], [[ZERO]] : i32
81-
// CHECK: [[MPOS:%.+]] = arith.cmpi sgt, [[ARG1]], [[ZERO]] : i32
82-
// CHECK: [[TERM1:%.+]] = arith.andi [[NNEG]], [[MPOS]] : i1
83-
// CHECK: [[TERM2:%.+]] = arith.andi [[NPOS]], [[MNEG]] : i1
84-
// CHECK: [[CMP2:%.+]] = arith.ori [[TERM1]], [[TERM2]] : i1
85-
// CHECK: [[RES:%.+]] = arith.select [[CMP2]], [[TRUE3]], [[FALSE]] : i32
69+
// CHECK: %[[QUOTIENT:.*]] = arith.divsi %arg0, %arg1 : i32
70+
// CHECK: %[[PRODUCT:.*]] = arith.muli %[[QUOTIENT]], %arg1 : i32
71+
// CHECK: %[[NOT_EQ_PRODUCT:.*]] = arith.cmpi ne, %arg0, %[[PRODUCT]] : i32
72+
// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : i32
73+
// CHECK: %[[NEG_DIVISOR:.*]] = arith.cmpi slt, %arg0, %[[ZERO]] : i32
74+
// CHECK: %[[NEG_DIVIDEND:.*]] = arith.cmpi slt, %arg1, %[[ZERO]] : i32
75+
// CHECK: %[[OPPOSITE_SIGN:.*]] = arith.cmpi ne, %[[NEG_DIVISOR]], %[[NEG_DIVIDEND]] : i1
76+
// CHECK: %[[CONDITION:.*]] = arith.andi %[[NOT_EQ_PRODUCT]], %[[OPPOSITE_SIGN]] : i1
77+
// CHECK-DAG: %[[NEG_ONE:.*]] = arith.constant -1 : i32
78+
// CHECK: %[[MINUS_ONE:.*]] = arith.subi %[[QUOTIENT]], %[[NEG_ONE]] : i32
79+
// CHECK: %[[RES:.*]] = arith.select %[[CONDITION]], %[[MINUS_ONE]], %[[QUOTIENT]] : i32
8680
}
8781

8882
// -----
@@ -93,23 +87,17 @@ func.func @floordivi(%arg0: i32, %arg1: i32) -> (i32) {
9387
func.func @floordivi_index(%arg0: index, %arg1: index) -> (index) {
9488
%res = arith.floordivsi %arg0, %arg1 : index
9589
return %res : index
96-
// CHECK: [[ONE:%.+]] = arith.constant 1 : index
97-
// CHECK: [[ZERO:%.+]] = arith.constant 0 : index
98-
// CHECK: [[MIN1:%.+]] = arith.constant -1 : index
99-
// CHECK: [[CMP1:%.+]] = arith.cmpi slt, [[ARG1]], [[ZERO]] : index
100-
// CHECK: [[X:%.+]] = arith.select [[CMP1]], [[ONE]], [[MIN1]] : index
101-
// CHECK: [[TRUE1:%.+]] = arith.subi [[X]], [[ARG0]] : index
102-
// CHECK: [[TRUE2:%.+]] = arith.divsi [[TRUE1]], [[ARG1]] : index
103-
// CHECK: [[TRUE3:%.+]] = arith.subi [[MIN1]], [[TRUE2]] : index
104-
// CHECK: [[FALSE:%.+]] = arith.divsi [[ARG0]], [[ARG1]] : index
105-
// CHECK: [[NNEG:%.+]] = arith.cmpi slt, [[ARG0]], [[ZERO]] : index
106-
// CHECK: [[NPOS:%.+]] = arith.cmpi sgt, [[ARG0]], [[ZERO]] : index
107-
// CHECK: [[MNEG:%.+]] = arith.cmpi slt, [[ARG1]], [[ZERO]] : index
108-
// CHECK: [[MPOS:%.+]] = arith.cmpi sgt, [[ARG1]], [[ZERO]] : index
109-
// CHECK: [[TERM1:%.+]] = arith.andi [[NNEG]], [[MPOS]] : i1
110-
// CHECK: [[TERM2:%.+]] = arith.andi [[NPOS]], [[MNEG]] : i1
111-
// CHECK: [[CMP2:%.+]] = arith.ori [[TERM1]], [[TERM2]] : i1
112-
// CHECK: [[RES:%.+]] = arith.select [[CMP2]], [[TRUE3]], [[FALSE]] : index
90+
// CHECK: %[[QUOTIENT:.*]] = arith.divsi %arg0, %arg1 : index
91+
// CHECK: %[[PRODUCT:.*]] = arith.muli %[[QUOTIENT]], %arg1 : index
92+
// CHECK: %[[NOT_EQ_PRODUCT:.*]] = arith.cmpi ne, %arg0, %[[PRODUCT]] : index
93+
// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : index
94+
// CHECK: %[[NEG_DIVISOR:.*]] = arith.cmpi slt, %arg0, %[[ZERO]] : index
95+
// CHECK: %[[NEG_DIVIDEND:.*]] = arith.cmpi slt, %arg1, %[[ZERO]] : index
96+
// CHECK: %[[OPPOSITE_SIGN:.*]] = arith.cmpi ne, %[[NEG_DIVISOR]], %[[NEG_DIVIDEND]] : i1
97+
// CHECK: %[[CONDITION:.*]] = arith.andi %[[NOT_EQ_PRODUCT]], %[[OPPOSITE_SIGN]] : i1
98+
// CHECK: %[[NEG_ONE:.*]] = arith.constant -1 : index
99+
// CHECK-DAG: %[[MINUS_ONE:.*]] = arith.subi %[[QUOTIENT]], %[[NEG_ONE]] : index
100+
// CHECK: %[[RES:.*]] = arith.select %[[CONDITION]], %[[MINUS_ONE]], %[[QUOTIENT]] : index
113101
}
114102

115103
// -----
@@ -121,23 +109,17 @@ func.func @floordivi_index(%arg0: index, %arg1: index) -> (index) {
121109
func.func @floordivi_vec(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> (vector<4xi32>) {
122110
%res = arith.floordivsi %arg0, %arg1 : vector<4xi32>
123111
return %res : vector<4xi32>
124-
// CHECK: %[[VAL_2:.*]] = arith.constant dense<1> : vector<4xi32>
125-
// CHECK: %[[VAL_3:.*]] = arith.constant dense<0> : vector<4xi32>
126-
// CHECK: %[[VAL_4:.*]] = arith.constant dense<-1> : vector<4xi32>
127-
// CHECK: %[[VAL_5:.*]] = arith.cmpi slt, %[[VAL_1]], %[[VAL_3]] : vector<4xi32>
128-
// CHECK: %[[VAL_6:.*]] = arith.select %[[VAL_5]], %[[VAL_2]], %[[VAL_4]] : vector<4xi1>, vector<4xi32>
129-
// CHECK: %[[VAL_7:.*]] = arith.subi %[[VAL_6]], %[[VAL_0]] : vector<4xi32>
130-
// CHECK: %[[VAL_8:.*]] = arith.divsi %[[VAL_7]], %[[VAL_1]] : vector<4xi32>
131-
// CHECK: %[[VAL_9:.*]] = arith.subi %[[VAL_4]], %[[VAL_8]] : vector<4xi32>
132-
// CHECK: %[[VAL_10:.*]] = arith.divsi %[[VAL_0]], %[[VAL_1]] : vector<4xi32>
133-
// CHECK: %[[VAL_11:.*]] = arith.cmpi slt, %[[VAL_0]], %[[VAL_3]] : vector<4xi32>
134-
// CHECK: %[[VAL_12:.*]] = arith.cmpi sgt, %[[VAL_0]], %[[VAL_3]] : vector<4xi32>
135-
// CHECK: %[[VAL_13:.*]] = arith.cmpi slt, %[[VAL_1]], %[[VAL_3]] : vector<4xi32>
136-
// CHECK: %[[VAL_14:.*]] = arith.cmpi sgt, %[[VAL_1]], %[[VAL_3]] : vector<4xi32>
137-
// CHECK: %[[VAL_15:.*]] = arith.andi %[[VAL_11]], %[[VAL_14]] : vector<4xi1>
138-
// CHECK: %[[VAL_16:.*]] = arith.andi %[[VAL_12]], %[[VAL_13]] : vector<4xi1>
139-
// CHECK: %[[VAL_17:.*]] = arith.ori %[[VAL_15]], %[[VAL_16]] : vector<4xi1>
140-
// CHECK: %[[VAL_18:.*]] = arith.select %[[VAL_17]], %[[VAL_9]], %[[VAL_10]] : vector<4xi1>, vector<4xi32>
112+
// CHECK: %[[QUOTIENT:.*]] = arith.divsi %arg0, %arg1 : vector<4xi32>
113+
// CHECK: %[[PRODUCT:.*]] = arith.muli %[[QUOTIENT]], %arg1 : vector<4xi32>
114+
// CHECK: %[[NOT_EQ_PRODUCT:.*]] = arith.cmpi ne, %arg0, %[[PRODUCT]] : vector<4xi32>
115+
// CHECK-DAG: %[[ZERO:.*]] = arith.constant dense<0> : vector<4xi32>
116+
// CHECK: %[[NEG_DIVISOR:.*]] = arith.cmpi slt, %arg0, %[[ZERO]] : vector<4xi32>
117+
// CHECK: %[[NEG_DIVIDEND:.*]] = arith.cmpi slt, %arg1, %[[ZERO]] : vector<4xi32>
118+
// CHECK: %[[OPPOSITE_SIGN:.*]] = arith.cmpi ne, %[[NEG_DIVISOR]], %[[NEG_DIVIDEND]] : vector<4xi1>
119+
// CHECK: %[[CONDITION:.*]] = arith.andi %[[NOT_EQ_PRODUCT]], %[[OPPOSITE_SIGN]] : vector<4xi1>
120+
// CHECK-DAG: %[[NEG_ONE:.*]] = arith.constant dense<-1> : vector<4xi32>
121+
// CHECK: %[[MINUS_ONE:.*]] = arith.subi %[[QUOTIENT]], %[[NEG_ONE]] : vector<4xi32>
122+
// CHECK: %[[RES:.*]] = arith.select %[[CONDITION]], %[[MINUS_ONE]], %[[QUOTIENT]] : vector<4xi1>, vector<4xi32>
141123
}
142124

143125
// -----

0 commit comments

Comments
 (0)