Skip to content

Commit 48441cb

Browse files
committed
[Matrix] Properly set Changed status when optimizing transposes.
Currently, Changed is not set when transposes are optimized, so analyses are not invalidated even though the IR was modified. Update optimizeTransposes to return whether it made any changes.
1 parent 449e2f5 commit 48441cb

File tree

2 files changed

+36
-8
lines changed

2 files changed

+36
-8
lines changed

llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -792,7 +792,8 @@ class LowerMatrixIntrinsics {
792792
/// This creates and erases instructions as needed, and returns the newly
793793
/// created instruction while updating the iterator to avoid invalidation. If
794794
/// this returns nullptr, no new instruction was created.
795-
Instruction *sinkTranspose(Instruction &I, BasicBlock::reverse_iterator &II) {
795+
Instruction *sinkTranspose(Instruction &I, BasicBlock::reverse_iterator &II,
796+
bool &Changed) {
796797
BasicBlock &BB = *I.getParent();
797798
IRBuilder<> IB(&I);
798799
MatrixBuilder Builder(IB);
@@ -809,13 +810,15 @@ class LowerMatrixIntrinsics {
809810
updateShapeAndReplaceAllUsesWith(I, TATA);
810811
eraseFromParentAndMove(&I, II, BB);
811812
eraseFromParentAndMove(TA, II, BB);
813+
Changed = true;
812814
return nullptr;
813815
}
814816

815817
// k^T -> k
816818
if (isSplat(TA)) {
817819
updateShapeAndReplaceAllUsesWith(I, TA);
818820
eraseFromParentAndMove(&I, II, BB);
821+
Changed = true;
819822
return nullptr;
820823
}
821824

@@ -834,6 +837,7 @@ class LowerMatrixIntrinsics {
834837
updateShapeAndReplaceAllUsesWith(I, NewInst);
835838
eraseFromParentAndMove(&I, II, BB);
836839
eraseFromParentAndMove(TA, II, BB);
840+
Changed = true;
837841
return NewInst;
838842
}
839843

@@ -859,6 +863,7 @@ class LowerMatrixIntrinsics {
859863
updateShapeAndReplaceAllUsesWith(I, NewInst);
860864
eraseFromParentAndMove(&I, II, BB);
861865
eraseFromParentAndMove(TA, II, BB);
866+
Changed = true;
862867
return NewInst;
863868
}
864869

@@ -880,13 +885,14 @@ class LowerMatrixIntrinsics {
880885
updateShapeAndReplaceAllUsesWith(I, NewInst);
881886
eraseFromParentAndMove(&I, II, BB);
882887
eraseFromParentAndMove(TA, II, BB);
888+
Changed = true;
883889
return NewInst;
884890
}
885891

886892
return nullptr;
887893
}
888894

889-
void liftTranspose(Instruction &I) {
895+
bool liftTranspose(Instruction &I) {
890896
// Erase dead Instructions after lifting transposes from binops.
891897
auto CleanupBinOp = [this](Instruction &T, Value *A, Value *B) {
892898
if (T.use_empty())
@@ -914,6 +920,7 @@ class LowerMatrixIntrinsics {
914920
R->getZExtValue());
915921
updateShapeAndReplaceAllUsesWith(I, NewInst);
916922
CleanupBinOp(I, A, B);
923+
return true;
917924
}
918925
// A^t + B^t -> (A + B)^t. Pick rows and columns from first transpose. If
919926
// the shape of the second transpose is different, there's a shape conflict
@@ -940,19 +947,22 @@ class LowerMatrixIntrinsics {
940947
ShapeMap[AddI] &&
941948
"Shape of updated addition doesn't match cached shape.");
942949
}
950+
return true;
943951
}
952+
return false;
944953
}
945954

946955
/// Try moving transposes in order to fold them away or into multiplies.
947-
void optimizeTransposes() {
956+
bool optimizeTransposes() {
957+
bool Changed = false;
948958
// First sink all transposes inside matmuls and adds, hoping that we end up
949959
// with NN, NT or TN variants.
950960
for (BasicBlock &BB : reverse(Func)) {
951961
for (auto II = BB.rbegin(); II != BB.rend();) {
952962
Instruction &I = *II;
953963
// We may remove II. By default continue on the next/prev instruction.
954964
++II;
955-
if (Instruction *NewInst = sinkTranspose(I, II))
965+
if (Instruction *NewInst = sinkTranspose(I, II, Changed))
956966
II = std::next(BasicBlock::reverse_iterator(NewInst));
957967
}
958968
}
@@ -961,9 +971,10 @@ class LowerMatrixIntrinsics {
961971
// to fold into consuming multiply or add.
962972
for (BasicBlock &BB : Func) {
963973
for (Instruction &I : llvm::make_early_inc_range(BB)) {
964-
liftTranspose(I);
974+
Changed |= liftTranspose(I);
965975
}
966976
}
977+
return Changed;
967978
}
968979

969980
bool Visit() {
@@ -1006,15 +1017,15 @@ class LowerMatrixIntrinsics {
10061017
WorkList = propagateShapeBackward(WorkList);
10071018
}
10081019

1020+
bool Changed = false;
10091021
if (!isMinimal()) {
1010-
optimizeTransposes();
1022+
Changed |= optimizeTransposes();
10111023
if (PrintAfterTransposeOpt) {
10121024
dbgs() << "Dump after matrix transpose optimization:\n";
10131025
Func.print(dbgs());
10141026
}
10151027
}
10161028

1017-
bool Changed = false;
10181029
SmallVector<CallInst *, 16> MaybeFusableInsts;
10191030
SmallVector<Instruction *, 16> MatrixInsts;
10201031
SmallVector<IntrinsicInst *, 16> LifetimeEnds;
@@ -1043,7 +1054,7 @@ class LowerMatrixIntrinsics {
10431054
if (!FusedInsts.contains(CI))
10441055
LowerMatrixMultiplyFused(CI, FusedInsts, LifetimeEnds);
10451056

1046-
Changed = !FusedInsts.empty();
1057+
Changed |= !FusedInsts.empty();
10471058

10481059
// Fourth, lower remaining instructions with shape information.
10491060
for (Instruction *Inst : MatrixInsts) {
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
2+
; RUN: opt -p lower-matrix-intrinsics -verify-analysis-invalidation -S %s | FileCheck %s
3+
4+
define <3 x float> @splat_transpose(<3 x float> %in) {
5+
; CHECK-LABEL: define <3 x float> @splat_transpose(
6+
; CHECK-SAME: <3 x float> [[IN:%.*]]) {
7+
; CHECK-NEXT: [[ENTRY:.*:]]
8+
; CHECK-NEXT: [[SPLAT:%.*]] = shufflevector <3 x float> [[IN]], <3 x float> zeroinitializer, <3 x i32> zeroinitializer
9+
; CHECK-NEXT: ret <3 x float> [[SPLAT]]
10+
;
11+
entry:
12+
%splat = shufflevector <3 x float> %in, <3 x float> zeroinitializer, <3 x i32> zeroinitializer
13+
%r = tail call <3 x float> @llvm.matrix.transpose.v3f32(<3 x float> %splat, i32 3, i32 1)
14+
ret <3 x float> %r
15+
}
16+
17+
declare <3 x float> @llvm.matrix.transpose.v3f32(<3 x float>, i32 immarg, i32 immarg)

0 commit comments

Comments
 (0)