@@ -340,6 +340,8 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
340
340
if (auto newi = dyn_cast<Instruction>(toreturn)) {
341
341
newi->copyIRFlags (op);
342
342
unwrappedLoads[newi] = val;
343
+ if (newi->getParent ()->getParent () != op->getParent ()->getParent ())
344
+ newi->setDebugLoc (nullptr );
343
345
}
344
346
if (permitCache)
345
347
unwrap_cache[BuilderM.GetInsertBlock ()][idx] = toreturn;
@@ -356,6 +358,8 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
356
358
if (auto newi = dyn_cast<Instruction>(toreturn)) {
357
359
newi->copyIRFlags (op);
358
360
unwrappedLoads[newi] = val;
361
+ if (newi->getParent ()->getParent () != op->getParent ()->getParent ())
362
+ newi->setDebugLoc (nullptr );
359
363
}
360
364
assert (val->getType () == toreturn->getType ());
361
365
return toreturn;
@@ -373,6 +377,8 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
373
377
if (auto newi = dyn_cast<Instruction>(toreturn)) {
374
378
newi->copyIRFlags (op);
375
379
unwrappedLoads[newi] = val;
380
+ if (newi->getParent ()->getParent () != op->getParent ()->getParent ())
381
+ newi->setDebugLoc (nullptr );
376
382
}
377
383
assert (val->getType () == toreturn->getType ());
378
384
return toreturn;
@@ -390,6 +396,8 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
390
396
if (auto newi = dyn_cast<Instruction>(toreturn)) {
391
397
newi->copyIRFlags (op);
392
398
unwrappedLoads[newi] = val;
399
+ if (newi->getParent ()->getParent () != op->getParent ()->getParent ())
400
+ newi->setDebugLoc (nullptr );
393
401
}
394
402
assert (val->getType () == toreturn->getType ());
395
403
return toreturn;
@@ -410,6 +418,8 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
410
418
if (auto newi = dyn_cast<Instruction>(toreturn)) {
411
419
newi->copyIRFlags (op);
412
420
unwrappedLoads[newi] = val;
421
+ if (newi->getParent ()->getParent () != op->getParent ()->getParent ())
422
+ newi->setDebugLoc (nullptr );
413
423
}
414
424
assert (val->getType () == toreturn->getType ());
415
425
return toreturn;
@@ -432,6 +442,8 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
432
442
if (auto newi = dyn_cast<Instruction>(toreturn)) {
433
443
newi->copyIRFlags (op);
434
444
unwrappedLoads[newi] = val;
445
+ if (newi->getParent ()->getParent () != op->getParent ()->getParent ())
446
+ newi->setDebugLoc (nullptr );
435
447
}
436
448
assert (val->getType () == toreturn->getType ());
437
449
return toreturn;
@@ -453,6 +465,8 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
453
465
if (auto newi = dyn_cast<Instruction>(toreturn)) {
454
466
newi->copyIRFlags (op);
455
467
unwrappedLoads[newi] = val;
468
+ if (newi->getParent ()->getParent () != op->getParent ()->getParent ())
469
+ newi->setDebugLoc (nullptr );
456
470
}
457
471
if (permitCache)
458
472
unwrap_cache[BuilderM.GetInsertBlock ()][idx] = toreturn;
@@ -470,6 +484,8 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
470
484
if (auto newi = dyn_cast<Instruction>(toreturn)) {
471
485
newi->copyIRFlags (op);
472
486
unwrappedLoads[newi] = val;
487
+ if (newi->getParent ()->getParent () != op->getParent ()->getParent ())
488
+ newi->setDebugLoc (nullptr );
473
489
}
474
490
if (permitCache)
475
491
unwrap_cache[BuilderM.GetInsertBlock ()][idx] = toreturn;
@@ -487,6 +503,8 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
487
503
if (auto newi = dyn_cast<Instruction>(toreturn)) {
488
504
newi->copyIRFlags (op);
489
505
unwrappedLoads[newi] = val;
506
+ if (newi->getParent ()->getParent () != op->getParent ()->getParent ())
507
+ newi->setDebugLoc (nullptr );
490
508
}
491
509
if (permitCache)
492
510
unwrap_cache[BuilderM.GetInsertBlock ()][idx] = toreturn;
@@ -503,6 +521,9 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
503
521
if (auto newi = dyn_cast<Instruction>(toreturn)) {
504
522
newi->copyIRFlags (op);
505
523
unwrappedLoads[newi] = val;
524
+ if (newi->getParent ()->getParent () !=
525
+ cast<Instruction>(val)->getParent ()->getParent ())
526
+ newi->setDebugLoc (nullptr );
506
527
}
507
528
if (permitCache)
508
529
unwrap_cache[BuilderM.GetInsertBlock ()][idx] = toreturn;
@@ -524,6 +545,8 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
524
545
if (auto newi = dyn_cast<Instruction>(toreturn)) {
525
546
newi->copyIRFlags (op);
526
547
unwrappedLoads[newi] = val;
548
+ if (newi->getParent ()->getParent () != op->getParent ()->getParent ())
549
+ newi->setDebugLoc (nullptr );
527
550
}
528
551
if (permitCache)
529
552
unwrap_cache[BuilderM.GetInsertBlock ()][idx] = toreturn;
@@ -549,6 +572,8 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
549
572
if (auto newi = dyn_cast<Instruction>(toreturn)) {
550
573
newi->copyIRFlags (inst);
551
574
unwrappedLoads[newi] = val;
575
+ if (newi->getParent ()->getParent () != inst->getParent ()->getParent ())
576
+ newi->setDebugLoc (nullptr );
552
577
}
553
578
if (permitCache)
554
579
unwrap_cache[BuilderM.GetInsertBlock ()][idx] = toreturn;
@@ -560,7 +585,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
560
585
561
586
bool legalMove = mode == UnwrapMode::LegalFullUnwrap ||
562
587
mode == UnwrapMode::LegalFullUnwrapNoTapeReplace;
563
- if (mode != UnwrapMode::LegalFullUnwrap ) {
588
+ if (!legalMove ) {
564
589
BasicBlock *parent = nullptr ;
565
590
if (isOriginalBlock (*BuilderM.GetInsertBlock ()))
566
591
parent = BuilderM.GetInsertBlock ();
@@ -591,9 +616,11 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
591
616
if (pidx->getType () != load->getOperand (0 )->getType ()) {
592
617
llvm::errs () << " load: " << *load << " \n " ;
593
618
llvm::errs () << " load->getOperand(0): " << *load->getOperand (0 ) << " \n " ;
594
- llvm::errs () << " idx: " << *pidx << " \n " ;
619
+ llvm::errs () << " idx: " << *pidx << " unwrapping: " << *val
620
+ << " mode=" << mode << " \n " ;
595
621
}
596
622
assert (pidx->getType () == load->getOperand (0 )->getType ());
623
+
597
624
auto toreturn = BuilderM.CreateLoad (pidx, load->getName () + " _unwrap" );
598
625
toreturn->copyIRFlags (load);
599
626
unwrappedLoads[toreturn] = load;
@@ -605,7 +632,10 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
605
632
toreturn->setVolatile (load->isVolatile ());
606
633
toreturn->setOrdering (load->getOrdering ());
607
634
toreturn->setSyncScopeID (load->getSyncScopeID ());
608
- toreturn->setDebugLoc (getNewFromOriginal (load->getDebugLoc ()));
635
+ if (toreturn->getParent ()->getParent () != load->getParent ()->getParent ())
636
+ toreturn->setDebugLoc (nullptr );
637
+ else
638
+ toreturn->setDebugLoc (getNewFromOriginal (load->getDebugLoc ()));
609
639
toreturn->setMetadata (LLVMContext::MD_tbaa,
610
640
load->getMetadata (LLVMContext::MD_tbaa));
611
641
toreturn->setMetadata (LLVMContext::MD_invariant_group,
@@ -619,7 +649,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
619
649
620
650
bool legalMove = mode == UnwrapMode::LegalFullUnwrap ||
621
651
mode == UnwrapMode::LegalFullUnwrapNoTapeReplace;
622
- if (mode != UnwrapMode::LegalFullUnwrap ) {
652
+ if (!legalMove ) {
623
653
legalMove = legalRecompute (op, available, &BuilderM);
624
654
}
625
655
if (!legalMove)
@@ -646,7 +676,10 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
646
676
toreturn->setAttributes (op->getAttributes ());
647
677
toreturn->setCallingConv (op->getCallingConv ());
648
678
toreturn->setTailCallKind (op->getTailCallKind ());
649
- toreturn->setDebugLoc (getNewFromOriginal (op->getDebugLoc ()));
679
+ if (toreturn->getParent ()->getParent () == op->getParent ()->getParent ())
680
+ toreturn->setDebugLoc (getNewFromOriginal (op->getDebugLoc ()));
681
+ else
682
+ toreturn->setDebugLoc (nullptr );
650
683
if (permitCache)
651
684
unwrap_cache[BuilderM.GetInsertBlock ()][idx] = toreturn;
652
685
unwrappedLoads[toreturn] = val;
@@ -664,7 +697,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
664
697
665
698
bool legalMove = mode == UnwrapMode::LegalFullUnwrap ||
666
699
mode == UnwrapMode::LegalFullUnwrapNoTapeReplace;
667
- if (mode != UnwrapMode::LegalFullUnwrap ) {
700
+ if (!legalMove ) {
668
701
// TODO actually consider whether this is legal to move to the new
669
702
// location, rather than recomputable anywhere
670
703
legalMove = legalRecompute (dli, available, &BuilderM);
@@ -1884,13 +1917,18 @@ bool GradientUtils::legalRecompute(const Value *val,
1884
1917
const Instruction *orig = nullptr ;
1885
1918
if (li->getParent ()->getParent () == oldFunc) {
1886
1919
orig = li;
1887
- } else {
1920
+ } else if (li-> getParent ()-> getParent () == newFunc) {
1888
1921
orig = isOriginal (li);
1889
1922
// todo consider when we pass non original queries
1890
1923
if (orig && !isa<LoadInst>(orig)) {
1891
1924
return legalRecompute (orig, available, BuilderM, reverse,
1892
1925
legalRecomputeCache);
1893
1926
}
1927
+ } else {
1928
+ llvm::errs () << " newFunc: " << *newFunc << " \n " ;
1929
+ llvm::errs () << " parent: " << *li->getParent ()->getParent () << " \n " ;
1930
+ llvm::errs () << " li: " << *li << " \n " ;
1931
+ assert (0 && " illegal load legalRecopmute query" );
1894
1932
}
1895
1933
1896
1934
if (orig) {
@@ -2005,7 +2043,9 @@ bool GradientUtils::legalRecompute(const Value *val,
2005
2043
if (n == " lgamma" || n == " lgammaf" || n == " lgammal" ||
2006
2044
n == " lgamma_r" || n == " lgammaf_r" || n == " lgammal_r" ||
2007
2045
n == " __lgamma_r_finite" || n == " __lgammaf_r_finite" ||
2008
- n == " __lgammal_r_finite" || isMemFreeLibMFunction (n)) {
2046
+ n == " __lgammal_r_finite" || isMemFreeLibMFunction (n) ||
2047
+ n.startswith (" enzyme_wrapmpi$$" ) || n == " omp_get_thread_num" ||
2048
+ n == " omp_get_max_threads" ) {
2009
2049
return true ;
2010
2050
}
2011
2051
}
@@ -2173,7 +2213,9 @@ bool GradientUtils::shouldRecompute(const Value *val,
2173
2213
n == " __lgamma_r_finite" || n == " __lgammaf_r_finite" ||
2174
2214
n == " __lgammal_r_finite" || n == " tanh" || n == " tanhf" ||
2175
2215
n == " __pow_finite" || n == " __fd_sincos_1" ||
2176
- isMemFreeLibMFunction (n) || n == " julia.pointer_from_objref" ) {
2216
+ isMemFreeLibMFunction (n) || n == " julia.pointer_from_objref" ||
2217
+ n.startswith (" enzyme_wrapmpi$$" ) || n == " omp_get_thread_num" ||
2218
+ n == " omp_get_max_threads" ) {
2177
2219
return true ;
2178
2220
}
2179
2221
}
0 commit comments