Skip to content

Commit dc246f9

Browse files
committed
Adding missed optimisation
1 parent 00f2c40 commit dc246f9

File tree

3 files changed

+113
-24
lines changed

3 files changed

+113
-24
lines changed

llvm/include/llvm/IR/Operator.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,9 @@ class OverflowingBinaryOperator : public Operator {
123123
return NoWrapKind;
124124
}
125125

126+
/// Return true if the instruction is commutative:
127+
bool isCommutative() const { return Instruction::isCommutative(getOpcode()); }
128+
126129
static bool classof(const Instruction *I) {
127130
return I->getOpcode() == Instruction::Add ||
128131
I->getOpcode() == Instruction::Sub ||

llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1505,6 +1505,97 @@ foldMinimumOverTrailingOrLeadingZeroCount(Value *I0, Value *I1,
15051505
ConstantInt::getTrue(ZeroUndef->getType()));
15061506
}
15071507

1508+
/// Return whether "X LOp (Y ROp Z)" is always equal to
1509+
/// "(X LOp Y) ROp (X LOp Z)".
1510+
static bool leftDistributesOverRightIntrinsic(Instruction::BinaryOps LOp,
1511+
bool hasNUW, bool hasNSW,
1512+
Intrinsic::ID ROp) {
1513+
switch (ROp) {
1514+
case Intrinsic::umax:
1515+
return hasNUW && LOp == Instruction::Add;
1516+
case Intrinsic::umin:
1517+
return hasNUW && LOp == Instruction::Add;
1518+
case Intrinsic::smax:
1519+
return hasNSW && LOp == Instruction::Add;
1520+
case Intrinsic::smin:
1521+
return hasNSW && LOp == Instruction::Add;
1522+
default:
1523+
return false;
1524+
}
1525+
}
1526+
1527+
// Attempts to factorise a common term
1528+
// in an instruction that has the form "(A op' B) op (C op' D)
1529+
// where op is an intrinsic and op' is a binop
1530+
static Value *
1531+
foldIntrinsicUsingDistributiveLaws(IntrinsicInst *II,
1532+
InstCombiner::BuilderTy &Builder) {
1533+
Value *LHS = II->getOperand(0), *RHS = II->getOperand(1);
1534+
Intrinsic::ID TopLevelOpcode = II->getIntrinsicID();
1535+
1536+
OverflowingBinaryOperator *Op0 = dyn_cast<OverflowingBinaryOperator>(LHS);
1537+
OverflowingBinaryOperator *Op1 = dyn_cast<OverflowingBinaryOperator>(RHS);
1538+
1539+
if (!Op0 || !Op1)
1540+
return nullptr;
1541+
1542+
if (Op0->getOpcode() != Op1->getOpcode())
1543+
return nullptr;
1544+
1545+
if (Op0->hasNoUnsignedWrap() != Op1->hasNoUnsignedWrap() ||
1546+
Op0->hasNoSignedWrap() != Op1->hasNoSignedWrap())
1547+
return nullptr;
1548+
1549+
if (!Op0->hasOneUse() || !Op1->hasOneUse())
1550+
return nullptr;
1551+
1552+
Instruction::BinaryOps InnerOpcode =
1553+
static_cast<Instruction::BinaryOps>(Op0->getOpcode());
1554+
bool HasNUW = Op0->hasNoUnsignedWrap();
1555+
bool HasNSW = Op0->hasNoSignedWrap();
1556+
1557+
if (!InnerOpcode)
1558+
return nullptr;
1559+
1560+
if (!leftDistributesOverRightIntrinsic(InnerOpcode, HasNUW, HasNSW,
1561+
TopLevelOpcode))
1562+
return nullptr;
1563+
1564+
assert(II->isCommutative() && Op0->isCommutative() &&
1565+
"Only inner and outer commutative op codes are supported.");
1566+
1567+
Value *A = Op0->getOperand(0);
1568+
Value *B = Op0->getOperand(1);
1569+
Value *C = Op1->getOperand(0);
1570+
Value *D = Op1->getOperand(1);
1571+
1572+
if (A == C || A == D) {
1573+
if (A != C)
1574+
std::swap(C, D);
1575+
1576+
Value *NewIntrinsic = Builder.CreateBinaryIntrinsic(TopLevelOpcode, B, D);
1577+
BinaryOperator *NewBinop =
1578+
cast<BinaryOperator>(Builder.CreateBinOp(InnerOpcode, NewIntrinsic, A));
1579+
NewBinop->setHasNoSignedWrap(HasNSW);
1580+
NewBinop->setHasNoUnsignedWrap(HasNUW);
1581+
return NewBinop;
1582+
}
1583+
1584+
if (B == D || B == C) {
1585+
if (B != D)
1586+
std::swap(C, D);
1587+
1588+
Value *NewIntrinsic = Builder.CreateBinaryIntrinsic(TopLevelOpcode, A, C);
1589+
BinaryOperator *NewBinop =
1590+
cast<BinaryOperator>(Builder.CreateBinOp(InnerOpcode, NewIntrinsic, B));
1591+
NewBinop->setHasNoSignedWrap(HasNSW);
1592+
NewBinop->setHasNoUnsignedWrap(HasNUW);
1593+
return NewBinop;
1594+
}
1595+
1596+
return nullptr;
1597+
}
1598+
15081599
/// CallInst simplification. This mostly only handles folding of intrinsic
15091600
/// instructions. For normal calls, it allows visitCallBase to do the heavy
15101601
/// lifting.
@@ -1929,6 +2020,9 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
19292020
}
19302021
}
19312022

