Skip to content

Commit 14e7244

Browse files
committed
!fixup also remove entries from ShapeMap when removing instructions.
1 parent db42110 commit 14e7244

File tree

1 file changed

+29
-22
lines changed

1 file changed

+29
-22
lines changed

llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp

Lines changed: 29 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -97,19 +97,6 @@ static DISubprogram *getSubprogram(DIScope *Scope) {
9797
return cast<DILocalScope>(Scope)->getSubprogram();
9898
}
9999

100-
/// Erase \p V from \p BB and move \II forward to avoid invalidating
101-
/// iterators.
102-
static void eraseFromParentAndMove(Value *V, BasicBlock::reverse_iterator &II,
103-
BasicBlock &BB) {
104-
auto *Inst = cast<Instruction>(V);
105-
// Still used, don't erase.
106-
if (!Inst->use_empty())
107-
return;
108-
if (II != BB.rend() && Inst == &*II)
109-
++II;
110-
Inst->eraseFromParent();
111-
}
112-
113100
/// Return true if V is a splat of a value (which is used when multiplying a
114101
/// matrix with a scalar).
115102
static bool isSplat(Value *V) {
@@ -756,6 +743,26 @@ class LowerMatrixIntrinsics {
756743
return Operation(T0, Shape0.t(), T1, Shape1.t());
757744
}
758745

746+
void eraseFromParentAndRemoveFromShapeMap(Instruction *Inst) {
747+
auto Iter = ShapeMap.find(Inst);
748+
if (Iter != ShapeMap.end())
749+
ShapeMap.erase(Iter);
750+
Inst->eraseFromParent();
751+
}
752+
753+
/// Erase \p V from \p BB and move \II forward to avoid invalidating
754+
/// iterators.
755+
void eraseFromParentAndMove(Value *V, BasicBlock::reverse_iterator &II,
756+
BasicBlock &BB) {
757+
auto *Inst = cast<Instruction>(V);
758+
// Still used, don't erase.
759+
if (!Inst->use_empty())
760+
return;
761+
if (II != BB.rend() && Inst == &*II)
762+
++II;
763+
eraseFromParentAndRemoveFromShapeMap(Inst);
764+
}
765+
759766
void updateShapeAndReplaceAllUsesWith(Instruction &Old, Value *New) {
760767
// We need to remove Old from the ShapeMap otherwise RAUW will replace it
761768
// with New. We should only add New it it supportsShapeInfo so we insert
@@ -869,13 +876,13 @@ class LowerMatrixIntrinsics {
869876

870877
void liftTranspose(Instruction &I) {
871878
// Erase dead Instructions after lifting transposes from binops.
872-
auto CleanupBinOp = [](Instruction &T, Value *A, Value *B) {
879+
auto CleanupBinOp = [this](Instruction &T, Value *A, Value *B) {
873880
if (T.use_empty())
874-
T.eraseFromParent();
881+
eraseFromParentAndRemoveFromShapeMap(&T);
875882
if (A->use_empty())
876-
cast<Instruction>(A)->eraseFromParent();
883+
eraseFromParentAndRemoveFromShapeMap(cast<Instruction>(A));
877884
if (A != B && B->use_empty())
878-
cast<Instruction>(B)->eraseFromParent();
885+
eraseFromParentAndRemoveFromShapeMap(cast<Instruction>(B));
879886
};
880887

881888
Value *A, *B, *AT, *BT;
@@ -1482,7 +1489,7 @@ class LowerMatrixIntrinsics {
14821489
m_Value(Arg)))) {
14831490
auto *NewLoad = Builder.CreateLoad(Op->getType(), Arg);
14841491
Op->replaceAllUsesWith(NewLoad);
1485-
cast<Instruction>(Op)->eraseFromParent();
1492+
eraseFromParentAndRemoveFromShapeMap(cast<Instruction>(Op));
14861493
return;
14871494
} else if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>(
14881495
m_Value(Arg)))) {
@@ -1851,15 +1858,15 @@ class LowerMatrixIntrinsics {
18511858
// Mark eliminated instructions as fused and remove them.
18521859
FusedInsts.insert(Store);
18531860
FusedInsts.insert(MatMul);
1854-
Store->eraseFromParent();
1855-
MatMul->eraseFromParent();
1861+
eraseFromParentAndRemoveFromShapeMap(Store);
1862+
eraseFromParentAndRemoveFromShapeMap(MatMul);
18561863
if (LoadOp0->hasNUses(0)) {
18571864
FusedInsts.insert(LoadOp0);
1858-
LoadOp0->eraseFromParent();
1865+
eraseFromParentAndRemoveFromShapeMap(eraseFromParent());
18591866
}
18601867
if (LoadOp1 != LoadOp0 && LoadOp1->hasNUses(0)) {
18611868
FusedInsts.insert(LoadOp1);
1862-
LoadOp1->eraseFromParent();
1869+
eraseFromParentAndRemoveFromShapeMap(eraseFromParent());
18631870
}
18641871
}
18651872

0 commit comments

Comments
 (0)