Skip to content

Commit 514f044

Browse files
authored
Only set debug info if unwrapping to the same function & Fix nondeterministic cache (rust-lang#246)
1 parent 1196783 commit 514f044

File tree

3 files changed

+393
-9
lines changed

3 files changed

+393
-9
lines changed

enzyme/Enzyme/EnzymeLogic.cpp

+9
Original file line numberDiff line numberDiff line change
@@ -3028,11 +3028,20 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
30283028
unwrapToOrig[pair.second].push_back(
30293029
const_cast<Instruction *>(pair.first));
30303030
gutils->unwrappedLoads.clear();
3031+
30313032
for (auto pair : newIToNextI) {
30323033
auto newi = pair.first;
30333034
auto nexti = pair.second;
30343035
newi->replaceAllUsesWith(nexti);
30353036
gutils->erase(newi);
3037+
}
3038+
3039+
// This most occur after all the replacements have been made
3040+
// in the previous loop, lest a loop bound being unwrapped use
3041+
// a value being replaced.
3042+
for (auto pair : newIToNextI) {
3043+
auto newi = pair.first;
3044+
auto nexti = pair.second;
30363045
for (auto V : unwrapToOrig[newi]) {
30373046
ValueToValueMapTy empty;
30383047
IRBuilder<> lb(V);

enzyme/Enzyme/GradientUtils.cpp

+51-9
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,8 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
340340
if (auto newi = dyn_cast<Instruction>(toreturn)) {
341341
newi->copyIRFlags(op);
342342
unwrappedLoads[newi] = val;
343+
if (newi->getParent()->getParent() != op->getParent()->getParent())
344+
newi->setDebugLoc(nullptr);
343345
}
344346
if (permitCache)
345347
unwrap_cache[BuilderM.GetInsertBlock()][idx] = toreturn;
@@ -356,6 +358,8 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
356358
if (auto newi = dyn_cast<Instruction>(toreturn)) {
357359
newi->copyIRFlags(op);
358360
unwrappedLoads[newi] = val;
361+
if (newi->getParent()->getParent() != op->getParent()->getParent())
362+
newi->setDebugLoc(nullptr);
359363
}
360364
assert(val->getType() == toreturn->getType());
361365
return toreturn;
@@ -373,6 +377,8 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
373377
if (auto newi = dyn_cast<Instruction>(toreturn)) {
374378
newi->copyIRFlags(op);
375379
unwrappedLoads[newi] = val;
380+
if (newi->getParent()->getParent() != op->getParent()->getParent())
381+
newi->setDebugLoc(nullptr);
376382
}
377383
assert(val->getType() == toreturn->getType());
378384
return toreturn;
@@ -390,6 +396,8 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
390396
if (auto newi = dyn_cast<Instruction>(toreturn)) {
391397
newi->copyIRFlags(op);
392398
unwrappedLoads[newi] = val;
399+
if (newi->getParent()->getParent() != op->getParent()->getParent())
400+
newi->setDebugLoc(nullptr);
393401
}
394402
assert(val->getType() == toreturn->getType());
395403
return toreturn;
@@ -410,6 +418,8 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
410418
if (auto newi = dyn_cast<Instruction>(toreturn)) {
411419
newi->copyIRFlags(op);
412420
unwrappedLoads[newi] = val;
421+
if (newi->getParent()->getParent() != op->getParent()->getParent())
422+
newi->setDebugLoc(nullptr);
413423
}
414424
assert(val->getType() == toreturn->getType());
415425
return toreturn;
@@ -432,6 +442,8 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
432442
if (auto newi = dyn_cast<Instruction>(toreturn)) {
433443
newi->copyIRFlags(op);
434444
unwrappedLoads[newi] = val;
445+
if (newi->getParent()->getParent() != op->getParent()->getParent())
446+
newi->setDebugLoc(nullptr);
435447
}
436448
assert(val->getType() == toreturn->getType());
437449
return toreturn;
@@ -453,6 +465,8 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
453465
if (auto newi = dyn_cast<Instruction>(toreturn)) {
454466
newi->copyIRFlags(op);
455467
unwrappedLoads[newi] = val;
468+
if (newi->getParent()->getParent() != op->getParent()->getParent())
469+
newi->setDebugLoc(nullptr);
456470
}
457471
if (permitCache)
458472
unwrap_cache[BuilderM.GetInsertBlock()][idx] = toreturn;
@@ -470,6 +484,8 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
470484
if (auto newi = dyn_cast<Instruction>(toreturn)) {
471485
newi->copyIRFlags(op);
472486
unwrappedLoads[newi] = val;
487+
if (newi->getParent()->getParent() != op->getParent()->getParent())
488+
newi->setDebugLoc(nullptr);
473489
}
474490
if (permitCache)
475491
unwrap_cache[BuilderM.GetInsertBlock()][idx] = toreturn;
@@ -487,6 +503,8 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
487503
if (auto newi = dyn_cast<Instruction>(toreturn)) {
488504
newi->copyIRFlags(op);
489505
unwrappedLoads[newi] = val;
506+
if (newi->getParent()->getParent() != op->getParent()->getParent())
507+
newi->setDebugLoc(nullptr);
490508
}
491509
if (permitCache)
492510
unwrap_cache[BuilderM.GetInsertBlock()][idx] = toreturn;
@@ -503,6 +521,9 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
503521
if (auto newi = dyn_cast<Instruction>(toreturn)) {
504522
newi->copyIRFlags(op);
505523
unwrappedLoads[newi] = val;
524+
if (newi->getParent()->getParent() !=
525+
cast<Instruction>(val)->getParent()->getParent())
526+
newi->setDebugLoc(nullptr);
506527
}
507528
if (permitCache)
508529
unwrap_cache[BuilderM.GetInsertBlock()][idx] = toreturn;
@@ -524,6 +545,8 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
524545
if (auto newi = dyn_cast<Instruction>(toreturn)) {
525546
newi->copyIRFlags(op);
526547
unwrappedLoads[newi] = val;
548+
if (newi->getParent()->getParent() != op->getParent()->getParent())
549+
newi->setDebugLoc(nullptr);
527550
}
528551
if (permitCache)
529552
unwrap_cache[BuilderM.GetInsertBlock()][idx] = toreturn;
@@ -549,6 +572,8 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
549572
if (auto newi = dyn_cast<Instruction>(toreturn)) {
550573
newi->copyIRFlags(inst);
551574
unwrappedLoads[newi] = val;
575+
if (newi->getParent()->getParent() != inst->getParent()->getParent())
576+
newi->setDebugLoc(nullptr);
552577
}
553578
if (permitCache)
554579
unwrap_cache[BuilderM.GetInsertBlock()][idx] = toreturn;
@@ -560,7 +585,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
560585

561586
bool legalMove = mode == UnwrapMode::LegalFullUnwrap ||
562587
mode == UnwrapMode::LegalFullUnwrapNoTapeReplace;
563-
if (mode != UnwrapMode::LegalFullUnwrap) {
588+
if (!legalMove) {
564589
BasicBlock *parent = nullptr;
565590
if (isOriginalBlock(*BuilderM.GetInsertBlock()))
566591
parent = BuilderM.GetInsertBlock();
@@ -591,9 +616,11 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
591616
if (pidx->getType() != load->getOperand(0)->getType()) {
592617
llvm::errs() << "load: " << *load << "\n";
593618
llvm::errs() << "load->getOperand(0): " << *load->getOperand(0) << "\n";
594-
llvm::errs() << "idx: " << *pidx << "\n";
619+
llvm::errs() << "idx: " << *pidx << " unwrapping: " << *val
620+
<< " mode=" << mode << "\n";
595621
}
596622
assert(pidx->getType() == load->getOperand(0)->getType());
623+
597624
auto toreturn = BuilderM.CreateLoad(pidx, load->getName() + "_unwrap");
598625
toreturn->copyIRFlags(load);
599626
unwrappedLoads[toreturn] = load;
@@ -605,7 +632,10 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
605632
toreturn->setVolatile(load->isVolatile());
606633
toreturn->setOrdering(load->getOrdering());
607634
toreturn->setSyncScopeID(load->getSyncScopeID());
608-
toreturn->setDebugLoc(getNewFromOriginal(load->getDebugLoc()));
635+
if (toreturn->getParent()->getParent() != load->getParent()->getParent())
636+
toreturn->setDebugLoc(nullptr);
637+
else
638+
toreturn->setDebugLoc(getNewFromOriginal(load->getDebugLoc()));
609639
toreturn->setMetadata(LLVMContext::MD_tbaa,
610640
load->getMetadata(LLVMContext::MD_tbaa));
611641
toreturn->setMetadata(LLVMContext::MD_invariant_group,
@@ -619,7 +649,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
619649

620650
bool legalMove = mode == UnwrapMode::LegalFullUnwrap ||
621651
mode == UnwrapMode::LegalFullUnwrapNoTapeReplace;
622-
if (mode != UnwrapMode::LegalFullUnwrap) {
652+
if (!legalMove) {
623653
legalMove = legalRecompute(op, available, &BuilderM);
624654
}
625655
if (!legalMove)
@@ -646,7 +676,10 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
646676
toreturn->setAttributes(op->getAttributes());
647677
toreturn->setCallingConv(op->getCallingConv());
648678
toreturn->setTailCallKind(op->getTailCallKind());
649-
toreturn->setDebugLoc(getNewFromOriginal(op->getDebugLoc()));
679+
if (toreturn->getParent()->getParent() == op->getParent()->getParent())
680+
toreturn->setDebugLoc(getNewFromOriginal(op->getDebugLoc()));
681+
else
682+
toreturn->setDebugLoc(nullptr);
650683
if (permitCache)
651684
unwrap_cache[BuilderM.GetInsertBlock()][idx] = toreturn;
652685
unwrappedLoads[toreturn] = val;
@@ -664,7 +697,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
664697

665698
bool legalMove = mode == UnwrapMode::LegalFullUnwrap ||
666699
mode == UnwrapMode::LegalFullUnwrapNoTapeReplace;
667-
if (mode != UnwrapMode::LegalFullUnwrap) {
700+
if (!legalMove) {
668701
// TODO actually consider whether this is legal to move to the new
669702
// location, rather than recomputable anywhere
670703
legalMove = legalRecompute(dli, available, &BuilderM);
@@ -1884,13 +1917,18 @@ bool GradientUtils::legalRecompute(const Value *val,
18841917
const Instruction *orig = nullptr;
18851918
if (li->getParent()->getParent() == oldFunc) {
18861919
orig = li;
1887-
} else {
1920+
} else if (li->getParent()->getParent() == newFunc) {
18881921
orig = isOriginal(li);
18891922
// todo consider when we pass non original queries
18901923
if (orig && !isa<LoadInst>(orig)) {
18911924
return legalRecompute(orig, available, BuilderM, reverse,
18921925
legalRecomputeCache);
18931926
}
1927+
} else {
1928+
llvm::errs() << " newFunc: " << *newFunc << "\n";
1929+
llvm::errs() << " parent: " << *li->getParent()->getParent() << "\n";
1930+
llvm::errs() << " li: " << *li << "\n";
1931+
assert(0 && "illegal load legalRecopmute query");
18941932
}
18951933

18961934
if (orig) {
@@ -2005,7 +2043,9 @@ bool GradientUtils::legalRecompute(const Value *val,
20052043
if (n == "lgamma" || n == "lgammaf" || n == "lgammal" ||
20062044
n == "lgamma_r" || n == "lgammaf_r" || n == "lgammal_r" ||
20072045
n == "__lgamma_r_finite" || n == "__lgammaf_r_finite" ||
2008-
n == "__lgammal_r_finite" || isMemFreeLibMFunction(n)) {
2046+
n == "__lgammal_r_finite" || isMemFreeLibMFunction(n) ||
2047+
n.startswith("enzyme_wrapmpi$$") || n == "omp_get_thread_num" ||
2048+
n == "omp_get_max_threads") {
20092049
return true;
20102050
}
20112051
}
@@ -2173,7 +2213,9 @@ bool GradientUtils::shouldRecompute(const Value *val,
21732213
n == "__lgamma_r_finite" || n == "__lgammaf_r_finite" ||
21742214
n == "__lgammal_r_finite" || n == "tanh" || n == "tanhf" ||
21752215
n == "__pow_finite" || n == "__fd_sincos_1" ||
2176-
isMemFreeLibMFunction(n) || n == "julia.pointer_from_objref") {
2216+
isMemFreeLibMFunction(n) || n == "julia.pointer_from_objref" ||
2217+
n.startswith("enzyme_wrapmpi$$") || n == "omp_get_thread_num" ||
2218+
n == "omp_get_max_threads") {
21772219
return true;
21782220
}
21792221
}

0 commit comments

Comments
 (0)