2023+
if (Value *V = foldIntrinsicUsingDistributiveLaws(II, Builder))
2024+
return replaceInstUsesWith(*II, V);
2025+
19322026
break;
19332027
}
19342028
case Intrinsic::bitreverse: {

llvm/test/Transforms/InstCombine/intrinsic-distributive.ll

Lines changed: 16 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,8 @@
55
define i8 @umax_of_add_nuw(i8 %a, i8 %b, i8 %c) {
66
; CHECK-LABEL: define i8 @umax_of_add_nuw(
77
; CHECK-SAME: i8 [[A:%.*]], i8 [[B:%.*]], i8 [[C:%.*]]) {
8-
; CHECK-NEXT: [[ADD1:%.*]] = add nuw i8 [[B]], [[A]]
9-
; CHECK-NEXT: [[ADD2:%.*]] = add nuw i8 [[C]], [[A]]
10-
; CHECK-NEXT: [[MAX:%.*]] = call i8 @llvm.umax.i8(i8 [[ADD1]], i8 [[ADD2]])
8+
; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.umax.i8(i8 [[B]], i8 [[C]])
9+
; CHECK-NEXT: [[MAX:%.*]] = add nuw i8 [[TMP1]], [[A]]
1110
; CHECK-NEXT: ret i8 [[MAX]]
1211
;
1312
%add1 = add nuw i8 %b, %a
@@ -19,9 +18,8 @@ define i8 @umax_of_add_nuw(i8 %a, i8 %b, i8 %c) {
1918
define i8 @umax_of_add_nuw_comm(i8 %a, i8 %b, i8 %c) {
2019
; CHECK-LABEL: define i8 @umax_of_add_nuw_comm(
2120
; CHECK-SAME: i8 [[A:%.*]], i8 [[B:%.*]], i8 [[C:%.*]]) {
22-
; CHECK-NEXT: [[ADD1:%.*]] = add nuw i8 [[A]], [[B]]
23-
; CHECK-NEXT: [[ADD2:%.*]] = add nuw i8 [[A]], [[C]]
24-
; CHECK-NEXT: [[MAX:%.*]] = call i8 @llvm.umax.i8(i8 [[ADD1]], i8 [[ADD2]])
21+
; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.umax.i8(i8 [[B]], i8 [[C]])
22+
; CHECK-NEXT: [[MAX:%.*]] = add nuw i8 [[TMP1]], [[A]]
2523
; CHECK-NEXT: ret i8 [[MAX]]
2624
;
2725
%add1 = add nuw i8 %a, %b
@@ -64,9 +62,8 @@ define i8 @umax_of_add(i8 %a, i8 %b, i8 %c) {
6462
define i8 @umin_of_add_nuw(i8 %a, i8 %b, i8 %c) {
6563
; CHECK-LABEL: define i8 @umin_of_add_nuw(
6664
; CHECK-SAME: i8 [[A:%.*]], i8 [[B:%.*]], i8 [[C:%.*]]) {
67-
; CHECK-NEXT: [[ADD1:%.*]] = add nuw i8 [[B]], [[A]]
68-
; CHECK-NEXT: [[ADD2:%.*]] = add nuw i8 [[C]], [[A]]
69-
; CHECK-NEXT: [[MIN:%.*]] = call i8 @llvm.umin.i8(i8 [[ADD1]], i8 [[ADD2]])
65+
; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.umin.i8(i8 [[B]], i8 [[C]])
66+
; CHECK-NEXT: [[MIN:%.*]] = add nuw i8 [[TMP1]], [[A]]
7067
; CHECK-NEXT: ret i8 [[MIN]]
7168
;
7269
%add1 = add nuw i8 %b, %a
@@ -78,9 +75,8 @@ define i8 @umin_of_add_nuw(i8 %a, i8 %b, i8 %c) {
7875
define i8 @umin_of_add_nuw_comm(i8 %a, i8 %b, i8 %c) {
7976
; CHECK-LABEL: define i8 @umin_of_add_nuw_comm(
8077
; CHECK-SAME: i8 [[A:%.*]], i8 [[B:%.*]], i8 [[C:%.*]]) {
81-
; CHECK-NEXT: [[ADD1:%.*]] = add nuw i8 [[A]], [[B]]
82-
; CHECK-NEXT: [[ADD2:%.*]] = add nuw i8 [[A]], [[C]]
83-
; CHECK-NEXT: [[MIN:%.*]] = call i8 @llvm.umin.i8(i8 [[ADD1]], i8 [[ADD2]])
78+
; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.umin.i8(i8 [[B]], i8 [[C]])
79+
; CHECK-NEXT: [[MIN:%.*]] = add nuw i8 [[TMP1]], [[A]]
8480
; CHECK-NEXT: ret i8 [[MIN]]
8581
;
8682
%add1 = add nuw i8 %a, %b
@@ -137,9 +133,8 @@ define i8 @smax_of_add_nuw(i8 %a, i8 %b, i8 %c) {
137133
define i8 @smax_of_add_nsw(i8 %a, i8 %b, i8 %c) {
138134
; CHECK-LABEL: define i8 @smax_of_add_nsw(
139135
; CHECK-SAME: i8 [[A:%.*]], i8 [[B:%.*]], i8 [[C:%.*]]) {
140-
; CHECK-NEXT: [[ADD1:%.*]] = add nsw i8 [[B]], [[A]]
141-
; CHECK-NEXT: [[ADD2:%.*]] = add nsw i8 [[C]], [[A]]
142-
; CHECK-NEXT: [[MAX:%.*]] = call i8 @llvm.smax.i8(i8 [[ADD1]], i8 [[ADD2]])
136+
; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.smax.i8(i8 [[B]], i8 [[C]])
137+
; CHECK-NEXT: [[MAX:%.*]] = add nsw i8 [[TMP1]], [[A]]
143138
; CHECK-NEXT: ret i8 [[MAX]]
144139
;
145140
%add1 = add nsw i8 %b, %a
@@ -151,9 +146,8 @@ define i8 @smax_of_add_nsw(i8 %a, i8 %b, i8 %c) {
151146
define i8 @smax_of_add_nsw_comm(i8 %a, i8 %b, i8 %c) {
152147
; CHECK-LABEL: define i8 @smax_of_add_nsw_comm(
153148
; CHECK-SAME: i8 [[A:%.*]], i8 [[B:%.*]], i8 [[C:%.*]]) {
154-
; CHECK-NEXT: [[ADD1:%.*]] = add nsw i8 [[A]], [[B]]
155-
; CHECK-NEXT: [[ADD2:%.*]] = add nsw i8 [[A]], [[C]]
156-
; CHECK-NEXT: [[MAX:%.*]] = call i8 @llvm.smax.i8(i8 [[ADD1]], i8 [[ADD2]])
149+
; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.smax.i8(i8 [[B]], i8 [[C]])
150+
; CHECK-NEXT: [[MAX:%.*]] = add nsw i8 [[TMP1]], [[A]]
157151
; CHECK-NEXT: ret i8 [[MAX]]
158152
;
159153
%add1 = add nsw i8 %a, %b
@@ -195,9 +189,8 @@ define i8 @smin_of_add_nuw(i8 %a, i8 %b, i8 %c) {
195189
define i8 @smin_of_add_nsw(i8 %a, i8 %b, i8 %c) {
196190
; CHECK-LABEL: define i8 @smin_of_add_nsw(
197191
; CHECK-SAME: i8 [[A:%.*]], i8 [[B:%.*]], i8 [[C:%.*]]) {
198-
; CHECK-NEXT: [[ADD1:%.*]] = add nsw i8 [[B]], [[A]]
199-
; CHECK-NEXT: [[ADD2:%.*]] = add nsw i8 [[C]], [[A]]
200-
; CHECK-NEXT: [[MIN:%.*]] = call i8 @llvm.smin.i8(i8 [[ADD1]], i8 [[ADD2]])
192+
; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.smin.i8(i8 [[B]], i8 [[C]])
193+
; CHECK-NEXT: [[MIN:%.*]] = add nsw i8 [[TMP1]], [[A]]
201194
; CHECK-NEXT: ret i8 [[MIN]]
202195
;
203196
%add1 = add nsw i8 %b, %a
@@ -209,9 +202,8 @@ define i8 @smin_of_add_nsw(i8 %a, i8 %b, i8 %c) {
209202
define i8 @smin_of_add_nsw_comm(i8 %a, i8 %b, i8 %c) {
210203
; CHECK-LABEL: define i8 @smin_of_add_nsw_comm(
211204
; CHECK-SAME: i8 [[A:%.*]], i8 [[B:%.*]], i8 [[C:%.*]]) {
212-
; CHECK-NEXT: [[ADD1:%.*]] = add nsw i8 [[A]], [[B]]
213-
; CHECK-NEXT: [[ADD2:%.*]] = add nsw i8 [[A]], [[C]]
214-
; CHECK-NEXT: [[MIN:%.*]] = call i8 @llvm.smin.i8(i8 [[ADD1]], i8 [[ADD2]])
205+
; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.smin.i8(i8 [[B]], i8 [[C]])
206+
; CHECK-NEXT: [[MIN:%.*]] = add nsw i8 [[TMP1]], [[A]]
215207
; CHECK-NEXT: ret i8 [[MIN]]
216208
;
217209
%add1 = add nsw i8 %a, %b

0 commit comments

Comments
 (0)