Skip to content

Commit fe6cf6f

Browse files
committed
[InstCombine] Fold adds + shifts with nsw and nuw flags
[InstCombine] Fold adds + shifts with nsw and nuw flags I also added mul nsw/nuw 3, div 2 since this was the canonical version of ((x << 1) + x) / 2, which is a specific expression which canonicalization causes the InstCombine to miss it. Proofs: https://alive2.llvm.org/ce/z/kDVTiL https://alive2.llvm.org/ce/z/wORNYm
1 parent ef0ebd3 commit fe6cf6f

File tree

3 files changed

+62
-14
lines changed

3 files changed

+62
-14
lines changed

llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1267,6 +1267,18 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
12671267
match(Op1, m_SpecificIntAllowPoison(BitWidth - 1)))
12681268
return new ZExtInst(Builder.CreateIsNotNeg(X, "isnotneg"), Ty);
12691269

1270+
// If both the add and the shift are nuw, then:
1271+
// ((X << Z) + Y) nuw >>u Z --> X + (Y >>u Z) nuw
1272+
Value *Y;
1273+
if (match(Op0, m_OneUse(m_c_NUWAdd(m_NUWShl(m_Value(X), m_Specific(Op1)),
1274+
m_Value(Y))))) {
1275+
Value *NewLshr = Builder.CreateLShr(Y, Op1, "", I.isExact());
1276+
auto *newAdd = BinaryOperator::CreateNUWAdd(NewLshr, X);
1277+
if (auto *Op0Bin = cast<OverflowingBinaryOperator>(Op0))
1278+
newAdd->setHasNoSignedWrap(Op0Bin->hasNoSignedWrap());
1279+
return newAdd;
1280+
}
1281+
12701282
if (match(Op1, m_APInt(C))) {
12711283
unsigned ShAmtC = C->getZExtValue();
12721284
auto *II = dyn_cast<IntrinsicInst>(Op0);
@@ -1283,7 +1295,6 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
12831295
return new ZExtInst(Cmp, Ty);
12841296
}
12851297

1286-
Value *X;
12871298
const APInt *C1;
12881299
if (match(Op0, m_Shl(m_Value(X), m_APInt(C1))) && C1->ult(BitWidth)) {
12891300
if (C1->ult(ShAmtC)) {
@@ -1328,7 +1339,6 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
13281339
// ((X << C) + Y) >>u C --> (X + (Y >>u C)) & (-1 >>u C)
13291340
// TODO: Consolidate with the more general transform that starts from shl
13301341
// (the shifts are in the opposite order).
1331-
Value *Y;
13321342
if (match(Op0,
13331343
m_OneUse(m_c_Add(m_OneUse(m_Shl(m_Value(X), m_Specific(Op1))),
13341344
m_Value(Y))))) {
@@ -1450,9 +1460,24 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
14501460
NewMul->setHasNoSignedWrap(true);
14511461
return NewMul;
14521462
}
1463+
1464+
// Special case: lshr nuw (mul (X, 3), 1) -> add nuw nsw (X, lshr(X, 1)
1465+
if (ShAmtC == 1 && MulC->getZExtValue() == 3) {
1466+
auto *NewAdd = BinaryOperator::CreateNUWAdd(
1467+
X,
1468+
Builder.CreateLShr(X, ConstantInt::get(Ty, 1), "", I.isExact()));
1469+
NewAdd->setHasNoSignedWrap(true);
1470+
return NewAdd;
1471+
}
14531472
}
14541473
}
14551474

