Skip to content

Commit 1bbfc35

Browse files
tgymnichwsmoses
andauthored
Improve DifferentialUseAnalysis (rust-lang#709)
* Improve DifferentialUseAnalysis * Fix test Co-authored-by: William S. Moses <[email protected]>
1 parent 9cf2b8c commit 1bbfc35

File tree

2 files changed

+38
-8
lines changed

2 files changed

+38
-8
lines changed

enzyme/Enzyme/DifferentialUseAnalysis.h

+38-2
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,6 @@ static inline bool is_use_directly_needed_in_reverse(
107107

108108
if (isa<CmpInst>(user) || isa<BranchInst>(user) || isa<ReturnInst>(user) ||
109109
isa<FPExtInst>(user) || isa<FPTruncInst>(user) ||
110-
(isa<InsertElementInst>(user) &&
111-
cast<InsertElementInst>(user)->getOperand(2) != val) ||
112110
(isa<ExtractElementInst>(user) &&
113111
cast<ExtractElementInst>(user)->getIndexOperand() != val)
114112
#if LLVM_VERSION_MAJOR >= 10
@@ -121,6 +119,44 @@ static inline bool is_use_directly_needed_in_reverse(
121119
return false;
122120
}
123121

122+
if (auto IEI = dyn_cast<InsertElementInst>(user)) {
123+
// Only need the index in the reverse, so if the value is not
124+
// the index, short circuit and say we don't need
125+
if (IEI->getOperand(2) != val) {
126+
return false;
127+
}
128+
// The index is only needed in the reverse if the value being inserted
129+
// is a possible active floating point value
130+
if (gutils->isConstantValue(const_cast<InsertElementInst *>(IEI)) ||
131+
TR.query(const_cast<InsertElementInst *>(IEI))[{-1}] ==
132+
BaseType::Pointer)
133+
return false;
134+
// Otherwise, we need the value.
135+
return true;
136+
}
137+
138+
if (auto IVI = dyn_cast<InsertValueInst>(user)) {
139+
// Only need the index in the reverse, so if the value is not
140+
// the index, short circuit and say we don't need
141+
bool valueIsIndex = false;
142+
for (unsigned i = 2; i < IVI->getNumOperands(); ++i) {
143+
if (IVI->getOperand(i) == val) {
144+
valueIsIndex = true;
145+
}
146+
}
147+
148+
if (!valueIsIndex)
149+
return false;
150+
151+
// The index is only needed in the reverse if the value being inserted
152+
// is a possible active floating point value
153+
if (gutils->isConstantValue(const_cast<InsertValueInst *>(IVI)) ||
154+
TR.query(const_cast<InsertValueInst *>(IVI))[{-1}] == BaseType::Pointer)
155+
return false;
156+
// Otherwise, we need the value.
157+
return true;
158+
}
159+
124160
Intrinsic::ID ID = Intrinsic::not_intrinsic;
125161
if (auto II = dyn_cast<IntrinsicInst>(user)) {
126162
ID = II->getIntrinsicID();

enzyme/test/Enzyme/ForwardModeVector/vecsquare.ll

-6
Original file line numberDiff line numberDiff line change
@@ -56,32 +56,26 @@ entry:
5656
; CHECK-NEXT: [[TMP26:%.*]] = insertvalue [2 x <4 x float>] undef, <4 x float> %"id'ipsv", 0
5757
; CHECK-NEXT: %"id'ipsv1" = shufflevector <4 x float> [[TMP10]], <4 x float> [[TMP22]], <4 x i32> <i32 0, i32 1, i32 4, i32 5>
5858
; CHECK-NEXT: [[TMP29:%.*]] = insertvalue [2 x <4 x float>] [[TMP26]], <4 x float> %"id'ipsv1", 1
59-
; CHECK-NEXT: [[ID:%.*]] = shufflevector <4 x float> [[SQ]], <4 x float> [[CB]], <4 x i32> <i32 0, i32 1, i32 4, i32 5>
6059
; CHECK-NEXT: %"res1'ipee" = extractelement <4 x float> %"id'ipsv", i32 1
6160
; CHECK-NEXT: [[TMP31:%.*]] = insertvalue [2 x float] undef, float %"res1'ipee", 0
6261
; CHECK-NEXT: %"res1'ipee2" = extractelement <4 x float> %"id'ipsv1", i32 1
6362
; CHECK-NEXT: [[TMP33:%.*]] = insertvalue [2 x float] [[TMP31]], float %"res1'ipee2", 1
64-
; CHECK-NEXT: [[RES1:%.*]] = extractelement <4 x float> [[ID]], i32 1
6563
; CHECK-NEXT: %"res2'ipee" = extractelement <4 x float> %"id'ipsv", i32 2
6664
; CHECK-NEXT: [[TMP35:%.*]] = insertvalue [2 x float] undef, float %"res2'ipee", 0
6765
; CHECK-NEXT: %"res2'ipee3" = extractelement <4 x float> %"id'ipsv1", i32 2
6866
; CHECK-NEXT: [[TMP37:%.*]] = insertvalue [2 x float] [[TMP35]], float %"res2'ipee3", 1
69-
; CHECK-NEXT: [[RES2:%.*]] = extractelement <4 x float> [[ID]], i32 2
7067
; CHECK-NEXT: %"res3'ipee" = extractelement <4 x float> %"id'ipsv", i32 3
7168
; CHECK-NEXT: [[TMP39:%.*]] = insertvalue [2 x float] undef, float %"res3'ipee", 0
7269
; CHECK-NEXT: %"res3'ipee4" = extractelement <4 x float> %"id'ipsv1", i32 3
7370
; CHECK-NEXT: [[TMP41:%.*]] = insertvalue [2 x float] [[TMP39]], float %"res3'ipee4", 1
74-
; CHECK-NEXT: [[RES3:%.*]] = extractelement <4 x float> [[ID]], i32 3
7571
; CHECK-NEXT: %"agg1'ipiv" = insertvalue { float, float, float } zeroinitializer, float %"res1'ipee", 0
7672
; CHECK-NEXT: [[TMP43:%.*]] = insertvalue [2 x { float, float, float }] undef, { float, float, float } %"agg1'ipiv", 0
7773
; CHECK-NEXT: %"agg1'ipiv5" = insertvalue { float, float, float } zeroinitializer, float %"res1'ipee2", 0
7874
; CHECK-NEXT: [[TMP45:%.*]] = insertvalue [2 x { float, float, float }] [[TMP43]], { float, float, float } %"agg1'ipiv5", 1
79-
; CHECK-NEXT: [[AGG1:%.*]] = insertvalue { float, float, float } undef, float [[RES1]], 0
8075
; CHECK-NEXT: %"agg2'ipiv" = insertvalue { float, float, float } %"agg1'ipiv", float %"res2'ipee", 1
8176
; CHECK-NEXT: [[TMP48:%.*]] = insertvalue [2 x { float, float, float }] undef, { float, float, float } %"agg2'ipiv", 0
8277
; CHECK-NEXT: %"agg2'ipiv6" = insertvalue { float, float, float } %"agg1'ipiv5", float %"res2'ipee3", 1
8378
; CHECK-NEXT: [[TMP51:%.*]] = insertvalue [2 x { float, float, float }] [[TMP48]], { float, float, float } %"agg2'ipiv6", 1
84-
; CHECK-NEXT: [[AGG2:%.*]] = insertvalue { float, float, float } [[AGG1]], float [[RES2]], 1
8579
; CHECK-NEXT: %"agg3'ipiv" = insertvalue { float, float, float } %"agg2'ipiv", float %"res3'ipee", 2
8680
; CHECK-NEXT: [[TMP54:%.*]] = insertvalue [2 x { float, float, float }] undef, { float, float, float } %"agg3'ipiv", 0
8781
; CHECK-NEXT: %"agg3'ipiv7" = insertvalue { float, float, float } %"agg2'ipiv6", float %"res3'ipee4", 2

0 commit comments

Comments
 (0)