Skip to content

Commit 7b6e0d9

Browse files
authored
[Matrix] Use DenseMap for ShapeMap instead of ValueMap. (#118282)
ValueMap automatically updates entries with the new value if they have been RAUW'd. This can cause instructions that are expected not to have shape info to be added to the map (e.g. shufflevector, as in the added test case), which leads to incorrect results. ValueMap was originally used for the transpose optimizations, but they now all use updateShapeAndReplaceAllUsesWith, which takes care of updating the shape info as needed. This fixes a crash in the newly added test cases. PR: #118282
1 parent 7235ac9 commit 7b6e0d9

File tree

3 files changed

+94
-27
lines changed

3 files changed

+94
-27
lines changed

llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp

Lines changed: 44 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -97,19 +97,6 @@ static DISubprogram *getSubprogram(DIScope *Scope) {
9797
return cast<DILocalScope>(Scope)->getSubprogram();
9898
}
9999

100-
/// Erase \p V from \p BB and move \p II forward to avoid invalidating
101-
/// iterators.
102-
static void eraseFromParentAndMove(Value *V, BasicBlock::reverse_iterator &II,
103-
BasicBlock &BB) {
104-
auto *Inst = cast<Instruction>(V);
105-
// Still used, don't erase.
106-
if (!Inst->use_empty())
107-
return;
108-
if (II != BB.rend() && Inst == &*II)
109-
++II;
110-
Inst->eraseFromParent();
111-
}
112-
113100
/// Return true if V is a splat of a value (which is used when multiplying a
114101
/// matrix with a scalar).
115102
static bool isSplat(Value *V) {
@@ -259,7 +246,7 @@ static bool isUniformShape(Value *V) {
259246
/// Return the ShapeInfo for the result of \p I, if it can be determined.
260247
static std::optional<ShapeInfo>
261248
computeShapeInfoForInst(Instruction *I,
262-
const ValueMap<Value *, ShapeInfo> &ShapeMap) {
249+
const DenseMap<Value *, ShapeInfo> &ShapeMap) {
263250
Value *M;
264251
Value *N;
265252
Value *K;
@@ -493,10 +480,16 @@ class LowerMatrixIntrinsics {
493480
/// the result value of the instruction, with the only exceptions being store
494481
/// instructions and the matrix_column_major_store intrinsics. For those, the
495482
/// shape information indicates that those instructions should be lowered
496-
/// using shape information as well. A ValueMap is used so that when
497-
/// sub-passes like optimizeTransposes performs RAUW the map stays
498-
/// up-to-date.
499-
ValueMap<Value *, ShapeInfo> ShapeMap;
483+
/// using shape information as well. Note that extra care is needed when
484+
/// erasing or RAUW'ing a value that is present in ShapeMap. If the
485+
/// replacement is also a matrix operation, use
486+
/// updateShapeAndReplaceAllUsesWith to make sure the replacement is added to
487+
/// ShapeMap. We don't use ValueMap, as there are also cases where we do not
488+
/// want to add shape information for a replacement instruction. When directly
489+
/// erasing a value with an entry in ShapeMap, use
490+
/// eraseFromParentAndRemoveFromShapeMap to make sure ShapeMap is also updated
491+
/// accordingly.
492+
DenseMap<Value *, ShapeInfo> ShapeMap;
500493

501494
/// List of instructions to remove. While lowering, we are not replacing all
502495
/// users of a lowered instruction, if shape information is available and
@@ -758,6 +751,30 @@ class LowerMatrixIntrinsics {
758751
return Operation(T0, Shape0.t(), T1, Shape1.t());
759752
}
760753

754+
/// Remove the entry for \p Inst from ShapeMap (if one exists) and then erase
755+
/// \p Inst itself.
756+
void eraseFromParentAndRemoveFromShapeMap(Instruction *Inst) {
757+
auto Iter = ShapeMap.find(Inst);
758+
if (Iter != ShapeMap.end())
759+
ShapeMap.erase(Iter);
760+
Inst->eraseFromParent();
761+
}
762+
763+
/// Erase \p V from \p BB and move \p II forward to avoid invalidating
764+
/// iterators.
765+
void eraseFromParentAndMove(Value *V, BasicBlock::reverse_iterator &II,
766+
BasicBlock &BB) {
767+
auto *Inst = cast<Instruction>(V);
768+
// Still used, don't erase.
769+
if (!Inst->use_empty())
770+
return;
771+
if (II != BB.rend() && Inst == &*II)
772+
++II;
773+
eraseFromParentAndRemoveFromShapeMap(Inst);
774+
}
775+
776+
/// Add a new entry to ShapeMap for \p New with \p Old's shape info, erase the
777+
/// entry for \p Old and replace all uses of \p Old with \p New.
761778
void updateShapeAndReplaceAllUsesWith(Instruction &Old, Value *New) {
762779
// We need to remove Old from the ShapeMap otherwise RAUW will replace it
763780
// with New. We should only add New if it supportsShapeInfo so we insert
@@ -871,13 +888,13 @@ class LowerMatrixIntrinsics {
871888

872889
void liftTranspose(Instruction &I) {
873890
// Erase dead Instructions after lifting transposes from binops.
874-
auto CleanupBinOp = [](Instruction &T, Value *A, Value *B) {
891+
auto CleanupBinOp = [this](Instruction &T, Value *A, Value *B) {
875892
if (T.use_empty())
876-
T.eraseFromParent();
893+
eraseFromParentAndRemoveFromShapeMap(&T);
877894
if (A->use_empty())
878-
cast<Instruction>(A)->eraseFromParent();
895+
eraseFromParentAndRemoveFromShapeMap(cast<Instruction>(A));
879896
if (A != B && B->use_empty())
880-
cast<Instruction>(B)->eraseFromParent();
897+
eraseFromParentAndRemoveFromShapeMap(cast<Instruction>(B));
881898
};
882899

883900
Value *A, *B, *AT, *BT;
@@ -1484,7 +1501,7 @@ class LowerMatrixIntrinsics {
14841501
m_Value(Arg)))) {
14851502
auto *NewLoad = Builder.CreateLoad(Op->getType(), Arg);
14861503
Op->replaceAllUsesWith(NewLoad);
1487-
cast<Instruction>(Op)->eraseFromParent();
1504+
eraseFromParentAndRemoveFromShapeMap(cast<Instruction>(Op));
14881505
return;
14891506
} else if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>(
14901507
m_Value(Arg)))) {
@@ -1853,15 +1870,15 @@ class LowerMatrixIntrinsics {
18531870
// Mark eliminated instructions as fused and remove them.
18541871
FusedInsts.insert(Store);
18551872
FusedInsts.insert(MatMul);
1856-
Store->eraseFromParent();
1857-
MatMul->eraseFromParent();
1873+
eraseFromParentAndRemoveFromShapeMap(Store);
1874+
eraseFromParentAndRemoveFromShapeMap(MatMul);
18581875
if (LoadOp0->hasNUses(0)) {
18591876
FusedInsts.insert(LoadOp0);
1860-
LoadOp0->eraseFromParent();
1877+
eraseFromParentAndRemoveFromShapeMap(LoadOp0);
18611878
}
18621879
if (LoadOp1 != LoadOp0 && LoadOp1->hasNUses(0)) {
18631880
FusedInsts.insert(LoadOp1);
1864-
LoadOp1->eraseFromParent();
1881+
eraseFromParentAndRemoveFromShapeMap(LoadOp1);
18651882
}
18661883
}
18671884

llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-transpose-int.ll

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,3 +190,33 @@ declare <1 x i32> @llvm.matrix.multiply.v1i32.v5i32.v5i32(<5 x i32>, <5 x i32>,
190190
declare <5 x i32> @llvm.matrix.column.major.load.v5i32.i64(ptr nocapture, i64, i1 immarg, i32 immarg, i32 immarg) #1
191191

192192
declare <5 x i32> @llvm.matrix.transpose.v5i32(<5 x i32>, i32 immarg, i32 immarg) #0
193+
194+
define <1 x i32> @test_dot_product_with_transposed_shuffle_op(<4 x i32> %a, <2 x i32> %b) {
195+
; CHECK-LABEL: @test_dot_product_with_transposed_shuffle_op(
196+
; CHECK-NEXT: entry:
197+
; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <4 x i32> [[A:%.*]], <4 x i32> poison, <2 x i32> <i32 0, i32 1>
198+
; CHECK-NEXT: [[SPLIT1:%.*]] = shufflevector <4 x i32> [[A]], <4 x i32> poison, <2 x i32> <i32 2, i32 3>
199+
; CHECK-NEXT: [[TMP0:%.*]] = extractelement <2 x i32> [[SPLIT]], i64 0
200+
; CHECK-NEXT: [[TMP1:%.*]] = insertelement <2 x i32> poison, i32 [[TMP0]], i64 0
201+
; CHECK-NEXT: [[TMP2:%.*]] = extractelement <2 x i32> [[SPLIT1]], i64 0
202+
; CHECK-NEXT: [[TMP3:%.*]] = insertelement <2 x i32> [[TMP1]], i32 [[TMP2]], i64 1
203+
; CHECK-NEXT: [[TMP4:%.*]] = extractelement <2 x i32> [[SPLIT]], i64 1
204+
; CHECK-NEXT: [[TMP5:%.*]] = insertelement <2 x i32> poison, i32 [[TMP4]], i64 0
205+
; CHECK-NEXT: [[TMP6:%.*]] = extractelement <2 x i32> [[SPLIT1]], i64 1
206+
; CHECK-NEXT: [[TMP7:%.*]] = insertelement <2 x i32> [[TMP5]], i32 [[TMP6]], i64 1
207+
; CHECK-NEXT: [[TMP8:%.*]] = shufflevector <2 x i32> [[TMP3]], <2 x i32> [[TMP7]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
208+
; CHECK-NEXT: [[SHUFFLE:%.*]] = shufflevector <4 x i32> [[TMP8]], <4 x i32> zeroinitializer, <2 x i32> <i32 0, i32 1>
209+
; CHECK-NEXT: [[TMP9:%.*]] = mul <2 x i32> [[SHUFFLE]], [[B:%.*]]
210+
; CHECK-NEXT: [[TMP10:%.*]] = call i32 @llvm.vector.reduce.add.v2i32(<2 x i32> [[TMP9]])
211+
; CHECK-NEXT: [[TMP11:%.*]] = insertelement <1 x i32> poison, i32 [[TMP10]], i64 0
212+
; CHECK-NEXT: ret <1 x i32> [[TMP11]]
213+
;
214+
entry:
215+
%t.a = tail call <4 x i32> @llvm.matrix.transpose.v4i32(<4 x i32> %a, i32 2, i32 2)
216+
%shuffle = shufflevector <4 x i32> %t.a, <4 x i32> zeroinitializer, <2 x i32> <i32 0, i32 1>
217+
%t.shuffle = call <2 x i32> @llvm.matrix.transpose.v2i32(<2 x i32> %shuffle, i32 2, i32 1)
218+
%m = call <1 x i32> @llvm.matrix.multiply.v1i32.v2i32.v2i32(<2 x i32> %t.shuffle, <2 x i32> %b, i32 1, i32 2, i32 1)
219+
ret <1 x i32> %m
220+
}
221+
222+
declare <2 x i32> @llvm.matrix.transpose.v2i32(<2 x i32>, i32 immarg, i32 immarg)

llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts-lifting.ll

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,8 +144,28 @@ entry:
144144
ret <6 x double> %mul
145145
}
146146

147+
define void @test_remove_entries_from_shape_map(<3 x float> %a, <2 x float> %b, <6 x float> %c, ptr %dst) {
148+
; CHECK-LABEL: define void @test_remove_entries_from_shape_map(
149+
; CHECK-SAME: <3 x float> [[A:%.*]], <2 x float> [[B:%.*]], <6 x float> [[C:%.*]], ptr [[DST:%.*]]) {
150+
; CHECK-NEXT: [[ENTRY:.*:]]
151+
; CHECK-NEXT: [[TMP0:%.*]] = call <6 x float> @llvm.matrix.multiply.v6f32.v3f32.v2f32(<3 x float> [[A]], <2 x float> [[B]], i32 3, i32 1, i32 2)
152+
; CHECK-NEXT: [[MFADD:%.*]] = fadd <6 x float> [[C]], [[TMP0]]
153+
; CHECK-NEXT: [[MFADD_T:%.*]] = call <6 x float> @llvm.matrix.transpose.v6f32(<6 x float> [[MFADD]], i32 3, i32 2)
154+
; CHECK-NEXT: store <6 x float> [[MFADD_T]], ptr [[DST]], align 4
155+
; CHECK-NEXT: ret void
156+
;
157+
entry:
158+
%m = tail call <6 x float> @llvm.matrix.multiply.v6f32.v3f32.v2f32(<3 x float> %a, <2 x float> %b, i32 3, i32 1, i32 2)
159+
%add = fadd <6 x float> %c, %m
160+
%t = tail call <6 x float> @llvm.matrix.transpose.v6f32(<6 x float> %add, i32 3, i32 2)
161+
store <6 x float> %t, ptr %dst, align 4
162+
ret void
163+
}
164+
147165
declare <6 x double> @llvm.matrix.transpose.v6f64.v6f64(<6 x double>, i32, i32)
148166
declare <4 x double> @llvm.matrix.transpose.v4f64.v4f64(<4 x double>, i32, i32)
149167
declare <9 x double> @llvm.matrix.multiply.v9f64.v6f64(<6 x double>, <6 x double>, i32, i32, i32)
150168
declare <6 x double> @llvm.matrix.multiply.v6f64.v6f64.v4f64(<6 x double>, <4 x double>, i32, i32, i32)
151169
declare <6 x double> @llvm.matrix.multiply.v6f64.v6f64.v6f64(<6 x double>, <4 x double>, i32, i32, i32)
170+
declare <6 x float> @llvm.matrix.transpose.v6f32(<6 x float>, i32 immarg, i32 immarg)
171+
declare <6 x float> @llvm.matrix.multiply.v6f32.v3f32.v2f32(<3 x float>, <2 x float>, i32 immarg, i32 immarg, i32 immarg)

0 commit comments

Comments
 (0)