Skip to content

[Matrix] Fix dimensions when hoisting transpose across add. #81507

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 12, 2024

Conversation

fhahn
Copy link
Contributor

@fhahn fhahn commented Feb 12, 2024

Row and column arguments for matrix_transpose indicate the shape of the operand. When hoisting the transpose to the result of the add, the add operates on the original operand's shape, and so does the hoisted transpose.

This patch also adds an assert that the shape for the original add and the transpose match, as well as the shape of the new add matches the cached shape for it.

The assert could potentially be moved to updateShapeAndReplaceAllUsesWith.

Row and column arguments for matrix_transpose indicate the shape of the
operand. When hoisting the transpose to the result of the add, the add
operates on the original operand's shape, and so does the hoisted
transpose.

This patch also adds an assert that the shape for the original add and
the transpose match, as well as the shape of the new add matches the
cached shape for it.

The assert could potentially be moved to updateShapeAndReplaceAllUsesWith.
@llvmbot
Copy link
Member

llvmbot commented Feb 12, 2024

@llvm/pr-subscribers-llvm-transforms

Author: Florian Hahn (fhahn)

Changes

Row and column arguments for matrix_transpose indicate the shape of the operand. When hoisting the transpose to the result of the add, the add operates on the original operand's shape, and so does the hoisted transpose.

This patch also adds an assert that the shape for the original add and the transpose match, as well as the shape of the new add matches the cached shape for it.

The assert could potentially be moved to updateShapeAndReplaceAllUsesWith.


Full diff: https://github.com/llvm/llvm-project/pull/81507.diff

3 Files Affected:

  • (modified) llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp (+15-6)
  • (modified) llvm/test/Transforms/LowerMatrixIntrinsics/propagate-backward.ll (+60-59)
  • (modified) llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts-lifting.ll (+40-4)
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 03e289f7a087a..075388f69a85b 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -765,8 +765,9 @@ class LowerMatrixIntrinsics {
     auto S = ShapeMap.find(&Old);
     if (S != ShapeMap.end()) {
       ShapeMap.erase(S);
-      if (supportsShapeInfo(New))
+      if (supportsShapeInfo(New)) {
         ShapeMap.insert({New, S->second});
+      }
     }
     Old.replaceAllUsesWith(New);
   }
@@ -898,20 +899,28 @@ class LowerMatrixIntrinsics {
       updateShapeAndReplaceAllUsesWith(I, NewInst);
       CleanupBinOp(I, A, B);
     }
-    // A^t + B ^t -> (A + B)^t
+    // A^t + B ^t -> (A + B)^t. Pick rows and columns from first transpose. If
+    // the shape of the second transpose is different, there's a shape conflict
+    // which gets resolved by picking the shape of the first operand.
     else if (match(&I, m_FAdd(m_Value(A), m_Value(B))) &&
              match(A, m_Intrinsic<Intrinsic::matrix_transpose>(
                           m_Value(AT), m_ConstantInt(R), m_ConstantInt(C))) &&
              match(B, m_Intrinsic<Intrinsic::matrix_transpose>(
-                          m_Value(BT), m_ConstantInt(R), m_ConstantInt(C)))) {
+                          m_Value(BT), m_ConstantInt(), m_ConstantInt()))) {
       IRBuilder<> Builder(&I);
-      Value *Add = cast<Instruction>(Builder.CreateFAdd(AT, BT, "mfadd"));
-      setShapeInfo(Add, {C, R});
+      auto *Add = cast<Instruction>(Builder.CreateFAdd(AT, BT, "mfadd"));
+      setShapeInfo(Add, {R, C});
       MatrixBuilder MBuilder(Builder);
       Instruction *NewInst = MBuilder.CreateMatrixTranspose(
-          Add, C->getZExtValue(), R->getZExtValue(), "mfadd_t");
+          Add, R->getZExtValue(), C->getZExtValue(), "mfadd_t");
       updateShapeAndReplaceAllUsesWith(I, NewInst);
