Skip to content

Commit 0ae6431

Browse files
[InstCombine] Handle ceil division idiom
The expression `add (udiv (sub A, Bias), B), Bias` can be folded to `udiv (add A, B - 1), B` when the sum of `A` and `B` is known not to overflow, and `Bias = A != 0`. Proof: https://alive2.llvm.org/ce/z/hiWHQA.
1 parent 87aa097 commit 0ae6431

File tree

2 files changed

+70
-44
lines changed

2 files changed

+70
-44
lines changed

llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp

Lines changed: 70 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1257,6 +1257,74 @@ static Instruction *foldToUnsignedSaturatedAdd(BinaryOperator &I) {
12571257
return nullptr;
12581258
}
12591259

1260+
/// Try to fold one of the "ceil" idioms rooted at the add instruction \p I.
///
/// Two patterns are recognised:
///   * log2 ceil:
///       zext (ctpop(A) >u/!= 1) + (ctlz (A, true) ^ (BW - 1))
///         -> BW - ctlz (A - 1, false)
///   * ceil division:
///       add (udiv (sub A, Bias), B), Bias
///         -> udiv (add A, B - 1), B
///       with Bias = zext (A != 0) and A + B known not to overflow.
///
/// \param I  The add instruction to inspect.
/// \param IC The combiner whose Builder is used to emit the replacement.
/// \returns  The replacement value built with IC.Builder, or nullptr when
///           neither idiom matches. The caller is responsible for replacing
///           the uses of \p I.
static Value *foldCeilIdioms(BinaryOperator &I, InstCombinerImpl &IC) {
  assert(I.getOpcode() == Instruction::Add && "Expecting add instruction.");
  Value *A, *B;
  auto &ICB = IC.Builder;

  // Fold the log2 ceil idiom:
  // zext (ctpop(A) >u/!= 1) + (ctlz (A, true) ^ (BW - 1))
  // -> BW - ctlz (A - 1, false)
  const APInt *XorC;
  CmpPredicate Pred;
  if (match(&I,
            m_c_Add(
                m_ZExt(m_ICmp(Pred, m_Intrinsic<Intrinsic::ctpop>(m_Value(A)),
                              m_One())),
                m_OneUse(m_ZExtOrSelf(m_OneUse(m_Xor(
                    m_OneUse(m_TruncOrSelf(m_OneUse(
                        m_Intrinsic<Intrinsic::ctlz>(m_Deferred(A), m_One())))),
                    m_APInt(XorC))))))) &&
      (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_NE) &&
      *XorC == A->getType()->getScalarSizeInBits() - 1) {
    // A - 1, expressed as A + (-1).
    Value *Sub = ICB.CreateAdd(A, Constant::getAllOnesValue(A->getType()));
    Value *Ctlz = ICB.CreateIntrinsic(Intrinsic::ctlz, {A->getType()},
                                      {Sub, ICB.getFalse()});
    Value *Ret = ICB.CreateSub(
        ConstantInt::get(A->getType(), A->getType()->getScalarSizeInBits()),
        Ctlz, "", /*HasNUW*/ true, /*HasNSW*/ true);
    return ICB.CreateZExtOrTrunc(Ret, I.getType());
  }

  // Fold the ceil division idiom:
  // add (udiv (sub A, Bias), B), Bias
  // -> udiv (add A, B - 1), B
  // with Bias = A != 0; A + B not to overflow

  // Match Div as an unsigned division of DivOp0 by DivOp1. Besides a plain
  // udiv, this accepts `lshr DivOp0, (BW - 1 - ctlz(DivOp1, false))` when
  // DivOp1 is known to be a (non-zero) power of two.
  auto MatchDivision = [&IC](Instruction *Div, Value *&DivOp0, Value *&DivOp1) {
    if (match(Div, m_UDiv(m_Value(DivOp0), m_Value(DivOp1))))
      return true;

    Value *N;
    if (match(Div, m_LShr(m_Value(DivOp0), m_Value(N))) &&
        match(N,
              m_Sub(m_SpecificInt(Div->getType()->getScalarSizeInBits() - 1),
                    m_Intrinsic<Intrinsic::ctlz>(m_Value(DivOp1), m_Zero()))) &&
        IC.isKnownToBeAPowerOfTwo(DivOp1, /*OrZero*/ false, 0, Div))
      return true;

    return false;
  };

  // Check whether (DivV, Bias) form the idiom, binding A and B on success.
  // The value subtracted from A must be exactly the Bias operand of the add
  // (m_Specific): rebinding it with m_Value would let an unrelated add operand
  // be silently dropped by the fold.
  auto MatchCeilDiv = [&](Value *DivV, Value *Bias) {
    auto *Div = dyn_cast<Instruction>(DivV);
    Value *Sub;
    return Div && MatchDivision(Div, Sub, B) &&
           match(Sub, m_Sub(m_Value(A), m_Specific(Bias))) &&
           match(Bias, m_ZExt(m_SpecificICmp(ICmpInst::ICMP_NE, m_Specific(A),
                                             m_ZeroInt()))) &&
           // Bias may only be used by the sub and by this add.
           Bias->hasNUses(2);
  };

  // Try both operand orders explicitly. A commutative matcher whose first
  // sub-pattern accepts any instruction never falls back to the swapped order
  // when operand 0 is itself an instruction (e.g. the zext forming the bias).
  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
  if (MatchCeilDiv(Op0, Op1) || MatchCeilDiv(Op1, Op0)) {
    WithCache<const Value *> LHSCache(A), RHSCache(B);
    auto OR = IC.computeOverflowForUnsignedAdd(LHSCache, RHSCache, &I);
    if (OR == OverflowResult::NeverOverflows) {
      // B - 1, expressed as B + (-1).
      auto *BMinusOne =
          ICB.CreateAdd(B, Constant::getAllOnesValue(I.getType()));
      return ICB.CreateUDiv(ICB.CreateAdd(A, BMinusOne), B);
    }
  }

  return nullptr;
}
1327+
12601328
// Transform:
12611329
// (add A, (shl (neg B), Y))
12621330
// -> (sub A, (shl B, Y))
@@ -1838,30 +1906,8 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
18381906
I, Builder.CreateIntrinsic(Intrinsic::ctpop, {I.getType()},
18391907
{Builder.CreateOr(A, B)}));
18401908

