@@ -529,9 +529,19 @@ class GradientUtils : public CacheUtility {
529
529
return cast_or_null<BasicBlock>(isOriginal ((const Value *)newinst));
530
530
}
531
531
532
+ struct LoadLikeCall {
533
+ CallInst *loadCall;
534
+ Value *operand;
535
+ LoadLikeCall () = default ;
536
+ LoadLikeCall (CallInst *a, Value *b) : loadCall(a), operand(b) {}
537
+ };
538
+
532
539
struct Rematerializer {
533
540
// Loads which may need to be rematerialized.
534
- SmallPtrSet<LoadInst *, 1 > loads;
541
+ SmallVector<LoadInst *, 1 > loads;
542
+
543
+ // Loads-like calls which need the memory initialized for the reverse.
544
+ SmallVector<LoadLikeCall, 1 > loadLikeCalls;
535
545
536
546
// Operations which must be rerun to rematerialize
537
547
// the value.
@@ -544,10 +554,12 @@ class GradientUtils : public CacheUtility {
544
554
Loop *LI;
545
555
546
556
Rematerializer () : loads(), stores(), frees(), LI(nullptr ) {}
547
- Rematerializer (const SmallPtrSetImpl<LoadInst *> &loads,
557
+ Rematerializer (const SmallVectorImpl<LoadInst *> &loads,
558
+ const SmallVectorImpl<LoadLikeCall> &loadLikeCalls,
548
559
const SmallPtrSetImpl<Instruction *> &stores,
549
560
const SmallPtrSetImpl<Instruction *> &frees, Loop *LI)
550
561
: loads(loads.begin(), loads.end()),
562
+ loadLikeCalls (loadLikeCalls.begin(), loadLikeCalls.end()),
551
563
stores(stores.begin(), stores.end()),
552
564
frees(frees.begin(), frees.end()), LI(LI) {}
553
565
};
@@ -586,7 +598,8 @@ class GradientUtils : public CacheUtility {
586
598
void computeForwardingProperties (Instruction *V, TypeResults &TR) {
587
599
if (!EnzymeRematerialize)
588
600
return ;
589
- SmallPtrSet<LoadInst *, 1 > loads;
601
+ SmallVector<LoadInst *, 1 > loads;
602
+ SmallVector<LoadLikeCall, 1 > loadLikeCalls;
590
603
SmallPtrSet<Instruction *, 1 > stores;
591
604
SmallPtrSet<Instruction *, 1 > frees;
592
605
SmallPtrSet<IntrinsicInst *, 1 > LifetimeStarts;
@@ -638,7 +651,7 @@ class GradientUtils : public CacheUtility {
638
651
shadowpromotable = false ;
639
652
}
640
653
}
641
- loads.insert (load);
654
+ loads.push_back (load);
642
655
} else if (auto store = dyn_cast<StoreInst>(cur)) {
643
656
// TODO only add store to shadow iff non float type
644
657
if (store->getValueOperand () == prev) {
@@ -710,12 +723,8 @@ class GradientUtils : public CacheUtility {
710
723
continue ;
711
724
}
712
725
713
- promotable = false ;
714
-
715
- EmitWarning (" NotPromotable" , cur->getDebugLoc (), oldFunc,
716
- cur->getParent (), " Could not promote allocation " , *V,
717
- " due to unknown call " , *cur);
718
726
size_t idx = 0 ;
727
+ bool seenLoadLikeCall = false ;
719
728
#if LLVM_VERSION_MAJOR >= 14
720
729
for (auto &arg : CI->args ())
721
730
#else
@@ -735,18 +744,63 @@ class GradientUtils : public CacheUtility {
735
744
#if LLVM_VERSION_MAJOR >= 8
736
745
if (CI->doesNotCapture (idx))
737
746
#else
738
- if (CI->dataOperandHasImpliedAttr (idx, Attribute::NoCapture) ||
747
+ if (CI->dataOperandHasImpliedAttr (idx + 1 , Attribute::NoCapture) ||
739
748
(F && F->hasParamAttribute (idx, Attribute::NoCapture)))
740
749
#endif
741
750
{
751
+ #if LLVM_VERSION_MAJOR >= 8
752
+ if (CI->onlyReadsMemory (idx))
753
+ #else
754
+ if (CI->dataOperandHasImpliedAttr (idx + 1 , Attribute::ReadOnly) ||
755
+ CI->dataOperandHasImpliedAttr (idx + 1 , Attribute::ReadNone) ||
756
+ (F && (F->hasParamAttribute (idx, Attribute::ReadOnly) ||
757
+ F->hasParamAttribute (idx, Attribute::ReadNone))))
758
+ #endif
759
+ {
760
+ // if only reading memory, ok to duplicate in forward /
761
+ // reverse if it is a stack or GC allocation.
762
+ // Said memory will still be primal initialized.
763
+ StringRef funcName = " " ;
764
+ if (auto CI = dyn_cast<CallInst>(V))
765
+ if (Function *originCall = getFunctionFromCall (CI))
766
+ funcName = originCall->getName ();
767
+ if (isa<AllocaInst>(V) || hasMetadata (V, " enzyme_fromstack" ) ||
768
+ funcName == " jl_alloc_array_1d" ||
769
+ funcName == " jl_alloc_array_2d" ||
770
+ funcName == " jl_alloc_array_3d" ||
771
+ funcName == " jl_array_copy" ||
772
+ funcName == " ijl_alloc_array_1d" ||
773
+ funcName == " ijl_alloc_array_2d" ||
774
+ funcName == " ijl_alloc_array_3d" ||
775
+ funcName == " ijl_array_copy" ||
776
+ funcName == " julia.gc_alloc_obj" ) {
777
+ if (!seenLoadLikeCall) {
778
+ loadLikeCalls.push_back (LoadLikeCall (CI, prev));
779
+ seenLoadLikeCall = true ;
780
+ }
781
+ } else {
782
+ promotable = false ;
783
+ EmitWarning (" NotPromotable" , cur->getDebugLoc (), oldFunc,
784
+ cur->getParent (), " Could not promote allocation " ,
785
+ *V, " due to unknown non-local call " , *cur);
786
+ }
787
+ } else {
788
+ promotable = false ;
789
+ EmitWarning (" NotPromotable" , cur->getDebugLoc (), oldFunc,
790
+ cur->getParent (), " Could not promote allocation " ,
791
+ *V, " due to unknown writing call " , *cur);
792
+ }
793
+
742
794
if (TT.isFloat ()) {
743
795
// all floats ok
744
796
}
745
797
#if LLVM_VERSION_MAJOR >= 8
746
798
else if (CI->onlyReadsMemory (idx))
747
799
#else
748
- else if (CI->dataOperandHasImpliedAttr (idx, Attribute::ReadOnly) ||
749
- CI->dataOperandHasImpliedAttr (idx, Attribute::ReadNone) ||
800
+ else if (CI->dataOperandHasImpliedAttr (idx + 1 ,
801
+ Attribute::ReadOnly) ||
802
+ CI->dataOperandHasImpliedAttr (idx + 1 ,
803
+ Attribute::ReadNone) ||
750
804
(F && (F->hasParamAttribute (idx, Attribute::ReadOnly) ||
751
805
F->hasParamAttribute (idx, Attribute::ReadNone))))
752
806
#endif
@@ -775,9 +829,12 @@ class GradientUtils : public CacheUtility {
775
829
} else {
776
830
shadowpromotable = false ;
777
831
}
778
- break ;
779
832
} else {
780
833
shadowpromotable = false ;
834
+ promotable = false ;
835
+ EmitWarning (" NotPromotable" , cur->getDebugLoc (), oldFunc,
836
+ cur->getParent (), " Could not promote allocation " , *V,
837
+ " due to unknown capturing call " , *cur);
781
838
}
782
839
idx++;
783
840
}
@@ -832,8 +889,25 @@ class GradientUtils : public CacheUtility {
832
889
}
833
890
rematerializable.insert (LI);
834
891
}
892
+ for (auto LI : loadLikeCalls) {
893
+ // Is there a store which could occur after the load.
894
+ // In other words
895
+ SmallVector<Instruction *, 2 > results;
896
+ mayExecuteAfter (results, LI.loadCall , stores, outer);
897
+ for (auto res : results) {
898
+ if (overwritesToMemoryReadBy (OrigAA, SE, OrigLI, OrigDT, LI.loadCall ,
899
+ res, outer)) {
900
+ EmitWarning (" NotPromotable" , LI.loadCall ->getDebugLoc (), oldFunc,
901
+ LI.loadCall ->getParent (),
902
+ " Could not promote allocation " , *V,
903
+ " due to load-like call " , *LI.loadCall ,
904
+ " which does not postdominates store " , *res);
905
+ return ;
906
+ }
907
+ }
908
+ }
835
909
rematerializableAllocations[V] =
836
- Rematerializer (loads, stores, frees, outer);
910
+ Rematerializer (loads, loadLikeCalls, stores, frees, outer);
837
911
}
838
912
839
913
void computeGuaranteedFrees (
0 commit comments