Skip to content

Commit e496dfa

Browse files
authored
Expose function replacement (rust-lang#741)
1 parent 3dfdabc commit e496dfa

File tree

3 files changed

+42
-30
lines changed

3 files changed

+42
-30
lines changed

enzyme/Enzyme/CApi.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -600,6 +600,10 @@ void EnzymeSetMustCache(LLVMValueRef inst1) {
600600
I1->setMetadata("enzyme_mustcache", MDNode::get(I1->getContext(), {}));
601601
}
602602

603+
/// C API entry point: unwraps the LLVM-C module handle `M` and forwards the
/// resulting module to ReplaceFunctionImplementation.
/// `M` is dereferenced without a null check, so callers must pass a valid
/// module handle.
void EnzymeReplaceFunctionImplementation(LLVMModuleRef M) {
604+
ReplaceFunctionImplementation(*unwrap(M));
605+
}
606+
603607
#if LLVM_VERSION_MAJOR >= 9
604608
void EnzymeAddAttributorLegacyPass(LLVMPassManagerRef PM) {
605609
unwrap(PM)->add(createAttributorLegacyPass());

enzyme/Enzyme/FunctionUtils.cpp

Lines changed: 34 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2070,42 +2070,16 @@ void SelectOptimization(Function *F) {
20702070
}
20712071
}
20722072
}
2073-
void PreProcessCache::optimizeIntermediate(Function *F) {
2074-
PromotePass().run(*F, FAM);
2075-
#if LLVM_VERSION_MAJOR >= 14 && !defined(FLANG)
2076-
GVNPass().run(*F, FAM);
2077-
#else
2078-
GVN().run(*F, FAM);
2079-
#endif
2080-
#if LLVM_VERSION_MAJOR >= 14 && !defined(FLANG)
2081-
SROAPass().run(*F, FAM);
2082-
#else
2083-
SROA().run(*F, FAM);
2084-
#endif
20852073

2086-
if (EnzymeSelectOpt) {
2087-
#if LLVM_VERSION_MAJOR >= 12
2088-
SimplifyCFGOptions scfgo;
2089-
#else
2090-
SimplifyCFGOptions scfgo(
2091-
/*unsigned BonusThreshold=*/1, /*bool ForwardSwitchCond=*/false,
2092-
/*bool SwitchToLookup=*/false, /*bool CanonicalLoops=*/true,
2093-
/*bool SinkCommon=*/true, /*AssumptionCache *AssumpCache=*/nullptr);
2094-
#endif
2095-
SimplifyCFGPass(scfgo).run(*F, FAM);
2096-
CorrelatedValuePropagationPass().run(*F, FAM);
2097-
SelectOptimization(F);
2098-
}
2099-
// EarlyCSEPass(/*memoryssa*/ true).run(*F, FAM);
2100-
2101-
for (Function &Impl : *F->getParent()) {
2074+
void ReplaceFunctionImplementation(Module &M) {
2075+
for (Function &Impl : M) {
21022076
for (auto attr : {"implements", "implements2"}) {
21032077
if (!Impl.hasFnAttribute(attr))
21042078
continue;
21052079
const Attribute &A = Impl.getFnAttribute(attr);
21062080

21072081
const StringRef SpecificationName = A.getValueAsString();
2108-
Function *Specification = F->getParent()->getFunction(SpecificationName);
2082+
Function *Specification = M.getFunction(SpecificationName);
21092083
if (!Specification) {
21102084
LLVM_DEBUG(dbgs() << "Found implementation '" << Impl.getName()
21112085
<< "' but no matching specification with name '"
@@ -2139,10 +2113,41 @@ void PreProcessCache::optimizeIntermediate(Function *F) {
21392113
}
21402114
}
21412115
}
2116+
}
2117+
2118+
void PreProcessCache::optimizeIntermediate(Function *F) {
2119+
PromotePass().run(*F, FAM);
2120+
#if LLVM_VERSION_MAJOR >= 14 && !defined(FLANG)
2121+
GVNPass().run(*F, FAM);
2122+
#else
2123+
GVN().run(*F, FAM);
2124+
#endif
2125+
#if LLVM_VERSION_MAJOR >= 14 && !defined(FLANG)
2126+
SROAPass().run(*F, FAM);
2127+
#else
2128+
SROA().run(*F, FAM);
2129+
#endif
2130+
2131+
if (EnzymeSelectOpt) {
2132+
#if LLVM_VERSION_MAJOR >= 12
2133+
SimplifyCFGOptions scfgo;
2134+
#else
2135+
SimplifyCFGOptions scfgo(
2136+
/*unsigned BonusThreshold=*/1, /*bool ForwardSwitchCond=*/false,
2137+
/*bool SwitchToLookup=*/false, /*bool CanonicalLoops=*/true,
2138+
/*bool SinkCommon=*/true, /*AssumptionCache *AssumpCache=*/nullptr);
2139+
#endif
2140+
SimplifyCFGPass(scfgo).run(*F, FAM);
2141+
CorrelatedValuePropagationPass().run(*F, FAM);
2142+
SelectOptimization(F);
2143+
}
2144+
// EarlyCSEPass(/*memoryssa*/ true).run(*F, FAM);
21422145

21432146
if (EnzymeCoalese)
21442147
CoaleseTrivialMallocs(*F, FAM.getResult<DominatorTreeAnalysis>(*F));
21452148

2149+
ReplaceFunctionImplementation(*F->getParent());
2150+
21462151
PreservedAnalyses PA;
21472152
FAM.invalidate(*F, PA);
21482153
// TODO actually run post optimizations.

enzyme/Enzyme/FunctionUtils.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -346,11 +346,14 @@ static inline void calculateUnusedStores(
346346
}
347347
}
348348

349+
/// Walks every function in `M`; for each one carrying an "implements" or
/// "implements2" function attribute, looks up the function in `M` whose name
/// is the attribute's string value (the "specification"). When no such
/// function exists, a debug message is emitted.
/// NOTE(review): the name suggests the specification is then replaced by the
/// implementation -- confirm against the definition in FunctionUtils.cpp.
void ReplaceFunctionImplementation(llvm::Module &M);
350+
349351
/// Is the use of value val as an argument of call CI potentially captured
350352
bool couldFunctionArgumentCapture(llvm::CallInst *CI, llvm::Value *val);
351-
#endif
352353

353354
llvm::FunctionType *getFunctionTypeForClone(
354355
llvm::FunctionType *FTy, DerivativeMode mode, unsigned width,
355356
llvm::Type *additionalArg, llvm::ArrayRef<DIFFE_TYPE> constant_args,
356357
bool diffeReturnArg, ReturnType returnValue, DIFFE_TYPE returnType);
358+
359+
#endif

0 commit comments

Comments
 (0)