@@ -97,19 +97,6 @@ static DISubprogram *getSubprogram(DIScope *Scope) {
97
97
return cast<DILocalScope>(Scope)->getSubprogram ();
98
98
}
99
99
100
- // / Erase \p V from \p BB and move \II forward to avoid invalidating
101
- // / iterators.
102
- static void eraseFromParentAndMove (Value *V, BasicBlock::reverse_iterator &II,
103
- BasicBlock &BB) {
104
- auto *Inst = cast<Instruction>(V);
105
- // Still used, don't erase.
106
- if (!Inst->use_empty ())
107
- return ;
108
- if (II != BB.rend () && Inst == &*II)
109
- ++II;
110
- Inst->eraseFromParent ();
111
- }
112
-
113
100
// / Return true if V is a splat of a value (which is used when multiplying a
114
101
// / matrix with a scalar).
115
102
static bool isSplat (Value *V) {
@@ -259,7 +246,7 @@ static bool isUniformShape(Value *V) {
259
246
// / Return the ShapeInfo for the result of \p I, if it can be determined.
260
247
static std::optional<ShapeInfo>
261
248
computeShapeInfoForInst (Instruction *I,
262
- const ValueMap <Value *, ShapeInfo> &ShapeMap) {
249
+ const DenseMap <Value *, ShapeInfo> &ShapeMap) {
263
250
Value *M;
264
251
Value *N;
265
252
Value *K;
@@ -493,10 +480,16 @@ class LowerMatrixIntrinsics {
493
480
// / the result value of the instruction, with the only exceptions being store
494
481
// / instructions and the matrix_column_major_store intrinsics. For those, the
495
482
// / shape information indicates that those instructions should be lowered
496
- // / using shape information as well. A ValueMap is used so that when
497
- // / sub-passes like optimizeTransposes performs RAUW the map stays
498
- // / up-to-date.
499
- ValueMap<Value *, ShapeInfo> ShapeMap;
483
+ // / using shape information as well. Note that extra care is needed when
484
+ // / erasing or RAUW'ing a value that is present in ShapeMap. If the
485
+ // / replacement is also a matrix operation, use
486
+ // / updateShapeAndReplaceAllUsesWith to make sure the replacement is added to
487
+ // / ShapeMap. We don't use ValueMap, as there are also cases where we do not
488
+ // / want to add shape information for a replacement instruction. When directly
489
+ // / erasing a value with an entry in ShapeMap, use
490
+ // / eraseFromParentAndRemoveFromShapeMap to make sure ShapeMap is also updated
491
+ // / accordingly.
492
+ DenseMap<Value *, ShapeInfo> ShapeMap;
500
493
501
494
// / List of instructions to remove. While lowering, we are not replacing all
502
495
// / users of a lowered instruction, if shape information is available and
@@ -758,6 +751,30 @@ class LowerMatrixIntrinsics {
758
751
return Operation (T0, Shape0.t (), T1, Shape1.t ());
759
752
}
760
753
754
+ // / Erase \p Inst from both ShapeMap (if an entry exists) and erase \p Inst
755
+ // / itself.
756
+ void eraseFromParentAndRemoveFromShapeMap (Instruction *Inst) {
757
+ auto Iter = ShapeMap.find (Inst);
758
+ if (Iter != ShapeMap.end ())
759
+ ShapeMap.erase (Iter);
760
+ Inst->eraseFromParent ();
761
+ }
762
+
763
+ // / Erase \p V from \p BB and move \p II forward to avoid invalidating
764
+ // / iterators.
765
+ void eraseFromParentAndMove (Value *V, BasicBlock::reverse_iterator &II,
766
+ BasicBlock &BB) {
767
+ auto *Inst = cast<Instruction>(V);
768
+ // Still used, don't erase.
769
+ if (!Inst->use_empty ())
770
+ return ;
771
+ if (II != BB.rend () && Inst == &*II)
772
+ ++II;
773
+ eraseFromParentAndRemoveFromShapeMap (Inst);
774
+ }
775
+
776
+ // / Add a new entry to ShapeMap for \p New with \p Old's shape info, erase the
777
+ // / entry for \p Old and replace all uses of \p Old with \p New.
761
778
void updateShapeAndReplaceAllUsesWith (Instruction &Old, Value *New) {
762
779
// We need to remove Old from the ShapeMap otherwise RAUW will replace it
763
780
// with New. We should only add New if it supportsShapeInfo so we insert
@@ -871,13 +888,13 @@ class LowerMatrixIntrinsics {
871
888
872
889
void liftTranspose (Instruction &I) {
873
890
// Erase dead Instructions after lifting transposes from binops.
874
- auto CleanupBinOp = [](Instruction &T, Value *A, Value *B) {
891
+ auto CleanupBinOp = [this ](Instruction &T, Value *A, Value *B) {
875
892
if (T.use_empty ())
876
- T. eraseFromParent ( );
893
+ eraseFromParentAndRemoveFromShapeMap (&T );
877
894
if (A->use_empty ())
878
- cast<Instruction>(A)-> eraseFromParent ( );
895
+ eraseFromParentAndRemoveFromShapeMap ( cast<Instruction>(A));
879
896
if (A != B && B->use_empty ())
880
- cast<Instruction>(B)-> eraseFromParent ( );
897
+ eraseFromParentAndRemoveFromShapeMap ( cast<Instruction>(B));
881
898
};
882
899
883
900
Value *A, *B, *AT, *BT;
@@ -1484,7 +1501,7 @@ class LowerMatrixIntrinsics {
1484
1501
m_Value (Arg)))) {
1485
1502
auto *NewLoad = Builder.CreateLoad (Op->getType (), Arg);
1486
1503
Op->replaceAllUsesWith (NewLoad);
1487
- cast<Instruction>(Op)-> eraseFromParent ( );
1504
+ eraseFromParentAndRemoveFromShapeMap ( cast<Instruction>(Op));
1488
1505
return ;
1489
1506
} else if (match (Op, m_Intrinsic<Intrinsic::matrix_transpose>(
1490
1507
m_Value (Arg)))) {
@@ -1853,15 +1870,15 @@ class LowerMatrixIntrinsics {
1853
1870
// Mark eliminated instructions as fused and remove them.
1854
1871
FusedInsts.insert (Store);
1855
1872
FusedInsts.insert (MatMul);
1856
- Store-> eraseFromParent ( );
1857
- MatMul-> eraseFromParent ( );
1873
+ eraseFromParentAndRemoveFromShapeMap (Store );
1874
+ eraseFromParentAndRemoveFromShapeMap (MatMul );
1858
1875
if (LoadOp0->hasNUses (0 )) {
1859
1876
FusedInsts.insert (LoadOp0);
1860
- LoadOp0-> eraseFromParent ( );
1877
+ eraseFromParentAndRemoveFromShapeMap (LoadOp0 );
1861
1878
}
1862
1879
if (LoadOp1 != LoadOp0 && LoadOp1->hasNUses (0 )) {
1863
1880
FusedInsts.insert (LoadOp1);
1864
- LoadOp1-> eraseFromParent ( );
1881
+ eraseFromParentAndRemoveFromShapeMap (LoadOp1 );
1865
1882
}
1866
1883
}
1867
1884
0 commit comments