Skip to content

Commit c7ceaad

Browse files
jreifferslravenclaw
authored andcommitted
[mlir] Fold ceil/floordiv with negative RHS. (llvm#97031)
Currently, we only fold if the RHS is a positive constant. There doesn't seem to be a good reason to do that. The comment claims that division by negative values is undefined, but I suspect that was just copied over from the `mod` simplifier.
1 parent c1bd8e6 commit c7ceaad

File tree

2 files changed

+31
-10
lines changed

2 files changed

+31
-10
lines changed

mlir/lib/IR/AffineExpr.cpp

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -855,8 +855,7 @@ static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs) {
855855
auto lhsConst = dyn_cast<AffineConstantExpr>(lhs);
856856
auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
857857

858-
// mlir floordiv by zero or negative numbers is undefined and preserved as is.
859-
if (!rhsConst || rhsConst.getValue() < 1)
858+
if (!rhsConst || rhsConst.getValue() == 0)
860859
return nullptr;
861860

862861
if (lhsConst) {
@@ -875,12 +874,12 @@ static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs) {
875874
if (rhsConst == 1)
876875
return lhs;
877876

878-
// Simplify (expr * const) floordiv divConst when expr is known to be a
879-
// multiple of divConst.
877+
// Simplify `(expr * lrhs) floordiv rhsConst` when `lrhs` is known to be a
878+
// multiple of `rhsConst`.
880879
auto lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
881880
if (lBin && lBin.getKind() == AffineExprKind::Mul) {
882881
if (auto lrhs = dyn_cast<AffineConstantExpr>(lBin.getRHS())) {
883-
// rhsConst is known to be a positive constant.
882+
// `rhsConst` is known to be a nonzero constant.
884883
if (lrhs.getValue() % rhsConst.getValue() == 0)
885884
return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
886885
}
@@ -891,7 +890,7 @@ static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs) {
891890
if (lBin && lBin.getKind() == AffineExprKind::Add) {
892891
int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor();
893892
int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
894-
// rhsConst is known to be a positive constant.
893+
// rhsConst is known to be a nonzero constant.
895894
if (llhsDiv % rhsConst.getValue() == 0 ||
896895
lrhsDiv % rhsConst.getValue() == 0)
897896
return lBin.getLHS().floorDiv(rhsConst.getValue()) +
@@ -918,7 +917,7 @@ static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs) {
918917
auto lhsConst = dyn_cast<AffineConstantExpr>(lhs);
919918
auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
920919

921-
if (!rhsConst || rhsConst.getValue() < 1)
920+
if (!rhsConst || rhsConst.getValue() == 0)
922921
return nullptr;
923922

924923
if (lhsConst) {
@@ -937,12 +936,12 @@ static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs) {
937936
if (rhsConst.getValue() == 1)
938937
return lhs;
939938

940-
// Simplify (expr * const) ceildiv divConst when const is known to be a
941-
// multiple of divConst.
939+
// Simplify `(expr * lrhs) ceildiv rhsConst` when `lrhs` is known to be a
940+
// multiple of `rhsConst`.
942941
auto lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
943942
if (lBin && lBin.getKind() == AffineExprKind::Mul) {
944943
if (auto lrhs = dyn_cast<AffineConstantExpr>(lBin.getRHS())) {
945-
// rhsConst is known to be a positive constant.
944+
// `rhsConst` is known to be a nonzero constant.
946945
if (lrhs.getValue() % rhsConst.getValue() == 0)
947946
return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
948947
}

mlir/unittests/IR/AffineExprTest.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,3 +76,25 @@ TEST(AffineExprTest, constantFolding) {
7676
getAffineBinaryOpExpr(AffineExprKind::FloorDiv, cmin, cn1);
7777
ASSERT_EQ(cminfloordivcn1.getKind(), AffineExprKind::FloorDiv);
7878
}
79+
80+
TEST(AffineExprTest, divisionSimplification) {
81+
MLIRContext ctx;
82+
OpBuilder b(&ctx);
83+
auto cn6 = b.getAffineConstantExpr(-6);
84+
auto c6 = b.getAffineConstantExpr(6);
85+
auto d0 = b.getAffineDimExpr(0);
86+
auto d1 = b.getAffineDimExpr(1);
87+
88+
ASSERT_EQ(c6.floorDiv(-1), cn6);
89+
ASSERT_EQ((d0 * 6).floorDiv(2), d0 * 3);
90+
ASSERT_EQ((d0 * 6).floorDiv(4).getKind(), AffineExprKind::FloorDiv);
91+
ASSERT_EQ((d0 * 6).floorDiv(-2), d0 * -3);
92+
ASSERT_EQ((d0 * 6 + d1).floorDiv(2), d0 * 3 + d1.floorDiv(2));
93+
ASSERT_EQ((d0 * 6 + d1).floorDiv(-2), d0 * -3 + d1.floorDiv(-2));
94+
ASSERT_EQ((d0 * 6 + d1).floorDiv(4).getKind(), AffineExprKind::FloorDiv);
95+
96+
ASSERT_EQ(c6.ceilDiv(-1), cn6);
97+
ASSERT_EQ((d0 * 6).ceilDiv(2), d0 * 3);
98+
ASSERT_EQ((d0 * 6).ceilDiv(4).getKind(), AffineExprKind::CeilDiv);
99+
ASSERT_EQ((d0 * 6).ceilDiv(-2), d0 * -3);
100+
}

0 commit comments

Comments
 (0)