Skip to content

Commit b1da82a

Browse files
authored
[mlir][arith] Fix overflow bug in arith::CeilDivSIOp::fold (#90947)
The folder for arith::CeilDivSIOp should only be applied when it can be guaranteed that no overflow would happen. The current implementation works fine when both dividends are positive and the only arithmetic operation is the division itself. However, in cases where either the dividend or divisor is negative (or both), the division is split into multiple arith operations, e.g.: `- ( -a / b)`. That's additional 2 operations on top of the actual division that can overflow - the folder should check all 3 ops for overflow. The current logic doesn't do that - it effectively only checks the last operation (i.e. the division). It breaks when using e.g. MININT values (e.g. -128 for 8-bit integers) - negating such values overflows. This PR makes sure that no folding happens if any of the intermediate arithmetic operations overflows. Fixes #89382
1 parent 559accf commit b1da82a

File tree

2 files changed

+63
-9
lines changed

2 files changed

+63
-9
lines changed

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -683,6 +683,8 @@ OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
683683
return getLhs();
684684

685685
// Don't fold if it would overflow or if it requires a division by zero.
686+
// TODO: This hook won't fold operations where a = MININT, because
687+
// negating MININT overflows. This can be improved.
686688
bool overflowOrDiv0 = false;
687689
auto result = constFoldBinaryOp<IntegerAttr>(
688690
adaptor.getOperands(), [&](APInt a, const APInt &b) {
@@ -701,22 +703,36 @@ OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
701703
// Both positive, return ceil(a, b).
702704
return signedCeilNonnegInputs(a, b, overflowOrDiv0);
703705
}
706+
707+
// No folding happens if any of the intermediate arithmetic operations
708+
// overflows.
709+
bool overflowNegA = false;
710+
bool overflowNegB = false;
711+
bool overflowDiv = false;
712+
bool overflowNegRes = false;
704713
if (!aGtZero && !bGtZero) {
705714
// Both negative, return ceil(-a, -b).
706-
APInt posA = zero.ssub_ov(a, overflowOrDiv0);
707-
APInt posB = zero.ssub_ov(b, overflowOrDiv0);
708-
return signedCeilNonnegInputs(posA, posB, overflowOrDiv0);
715+
APInt posA = zero.ssub_ov(a, overflowNegA);
716+
APInt posB = zero.ssub_ov(b, overflowNegB);
717+
APInt res = signedCeilNonnegInputs(posA, posB, overflowDiv);
718+
overflowOrDiv0 = (overflowNegA || overflowNegB || overflowDiv);
719+
return res;
709720
}
710721
if (!aGtZero && bGtZero) {
711722
// A is negative, b is positive, return - ( -a / b).
712-
APInt posA = zero.ssub_ov(a, overflowOrDiv0);
713-
APInt div = posA.sdiv_ov(b, overflowOrDiv0);
714-
return zero.ssub_ov(div, overflowOrDiv0);
723+
APInt posA = zero.ssub_ov(a, overflowNegA);
724+
APInt div = posA.sdiv_ov(b, overflowDiv);
725+
APInt res = zero.ssub_ov(div, overflowNegRes);
726+
overflowOrDiv0 = (overflowNegA || overflowDiv || overflowNegRes);
727+
return res;
715728
}
716729
// A is positive, b is negative, return - (a / -b).
717-
APInt posB = zero.ssub_ov(b, overflowOrDiv0);
718-
APInt div = a.sdiv_ov(posB, overflowOrDiv0);
719-
return zero.ssub_ov(div, overflowOrDiv0);
730+
APInt posB = zero.ssub_ov(b, overflowNegB);
731+
APInt div = a.sdiv_ov(posB, overflowDiv);
732+
APInt res = zero.ssub_ov(div, overflowNegRes);
733+
734+
overflowOrDiv0 = (overflowNegB || overflowDiv || overflowNegRes);
735+
return res;
720736
});
721737

722738
return overflowOrDiv0 ? Attribute() : result;

mlir/test/Transforms/constant-fold.mlir

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,44 @@ func.func @simple_arith.ceildivsi() -> (i32, i32, i32, i32, i32) {
478478

479479
// -----
480480

481+
// CHECK-LABEL: func @simple_arith.ceildivsi_overflow
482+
func.func @simple_arith.ceildivsi_overflow() -> (i8, i16, i32) {
483+
// The negative values below are MININTs for the corresponding bit-width. The
484+
// folder will try to negate them (so that the division operates on two
485+
// positive numbers), but that would cause overflow (negating MININT
486+
// overflows). Hence folding should not happen and the original ceildivsi is
487+
// preserved.
488+
489+
// TODO: The folder should be able to fold the following by avoiding
490+
// intermediate operations that overflow.
491+
492+
// CHECK-DAG: %[[C_1:.*]] = arith.constant 7 : i8
493+
// CHECK-DAG: %[[MIN_I8:.*]] = arith.constant -128 : i8
494+
// CHECK-DAG: %[[C_2:.*]] = arith.constant 7 : i16
495+
// CHECK-DAG: %[[MIN_I16:.*]] = arith.constant -32768 : i16
496+
// CHECK-DAG: %[[C_3:.*]] = arith.constant 7 : i32
497+
// CHECK-DAG: %[[MIN_I32:.*]] = arith.constant -2147483648 : i32
498+
499+
// CHECK-NEXT: %[[CEILDIV_1:.*]] = arith.ceildivsi %[[MIN_I8]], %[[C_1]] : i8
500+
%0 = arith.constant 7 : i8
501+
%min_int_i8 = arith.constant -128 : i8
502+
%2 = arith.ceildivsi %min_int_i8, %0 : i8
503+
504+
// CHECK-NEXT: %[[CEILDIV_2:.*]] = arith.ceildivsi %[[MIN_I16]], %[[C_2]] : i16
505+
%3 = arith.constant 7 : i16
506+
%min_int_i16 = arith.constant -32768 : i16
507+
%5 = arith.ceildivsi %min_int_i16, %3 : i16
508+
509+
// CHECK-NEXT: %[[CEILDIV_2:.*]] = arith.ceildivsi %[[MIN_I32]], %[[C_3]] : i32
510+
%6 = arith.constant 7 : i32
511+
%min_int_i32 = arith.constant -2147483648 : i32
512+
%8 = arith.ceildivsi %min_int_i32, %6 : i32
513+
514+
return %2, %5, %8 : i8, i16, i32
515+
}
516+
517+
// -----
518+
481519
// CHECK-LABEL: func @simple_arith.ceildivui
482520
func.func @simple_arith.ceildivui() -> (i32, i32, i32, i32, i32) {
483521
// CHECK-DAG: [[C0:%.+]] = arith.constant 0

0 commit comments

Comments
 (0)