Skip to content

Commit db42110

Browse files
committed
[Matrix] Use DenseMap for ShapeMap instead of ValueMap.
ValueMap automatically re-keys an entry to the new value when the original value is replaced via RAUW (replaceAllUsesWith). As a result, instructions that are not expected to have shape info (e.g. the shufflevector in the added test case) can end up in the map, which leads to incorrect results. ValueMap was originally used for the transpose optimizations, but those now all go through updateShapeAndReplaceAllUsesWith, which takes care of updating the shape info as needed. This fixes a crash in the newly added test case.
1 parent e48c7fe commit db42110

File tree

2 files changed

+33
-5
lines changed

2 files changed

+33
-5
lines changed

llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ static bool isUniformShape(Value *V) {
259259
/// Return the ShapeInfo for the result of \p I, if it can be determined.
260260
static std::optional<ShapeInfo>
261261
computeShapeInfoForInst(Instruction *I,
262-
const ValueMap<Value *, ShapeInfo> &ShapeMap) {
262+
const DenseMap<Value *, ShapeInfo> &ShapeMap) {
263263
Value *M;
264264
Value *N;
265265
Value *K;
@@ -493,10 +493,8 @@ class LowerMatrixIntrinsics {
493493
/// the result value of the instruction, with the only exceptions being store
494494
/// instructions and the matrix_column_major_store intrinsics. For those, the
495495
/// shape information indicates that those instructions should be lowered
496-
/// using shape information as well. A ValueMap is used so that when
497-
/// sub-passes like optimizeTransposes performs RAUW the map stays
498-
/// up-to-date.
499-
ValueMap<Value *, ShapeInfo> ShapeMap;
496+
/// using shape information as well.
497+
DenseMap<Value *, ShapeInfo> ShapeMap;
500498

501499
/// List of instructions to remove. While lowering, we are not replacing all
502500
/// users of a lowered instruction, if shape information is available and

llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-transpose-int.ll

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,3 +190,33 @@ declare <1 x i32> @llvm.matrix.multiply.v1i32.v5i32.v5i32(<5 x i32>, <5 x i32>,
190190
declare <5 x i32> @llvm.matrix.column.major.load.v5i32.i64(ptr nocapture, i64, i1 immarg, i32 immarg, i32 immarg) #1
191191

192192
declare <5 x i32> @llvm.matrix.transpose.v5i32(<5 x i32>, i32 immarg, i32 immarg) #0
193+
194+
define <1 x i32> @test_dot_product_with_transposed_shuffle_op(<4 x i32> %a, <2 x i32> %b) {
195+
; CHECK-LABEL: @test_dot_product_with_transposed_shuffle_op(
196+
; CHECK-NEXT: entry:
197+
; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <4 x i32> [[A:%.*]], <4 x i32> poison, <2 x i32> <i32 0, i32 1>
198+
; CHECK-NEXT: [[SPLIT1:%.*]] = shufflevector <4 x i32> [[A]], <4 x i32> poison, <2 x i32> <i32 2, i32 3>
199+
; CHECK-NEXT: [[TMP0:%.*]] = extractelement <2 x i32> [[SPLIT]], i64 0
200+
; CHECK-NEXT: [[TMP1:%.*]] = insertelement <2 x i32> poison, i32 [[TMP0]], i64 0
201+
; CHECK-NEXT: [[TMP2:%.*]] = extractelement <2 x i32> [[SPLIT1]], i64 0
202+
; CHECK-NEXT: [[TMP3:%.*]] = insertelement <2 x i32> [[TMP1]], i32 [[TMP2]], i64 1
203+
; CHECK-NEXT: [[TMP4:%.*]] = extractelement <2 x i32> [[SPLIT]], i64 1
204+
; CHECK-NEXT: [[TMP5:%.*]] = insertelement <2 x i32> poison, i32 [[TMP4]], i64 0
205+
; CHECK-NEXT: [[TMP6:%.*]] = extractelement <2 x i32> [[SPLIT1]], i64 1
206+
; CHECK-NEXT: [[TMP7:%.*]] = insertelement <2 x i32> [[TMP5]], i32 [[TMP6]], i64 1
207+
; CHECK-NEXT: [[TMP8:%.*]] = shufflevector <2 x i32> [[TMP3]], <2 x i32> [[TMP7]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
208+
; CHECK-NEXT: [[SHUFFLE:%.*]] = shufflevector <4 x i32> [[TMP8]], <4 x i32> zeroinitializer, <2 x i32> <i32 0, i32 1>
209+
; CHECK-NEXT: [[TMP9:%.*]] = mul <2 x i32> [[SHUFFLE]], [[B:%.*]]
210+
; CHECK-NEXT: [[TMP10:%.*]] = call i32 @llvm.vector.reduce.add.v2i32(<2 x i32> [[TMP9]])
211+
; CHECK-NEXT: [[TMP11:%.*]] = insertelement <1 x i32> poison, i32 [[TMP10]], i64 0
212+
; CHECK-NEXT: ret <1 x i32> [[TMP11]]
213+
;
214+
entry:
215+
%t.a = tail call <4 x i32> @llvm.matrix.transpose.v4i32(<4 x i32> %a, i32 2, i32 2)
216+
%shuffle = shufflevector <4 x i32> %t.a, <4 x i32> zeroinitializer, <2 x i32> <i32 0, i32 1>
217+
%t.shuffle = call <2 x i32> @llvm.matrix.transpose.v2i32(<2 x i32> %shuffle, i32 2, i32 1)
218+
%m = call <1 x i32> @llvm.matrix.multiply.v1i32.v2i32.v2i32(<2 x i32> %t.shuffle, <2 x i32> %b, i32 1, i32 2, i32 1)
219+
ret <1 x i32> %m
220+
}
221+
222+
declare <2 x i32> @llvm.matrix.transpose.v2i32(<2 x i32>, i32 immarg, i32 immarg)

0 commit comments

Comments
 (0)