@@ -792,7 +792,8 @@ class LowerMatrixIntrinsics {
792
792
/// This creates and erases instructions as needed, and returns the newly
793
793
/// created instruction while updating the iterator to avoid invalidation. If
794
794
/// this returns nullptr, no new instruction was created.
795
- Instruction *sinkTranspose (Instruction &I, BasicBlock::reverse_iterator &II) {
795
+ Instruction *sinkTranspose (Instruction &I, BasicBlock::reverse_iterator &II,
796
+ bool &Changed) {
796
797
BasicBlock &BB = *I.getParent ();
797
798
IRBuilder<> IB (&I);
798
799
MatrixBuilder Builder (IB);
@@ -809,13 +810,15 @@ class LowerMatrixIntrinsics {
809
810
updateShapeAndReplaceAllUsesWith (I, TATA);
810
811
eraseFromParentAndMove (&I, II, BB);
811
812
eraseFromParentAndMove (TA, II, BB);
813
+ Changed = true ;
812
814
return nullptr ;
813
815
}
814
816
815
817
// k^T -> k
816
818
if (isSplat (TA)) {
817
819
updateShapeAndReplaceAllUsesWith (I, TA);
818
820
eraseFromParentAndMove (&I, II, BB);
821
+ Changed = true ;
819
822
return nullptr ;
820
823
}
821
824
@@ -834,6 +837,7 @@ class LowerMatrixIntrinsics {
834
837
updateShapeAndReplaceAllUsesWith (I, NewInst);
835
838
eraseFromParentAndMove (&I, II, BB);
836
839
eraseFromParentAndMove (TA, II, BB);
840
+ Changed = true ;
837
841
return NewInst;
838
842
}
839
843
@@ -859,6 +863,7 @@ class LowerMatrixIntrinsics {
859
863
updateShapeAndReplaceAllUsesWith (I, NewInst);
860
864
eraseFromParentAndMove (&I, II, BB);
861
865
eraseFromParentAndMove (TA, II, BB);
866
+ Changed = true ;
862
867
return NewInst;
863
868
}
864
869
@@ -880,13 +885,14 @@ class LowerMatrixIntrinsics {
880
885
updateShapeAndReplaceAllUsesWith (I, NewInst);
881
886
eraseFromParentAndMove (&I, II, BB);
882
887
eraseFromParentAndMove (TA, II, BB);
888
+ Changed = true ;
883
889
return NewInst;
884
890
}
885
891
886
892
return nullptr ;
887
893
}
888
894
889
- void liftTranspose (Instruction &I) {
895
+ bool liftTranspose (Instruction &I) {
890
896
// Erase dead Instructions after lifting transposes from binops.
891
897
auto CleanupBinOp = [this ](Instruction &T, Value *A, Value *B) {
892
898
if (T.use_empty ())
@@ -914,6 +920,7 @@ class LowerMatrixIntrinsics {
914
920
R->getZExtValue ());
915
921
updateShapeAndReplaceAllUsesWith (I, NewInst);
916
922
CleanupBinOp (I, A, B);
923
+ return true ;
917
924
}
918
925
// A^t + B^t -> (A + B)^t. Pick rows and columns from first transpose. If
919
926
// the shape of the second transpose is different, there's a shape conflict
@@ -940,19 +947,22 @@ class LowerMatrixIntrinsics {
940
947
ShapeMap[AddI] &&
941
948
" Shape of updated addition doesn't match cached shape." );
942
949
}
950
+ return true ;
943
951
}
952
+ return false ;
944
953
}
945
954
946
955
/// Try moving transposes in order to fold them away or into multiplies.
947
- void optimizeTransposes () {
956
+ bool optimizeTransposes () {
957
+ bool Changed = false ;
948
958
// First sink all transposes inside matmuls and adds, hoping that we end up
949
959
// with NN, NT or TN variants.
950
960
for (BasicBlock &BB : reverse (Func)) {
951
961
for (auto II = BB.rbegin (); II != BB.rend ();) {
952
962
Instruction &I = *II;
953
963
// We may remove II. By default continue on the next/prev instruction.
954
964
++II;
955
- if (Instruction *NewInst = sinkTranspose (I, II))
965
+ if (Instruction *NewInst = sinkTranspose (I, II, Changed ))
956
966
II = std::next (BasicBlock::reverse_iterator (NewInst));
957
967
}
958
968
}
@@ -961,9 +971,10 @@ class LowerMatrixIntrinsics {
961
971
// to fold into consuming multiply or add.
962
972
for (BasicBlock &BB : Func) {
963
973
for (Instruction &I : llvm::make_early_inc_range (BB)) {
964
- liftTranspose (I);
974
+ Changed |= liftTranspose (I);
965
975
}
966
976
}
977
+ return Changed;
967
978
}
968
979
969
980
bool Visit () {
@@ -1006,15 +1017,15 @@ class LowerMatrixIntrinsics {
1006
1017
WorkList = propagateShapeBackward (WorkList);
1007
1018
}
1008
1019
1020
+ bool Changed = false ;
1009
1021
if (!isMinimal ()) {
1010
- optimizeTransposes ();
1022
+ Changed |= optimizeTransposes ();
1011
1023
if (PrintAfterTransposeOpt) {
1012
1024
dbgs () << " Dump after matrix transpose optimization:\n " ;
1013
1025
Func.print (dbgs ());
1014
1026
}
1015
1027
}
1016
1028
1017
- bool Changed = false ;
1018
1029
SmallVector<CallInst *, 16 > MaybeFusableInsts;
1019
1030
SmallVector<Instruction *, 16 > MatrixInsts;
1020
1031
SmallVector<IntrinsicInst *, 16 > LifetimeEnds;
@@ -1043,7 +1054,7 @@ class LowerMatrixIntrinsics {
1043
1054
if (!FusedInsts.contains (CI))
1044
1055
LowerMatrixMultiplyFused (CI, FusedInsts, LifetimeEnds);
1045
1056
1046
- Changed = !FusedInsts.empty ();
1057
+ Changed | = !FusedInsts.empty ();
1047
1058
1048
1059
// Fourth, lower remaining instructions with shape information.
1049
1060
for (Instruction *Inst : MatrixInsts) {
0 commit comments