1475+
// // lshr nsw (mul (X, 3), 1) -> add nsw (X, lshr(X, 1)
1476+
if (match(Op0, m_OneUse(m_NSWMul(m_Value(X), m_SpecificInt(3)))) &&
1477+
ShAmtC == 1)
1478+
return BinaryOperator::CreateNSWAdd(
1479+
X, Builder.CreateLShr(X, ConstantInt::get(Ty, 1), "", I.isExact()));
1480+
14561481
// Try to narrow bswap.
14571482
// In the case where the shift amount equals the bitwidth difference, the
14581483
// shift is eliminated.
@@ -1656,6 +1681,26 @@ Instruction *InstCombinerImpl::visitAShr(BinaryOperator &I) {
16561681
if (match(Op0, m_OneUse(m_NSWSub(m_Value(X), m_Value(Y)))))
16571682
return new SExtInst(Builder.CreateICmpSLT(X, Y), Ty);
16581683
}
1684+
1685+
// Special case: ashr nuw (mul (X, 3), 1) -> add nuw nsw (X, lshr(X, 1)
1686+
if (match(Op0, m_OneUse(m_NSWMul(m_Value(X), m_SpecificInt(3)))) &&
1687+
ShAmt == 1) {
1688+
Value *Shift;
1689+
if (auto *Op0Bin = cast<OverflowingBinaryOperator>(Op0)) {
1690+
if (Op0Bin->hasNoUnsignedWrap())
1691+
// We can use lshr if the mul is nuw and nsw
1692+
Shift =
1693+
Builder.CreateLShr(X, ConstantInt::get(Ty, 1), "", I.isExact());
1694+
else
1695+
Shift =
1696+
Builder.CreateAShr(X, ConstantInt::get(Ty, 1), "", I.isExact());
1697+
1698+
auto *NewAdd = BinaryOperator::CreateNSWAdd(X, Shift);
1699+
NewAdd->setHasNoUnsignedWrap(Op0Bin->hasNoUnsignedWrap());
1700+
1701+
return NewAdd;
1702+
}
1703+
}
16591704
}
16601705

16611706
const SimplifyQuery Q = SQ.getWithInstruction(&I);