+      assert(computeShapeInfoForInst(NewInst, ShapeMap) ==
+                 computeShapeInfoForInst(&I, ShapeMap) &&
+             "Shape of new instruction doesn't match original shape.");
       CleanupBinOp(I, A, B);
+      assert(computeShapeInfoForInst(Add, ShapeMap).value_or(ShapeMap[Add]) ==
+                 ShapeMap[Add] &&
+             "Shape of updated addition doesn't match cached shape.");
     }
   }
 
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-backward.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-backward.ll
index 82ae93b31035d..33a338dbc4ea0 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-backward.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-backward.ll
@@ -4,31 +4,35 @@
 define <8 x double> @fadd_transpose(<8 x double> %a, <8 x double> %b) {
 ; CHECK-LABEL: @fadd_transpose(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[SPLIT:%.*]] = shufflevector <8 x double> [[A:%.*]], <8 x double> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-; CHECK-NEXT:    [[SPLIT2:%.*]] = shufflevector <8 x double> [[A]], <8 x double> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
-; CHECK-NEXT:    [[SPLIT3:%.*]] = shufflevector <8 x double> [[B:%.*]], <8 x double> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-; CHECK-NEXT:    [[SPLIT4:%.*]] = shufflevector <8 x double> [[B]], <8 x double> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
-; CHECK-NEXT:    [[TMP0:%.*]] = fadd <4 x double> [[SPLIT]], [[SPLIT3]]
-; CHECK-NEXT:    [[TMP1:%.*]] = fadd <4 x double> [[SPLIT2]], [[SPLIT4]]
-; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <4 x double> [[TMP0]], i64 0
-; CHECK-NEXT:    [[TMP3:%.*]] = insertelement <2 x double> poison, double [[TMP2]], i64 0
-; CHECK-NEXT:    [[TMP4:%.*]] = extractelement <4 x double> [[TMP1]], i64 0
-; CHECK-NEXT:    [[TMP5:%.*]] = insertelement <2 x double> [[TMP3]], double [[TMP4]], i64 1
-; CHECK-NEXT:    [[TMP6:%.*]] = extractelement <4 x double> [[TMP0]], i64 1
-; CHECK-NEXT:    [[TMP7:%.*]] = insertelement <2 x double> poison, double [[TMP6]], i64 0
-; CHECK-NEXT:    [[TMP8:%.*]] = extractelement <4 x double> [[TMP1]], i64 1
-; CHECK-NEXT:    [[TMP9:%.*]] = insertelement <2 x double> [[TMP7]], double [[TMP8]], i64 1
-; CHECK-NEXT:    [[TMP10:%.*]] = extractelement <4 x double> [[TMP0]], i64 2
-; CHECK-NEXT:    [[TMP11:%.*]] = insertelement <2 x double> poison, double [[TMP10]], i64 0
-; CHECK-NEXT:    [[TMP12:%.*]] = extractelement <4 x double> [[TMP1]], i64 2
-; CHECK-NEXT:    [[TMP13:%.*]] = insertelement <2 x double> [[TMP11]], double [[TMP12]], i64 1
-; CHECK-NEXT:    [[TMP14:%.*]] = extractelement <4 x double> [[TMP0]], i64 3
-; CHECK-NEXT:    [[TMP15:%.*]] = insertelement <2 x double> poison, double [[TMP14]], i64 0
-; CHECK-NEXT:    [[TMP16:%.*]] = extractelement <4 x double> [[TMP1]], i64 3
-; CHECK-NEXT:    [[TMP17:%.*]] = insertelement <2 x double> [[TMP15]], double [[TMP16]], i64 1
-; CHECK-NEXT:    [[TMP18:%.*]] = shufflevector <2 x double> [[TMP5]], <2 x double> [[TMP9]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-; CHECK-NEXT:    [[TMP19:%.*]] = shufflevector <2 x double> [[TMP13]], <2 x double> [[TMP17]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-; CHECK-NEXT:    [[TMP20:%.*]] = shufflevector <4 x double> [[TMP18]], <4 x double> [[TMP19]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[SPLIT:%.*]] = shufflevector <8 x double> [[A:%.*]], <8 x double> poison, <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT:    [[SPLIT1:%.*]] = shufflevector <8 x double> [[A]], <8 x double> poison, <2 x i32> <i32 2, i32 3>
+; CHECK-NEXT:    [[SPLIT2:%.*]] = shufflevector <8 x double> [[A]], <8 x double> poison, <2 x i32> <i32 4, i32 5>
+; CHECK-NEXT:    [[SPLIT3:%.*]] = shufflevector <8 x double> [[A]], <8 x double> poison, <2 x i32> <i32 6, i32 7>
+; CHECK-NEXT:    [[SPLIT4:%.*]] = shufflevector <8 x double> [[B:%.*]], <8 x double> poison, <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT:    [[SPLIT5:%.*]] = shufflevector <8 x double> [[B]], <8 x double> poison, <2 x i32> <i32 2, i32 3>
+; CHECK-NEXT:    [[SPLIT6:%.*]] = shufflevector <8 x double> [[B]], <8 x double> poison, <2 x i32> <i32 4, i32 5>
+; CHECK-NEXT:    [[SPLIT7:%.*]] = shufflevector <8 x double> [[B]], <8 x double> poison, <2 x i32> <i32 6, i32 7>
+; CHECK-NEXT:    [[TMP0:%.*]] = fadd <2 x double> [[SPLIT]], [[SPLIT4]]
+; CHECK-NEXT:    [[TMP1:%.*]] = fadd <2 x double> [[SPLIT1]], [[SPLIT5]]
+; CHECK-NEXT:    [[TMP2:%.*]] = fadd <2 x double> [[SPLIT2]], [[SPLIT6]]
+; CHECK-NEXT:    [[TMP3:%.*]] = fadd <2 x double> [[SPLIT3]], [[SPLIT7]]
+; CHECK-NEXT:    [[TMP4:%.*]] = extractelement <2 x double> [[TMP0]], i64 0
+; CHECK-NEXT:    [[TMP5:%.*]] = insertelement <4 x double> poison, double [[TMP4]], i64 0
+; CHECK-NEXT:    [[TMP6:%.*]] = extractelement <2 x double> [[TMP1]], i64 0
+; CHECK-NEXT:    [[TMP7:%.*]] = insertelement <4 x double> [[TMP5]], double [[TMP6]], i64 1
+; CHECK-NEXT:    [[TMP8:%.*]] = extractelement <2 x double> [[TMP2]], i64 0
+; CHECK-NEXT:    [[TMP9:%.*]] = insertelement <4 x double> [[TMP7]], double [[TMP8]], i64 2
+; CHECK-NEXT:    [[TMP10:%.*]] = extractelement <2 x double> [[TMP3]], i64 0
+; CHECK-NEXT:    [[TMP11:%.*]] = insertelement <4 x double> [[TMP9]], double [[TMP10]], i64 3
+; CHECK-NEXT:    [[TMP12:%.*]] = extractelement <2 x double> [[TMP0]], i64 1
+; CHECK-NEXT:    [[TMP13:%.*]] = insertelement <4 x double> poison, double [[TMP12]], i64 0
+; CHECK-NEXT:    [[TMP14:%.*]] = extractelement <2 x double> [[TMP1]], i64 1
+; CHECK-NEXT:    [[TMP15:%.*]] = insertelement <4 x double> [[TMP13]], double [[TMP14]], i64 1
+; CHECK-NEXT:    [[TMP16:%.*]] = extractelement <2 x double> [[TMP2]], i64 1
+; CHECK-NEXT:    [[TMP17:%.*]] = insertelement <4 x double> [[TMP15]], double [[TMP16]], i64 2
+; CHECK-NEXT:    [[TMP18:%.*]] = extractelement <2 x double> [[TMP3]], i64 1
+; CHECK-NEXT:    [[TMP19:%.*]] = insertelement <4 x double> [[TMP17]], double [[TMP18]], i64 3
+; CHECK-NEXT:    [[TMP20:%.*]] = shufflevector <4 x double> [[TMP11]], <4 x double> [[TMP19]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
 ; CHECK-NEXT:    ret <8 x double> [[TMP20]]
 ;
 entry:
@@ -42,40 +46,37 @@ define <8 x double> @load_fadd_transpose(ptr %A.Ptr, <8 x double> %b) {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x double>, ptr [[A_PTR:%.*]], align 8
 ; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, ptr [[A_PTR]], i64 2
-; CHECK-NEXT:    [[COL_LOAD2:%.*]] = load <2 x double>, ptr [[VEC_GEP]], align 8
-; CHECK-NEXT:    [[VEC_GEP3:%.*]] = getelementptr double, ptr [[A_PTR]], i64 4
-; CHECK-NEXT:    [[COL_LOAD4:%.*]] = load <2 x double>, ptr [[VEC_GEP3]], align 8
-; CHECK-NEXT:    [[VEC_GEP5:%.*]] = getelementptr double, ptr [[A_PTR]], i64 6
-; CHECK-NEXT:    [[COL_LOAD6:%.*]] = load <2 x double>, ptr [[VEC_GEP5]], align 8
-; CHECK-NEXT:    [[TMP0:%.*]] = shufflevector <2 x double> [[COL_LOAD]], <2 x double> [[COL_LOAD2]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <2 x double> [[COL_LOAD4]], <2 x double> [[COL_LOAD6]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <4 x double> [[TMP0]], <4 x double> [[TMP1]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
-; CHECK-NEXT:    [[SPLIT:%.*]] = shufflevector <8 x double> [[TMP2]], <8 x double> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-; CHECK-NEXT:    [[SPLIT7:%.*]] = shufflevector <8 x double> [[TMP2]], <8 x double> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
-; CHECK-NEXT:    [[SPLIT8:%.*]] = shufflevector <8 x double> [[B:%.*]], <8 x double> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-; CHECK-NEXT:    [[SPLIT9:%.*]] = shufflevector <8 x double> [[B]], <8 x double> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
-; CHECK-NEXT:    [[TMP3:%.*]] = fadd <4 x double> [[SPLIT]], [[SPLIT8]]
-; CHECK-NEXT:    [[TMP4:%.*]] = fadd <4 x double> [[SPLIT7]], [[SPLIT9]]
-; CHECK-NEXT:    [[TMP5:%.*]] = extractelement <4 x double> [[TMP3]], i64 0
-; CHECK-NEXT:    [[TMP6:%.*]] = insertelement <2 x double> poison, double [[TMP5]], i64 0
-; CHECK-NEXT:    [[TMP7:%.*]] = extractelement <4 x double> [[TMP4]], i64 0
-; CHECK-NEXT:    [[TMP8:%.*]] = insertelement <2 x double> [[TMP6]], double [[TMP7]], i64 1
-; CHECK-NEXT:    [[TMP9:%.*]] = extractelement <4 x double> [[TMP3]], i64 1
-; CHECK-NEXT:    [[TMP10:%.*]] = insertelement <2 x double> poison, double [[TMP9]], i64 0
-; CHECK-NEXT:    [[TMP11:%.*]] = extractelement <4 x double> [[TMP4]], i64 1
-; CHECK-NEXT:    [[TMP12:%.*]] = insertelement <2 x double> [[TMP10]], double [[TMP11]], i64 1
-; CHECK-NEXT:    [[TMP13:%.*]] = extractelement <4 x double> [[TMP3]], i64 2
-; CHECK-NEXT:    [[TMP14:%.*]] = insertelement <2 x double> poison, double [[TMP13]], i64 0
-; CHECK-NEXT:    [[TMP15:%.*]] = extractelement <4 x double> [[TMP4]], i64 2
-; CHECK-NEXT:    [[TMP16:%.*]] = insertelement <2 x double> [[TMP14]], double [[TMP15]], i64 1
-; CHECK-NEXT:    [[TMP17:%.*]] = extractelement <4 x double> [[TMP3]], i64 3
-; CHECK-NEXT:    [[TMP18:%.*]] = insertelement <2 x double> poison, double [[TMP17]], i64 0
-; CHECK-NEXT:    [[TMP19:%.*]] = extractelement <4 x double> [[TMP4]], i64 3
-; CHECK-NEXT:    [[TMP20:%.*]] = insertelement <2 x double> [[TMP18]], double [[TMP19]], i64 1
-; CHECK-NEXT:    [[TMP21:%.*]] = shufflevector <2 x double> [[TMP8]], <2 x double> [[TMP12]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-; CHECK-NEXT:    [[TMP22:%.*]] = shufflevector <2 x double> [[TMP16]], <2 x double> [[TMP20]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-; CHECK-NEXT:    [[TMP23:%.*]] = shufflevector <4 x double> [[TMP21]], <4 x double> [[TMP22]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
-; CHECK-NEXT:    ret <8 x double> [[TMP23]]
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x double>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr double, ptr [[A_PTR]], i64 4
+; CHECK-NEXT:    [[COL_LOAD3:%.*]] = load <2 x double>, ptr [[VEC_GEP2]], align 8
+; CHECK-NEXT:    [[VEC_GEP4:%.*]] = getelementptr double, ptr [[A_PTR]], i64 6
+; CHECK-NEXT:    [[COL_LOAD5:%.*]] = load <2 x double>, ptr [[VEC_GEP4]], align 8
+; CHECK-NEXT:    [[SPLIT:%.*]] = shufflevector <8 x double> [[B:%.*]], <8 x double> poison, <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT:    [[SPLIT6:%.*]] = shufflevector <8 x double> [[B]], <8 x double> poison, <2 x i32> <i32 2, i32 3>
+; CHECK-NEXT:    [[SPLIT7:%.*]] = shufflevector <8 x double> [[B]], <8 x double> poison, <2 x i32> <i32 4, i32 5>
+; CHECK-NEXT:    [[SPLIT8:%.*]] = shufflevector <8 x double> [[B]], <8 x double> poison, <2 x i32> <i32 6, i32 7>
+; CHECK-NEXT:    [[TMP0:%.*]] = fadd <2 x double> [[COL_LOAD]], [[SPLIT]]
+; CHECK-NEXT:    [[TMP1:%.*]] = fadd <2 x double> [[COL_LOAD1]], [[SPLIT6]]
+; CHECK-NEXT:    [[TMP2:%.*]] = fadd <2 x double> [[COL_LOAD3]], [[SPLIT7]]
+; CHECK-NEXT:    [[TMP3:%.*]] = fadd <2 x double> [[COL_LOAD5]], [[SPLIT8]]
+; CHECK-NEXT:    [[TMP4:%.*]] = extractelement <2 x double> [[TMP0]], i64 0
+; CHECK-NEXT:    [[TMP5:%.*]] = insertelement <4 x double> poison, double [[TMP4]], i64 0
+; CHECK-NEXT:    [[TMP6:%.*]] = extractelement <2 x double> [[TMP1]], i64 0
+; CHECK-NEXT:    [[TMP7:%.*]] = insertelement <4 x double> [[TMP5]], double [[TMP6]], i64 1
+; CHECK-NEXT:    [[TMP8:%.*]] = extractelement <2 x double> [[TMP2]], i64 0
+; CHECK-NEXT:    [[TMP9:%.*]] = insertelement <4 x double> [[TMP7]], double [[TMP8]], i64 2
+; CHECK-NEXT:    [[TMP10:%.*]] = extractelement <2 x double> [[TMP3]], i64 0
+; CHECK-NEXT:    [[TMP11:%.*]] = insertelement <4 x double> [[TMP9]], double [[TMP10]], i64 3
+; CHECK-NEXT:    [[TMP12:%.*]] = extractelement <2 x double> [[TMP0]], i64 1
+; CHECK-NEXT:    [[TMP13:%.*]] = insertelement <4 x double> poison, double [[TMP12]], i64 0
+; CHECK-NEXT:    [[TMP14:%.*]] = extractelement <2 x double> [[TMP1]], i64 1
+; CHECK-NEXT:    [[TMP15:%.*]] = insertelement <4 x double> [[TMP13]], double [[TMP14]], i64 1
+; CHECK-NEXT:    [[TMP16:%.*]] = extractelement <2 x double> [[TMP2]], i64 1
+; CHECK-NEXT:    [[TMP17:%.*]] = insertelement <4 x double> [[TMP15]], double [[TMP16]], i64 2
+; CHECK-NEXT:    [[TMP18:%.*]] = extractelement <2 x double> [[TMP3]], i64 1
+; CHECK-NEXT:    [[TMP19:%.*]] = insertelement <4 x double> [[TMP17]], double [[TMP18]], i64 3
+; CHECK-NEXT:    [[TMP20:%.*]] = shufflevector <4 x double> [[TMP11]], <4 x double> [[TMP19]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    ret <8 x double> [[TMP20]]
 ;
 
 
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts-lifting.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts-lifting.ll
index d0c67556224c8..fcf83b03bc3d2 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts-lifting.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts-lifting.ll
@@ -9,7 +9,7 @@ define <6 x double> @lift_through_add_matching_transpose_dimensions(<6 x double>
 ; CHECK-LABEL:  define <6 x double> @lift_through_add_matching_transpose_dimensions(<6 x double> %a, <6 x double> %b) {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[A:%.+]] = fadd <6 x double> %a, %b
-; CHECK-NEXT:    [[T:%.+]] = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> [[A]], i32 2, i32 3)
+; CHECK-NEXT:    [[T:%.+]] = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> [[A]], i32 3, i32 2)
 ; CHECK-NEXT:    ret <6 x double> [[T]]
 ;
 entry:
@@ -25,7 +25,7 @@ define <6 x double> @lift_through_add_matching_transpose_dimensions_ops_also_hav
 ; CHECK-NEXT:    [[A:%.+]] = load <6 x double>, ptr %a.ptr
 ; CHECK-NEXT:    [[B:%.+]] = load <6 x double>, ptr %b.ptr
 ; CHECK-NEXT:    [[ADD:%.+]] = fadd <6 x double> [[A]], [[B]]
-; CHECK-NEXT:    [[T:%.+]] = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> [[ADD]], i32 2, i32 3)
+; CHECK-NEXT:    [[T:%.+]] = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> [[ADD]], i32 3, i32 2)
 ; CHECK-NEXT:    ret <6 x double> [[T]]
 ;
 entry:
@@ -41,10 +41,28 @@ define <6 x double> @lift_through_add_mismatching_dimensions_1(<6 x double> %a,
 ; CHECK-LABEL:  define <6 x double> @lift_through_add_mismatching_dimensions_1(<6 x double> %a, <6 x double> %b) {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[A:%.+]] = fadd <6 x double> %a, %b
-; CHECK-NEXT:    [[T:%.+]] = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> [[A]], i32 2, i32 3)
+; CHECK-NEXT:    [[T:%.+]] = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> [[A]], i32 1, i32 6)
+; CHECK-NEXT:    ret <6 x double> [[T]]
+;
+entry:
+  %a.t = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> %a, i32 1, i32 6)
+  %b.t = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> %b, i32 3, i32 2)
+  %add = fadd <6 x double> %a.t, %b.t
+  ret <6 x double> %add
+}
+
+define <6 x double> @lift_through_add_mismatching_dimensions_1_transpose_dimensions_ops_also_have_shape_info(ptr %a.ptr, ptr %b.ptr) {
+; CHECK-LABEL: define <6 x double> @lift_through_add_mismatching_dimensions_1_transpose_dimensions_ops_also_have_shape_info(ptr %a.ptr, ptr %b.ptr)
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[A:%.+]] = load <6 x double>, ptr %a.ptr
+; CHECK-NEXT:    [[B:%.+]] = load <6 x double>, ptr %b.ptr
+; CHECK-NEXT:    [[ADD:%.+]] = fadd <6 x double> [[A]], [[B]]
+; CHECK-NEXT:    [[T:%.+]] = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> [[ADD]], i32 1, i32 6)
 ; CHECK-NEXT:    ret <6 x double> [[T]]
 ;
 entry:
+  %a = load <6 x double>, ptr %a.ptr
+  %b = load <6 x double>, ptr %b.ptr
   %a.t = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> %a, i32 1, i32 6)
   %b.t = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> %b, i32 3, i32 2)
   %add = fadd <6 x double> %a.t, %b.t
@@ -55,7 +73,7 @@ define <6 x double> @lift_through_add_mismatching_dimensions_2(<6 x double> %a,
 ; CHECK-LABEL:  define <6 x double> @lift_through_add_mismatching_dimensions_2(<6 x double> %a, <6 x double> %b) {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[A:%.+]] = fadd <6 x double> %a, %b
-; CHECK-NEXT:    [[T:%.+]] = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> [[A]], i32 1, i32 6)
+; CHECK-NEXT:    [[T:%.+]] = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> [[A]], i32 3, i32 2)
 ; CHECK-NEXT:    ret <6 x double> [[T]]
 ;
 
@@ -66,6 +84,24 @@ entry:
   ret <6 x double> %add
 }
 