1841-
// Fold the log2_ceil idiom:
1842-
// zext(ctpop(A) >u/!= 1) + (ctlz(A, true) ^ (BW - 1))
1843-
// -->
1844-
// BW - ctlz(A - 1, false)
1845-
const APInt *XorC;
1846-
CmpPredicate Pred;
1847-
if (match(&I,
1848-
m_c_Add(
1849-
m_ZExt(m_ICmp(Pred, m_Intrinsic<Intrinsic::ctpop>(m_Value(A)),
1850-
m_One())),
1851-
m_OneUse(m_ZExtOrSelf(m_OneUse(m_Xor(
1852-
m_OneUse(m_TruncOrSelf(m_OneUse(
1853-
m_Intrinsic<Intrinsic::ctlz>(m_Deferred(A), m_One())))),
1854-
m_APInt(XorC))))))) &&
1855-
(Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_NE) &&
1856-
*XorC == A->getType()->getScalarSizeInBits() - 1) {
1857-
Value *Sub = Builder.CreateAdd(A, Constant::getAllOnesValue(A->getType()));
1858-
Value *Ctlz = Builder.CreateIntrinsic(Intrinsic::ctlz, {A->getType()},
1859-
{Sub, Builder.getFalse()});
1860-
Value *Ret = Builder.CreateSub(
1861-
ConstantInt::get(A->getType(), A->getType()->getScalarSizeInBits()),
1862-
Ctlz, "", /*HasNUW*/ true, /*HasNSW*/ true);
1863-
return replaceInstUsesWith(I, Builder.CreateZExtOrTrunc(Ret, I.getType()));
1864-
}
1909+
if (Value *V = foldCeilIdioms(I, *this))
1910+
return replaceInstUsesWith(I, V);
18651911

18661912
if (Instruction *Res = foldSquareSumInt(I))
18671913
return Res;

