Skip to content

[arith][mlir] Fixed a bug in CeilDiv with neg values #90855

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 19 additions & 17 deletions mlir/lib/Dialect/Arith/IR/ArithOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -652,16 +652,18 @@ OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
APInt posB = zero.ssub_ov(b, overflowOrDiv0);
return signedCeilNonnegInputs(posA, posB, overflowOrDiv0);
}
if (!aGtZero && bGtZero) {
// A is negative, b is positive, return - ( -a / b).
APInt posA = zero.ssub_ov(a, overflowOrDiv0);
APInt div = posA.sdiv_ov(b, overflowOrDiv0);
return zero.ssub_ov(div, overflowOrDiv0);
}
// A is positive, b is negative, return - (a / -b).
APInt posB = zero.ssub_ov(b, overflowOrDiv0);
APInt div = a.sdiv_ov(posB, overflowOrDiv0);
return zero.ssub_ov(div, overflowOrDiv0);
// If either divisor or dividend is negative, then take their absolute
// value and then do a normal signedCeil Division, but add 1 to bring
// the quotient down. In essence, Ceil Division with one of the values
// negative works like a floorDivision with negated quotient.
// Mathematically, -1 * ((abs(a)-1)/abs(b) + 1) + 1, which after factoring
// out -1 yields -1 * [(abs(a)-1)/abs(b) + 1 - 1]. This is implemented
// below.
APInt posA = aGtZero ? a : zero.ssub_ov(a, overflowOrDiv0);
APInt posB = bGtZero ? b : zero.ssub_ov(b, overflowOrDiv0);
APInt div = signedCeilNonnegInputs(posA, posB, overflowOrDiv0);
APInt res = div.ssub_ov(APInt::getOneBitSet(bits, 0), overflowOrDiv0);
return zero.ssub_ov(res, overflowOrDiv0);
});

return overflowOrDiv0 ? Attribute() : result;
Expand Down Expand Up @@ -2260,12 +2262,12 @@ OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {

// Constant-fold constant operands over non-splat constant condition.
// select %cst_vec, %cst0, %cst1 => %cst2
if (auto cond =
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
if (auto lhs =
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
if (auto rhs =
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
if (auto cond = llvm::dyn_cast_if_present<DenseElementsAttr>(
adaptor.getCondition())) {
if (auto lhs = llvm::dyn_cast_if_present<DenseElementsAttr>(
adaptor.getTrueValue())) {
if (auto rhs = llvm::dyn_cast_if_present<DenseElementsAttr>(
adaptor.getFalseValue())) {
SmallVector<Attribute> results;
results.reserve(static_cast<size_t>(cond.getNumElements()));
auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
Expand Down Expand Up @@ -2520,7 +2522,7 @@ Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
return builder.create<arith::MaximumFOp>(loc, lhs, rhs);
case AtomicRMWKind::minimumf:
return builder.create<arith::MinimumFOp>(loc, lhs, rhs);
case AtomicRMWKind::maxnumf:
case AtomicRMWKind::maxnumf:
return builder.create<arith::MaxNumFOp>(loc, lhs, rhs);
case AtomicRMWKind::minnumf:
return builder.create<arith::MinNumFOp>(loc, lhs, rhs);
Expand Down
11 changes: 11 additions & 0 deletions mlir/test/Transforms/constant-fold.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,17 @@ func.func @simple_arith.ceildivsi() -> (i32, i32, i32, i32, i32) {

// -----

// Constant-folding test for arith.ceildivsi on i8 with a negative dividend
// and a positive divisor: -128 / 7 = -18.28..., which rounds up (toward
// positive infinity) to -18.
// CHECK-LABEL: simple_arith.ceildivsi_i8
func.func @simple_arith.ceildivsi_i8() -> (i8) {
// Divisor operand.
%0 = arith.constant 7 : i8
// Dividend operand: -128 is the smallest value representable in i8, so its
// negation (+128) does not fit in the type.
%1 = arith.constant -128 : i8
// The division below is expected to be replaced by a single constant -18.
// CHECK-NEXT: arith.constant -18
%2 = arith.ceildivsi %1, %0 : i8
return %2 : i8
}

// -----

// CHECK-LABEL: func @simple_arith.ceildivui
func.func @simple_arith.ceildivui() -> (i32, i32, i32, i32, i32) {
// CHECK-DAG: [[C0:%.+]] = arith.constant 0
Expand Down
Loading