@@ -269,10 +269,15 @@ class AdjointGenerator
269
269
IRBuilder<> BuilderZ (newi);
270
270
Value *newip = nullptr ;
271
271
272
- bool needShadow = is_value_needed_in_reverse<ValueType::ShadowPtr>(
273
- TR, gutils, &I,
274
- /* toplevel*/ Mode == DerivativeMode::ReverseModeCombined,
275
- oldUnreachable);
272
+ // TODO: In the case of fwd mode this should be true if the loaded value
273
+ // itself is used as a pointer.
274
+ bool needShadow =
275
+ Mode == DerivativeMode::ForwardMode
276
+ ? false
277
+ : is_value_needed_in_reverse<ValueType::ShadowPtr>(
278
+ TR, gutils, &I,
279
+ /* toplevel*/ Mode == DerivativeMode::ReverseModeCombined,
280
+ oldUnreachable);
276
281
277
282
switch (Mode) {
278
283
@@ -291,8 +296,8 @@ class AdjointGenerator
291
296
gutils->invertedPointers [&I] = newip;
292
297
break ;
293
298
}
294
-
295
- case DerivativeMode::ReverseModeGradient : {
299
+ case DerivativeMode::ReverseModeGradient:
300
+ case DerivativeMode::ForwardMode : {
296
301
// only make shadow where caching needed
297
302
if (can_modref && needShadow) {
298
303
newip = gutils->cacheForReverse (BuilderZ, placeholder,
@@ -322,13 +327,18 @@ class AdjointGenerator
322
327
323
328
Value *inst = newi;
324
329
330
+ // TODO: In the case of fwd mode this should be true if the loaded value
331
+ // itself is used as a pointer.
332
+ bool primalNeededInReverse =
333
+ Mode == DerivativeMode::ForwardMode
334
+ ? false
335
+ : is_value_needed_in_reverse<ValueType::Primal>(
336
+ TR, gutils, &I,
337
+ /* toplevel*/ Mode == DerivativeMode::ReverseModeCombined,
338
+ oldUnreachable);
325
339
// ! Store loads that need to be cached for use in reverse pass
326
340
if (cache_reads_always ||
327
- (!cache_reads_never && can_modref &&
328
- is_value_needed_in_reverse<ValueType::Primal>(
329
- TR, gutils, &I,
330
- /* toplevel*/ Mode == DerivativeMode::ReverseModeCombined,
331
- oldUnreachable))) {
341
+ (!cache_reads_never && can_modref && primalNeededInReverse)) {
332
342
if (!gutils->unnecessaryIntermediates .count (&I)) {
333
343
IRBuilder<> BuilderZ (gutils->getNewFromOriginal (&I)->getNextNode ());
334
344
// auto tbaa = inst->getMetadata(LLVMContext::MD_tbaa);
@@ -379,15 +389,36 @@ class AdjointGenerator
379
389
}
380
390
381
391
if (isfloat) {
382
- IRBuilder<> Builder2 (parent);
383
- getReverseBuilder (Builder2);
384
- auto prediff = diffe (&I, Builder2);
385
- setDiffe (&I, Constant::getNullValue (type), Builder2);
386
392
387
- if (!gutils->isConstantValue (I.getOperand (0 ))) {
388
- ((DiffeGradientUtils *)gutils)
389
- ->addToInvertedPtrDiffe (I.getOperand (0 ), prediff, Builder2,
390
- alignment, OrigOffset);
393
+ switch (Mode) {
394
+ case DerivativeMode::ForwardMode: {
395
+ IRBuilder<> Builder2 (&I);
396
+ getForwardBuilder (Builder2);
397
+
398
+ if (!gutils->isConstantValue (&I)) {
399
+ auto diff = Builder2.CreateLoad (
400
+ gutils->invertPointerM (I.getOperand (0 ), Builder2));
401
+ setDiffe (&I, diff, Builder2);
402
+ }
403
+ break ;
404
+ }
405
+ case DerivativeMode::ReverseModeGradient:
406
+ case DerivativeMode::ReverseModeCombined: {
407
+ IRBuilder<> Builder2 (parent);
408
+ getReverseBuilder (Builder2);
409
+
410
+ auto prediff = diffe (&I, Builder2);
411
+ setDiffe (&I, Constant::getNullValue (type), Builder2);
412
+
413
+ if (!gutils->isConstantValue (I.getOperand (0 ))) {
414
+ ((DiffeGradientUtils *)gutils)
415
+ ->addToInvertedPtrDiffe (I.getOperand (0 ), prediff, Builder2,
416
+ alignment, OrigOffset);
417
+ }
418
+ break ;
419
+ }
420
+ case DerivativeMode::ReverseModePrimal:
421
+ break ;
391
422
}
392
423
}
393
424
}
@@ -494,8 +525,9 @@ class AdjointGenerator
494
525
495
526
if (FT) {
496
527
// ! Only need to update the reverse function
497
- if (Mode == DerivativeMode::ReverseModeGradient ||
498
- Mode == DerivativeMode::ReverseModeCombined) {
528
+ switch (Mode) {
529
+ case DerivativeMode::ReverseModeGradient:
530
+ case DerivativeMode::ReverseModeCombined: {
499
531
IRBuilder<> Builder2 (SI.getParent ());
500
532
getReverseBuilder (Builder2);
501
533
@@ -512,13 +544,29 @@ class AdjointGenerator
512
544
ts = setPtrDiffe (orig_ptr, Constant::getNullValue (valType), Builder2);
513
545
addToDiffe (orig_val, dif1, Builder2, FT);
514
546
}
547
+ break ;
548
+ }
549
+ case DerivativeMode::ForwardMode: {
550
+ IRBuilder<> Builder2 (&SI);
551
+ getForwardBuilder (Builder2);
552
+
553
+ if (constantval) {
554
+ ts = setPtrDiffe (orig_ptr, Constant::getNullValue (valType), Builder2);
555
+ } else {
556
+ auto diff = diffe (orig_val, Builder2);
557
+
558
+ ts = setPtrDiffe (orig_ptr, diff, Builder2);
559
+ }
560
+ break ;
561
+ }
515
562
}
516
563
517
564
// ! Storing an integer or pointer
518
565
} else {
519
566
// ! Only need to update the forward function
520
567
if (Mode == DerivativeMode::ReverseModePrimal ||
521
- Mode == DerivativeMode::ReverseModeCombined) {
568
+ Mode == DerivativeMode::ReverseModeCombined ||
569
+ Mode == DerivativeMode::ForwardMode) {
522
570
IRBuilder<> storeBuilder (gutils->getNewFromOriginal (&SI));
523
571
524
572
Value *valueop = nullptr ;
@@ -935,25 +983,12 @@ class AdjointGenerator
935
983
setDiffe (&IVI, Constant::getNullValue (IVI.getType ()), Builder2);
936
984
}
937
985
938
- inline void getReverseBuilder (IRBuilder<> &Builder2, bool original = true ) {
939
- BasicBlock *BB = Builder2.GetInsertBlock ();
940
- if (original)
941
- BB = gutils->getNewFromOriginal (BB);
942
- BasicBlock *BB2 = gutils->reverseBlocks [BB].back ();
943
- if (!BB2) {
944
- llvm::errs () << " oldFunc: " << *gutils->oldFunc << " \n " ;
945
- llvm::errs () << " newFunc: " << *gutils->newFunc << " \n " ;
946
- llvm::errs () << " could not invert " << *BB;
947
- }
948
- assert (BB2);
949
-
950
- if (BB2->getTerminator ())
951
- Builder2.SetInsertPoint (BB2->getTerminator ());
952
- else
953
- Builder2.SetInsertPoint (BB2);
954
- Builder2.SetCurrentDebugLocation (
955
- gutils->getNewFromOriginal (Builder2.getCurrentDebugLocation ()));
956
- Builder2.setFastMathFlags (getFast ());
986
+ void getReverseBuilder (IRBuilder<> &Builder2, bool original = true ) {
987
+ ((GradientUtils *)gutils)->getReverseBuilder (Builder2, original);
988
+ }
989
+
990
+ void getForwardBuilder (IRBuilder<> &Builder2) {
991
+ ((GradientUtils *)gutils)->getForwardBuilder (Builder2);
957
992
}
958
993
959
994
Value *diffe (Value *val, IRBuilder<> &Builder) {
@@ -1398,19 +1433,7 @@ class AdjointGenerator
1398
1433
1399
1434
void createBinaryOperatorDual (llvm::BinaryOperator &BO) {
1400
1435
IRBuilder<> Builder2 (&BO);
1401
-
1402
- Instruction *nBO = gutils->getNewFromOriginal (&BO);
1403
-
1404
- assert (nBO);
1405
- assert (nBO->getNextNode ());
1406
-
1407
- if (nBO->getNextNode ()) {
1408
- Builder2.SetInsertPoint (nBO->getNextNode ());
1409
- }
1410
-
1411
- Builder2.SetCurrentDebugLocation (
1412
- gutils->getNewFromOriginal (Builder2.getCurrentDebugLocation ()));
1413
- Builder2.setFastMathFlags (getFast ());
1436
+ getForwardBuilder (Builder2);
1414
1437
1415
1438
Value *orig_op0 = BO.getOperand (0 );
1416
1439
Value *orig_op1 = BO.getOperand (1 );
0 commit comments