Skip to content

Commit d816645

Browse files
authored
Forward mode load store (#193)
* added getForwardBuilder * storeInst forward mode * loadInst forward mode * moved forward and reverse builder helpers into GradientUtils
1 parent 692cc07 commit d816645

11 files changed

+427
-65
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 77 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -269,10 +269,15 @@ class AdjointGenerator
269269
IRBuilder<> BuilderZ(newi);
270270
Value *newip = nullptr;
271271

272-
bool needShadow = is_value_needed_in_reverse<ValueType::ShadowPtr>(
273-
TR, gutils, &I,
274-
/*toplevel*/ Mode == DerivativeMode::ReverseModeCombined,
275-
oldUnreachable);
272+
// TODO: In the case of fwd mode this should be true if the loaded value
273+
// itself is used as a pointer.
274+
bool needShadow =
275+
Mode == DerivativeMode::ForwardMode
276+
? false
277+
: is_value_needed_in_reverse<ValueType::ShadowPtr>(
278+
TR, gutils, &I,
279+
/*toplevel*/ Mode == DerivativeMode::ReverseModeCombined,
280+
oldUnreachable);
276281

277282
switch (Mode) {
278283

@@ -291,8 +296,8 @@ class AdjointGenerator
291296
gutils->invertedPointers[&I] = newip;
292297
break;
293298
}
294-
295-
case DerivativeMode::ReverseModeGradient: {
299+
case DerivativeMode::ReverseModeGradient:
300+
case DerivativeMode::ForwardMode: {
296301
// only make shadow where caching needed
297302
if (can_modref && needShadow) {
298303
newip = gutils->cacheForReverse(BuilderZ, placeholder,
@@ -322,13 +327,18 @@ class AdjointGenerator
322327

323328
Value *inst = newi;
324329

330+
// TODO: In the case of fwd mode this should be true if the loaded value
331+
// itself is used as a pointer.
332+
bool primalNeededInReverse =
333+
Mode == DerivativeMode::ForwardMode
334+
? false
335+
: is_value_needed_in_reverse<ValueType::Primal>(
336+
TR, gutils, &I,
337+
/*toplevel*/ Mode == DerivativeMode::ReverseModeCombined,
338+
oldUnreachable);
325339
//! Store loads that need to be cached for use in reverse pass
326340
if (cache_reads_always ||
327-
(!cache_reads_never && can_modref &&
328-
is_value_needed_in_reverse<ValueType::Primal>(
329-
TR, gutils, &I,
330-
/*toplevel*/ Mode == DerivativeMode::ReverseModeCombined,
331-
oldUnreachable))) {
341+
(!cache_reads_never && can_modref && primalNeededInReverse)) {
332342
if (!gutils->unnecessaryIntermediates.count(&I)) {
333343
IRBuilder<> BuilderZ(gutils->getNewFromOriginal(&I)->getNextNode());
334344
// auto tbaa = inst->getMetadata(LLVMContext::MD_tbaa);
@@ -379,15 +389,36 @@ class AdjointGenerator
379389
}
380390

381391
if (isfloat) {
382-
IRBuilder<> Builder2(parent);
383-
getReverseBuilder(Builder2);
384-
auto prediff = diffe(&I, Builder2);
385-
setDiffe(&I, Constant::getNullValue(type), Builder2);
386392

387-
if (!gutils->isConstantValue(I.getOperand(0))) {
388-
((DiffeGradientUtils *)gutils)
389-
->addToInvertedPtrDiffe(I.getOperand(0), prediff, Builder2,
390-
alignment, OrigOffset);
393+
switch (Mode) {
394+
case DerivativeMode::ForwardMode: {
395+
IRBuilder<> Builder2(&I);
396+
getForwardBuilder(Builder2);
397+
398+
if (!gutils->isConstantValue(&I)) {
399+
auto diff = Builder2.CreateLoad(
400+
gutils->invertPointerM(I.getOperand(0), Builder2));
401+
setDiffe(&I, diff, Builder2);
402+
}
403+
break;
404+
}
405+
case DerivativeMode::ReverseModeGradient:
406+
case DerivativeMode::ReverseModeCombined: {
407+
IRBuilder<> Builder2(parent);
408+
getReverseBuilder(Builder2);
409+
410+
auto prediff = diffe(&I, Builder2);
411+
setDiffe(&I, Constant::getNullValue(type), Builder2);
412+
413+
if (!gutils->isConstantValue(I.getOperand(0))) {
414+
((DiffeGradientUtils *)gutils)
415+
->addToInvertedPtrDiffe(I.getOperand(0), prediff, Builder2,
416+
alignment, OrigOffset);
417+
}
418+
break;
419+
}
420+
case DerivativeMode::ReverseModePrimal:
421+
break;
391422
}
392423
}
393424
}
@@ -494,8 +525,9 @@ class AdjointGenerator
494525

495526
if (FT) {
496527
//! Only need to update the reverse function
497-
if (Mode == DerivativeMode::ReverseModeGradient ||
498-
Mode == DerivativeMode::ReverseModeCombined) {
528+
switch (Mode) {
529+
case DerivativeMode::ReverseModeGradient:
530+
case DerivativeMode::ReverseModeCombined: {
499531
IRBuilder<> Builder2(SI.getParent());
500532
getReverseBuilder(Builder2);
501533

@@ -512,13 +544,29 @@ class AdjointGenerator
512544
ts = setPtrDiffe(orig_ptr, Constant::getNullValue(valType), Builder2);
513545
addToDiffe(orig_val, dif1, Builder2, FT);
514546
}
547+
break;
548+
}
549+
case DerivativeMode::ForwardMode: {
550+
IRBuilder<> Builder2(&SI);
551+
getForwardBuilder(Builder2);
552+
553+
if (constantval) {
554+
ts = setPtrDiffe(orig_ptr, Constant::getNullValue(valType), Builder2);
555+
} else {
556+
auto diff = diffe(orig_val, Builder2);
557+
558+
ts = setPtrDiffe(orig_ptr, diff, Builder2);
559+
}
560+
break;
561+
}
515562
}
516563

517564
//! Storing an integer or pointer
518565
} else {
519566
//! Only need to update the forward function
520567
if (Mode == DerivativeMode::ReverseModePrimal ||
521-
Mode == DerivativeMode::ReverseModeCombined) {
568+
Mode == DerivativeMode::ReverseModeCombined ||
569+
Mode == DerivativeMode::ForwardMode) {
522570
IRBuilder<> storeBuilder(gutils->getNewFromOriginal(&SI));
523571

524572
Value *valueop = nullptr;
@@ -935,25 +983,12 @@ class AdjointGenerator
935983
setDiffe(&IVI, Constant::getNullValue(IVI.getType()), Builder2);
936984
}
937985

938-
inline void getReverseBuilder(IRBuilder<> &Builder2, bool original = true) {
939-
BasicBlock *BB = Builder2.GetInsertBlock();
940-
if (original)
941-
BB = gutils->getNewFromOriginal(BB);
942-
BasicBlock *BB2 = gutils->reverseBlocks[BB].back();
943-
if (!BB2) {
944-
llvm::errs() << "oldFunc: " << *gutils->oldFunc << "\n";
945-
llvm::errs() << "newFunc: " << *gutils->newFunc << "\n";
946-
llvm::errs() << "could not invert " << *BB;
947-
}
948-
assert(BB2);
949-
950-
if (BB2->getTerminator())
951-
Builder2.SetInsertPoint(BB2->getTerminator());
952-
else
953-
Builder2.SetInsertPoint(BB2);
954-
Builder2.SetCurrentDebugLocation(
955-
gutils->getNewFromOriginal(Builder2.getCurrentDebugLocation()));
956-
Builder2.setFastMathFlags(getFast());
986+
void getReverseBuilder(IRBuilder<> &Builder2, bool original = true) {
987+
((GradientUtils *)gutils)->getReverseBuilder(Builder2, original);
988+
}
989+
990+
void getForwardBuilder(IRBuilder<> &Builder2) {
991+
((GradientUtils *)gutils)->getForwardBuilder(Builder2);
957992
}
958993

959994
Value *diffe(Value *val, IRBuilder<> &Builder) {
@@ -1398,19 +1433,7 @@ class AdjointGenerator
13981433

13991434
void createBinaryOperatorDual(llvm::BinaryOperator &BO) {
14001435
IRBuilder<> Builder2(&BO);
1401-
1402-
Instruction *nBO = gutils->getNewFromOriginal(&BO);
1403-
1404-
assert(nBO);
1405-
assert(nBO->getNextNode());
1406-
1407-
if (nBO->getNextNode()) {
1408-
Builder2.SetInsertPoint(nBO->getNextNode());
1409-
}
1410-
1411-
Builder2.SetCurrentDebugLocation(
1412-
gutils->getNewFromOriginal(Builder2.getCurrentDebugLocation()));
1413-
Builder2.setFastMathFlags(getFast());
1436+
getForwardBuilder(Builder2);
14141437

14151438
Value *orig_op0 = BO.getOperand(0);
14161439
Value *orig_op1 = BO.getOperand(1);

enzyme/Enzyme/EnzymeLogic.cpp

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2659,16 +2659,26 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
26592659

26602660
assert(!todiff->empty());
26612661

2662-
ReturnType retVal =
2663-
returnValue ? (dretPtr ? ReturnType::ArgsWithTwoReturns
2664-
: ReturnType::ArgsWithReturn)
2665-
: (dretPtr ? ReturnType::ArgsWithReturn : ReturnType::Args);
2662+
ReturnType retVal;
2663+
if (fwdMode) {
2664+
auto TR = TA.analyzeFunction(oldTypeInfo);
2665+
bool retActive = TR.getReturnAnalysis().Inner0().isFloat();
2666+
2667+
retVal = returnValue
2668+
? (retActive ? ReturnType::TwoReturns : ReturnType::Return)
2669+
: (retActive ? ReturnType::Return : ReturnType::Void);
2670+
} else {
2671+
retVal = returnValue
2672+
? (dretPtr ? ReturnType::ArgsWithTwoReturns
2673+
: ReturnType::ArgsWithReturn)
2674+
: (dretPtr ? ReturnType::ArgsWithReturn : ReturnType::Args);
2675+
}
26662676

26672677
bool diffeReturnArg = fwdMode ? false : retType == DIFFE_TYPE::OUT_DIFF;
26682678

26692679
DiffeGradientUtils *gutils = DiffeGradientUtils::CreateFromClone(
26702680
*this, topLevel, todiff, TLI, TA, retType, diffeReturnArg, constant_args,
2671-
fwdMode ? ReturnType::Return : retVal, additionalArg);
2681+
retVal, additionalArg);
26722682

26732683
if (omp)
26742684
gutils->setupOMPFor();

enzyme/Enzyme/GradientUtils.h

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -957,11 +957,10 @@ class GradientUtils : public CacheUtility {
957957
if (!TR.query(inst).Inner0().isPossiblePointer())
958958
continue;
959959

960-
Instruction *newi = getNewFromOriginal(inst);
961-
962960
if (isa<LoadInst>(inst)) {
963-
IRBuilder<> BuilderZ(getNextNonDebugInstruction(newi));
964-
BuilderZ.setFastMathFlags(getFast());
961+
IRBuilder<> BuilderZ(inst);
962+
getForwardBuilder(BuilderZ);
963+
965964
PHINode *anti = BuilderZ.CreatePHI(inst->getType(), 1,
966965
inst->getName() + "'il_phi");
967966
invertedPointers[inst] = anti;
@@ -987,8 +986,9 @@ class GradientUtils : public CacheUtility {
987986
continue;
988987
}
989988

990-
IRBuilder<> BuilderZ(getNextNonDebugInstruction(newi));
991-
BuilderZ.setFastMathFlags(getFast());
989+
IRBuilder<> BuilderZ(inst);
990+
getForwardBuilder(BuilderZ);
991+
992992
PHINode *anti =
993993
BuilderZ.CreatePHI(op->getType(), 1, op->getName() + "'ip_phi");
994994
invertedPointers[inst] = anti;
@@ -1181,6 +1181,39 @@ class GradientUtils : public CacheUtility {
11811181
/*successor*/ BasicBlock *>>>
11821182
&targetToPreds,
11831183
const std::map<BasicBlock *, PHINode *> *replacePHIs = nullptr);
1184+
1185+
/// Repositions \p Builder2 into the block stored last in `reverseBlocks` for
/// the block the builder currently inserts into.
///
/// \param Builder2 builder to reposition. On return it inserts before the
///        terminator of the selected block, or at the end of that block when
///        it has no terminator. Its debug location is remapped through
///        getNewFromOriginal and its fast-math flags are set from getFast().
/// \param original when true, the builder's current block is first mapped
///        through getNewFromOriginal before the `reverseBlocks` lookup.
void getReverseBuilder(IRBuilder<> &Builder2, bool original = true) {
  BasicBlock *BB = Builder2.GetInsertBlock();
  if (original)
    BB = getNewFromOriginal(BB);
  BasicBlock *BB2 = reverseBlocks[BB].back();
  if (!BB2) {
    // Dereference the functions so their IR is printed; streaming the bare
    // pointers would only emit addresses, which is useless for diagnosing
    // which block failed to invert.
    llvm::errs() << "oldFunc: " << *oldFunc << "\n";
    llvm::errs() << "newFunc: " << *newFunc << "\n";
    llvm::errs() << "could not invert " << *BB;
  }
  assert(BB2);

  if (BB2->getTerminator())
    Builder2.SetInsertPoint(BB2->getTerminator());
  else
    Builder2.SetInsertPoint(BB2);
  Builder2.SetCurrentDebugLocation(
      getNewFromOriginal(Builder2.getCurrentDebugLocation()));
  Builder2.setFastMathFlags(getFast());
}
1205+
1206+
/// Repositions \p Builder2 for emitting forward-mode code.
///
/// The instruction the builder currently points at is mapped through
/// getNewFromOriginal, and the builder is moved to the instruction returned
/// by getNextNonDebugInstruction for that mapped instruction. The debug
/// location is then remapped through getNewFromOriginal and the fast-math
/// flags are set from getFast().
///
/// \param Builder2 builder whose insertion point must currently reference an
///        instruction (the insertion iterator is dereferenced).
void getForwardBuilder(IRBuilder<> &Builder2) {
  Instruction *origInst = &*Builder2.GetInsertPoint();
  Instruction *mappedInst = getNewFromOriginal(origInst);

  assert(mappedInst);

  // Move the insertion point first: the debug location is read from the
  // builder only after it has been repositioned.
  Instruction *insertBefore = getNextNonDebugInstruction(mappedInst);
  Builder2.SetInsertPoint(insertBefore);

  auto mappedLoc = getNewFromOriginal(Builder2.getCurrentDebugLocation());
  Builder2.SetCurrentDebugLocation(mappedLoc);
  Builder2.setFastMathFlags(getFast());
}
}
11841217
};
11851218

11861219
class DiffeGradientUtils : public GradientUtils {

enzyme/Enzyme/Utils.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,7 @@ enum class ReturnType {
220220
Tape,
221221
TwoReturns,
222222
Return,
223+
Void,
223224
};
224225

225226
/// Potential differentiable argument classifications
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -inline -mem2reg -instsimplify -gvn -dse -dse -S | FileCheck %s
2+
3+
; __attribute__((noinline))
4+
; void addOneMem(double *x) {
5+
; *x += 1;
6+
; }
7+
;
8+
; void test_derivative(double *x, double *xp) {
9+
; __builtin_autodiff(addOneMem, x, xp);
10+
; }
11+
12+
; Function Attrs: noinline norecurse nounwind uwtable
13+
define dso_local void @addOneMem(double* nocapture %x) {
14+
entry:
15+
%0 = load double, double* %x, align 8, !tbaa !2
16+
%add = fadd fast double %0, 1.000000e+00
17+
store double %add, double* %x, align 8, !tbaa !2
18+
ret void
19+
}
20+
21+
; Function Attrs: nounwind uwtable
22+
define dso_local void @test_derivative(double* %x, double* %xp) local_unnamed_addr {
23+
entry:
24+
%0 = tail call double (void (double*)*, ...) @__enzyme_fwddiff(void (double*)* nonnull @addOneMem, double* %x, double* %xp)
25+
ret void
26+
}
27+
28+
; Function Attrs: nounwind
29+
declare double @__enzyme_fwddiff(void (double*)*, ...)
30+
31+
!llvm.module.flags = !{!0}
32+
!llvm.ident = !{!1}
33+
34+
!0 = !{i32 1, !"wchar_size", i32 4}
35+
!1 = !{!"clang version 7.1.0 "}
36+
!2 = !{!3, !3, i64 0}
37+
!3 = !{!"double", !4, i64 0}
38+
!4 = !{!"omnipotent char", !5, i64 0}
39+
!5 = !{!"Simple C/C++ TBAA"}
40+
41+
42+
; CHECK: define {{(dso_local )?}}void @test_derivative(double* %x, double* %xp)
43+
; CHECK-NEXT: entry:
44+
; CHECK-NEXT: %0 = load double, double* %x, align 8, !tbaa !2
45+
; CHECK-NEXT: %1 = load double, double* %xp
46+
; CHECK-NEXT: %add.i = fadd fast double %0, 1.000000e+00
47+
; CHECK-NEXT: store double %add.i, double* %x, align 8, !tbaa !2
48+
; CHECK-NEXT: store double %1, double* %xp, align 8
49+
; CHECK-NEXT: ret void
50+
; CHECK-NEXT: }
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s
2+
3+
; Function Attrs: noinline nounwind readnone uwtable
4+
define double @tester(double %x) {
5+
entry:
6+
tail call void @myprint(double %x)
7+
ret double %x
8+
}
9+
10+
define double @test_derivative(double %x) {
11+
entry:
12+
%0 = tail call double (double (double)*, ...) @__enzyme_fwddiff(double (double)* nonnull @tester, double %x, double 1.0)
13+
ret double %0
14+
}
15+
16+
declare void @myprint(double %x) #0
17+
18+
; Function Attrs: nounwind
19+
declare double @__enzyme_fwddiff(double (double)*, ...)
20+
21+
attributes #0 = { "enzyme_inactive" }
22+
23+
; CHECK: define internal {{(dso_local )?}}{ double } @diffetester(double %x, double %"x'")
24+
; CHECK-NEXT: entry:
25+
; CHECK-NEXT: tail call void @myprint(double %x)
26+
; CHECK-NEXT: %0 = insertvalue { double } undef, double %"x'", 0
27+
; CHECK-NEXT: ret { double } %0
28+
; CHECK-NEXT: }

0 commit comments

Comments
 (0)