Skip to content

Commit 725ed6b

Browse files
authored
Rematerialize primal readonly (rust-lang#663)
* Rematerialize primal readonly * Fix LLVM7 bug
1 parent b29fc8a commit 725ed6b

File tree

4 files changed

+118
-28
lines changed

4 files changed

+118
-28
lines changed

enzyme/Enzyme/DifferentialUseAnalysis.h

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -432,7 +432,7 @@ static inline bool is_value_needed_in_reverse(
432432
return seen[idx] = true;
433433
}
434434

435-
// Anything we may try to rematerialize requires its store opreands for
435+
// Anything we may try to rematerialize requires its store operands for
436436
// the reverse pass.
437437
if (!OneLevel) {
438438
if (isa<StoreInst>(user) || isa<MemTransferInst>(user) ||
@@ -443,12 +443,21 @@ static inline bool is_value_needed_in_reverse(
443443
// we'll set it to unused, then check the gep, then here we'll
444444
// directly say unused by induction instead of checking the final
445445
// loads.
446-
if (pair.second.stores.count(user))
446+
if (pair.second.stores.count(user)) {
447447
for (LoadInst *L : pair.second.loads)
448448
if (is_value_needed_in_reverse<VT>(TR, gutils, L, mode, seen,
449449
oldUnreachable)) {
450450
return seen[idx] = true;
451451
}
452+
for (auto &pair : pair.second.loadLikeCalls)
453+
if (is_use_directly_needed_in_reverse(TR, gutils, pair.operand,
454+
pair.loadCall,
455+
oldUnreachable) ||
456+
is_value_needed_in_reverse<VT>(TR, gutils, pair.loadCall,
457+
mode, seen, oldUnreachable)) {
458+
return seen[idx] = true;
459+
}
460+
}
452461
}
453462
}
454463
}
@@ -648,6 +657,11 @@ static inline void minCut(const DataLayout &DL, LoopInfo &OrigLI,
648657
G[Node(pair.first, true)].insert(Node(L, false));
649658
}
650659
}
660+
for (auto L : pair.second.loadLikeCalls) {
661+
if (Intermediates.count(L.loadCall)) {
662+
G[Node(pair.first, true)].insert(Node(L.loadCall, false));
663+
}
664+
}
651665
}
652666
}
653667
for (auto R : Required) {

enzyme/Enzyme/EnzymeLogic.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -822,7 +822,8 @@ void calculateUnusedValuesInFunction(
822822
// Don't erase any store that needs to be preserved for a
823823
// rematerialization. However, if not used in a rematerialization, the
824824
// store should be eliminated in the reverse pass
825-
if (mode == DerivativeMode::ReverseModeGradient) {
825+
if (mode == DerivativeMode::ReverseModeGradient ||
826+
mode == DerivativeMode::ForwardModeSplit) {
826827
auto CI = dyn_cast<CallInst>(const_cast<Instruction *>(inst));
827828
Function *CF = CI ? getFunctionFromCall(CI) : nullptr;
828829
StringRef funcName = CF ? CF->getName() : "";
@@ -832,8 +833,9 @@ void calculateUnusedValuesInFunction(
832833
if (pair.second.stores.count(inst)) {
833834
if (is_value_needed_in_reverse<ValueType::Primal>(
834835
TR, gutils, pair.first, mode, PrimalSeen,
835-
oldUnreachable))
836+
oldUnreachable)) {
836837
return UseReq::Need;
838+
}
837839
}
838840
}
839841
return UseReq::Recur;
@@ -4063,6 +4065,8 @@ Function *EnzymeLogic::CreateForwardDiff(
40634065
for (auto &I : *BB)
40644066
unnecessaryInstructionsTmp.insert(&I);
40654067
}
4068+
if (mode == DerivativeMode::ForwardModeSplit)
4069+
gutils->computeGuaranteedFrees(guaranteedUnreachable, TR);
40664070

40674071
SmallPtrSet<const Value *, 4> unnecessaryValues;
40684072
SmallPtrSet<const Instruction *, 4> unnecessaryInstructions;

enzyme/Enzyme/GradientUtils.h

Lines changed: 88 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -529,9 +529,19 @@ class GradientUtils : public CacheUtility {
529529
return cast_or_null<BasicBlock>(isOriginal((const Value *)newinst));
530530
}
531531

532+
struct LoadLikeCall {
533+
CallInst *loadCall;
534+
Value *operand;
535+
LoadLikeCall() = default;
536+
LoadLikeCall(CallInst *a, Value *b) : loadCall(a), operand(b) {}
537+
};
538+
532539
struct Rematerializer {
533540
// Loads which may need to be rematerialized.
534-
SmallPtrSet<LoadInst *, 1> loads;
541+
SmallVector<LoadInst *, 1> loads;
542+
543+
// Load-like calls which need the memory initialized for the reverse.
544+
SmallVector<LoadLikeCall, 1> loadLikeCalls;
535545

536546
// Operations which must be rerun to rematerialize
537547
// the value.
@@ -544,10 +554,12 @@ class GradientUtils : public CacheUtility {
544554
Loop *LI;
545555

546556
Rematerializer() : loads(), stores(), frees(), LI(nullptr) {}
547-
Rematerializer(const SmallPtrSetImpl<LoadInst *> &loads,
557+
Rematerializer(const SmallVectorImpl<LoadInst *> &loads,
558+
const SmallVectorImpl<LoadLikeCall> &loadLikeCalls,
548559
const SmallPtrSetImpl<Instruction *> &stores,
549560
const SmallPtrSetImpl<Instruction *> &frees, Loop *LI)
550561
: loads(loads.begin(), loads.end()),
562+
loadLikeCalls(loadLikeCalls.begin(), loadLikeCalls.end()),
551563
stores(stores.begin(), stores.end()),
552564
frees(frees.begin(), frees.end()), LI(LI) {}
553565
};
@@ -586,7 +598,8 @@ class GradientUtils : public CacheUtility {
586598
void computeForwardingProperties(Instruction *V, TypeResults &TR) {
587599
if (!EnzymeRematerialize)
588600
return;
589-
SmallPtrSet<LoadInst *, 1> loads;
601+
SmallVector<LoadInst *, 1> loads;
602+
SmallVector<LoadLikeCall, 1> loadLikeCalls;
590603
SmallPtrSet<Instruction *, 1> stores;
591604
SmallPtrSet<Instruction *, 1> frees;
592605
SmallPtrSet<IntrinsicInst *, 1> LifetimeStarts;
@@ -638,7 +651,7 @@ class GradientUtils : public CacheUtility {
638651
shadowpromotable = false;
639652
}
640653
}
641-
loads.insert(load);
654+
loads.push_back(load);
642655
} else if (auto store = dyn_cast<StoreInst>(cur)) {
643656
// TODO only add store to shadow iff non float type
644657
if (store->getValueOperand() == prev) {
@@ -710,12 +723,8 @@ class GradientUtils : public CacheUtility {
710723
continue;
711724
}
712725

713-
promotable = false;
714-
715-
EmitWarning("NotPromotable", cur->getDebugLoc(), oldFunc,
716-
cur->getParent(), " Could not promote allocation ", *V,
717-
" due to unknown call ", *cur);
718726
size_t idx = 0;
727+
bool seenLoadLikeCall = false;
719728
#if LLVM_VERSION_MAJOR >= 14
720729
for (auto &arg : CI->args())
721730
#else
@@ -735,18 +744,63 @@ class GradientUtils : public CacheUtility {
735744
#if LLVM_VERSION_MAJOR >= 8
736745
if (CI->doesNotCapture(idx))
737746
#else
738-
if (CI->dataOperandHasImpliedAttr(idx, Attribute::NoCapture) ||
747+
if (CI->dataOperandHasImpliedAttr(idx + 1, Attribute::NoCapture) ||
739748
(F && F->hasParamAttribute(idx, Attribute::NoCapture)))
740749
#endif
741750
{
751+
#if LLVM_VERSION_MAJOR >= 8
752+
if (CI->onlyReadsMemory(idx))
753+
#else
754+
if (CI->dataOperandHasImpliedAttr(idx + 1, Attribute::ReadOnly) ||
755+
CI->dataOperandHasImpliedAttr(idx + 1, Attribute::ReadNone) ||
756+
(F && (F->hasParamAttribute(idx, Attribute::ReadOnly) ||
757+
F->hasParamAttribute(idx, Attribute::ReadNone))))
758+
#endif
759+
{
760+
// if only reading memory, ok to duplicate in forward /
761+
// reverse if it is a stack or GC allocation.
762+
// Said memory will still be primal initialized.
763+
StringRef funcName = "";
764+
if (auto CI = dyn_cast<CallInst>(V))
765+
if (Function *originCall = getFunctionFromCall(CI))
766+
funcName = originCall->getName();
767+
if (isa<AllocaInst>(V) || hasMetadata(V, "enzyme_fromstack") ||
768+
funcName == "jl_alloc_array_1d" ||
769+
funcName == "jl_alloc_array_2d" ||
770+
funcName == "jl_alloc_array_3d" ||
771+
funcName == "jl_array_copy" ||
772+
funcName == "ijl_alloc_array_1d" ||
773+
funcName == "ijl_alloc_array_2d" ||
774+
funcName == "ijl_alloc_array_3d" ||
775+
funcName == "ijl_array_copy" ||
776+
funcName == "julia.gc_alloc_obj") {
777+
if (!seenLoadLikeCall) {
778+
loadLikeCalls.push_back(LoadLikeCall(CI, prev));
779+
seenLoadLikeCall = true;
780+
}
781+
} else {
782+
promotable = false;
783+
EmitWarning("NotPromotable", cur->getDebugLoc(), oldFunc,
784+
cur->getParent(), " Could not promote allocation ",
785+
*V, " due to unknown non-local call ", *cur);
786+
}
787+
} else {
788+
promotable = false;
789+
EmitWarning("NotPromotable", cur->getDebugLoc(), oldFunc,
790+
cur->getParent(), " Could not promote allocation ",
791+
*V, " due to unknown writing call ", *cur);
792+
}
793+
742794
if (TT.isFloat()) {
743795
// all floats ok
744796
}
745797
#if LLVM_VERSION_MAJOR >= 8
746798
else if (CI->onlyReadsMemory(idx))
747799
#else
748-
else if (CI->dataOperandHasImpliedAttr(idx, Attribute::ReadOnly) ||
749-
CI->dataOperandHasImpliedAttr(idx, Attribute::ReadNone) ||
800+
else if (CI->dataOperandHasImpliedAttr(idx + 1,
801+
Attribute::ReadOnly) ||
802+
CI->dataOperandHasImpliedAttr(idx + 1,
803+
Attribute::ReadNone) ||
750804
(F && (F->hasParamAttribute(idx, Attribute::ReadOnly) ||
751805
F->hasParamAttribute(idx, Attribute::ReadNone))))
752806
#endif
@@ -775,9 +829,12 @@ class GradientUtils : public CacheUtility {
775829
} else {
776830
shadowpromotable = false;
777831
}
778-
break;
779832
} else {
780833
shadowpromotable = false;
834+
promotable = false;
835+
EmitWarning("NotPromotable", cur->getDebugLoc(), oldFunc,
836+
cur->getParent(), " Could not promote allocation ", *V,
837+
" due to unknown capturing call ", *cur);
781838
}
782839
idx++;
783840
}
@@ -832,8 +889,25 @@ class GradientUtils : public CacheUtility {
832889
}
833890
rematerializable.insert(LI);
834891
}
892+
for (auto LI : loadLikeCalls) {
893+
// Is there a store which could occur after the load-like call?
894+
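// In other words, could a store that may execute after this call overwrite
// the memory the call reads? If so, the allocation cannot be promoted.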
895+
SmallVector<Instruction *, 2> results;
896+
mayExecuteAfter(results, LI.loadCall, stores, outer);
897+
for (auto res : results) {
898+
if (overwritesToMemoryReadBy(OrigAA, SE, OrigLI, OrigDT, LI.loadCall,
899+
res, outer)) {
900+
EmitWarning("NotPromotable", LI.loadCall->getDebugLoc(), oldFunc,
901+
LI.loadCall->getParent(),
902+
" Could not promote allocation ", *V,
903+
" due to load-like call ", *LI.loadCall,
904+
" which does not postdominates store ", *res);
905+
return;
906+
}
907+
}
908+
}
835909
rematerializableAllocations[V] =
836-
Rematerializer(loads, stores, frees, outer);
910+
Rematerializer(loads, loadLikeCalls, stores, frees, outer);
837911
}
838912

839913
void computeGuaranteedFrees(

enzyme/test/Enzyme/ForwardModeSplit/square2.ll

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
; }
2121

2222

23-
define dso_local void @square_(double* nocapture readonly %src, double* nocapture %dest) local_unnamed_addr #0 {
23+
define dso_local void @square_(double* nocapture readonly %src, double* nocapture noalias %dest) local_unnamed_addr #0 {
2424
entry:
2525
%0 = load double, double* %src, align 8
2626
%mul = fmul double %0, %0
@@ -60,30 +60,28 @@ attributes #3 = { "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-
6060
attributes #4 = { nounwind }
6161

6262

63-
; CHECK: define internal double @fwddiffesquare(double %x, double %"x'", i8* %tapeArg)
63+
; CHECK: define internal double @fwddiffesquare(double %x, double %"x'", i8* %malloccall1)
6464
; CHECK-NEXT: entry:
65-
; CHECK-NEXT: %0 = bitcast i8* %tapeArg to { double, i8*, i8* }*
66-
; CHECK-NEXT: %truetape = load { double, i8*, i8* }, { double, i8*, i8* }* %0
67-
; CHECK-NEXT: %malloccall = extractvalue { double, i8*, i8* } %truetape, 2
65+
; CHECK-NEXT: %[[malloccall:.+]] = alloca i8, i64 8, align 8
6866
; CHECK-NEXT: %"malloccall'mi" = alloca i8, i64 8, align 8
6967
; CHECK-NEXT: call void @llvm.memset.p0i8.i64(i8* nonnull dereferenceable(8) dereferenceable_or_null(8) %"malloccall'mi", i8 0, i64 8, i1 false)
7068
; CHECK-NEXT: %"x.addr'ipc" = bitcast i8* %"malloccall'mi" to double*
71-
; CHECK-NEXT: %x.addr = bitcast i8* %malloccall to double*
72-
; CHECK-NEXT: %malloccall1 = extractvalue { double, i8*, i8* } %truetape, 1
69+
; CHECK-NEXT: %x.addr = bitcast i8* %[[malloccall]] to double*
7370
; CHECK-NEXT: %"malloccall1'mi" = alloca i8, i64 8, align 8
7471
; CHECK-NEXT: call void @llvm.memset.p0i8.i64(i8* nonnull dereferenceable(8) dereferenceable_or_null(8) %"malloccall1'mi", i8 0, i64 8, i1 false)
7572
; CHECK-NEXT: %"y'ipc" = bitcast i8* %"malloccall1'mi" to double*
7673
; CHECK-NEXT: %y = bitcast i8* %malloccall1 to double*
74+
; CHECK-NEXT: store double %x, double* %x.addr, align 8
7775
; CHECK-NEXT: store double %"x'", double* %"x.addr'ipc", align 8
78-
; CHECK-NEXT: %[[tapeArg1:.+]] = extractvalue { double, i8*, i8* } %truetape, 0
79-
; CHECK-NEXT: call void @fwddiffesquare_(double* %x.addr, double* %"x.addr'ipc", double* %y, double* %"y'ipc", double %[[tapeArg1]])
76+
; CHECK-NEXT: call void @fwddiffesquare_(double* %x.addr, double* %"x.addr'ipc", double* %y, double* %"y'ipc")
8077
; CHECK-NEXT: %[[i1:.+]] = load double, double* %"y'ipc", align 8
8178
; CHECK-NEXT: ret double %[[i1]]
8279
; CHECK-NEXT: }
8380

84-
; CHECK: define internal void @fwddiffesquare_(double* nocapture readonly %src, double* nocapture %"src'", double* nocapture %dest, double* nocapture %"dest'", double
81+
; CHECK: define internal void @fwddiffesquare_(double* nocapture readonly %src, double* nocapture %"src'", double* noalias nocapture %dest, double* nocapture %"dest'")
8582
; CHECK-NEXT: entry:
8683
; CHECK-NEXT: %[[i1:.+]] = load double, double* %"src'", align 8
84+
; CHECK-NEXT: %0 = load double, double* %src
8785
; CHECK-NEXT: %[[i2:.+]] = fmul fast double %[[i1]], %0
8886
; CHECK-NEXT: %[[i3:.+]] = fmul fast double %[[i1]], %0
8987
; CHECK-NEXT: %[[i4:.+]] = fadd fast double %[[i2]], %[[i3]]

0 commit comments

Comments
 (0)