llvm/test/Transforms/InstCombine/ashr-lshr.ll

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -607,8 +607,8 @@ define <2 x i8> @ashr_known_pos_exact_vec(<2 x i8> %x, <2 x i8> %y) {
607607

608608
define i32 @ashr_mul_times_3_div_2(i32 %0) {
609609
; CHECK-LABEL: @ashr_mul_times_3_div_2(
610-
; CHECK-NEXT: [[MUL:%.*]] = mul nuw nsw i32 [[TMP0:%.*]], 3
611-
; CHECK-NEXT: [[ASHR:%.*]] = ashr i32 [[MUL]], 1
610+
; CHECK-NEXT: [[TMP2:%.*]] = lshr i32 [[TMP0:%.*]], 1
611+
; CHECK-NEXT: [[ASHR:%.*]] = add nuw nsw i32 [[TMP2]], [[TMP0]]
612612
; CHECK-NEXT: ret i32 [[ASHR]]
613613
;
614614
%mul = mul nsw nuw i32 %0, 3
@@ -618,8 +618,8 @@ define i32 @ashr_mul_times_3_div_2(i32 %0) {
618618

619619
define i32 @ashr_mul_times_3_div_2_exact(i32 %x) {
620620
; CHECK-LABEL: @ashr_mul_times_3_div_2_exact(
621-
; CHECK-NEXT: [[MUL:%.*]] = mul nsw i32 [[X:%.*]], 3
622-
; CHECK-NEXT: [[ASHR:%.*]] = ashr exact i32 [[MUL]], 1
621+
; CHECK-NEXT: [[TMP1:%.*]] = ashr exact i32 [[X:%.*]], 1
622+
; CHECK-NEXT: [[ASHR:%.*]] = add nsw i32 [[TMP1]], [[X]]
623623
; CHECK-NEXT: ret i32 [[ASHR]]
624624
;
625625
%mul = mul nsw i32 %x, 3

llvm/test/Transforms/InstCombine/lshr.ll

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -360,12 +360,12 @@ define <3 x i14> @mul_splat_fold_vec(<3 x i14> %x) {
360360
ret <3 x i14> %t
361361
}
362362

363-
; Negative tests
363+
; Negative test
364364

365365
define i32 @mul_times_3_div_2(i32 %x) {
366366
; CHECK-LABEL: @mul_times_3_div_2(
367-
; CHECK-NEXT: [[MUL:%.*]] = mul nuw nsw i32 [[X:%.*]], 3
368-
; CHECK-NEXT: [[RES:%.*]] = lshr i32 [[MUL]], 1
367+
; CHECK-NEXT: [[TMP1:%.*]] = lshr i32 [[X:%.*]], 1
368+
; CHECK-NEXT: [[RES:%.*]] = add nuw nsw i32 [[TMP1]], [[X]]
369369
; CHECK-NEXT: ret i32 [[RES]]
370370
;
371371
%mul = mul nsw nuw i32 %x, 3
@@ -375,9 +375,8 @@ define i32 @mul_times_3_div_2(i32 %x) {
375375

376376
define i32 @shl_add_lshr(i32 %x, i32 %c, i32 %y) {
377377
; CHECK-LABEL: @shl_add_lshr(
378-
; CHECK-NEXT: [[SHL:%.*]] = shl nuw i32 [[X:%.*]], [[C:%.*]]
379-
; CHECK-NEXT: [[ADD:%.*]] = add nuw nsw i32 [[SHL]], [[Y:%.*]]
380-
; CHECK-NEXT: [[LSHR:%.*]] = lshr exact i32 [[ADD]], [[C]]
378+
; CHECK-NEXT: [[TMP1:%.*]] = lshr exact i32 [[Y:%.*]], [[C:%.*]]
379+
; CHECK-NEXT: [[LSHR:%.*]] = add nuw nsw i32 [[TMP1]], [[X:%.*]]
381380
; CHECK-NEXT: ret i32 [[LSHR]]
382381
;
383382
%shl = shl nuw i32 %x, %c
@@ -399,8 +398,8 @@ define i32 @lshr_mul_times_3_div_2_nuw(i32 %0) {
399398

400399
define i32 @lshr_mul_times_3_div_2_nsw(i32 %0) {
401400
; CHECK-LABEL: @lshr_mul_times_3_div_2_nsw(
402-
; CHECK-NEXT: [[MUL:%.*]] = mul nsw i32 [[TMP0:%.*]], 3
403-
; CHECK-NEXT: [[LSHR:%.*]] = lshr i32 [[MUL]], 1
401+
; CHECK-NEXT: [[TMP2:%.*]] = lshr i32 [[TMP0:%.*]], 1
402+
; CHECK-NEXT: [[LSHR:%.*]] = add nsw i32 [[TMP2]], [[TMP0]]
404403
; CHECK-NEXT: ret i32 [[LSHR]]
405404
;
406405
%mul = mul nsw i32 %0, 3
@@ -445,6 +444,8 @@ define i32 @mul_splat_fold_wrong_mul_const(i32 %x) {
445444
ret i32 %t
446445
}
447446

447+
; Negative test
448+
448449
define i32 @shl_add_lshr_multiuse(i32 %x, i32 %y, i32 %z) {
449450
; CHECK-LABEL: @shl_add_lshr_multiuse(
450451
; CHECK-NEXT: [[SHL:%.*]] = shl nuw i32 [[X:%.*]], [[Y:%.*]]
@@ -484,6 +485,8 @@ define i32 @mul_splat_fold_wrong_lshr_const(i32 %x) {
484485
ret i32 %t
485486
}
486487

488+
; Negative test
489+
487490
define i32 @mul_splat_fold_no_nuw(i32 %x) {
488491
; CHECK-LABEL: @mul_splat_fold_no_nuw(
489492
; CHECK-NEXT: [[M:%.*]] = mul nsw i32 [[X:%.*]], 65537

0 commit comments

Comments
 (0)