llvm/test/Transforms/InstCombine/fold-ceil-div-idiom.ll

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,6 @@
44
define i8 @ceil_div_idiom(i8 %x, i8 %y) {
55
; CHECK-LABEL: define i8 @ceil_div_idiom(
66
; CHECK-SAME: i8 [[X:%.*]], i8 [[Y:%.*]]) {
7-
; CHECK-NEXT: [[WO:%.*]] = call { i8, i1 } @llvm.uadd.with.overflow.i8(i8 [[X]], i8 [[Y]])
8-
; CHECK-NEXT: [[OV:%.*]] = extractvalue { i8, i1 } [[WO]], 1
9-
; CHECK-NEXT: [[OV_NOT:%.*]] = xor i1 [[OV]], true
10-
; CHECK-NEXT: call void @llvm.assume(i1 [[OV_NOT]])
117
; CHECK-NEXT: [[NONZERO:%.*]] = icmp ne i8 [[X]], 0
128
; CHECK-NEXT: [[BIAS:%.*]] = zext i1 [[NONZERO]] to i8
139
; CHECK-NEXT: [[SUB:%.*]] = sub i8 [[X]], [[BIAS]]
@@ -56,10 +52,6 @@ define i8 @ceil_div_idiom_2(i8 %x, i8 %y) {
5652
define i8 @ceil_div_idiom_with_lshr(i8 %x, i8 %y) {
5753
; CHECK-LABEL: define i8 @ceil_div_idiom_with_lshr(
5854
; CHECK-SAME: i8 [[X:%.*]], i8 [[Y:%.*]]) {
59-
; CHECK-NEXT: [[WO:%.*]] = call { i8, i1 } @llvm.uadd.with.overflow.i8(i8 [[X]], i8 [[Y]])
60-
; CHECK-NEXT: [[OV:%.*]] = extractvalue { i8, i1 } [[WO]], 1
61-
; CHECK-NEXT: [[OV_NOT:%.*]] = xor i1 [[OV]], true
62-
; CHECK-NEXT: call void @llvm.assume(i1 [[OV_NOT]])
6355
; CHECK-NEXT: [[CTPOPULATION:%.*]] = call range(i8 0, 9) i8 @llvm.ctpop.i8(i8 [[Y]])
6456
; CHECK-NEXT: [[IS_POW_2:%.*]] = icmp eq i8 [[CTPOPULATION]], 1
6557
; CHECK-NEXT: call void @llvm.assume(i1 [[IS_POW_2]])
@@ -112,10 +104,6 @@ define i8 @ceil_div_idiom_add_may_overflow(i8 %x, i8 %y) {
112104
define i8 @ceil_div_idiom_multiuse_bias(i8 %x, i8 %y) {
113105
; CHECK-LABEL: define i8 @ceil_div_idiom_multiuse_bias(
114106
; CHECK-SAME: i8 [[X:%.*]], i8 [[Y:%.*]]) {
115-
; CHECK-NEXT: [[WO:%.*]] = call { i8, i1 } @llvm.uadd.with.overflow.i8(i8 [[X]], i8 [[Y]])
116-
; CHECK-NEXT: [[OV:%.*]] = extractvalue { i8, i1 } [[WO]], 1
117-
; CHECK-NEXT: [[OV_NOT:%.*]] = xor i1 [[OV]], true
118-
; CHECK-NEXT: call void @llvm.assume(i1 [[OV_NOT]])
119107
; CHECK-NEXT: [[NONZERO:%.*]] = icmp ne i8 [[X]], 0
120108
; CHECK-NEXT: [[BIAS:%.*]] = zext i1 [[NONZERO]] to i8
121109
; CHECK-NEXT: [[SUB:%.*]] = sub i8 [[X]], [[BIAS]]
@@ -141,10 +129,6 @@ define i8 @ceil_div_idiom_multiuse_bias(i8 %x, i8 %y) {
141129
define i8 @ceil_div_idiom_with_lshr_not_power_2(i8 %x, i8 %y) {
142130
; CHECK-LABEL: define i8 @ceil_div_idiom_with_lshr_not_power_2(
143131
; CHECK-SAME: i8 [[X:%.*]], i8 [[Y:%.*]]) {
144-
; CHECK-NEXT: [[WO:%.*]] = call { i8, i1 } @llvm.uadd.with.overflow.i8(i8 [[X]], i8 [[Y]])
145-
; CHECK-NEXT: [[OV:%.*]] = extractvalue { i8, i1 } [[WO]], 1
146-
; CHECK-NEXT: [[OV_NOT:%.*]] = xor i1 [[OV]], true
147-
; CHECK-NEXT: call void @llvm.assume(i1 [[OV_NOT]])
148132
; CHECK-NEXT: [[NONZERO:%.*]] = icmp ne i8 [[X]], 0
149133
; CHECK-NEXT: [[BIAS:%.*]] = zext i1 [[NONZERO]] to i8
150134
; CHECK-NEXT: [[SUB:%.*]] = sub i8 [[X]], [[BIAS]]
@@ -172,10 +156,6 @@ define i8 @ceil_div_idiom_with_lshr_not_power_2(i8 %x, i8 %y) {
172156
define i8 @ceil_div_idiom_with_lshr_wrong_bw(i8 %x, i8 %y) {
173157
; CHECK-LABEL: define i8 @ceil_div_idiom_with_lshr_wrong_bw(
174158
; CHECK-SAME: i8 [[X:%.*]], i8 [[Y:%.*]]) {
175-
; CHECK-NEXT: [[WO:%.*]] = call { i8, i1 } @llvm.uadd.with.overflow.i8(i8 [[X]], i8 [[Y]])
176-
; CHECK-NEXT: [[OV:%.*]] = extractvalue { i8, i1 } [[WO]], 1
177-
; CHECK-NEXT: [[OV_NOT:%.*]] = xor i1 [[OV]], true
178-
; CHECK-NEXT: call void @llvm.assume(i1 [[OV_NOT]])
179159
; CHECK-NEXT: [[CTPOPULATION:%.*]] = call range(i8 0, 9) i8 @llvm.ctpop.i8(i8 [[Y]])
180160
; CHECK-NEXT: [[IS_POW_2:%.*]] = icmp eq i8 [[CTPOPULATION]], 1
181161
; CHECK-NEXT: call void @llvm.assume(i1 [[IS_POW_2]])

0 commit comments

Comments
 (0)