+define <6 x double> @lift_through_add_mismatching_dimensions_2_transpose_dimensions_ops_also_have_shape_info(ptr %a.ptr, ptr %b.ptr) {
+; CHECK-LABEL: define <6 x double> @lift_through_add_mismatching_dimensions_2_transpose_dimensions_ops_also_have_shape_info(ptr %a.ptr, ptr %b.ptr)
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[A:%.+]] = load <6 x double>, ptr %a.ptr
+; CHECK-NEXT:    [[B:%.+]] = load <6 x double>, ptr %b.ptr
+; CHECK-NEXT:    [[ADD:%.+]] = fadd <6 x double> [[A]], [[B]]
+; CHECK-NEXT:    [[T:%.+]] = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> [[ADD]], i32 3, i32 2)
+; CHECK-NEXT:    ret <6 x double> [[T]]
+;
+entry:
+  %a = load <6 x double>, ptr %a.ptr
+  %b = load <6 x double>, ptr %b.ptr
+  %a.t = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> %a, i32 3, i32 2)
+  %b.t = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> %b, i32 6, i32 1)
+  %add = fadd <6 x double> %a.t, %b.t
+  ret <6 x double> %add
+}
+
 define <9 x double> @lift_through_multiply(<6 x double> %a, <6 x double> %b) {
 ; CHECK-LABEL: define <9 x double> @lift_through_multiply(<6 x double> %a, <6 x double> %b) {
 ; CHECK-NEXT:  entry:

Copy link
Collaborator

@francisvm francisvm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@fhahn fhahn merged commit dbe4143 into llvm:main Feb 12, 2024
@fhahn fhahn deleted the matrix-hoist-transpose branch February 12, 2024 18:45
fhahn added a commit to swiftlang/llvm-project that referenced this pull request Feb 13, 2024
Row and column arguments for matrix_transpose indicate the shape of the
operand. When hoisting the transpose to the result of the add, the add
operates on the original operand's shape, and so does the hoisted
transpose.

This patch also adds an assert that the shape for the original add and
the transpose match, as well as the shape of the new add matches the
cached shape for it.

The assert could potentially be moved to
updateShapeAndReplaceAllUsesWith.

(cherry-picked from dbe4143)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants