-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[Matrix] Fix dimensions when hoisting transpose across add. #81507
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Row and column arguments for matrix_transpose indicate the shape of the operand. When hoisting the transpose to the result of the add, the add operates on the original operand's shape, and so does the hoisted transpose. This patch also adds asserts that the shape of the original add matches the shape of the hoisted transpose, and that the shape of the new add matches the cached shape for it. The assert could potentially be moved to updateShapeAndReplaceAllUsesWith.
@llvm/pr-subscribers-llvm-transforms Author: Florian Hahn (fhahn) Changes: Row and column arguments for matrix_transpose indicate the shape of the operand. When hoisting the transpose to the result of the add, the add operates on the original operand's shape, and so does the hoisted transpose. This patch also adds asserts that the shape of the original add matches the shape of the hoisted transpose, and that the shape of the new add matches the cached shape for it. The assert could potentially be moved to updateShapeAndReplaceAllUsesWith. Full diff: https://github.com/llvm/llvm-project/pull/81507.diff 3 Files Affected:
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 03e289f7a087a..075388f69a85b 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -765,8 +765,9 @@ class LowerMatrixIntrinsics {
auto S = ShapeMap.find(&Old);
if (S != ShapeMap.end()) {
ShapeMap.erase(S);
- if (supportsShapeInfo(New))
+ if (supportsShapeInfo(New)) {
ShapeMap.insert({New, S->second});
+ }
}
Old.replaceAllUsesWith(New);
}
@@ -898,20 +899,28 @@ class LowerMatrixIntrinsics {
updateShapeAndReplaceAllUsesWith(I, NewInst);
CleanupBinOp(I, A, B);
}
- // A^t + B ^t -> (A + B)^t
+ // A^t + B ^t -> (A + B)^t. Pick rows and columns from first transpose. If
+ // the shape of the second transpose is different, there's a shape conflict
+ // which gets resolved by picking the shape of the first operand.
else if (match(&I, m_FAdd(m_Value(A), m_Value(B))) &&
match(A, m_Intrinsic<Intrinsic::matrix_transpose>(
m_Value(AT), m_ConstantInt(R), m_ConstantInt(C))) &&
match(B, m_Intrinsic<Intrinsic::matrix_transpose>(
- m_Value(BT), m_ConstantInt(R), m_ConstantInt(C)))) {
+ m_Value(BT), m_ConstantInt(), m_ConstantInt()))) {
IRBuilder<> Builder(&I);
- Value *Add = cast<Instruction>(Builder.CreateFAdd(AT, BT, "mfadd"));
- setShapeInfo(Add, {C, R});
+ auto *Add = cast<Instruction>(Builder.CreateFAdd(AT, BT, "mfadd"));
+ setShapeInfo(Add, {R, C});
MatrixBuilder MBuilder(Builder);
Instruction *NewInst = MBuilder.CreateMatrixTranspose(
- Add, C->getZExtValue(), R->getZExtValue(), "mfadd_t");
+ Add, R->getZExtValue(), C->getZExtValue(), "mfadd_t");
updateShapeAndReplaceAllUsesWith(I, NewInst);
+ assert(computeShapeInfoForInst(NewInst, ShapeMap) ==
+ computeShapeInfoForInst(&I, ShapeMap) &&
+ "Shape of new instruction doesn't match original shape.");
CleanupBinOp(I, A, B);
+ assert(computeShapeInfoForInst(Add, ShapeMap).value_or(ShapeMap[Add]) ==
+ ShapeMap[Add] &&
+ "Shape of updated addition doesn't match cached shape.");
}
}
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-backward.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-backward.ll
index 82ae93b31035d..33a338dbc4ea0 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-backward.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-backward.ll
@@ -4,31 +4,35 @@
define <8 x double> @fadd_transpose(<8 x double> %a, <8 x double> %b) {
; CHECK-LABEL: @fadd_transpose(
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <8 x double> [[A:%.*]], <8 x double> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-; CHECK-NEXT: [[SPLIT2:%.*]] = shufflevector <8 x double> [[A]], <8 x double> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
-; CHECK-NEXT: [[SPLIT3:%.*]] = shufflevector <8 x double> [[B:%.*]], <8 x double> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-; CHECK-NEXT: [[SPLIT4:%.*]] = shufflevector <8 x double> [[B]], <8 x double> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
-; CHECK-NEXT: [[TMP0:%.*]] = fadd <4 x double> [[SPLIT]], [[SPLIT3]]
-; CHECK-NEXT: [[TMP1:%.*]] = fadd <4 x double> [[SPLIT2]], [[SPLIT4]]
-; CHECK-NEXT: [[TMP2:%.*]] = extractelement <4 x double> [[TMP0]], i64 0
-; CHECK-NEXT: [[TMP3:%.*]] = insertelement <2 x double> poison, double [[TMP2]], i64 0
-; CHECK-NEXT: [[TMP4:%.*]] = extractelement <4 x double> [[TMP1]], i64 0
-; CHECK-NEXT: [[TMP5:%.*]] = insertelement <2 x double> [[TMP3]], double [[TMP4]], i64 1
-; CHECK-NEXT: [[TMP6:%.*]] = extractelement <4 x double> [[TMP0]], i64 1
-; CHECK-NEXT: [[TMP7:%.*]] = insertelement <2 x double> poison, double [[TMP6]], i64 0
-; CHECK-NEXT: [[TMP8:%.*]] = extractelement <4 x double> [[TMP1]], i64 1
-; CHECK-NEXT: [[TMP9:%.*]] = insertelement <2 x double> [[TMP7]], double [[TMP8]], i64 1
-; CHECK-NEXT: [[TMP10:%.*]] = extractelement <4 x double> [[TMP0]], i64 2
-; CHECK-NEXT: [[TMP11:%.*]] = insertelement <2 x double> poison, double [[TMP10]], i64 0
-; CHECK-NEXT: [[TMP12:%.*]] = extractelement <4 x double> [[TMP1]], i64 2
-; CHECK-NEXT: [[TMP13:%.*]] = insertelement <2 x double> [[TMP11]], double [[TMP12]], i64 1
-; CHECK-NEXT: [[TMP14:%.*]] = extractelement <4 x double> [[TMP0]], i64 3
-; CHECK-NEXT: [[TMP15:%.*]] = insertelement <2 x double> poison, double [[TMP14]], i64 0
-; CHECK-NEXT: [[TMP16:%.*]] = extractelement <4 x double> [[TMP1]], i64 3
-; CHECK-NEXT: [[TMP17:%.*]] = insertelement <2 x double> [[TMP15]], double [[TMP16]], i64 1
-; CHECK-NEXT: [[TMP18:%.*]] = shufflevector <2 x double> [[TMP5]], <2 x double> [[TMP9]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-; CHECK-NEXT: [[TMP19:%.*]] = shufflevector <2 x double> [[TMP13]], <2 x double> [[TMP17]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-; CHECK-NEXT: [[TMP20:%.*]] = shufflevector <4 x double> [[TMP18]], <4 x double> [[TMP19]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <8 x double> [[A:%.*]], <8 x double> poison, <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT: [[SPLIT1:%.*]] = shufflevector <8 x double> [[A]], <8 x double> poison, <2 x i32> <i32 2, i32 3>
+; CHECK-NEXT: [[SPLIT2:%.*]] = shufflevector <8 x double> [[A]], <8 x double> poison, <2 x i32> <i32 4, i32 5>
+; CHECK-NEXT: [[SPLIT3:%.*]] = shufflevector <8 x double> [[A]], <8 x double> poison, <2 x i32> <i32 6, i32 7>
+; CHECK-NEXT: [[SPLIT4:%.*]] = shufflevector <8 x double> [[B:%.*]], <8 x double> poison, <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT: [[SPLIT5:%.*]] = shufflevector <8 x double> [[B]], <8 x double> poison, <2 x i32> <i32 2, i32 3>
+; CHECK-NEXT: [[SPLIT6:%.*]] = shufflevector <8 x double> [[B]], <8 x double> poison, <2 x i32> <i32 4, i32 5>
+; CHECK-NEXT: [[SPLIT7:%.*]] = shufflevector <8 x double> [[B]], <8 x double> poison, <2 x i32> <i32 6, i32 7>
+; CHECK-NEXT: [[TMP0:%.*]] = fadd <2 x double> [[SPLIT]], [[SPLIT4]]
+; CHECK-NEXT: [[TMP1:%.*]] = fadd <2 x double> [[SPLIT1]], [[SPLIT5]]
+; CHECK-NEXT: [[TMP2:%.*]] = fadd <2 x double> [[SPLIT2]], [[SPLIT6]]
+; CHECK-NEXT: [[TMP3:%.*]] = fadd <2 x double> [[SPLIT3]], [[SPLIT7]]
+; CHECK-NEXT: [[TMP4:%.*]] = extractelement <2 x double> [[TMP0]], i64 0
+; CHECK-NEXT: [[TMP5:%.*]] = insertelement <4 x double> poison, double [[TMP4]], i64 0
+; CHECK-NEXT: [[TMP6:%.*]] = extractelement <2 x double> [[TMP1]], i64 0
+; CHECK-NEXT: [[TMP7:%.*]] = insertelement <4 x double> [[TMP5]], double [[TMP6]], i64 1
+; CHECK-NEXT: [[TMP8:%.*]] = extractelement <2 x double> [[TMP2]], i64 0
+; CHECK-NEXT: [[TMP9:%.*]] = insertelement <4 x double> [[TMP7]], double [[TMP8]], i64 2
+; CHECK-NEXT: [[TMP10:%.*]] = extractelement <2 x double> [[TMP3]], i64 0
+; CHECK-NEXT: [[TMP11:%.*]] = insertelement <4 x double> [[TMP9]], double [[TMP10]], i64 3
+; CHECK-NEXT: [[TMP12:%.*]] = extractelement <2 x double> [[TMP0]], i64 1
+; CHECK-NEXT: [[TMP13:%.*]] = insertelement <4 x double> poison, double [[TMP12]], i64 0
+; CHECK-NEXT: [[TMP14:%.*]] = extractelement <2 x double> [[TMP1]], i64 1
+; CHECK-NEXT: [[TMP15:%.*]] = insertelement <4 x double> [[TMP13]], double [[TMP14]], i64 1
+; CHECK-NEXT: [[TMP16:%.*]] = extractelement <2 x double> [[TMP2]], i64 1
+; CHECK-NEXT: [[TMP17:%.*]] = insertelement <4 x double> [[TMP15]], double [[TMP16]], i64 2
+; CHECK-NEXT: [[TMP18:%.*]] = extractelement <2 x double> [[TMP3]], i64 1
+; CHECK-NEXT: [[TMP19:%.*]] = insertelement <4 x double> [[TMP17]], double [[TMP18]], i64 3
+; CHECK-NEXT: [[TMP20:%.*]] = shufflevector <4 x double> [[TMP11]], <4 x double> [[TMP19]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
; CHECK-NEXT: ret <8 x double> [[TMP20]]
;
entry:
@@ -42,40 +46,37 @@ define <8 x double> @load_fadd_transpose(ptr %A.Ptr, <8 x double> %b) {
; CHECK-NEXT: entry:
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x double>, ptr [[A_PTR:%.*]], align 8
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr double, ptr [[A_PTR]], i64 2
-; CHECK-NEXT: [[COL_LOAD2:%.*]] = load <2 x double>, ptr [[VEC_GEP]], align 8
-; CHECK-NEXT: [[VEC_GEP3:%.*]] = getelementptr double, ptr [[A_PTR]], i64 4
-; CHECK-NEXT: [[COL_LOAD4:%.*]] = load <2 x double>, ptr [[VEC_GEP3]], align 8
-; CHECK-NEXT: [[VEC_GEP5:%.*]] = getelementptr double, ptr [[A_PTR]], i64 6
-; CHECK-NEXT: [[COL_LOAD6:%.*]] = load <2 x double>, ptr [[VEC_GEP5]], align 8
-; CHECK-NEXT: [[TMP0:%.*]] = shufflevector <2 x double> [[COL_LOAD]], <2 x double> [[COL_LOAD2]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <2 x double> [[COL_LOAD4]], <2 x double> [[COL_LOAD6]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <4 x double> [[TMP0]], <4 x double> [[TMP1]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
-; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <8 x double> [[TMP2]], <8 x double> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-; CHECK-NEXT: [[SPLIT7:%.*]] = shufflevector <8 x double> [[TMP2]], <8 x double> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
-; CHECK-NEXT: [[SPLIT8:%.*]] = shufflevector <8 x double> [[B:%.*]], <8 x double> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-; CHECK-NEXT: [[SPLIT9:%.*]] = shufflevector <8 x double> [[B]], <8 x double> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
-; CHECK-NEXT: [[TMP3:%.*]] = fadd <4 x double> [[SPLIT]], [[SPLIT8]]
-; CHECK-NEXT: [[TMP4:%.*]] = fadd <4 x double> [[SPLIT7]], [[SPLIT9]]
-; CHECK-NEXT: [[TMP5:%.*]] = extractelement <4 x double> [[TMP3]], i64 0
-; CHECK-NEXT: [[TMP6:%.*]] = insertelement <2 x double> poison, double [[TMP5]], i64 0
-; CHECK-NEXT: [[TMP7:%.*]] = extractelement <4 x double> [[TMP4]], i64 0
-; CHECK-NEXT: [[TMP8:%.*]] = insertelement <2 x double> [[TMP6]], double [[TMP7]], i64 1
-; CHECK-NEXT: [[TMP9:%.*]] = extractelement <4 x double> [[TMP3]], i64 1
-; CHECK-NEXT: [[TMP10:%.*]] = insertelement <2 x double> poison, double [[TMP9]], i64 0
-; CHECK-NEXT: [[TMP11:%.*]] = extractelement <4 x double> [[TMP4]], i64 1
-; CHECK-NEXT: [[TMP12:%.*]] = insertelement <2 x double> [[TMP10]], double [[TMP11]], i64 1
-; CHECK-NEXT: [[TMP13:%.*]] = extractelement <4 x double> [[TMP3]], i64 2
-; CHECK-NEXT: [[TMP14:%.*]] = insertelement <2 x double> poison, double [[TMP13]], i64 0
-; CHECK-NEXT: [[TMP15:%.*]] = extractelement <4 x double> [[TMP4]], i64 2
-; CHECK-NEXT: [[TMP16:%.*]] = insertelement <2 x double> [[TMP14]], double [[TMP15]], i64 1
-; CHECK-NEXT: [[TMP17:%.*]] = extractelement <4 x double> [[TMP3]], i64 3
-; CHECK-NEXT: [[TMP18:%.*]] = insertelement <2 x double> poison, double [[TMP17]], i64 0
-; CHECK-NEXT: [[TMP19:%.*]] = extractelement <4 x double> [[TMP4]], i64 3
-; CHECK-NEXT: [[TMP20:%.*]] = insertelement <2 x double> [[TMP18]], double [[TMP19]], i64 1
-; CHECK-NEXT: [[TMP21:%.*]] = shufflevector <2 x double> [[TMP8]], <2 x double> [[TMP12]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-; CHECK-NEXT: [[TMP22:%.*]] = shufflevector <2 x double> [[TMP16]], <2 x double> [[TMP20]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-; CHECK-NEXT: [[TMP23:%.*]] = shufflevector <4 x double> [[TMP21]], <4 x double> [[TMP22]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
-; CHECK-NEXT: ret <8 x double> [[TMP23]]
+; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x double>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr double, ptr [[A_PTR]], i64 4
+; CHECK-NEXT: [[COL_LOAD3:%.*]] = load <2 x double>, ptr [[VEC_GEP2]], align 8
+; CHECK-NEXT: [[VEC_GEP4:%.*]] = getelementptr double, ptr [[A_PTR]], i64 6
+; CHECK-NEXT: [[COL_LOAD5:%.*]] = load <2 x double>, ptr [[VEC_GEP4]], align 8
+; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <8 x double> [[B:%.*]], <8 x double> poison, <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT: [[SPLIT6:%.*]] = shufflevector <8 x double> [[B]], <8 x double> poison, <2 x i32> <i32 2, i32 3>
+; CHECK-NEXT: [[SPLIT7:%.*]] = shufflevector <8 x double> [[B]], <8 x double> poison, <2 x i32> <i32 4, i32 5>
+; CHECK-NEXT: [[SPLIT8:%.*]] = shufflevector <8 x double> [[B]], <8 x double> poison, <2 x i32> <i32 6, i32 7>
+; CHECK-NEXT: [[TMP0:%.*]] = fadd <2 x double> [[COL_LOAD]], [[SPLIT]]
+; CHECK-NEXT: [[TMP1:%.*]] = fadd <2 x double> [[COL_LOAD1]], [[SPLIT6]]
+; CHECK-NEXT: [[TMP2:%.*]] = fadd <2 x double> [[COL_LOAD3]], [[SPLIT7]]
+; CHECK-NEXT: [[TMP3:%.*]] = fadd <2 x double> [[COL_LOAD5]], [[SPLIT8]]
+; CHECK-NEXT: [[TMP4:%.*]] = extractelement <2 x double> [[TMP0]], i64 0
+; CHECK-NEXT: [[TMP5:%.*]] = insertelement <4 x double> poison, double [[TMP4]], i64 0
+; CHECK-NEXT: [[TMP6:%.*]] = extractelement <2 x double> [[TMP1]], i64 0
+; CHECK-NEXT: [[TMP7:%.*]] = insertelement <4 x double> [[TMP5]], double [[TMP6]], i64 1
+; CHECK-NEXT: [[TMP8:%.*]] = extractelement <2 x double> [[TMP2]], i64 0
+; CHECK-NEXT: [[TMP9:%.*]] = insertelement <4 x double> [[TMP7]], double [[TMP8]], i64 2
+; CHECK-NEXT: [[TMP10:%.*]] = extractelement <2 x double> [[TMP3]], i64 0
+; CHECK-NEXT: [[TMP11:%.*]] = insertelement <4 x double> [[TMP9]], double [[TMP10]], i64 3
+; CHECK-NEXT: [[TMP12:%.*]] = extractelement <2 x double> [[TMP0]], i64 1
+; CHECK-NEXT: [[TMP13:%.*]] = insertelement <4 x double> poison, double [[TMP12]], i64 0
+; CHECK-NEXT: [[TMP14:%.*]] = extractelement <2 x double> [[TMP1]], i64 1
+; CHECK-NEXT: [[TMP15:%.*]] = insertelement <4 x double> [[TMP13]], double [[TMP14]], i64 1
+; CHECK-NEXT: [[TMP16:%.*]] = extractelement <2 x double> [[TMP2]], i64 1
+; CHECK-NEXT: [[TMP17:%.*]] = insertelement <4 x double> [[TMP15]], double [[TMP16]], i64 2
+; CHECK-NEXT: [[TMP18:%.*]] = extractelement <2 x double> [[TMP3]], i64 1
+; CHECK-NEXT: [[TMP19:%.*]] = insertelement <4 x double> [[TMP17]], double [[TMP18]], i64 3
+; CHECK-NEXT: [[TMP20:%.*]] = shufflevector <4 x double> [[TMP11]], <4 x double> [[TMP19]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT: ret <8 x double> [[TMP20]]
;
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts-lifting.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts-lifting.ll
index d0c67556224c8..fcf83b03bc3d2 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts-lifting.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts-lifting.ll
@@ -9,7 +9,7 @@ define <6 x double> @lift_through_add_matching_transpose_dimensions(<6 x double>
; CHECK-LABEL: define <6 x double> @lift_through_add_matching_transpose_dimensions(<6 x double> %a, <6 x double> %b) {
; CHECK-NEXT: entry:
; CHECK-NEXT: [[A:%.+]] = fadd <6 x double> %a, %b
-; CHECK-NEXT: [[T:%.+]] = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> [[A]], i32 2, i32 3)
+; CHECK-NEXT: [[T:%.+]] = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> [[A]], i32 3, i32 2)
; CHECK-NEXT: ret <6 x double> [[T]]
;
entry:
@@ -25,7 +25,7 @@ define <6 x double> @lift_through_add_matching_transpose_dimensions_ops_also_hav
; CHECK-NEXT: [[A:%.+]] = load <6 x double>, ptr %a.ptr
; CHECK-NEXT: [[B:%.+]] = load <6 x double>, ptr %b.ptr
; CHECK-NEXT: [[ADD:%.+]] = fadd <6 x double> [[A]], [[B]]
-; CHECK-NEXT: [[T:%.+]] = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> [[ADD]], i32 2, i32 3)
+; CHECK-NEXT: [[T:%.+]] = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> [[ADD]], i32 3, i32 2)
; CHECK-NEXT: ret <6 x double> [[T]]
;
entry:
@@ -41,10 +41,28 @@ define <6 x double> @lift_through_add_mismatching_dimensions_1(<6 x double> %a,
; CHECK-LABEL: define <6 x double> @lift_through_add_mismatching_dimensions_1(<6 x double> %a, <6 x double> %b) {
; CHECK-NEXT: entry:
; CHECK-NEXT: [[A:%.+]] = fadd <6 x double> %a, %b
-; CHECK-NEXT: [[T:%.+]] = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> [[A]], i32 2, i32 3)
+; CHECK-NEXT: [[T:%.+]] = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> [[A]], i32 1, i32 6)
+; CHECK-NEXT: ret <6 x double> [[T]]
+;
+entry:
+ %a.t = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> %a, i32 1, i32 6)
+ %b.t = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> %b, i32 3, i32 2)
+ %add = fadd <6 x double> %a.t, %b.t
+ ret <6 x double> %add
+}
+
+define <6 x double> @lift_through_add_mismatching_dimensions_1_transpose_dimensions_ops_also_have_shape_info(ptr %a.ptr, ptr %b.ptr) {
+; CHECK-LABEL: define <6 x double> @lift_through_add_mismatching_dimensions_1_transpose_dimensions_ops_also_have_shape_info(ptr %a.ptr, ptr %b.ptr)
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[A:%.+]] = load <6 x double>, ptr %a.ptr
+; CHECK-NEXT: [[B:%.+]] = load <6 x double>, ptr %b.ptr
+; CHECK-NEXT: [[ADD:%.+]] = fadd <6 x double> [[A]], [[B]]
+; CHECK-NEXT: [[T:%.+]] = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> [[ADD]], i32 1, i32 6)
; CHECK-NEXT: ret <6 x double> [[T]]
;
entry:
+ %a = load <6 x double>, ptr %a.ptr
+ %b = load <6 x double>, ptr %b.ptr
%a.t = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> %a, i32 1, i32 6)
%b.t = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> %b, i32 3, i32 2)
%add = fadd <6 x double> %a.t, %b.t
@@ -55,7 +73,7 @@ define <6 x double> @lift_through_add_mismatching_dimensions_2(<6 x double> %a,
; CHECK-LABEL: define <6 x double> @lift_through_add_mismatching_dimensions_2(<6 x double> %a, <6 x double> %b) {
; CHECK-NEXT: entry:
; CHECK-NEXT: [[A:%.+]] = fadd <6 x double> %a, %b
-; CHECK-NEXT: [[T:%.+]] = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> [[A]], i32 1, i32 6)
+; CHECK-NEXT: [[T:%.+]] = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> [[A]], i32 3, i32 2)
; CHECK-NEXT: ret <6 x double> [[T]]
;
@@ -66,6 +84,24 @@ entry:
ret <6 x double> %add
}
+define <6 x double> @lift_through_add_mismatching_dimensions_2_transpose_dimensions_ops_also_have_shape_info(ptr %a.ptr, ptr %b.ptr) {
+; CHECK-LABEL: define <6 x double> @lift_through_add_mismatching_dimensions_2_transpose_dimensions_ops_also_have_shape_info(ptr %a.ptr, ptr %b.ptr)
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[A:%.+]] = load <6 x double>, ptr %a.ptr
+; CHECK-NEXT: [[B:%.+]] = load <6 x double>, ptr %b.ptr
+; CHECK-NEXT: [[ADD:%.+]] = fadd <6 x double> [[A]], [[B]]
+; CHECK-NEXT: [[T:%.+]] = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> [[ADD]], i32 3, i32 2)
+; CHECK-NEXT: ret <6 x double> [[T]]
+;
+entry:
+ %a = load <6 x double>, ptr %a.ptr
+ %b = load <6 x double>, ptr %b.ptr
+ %a.t = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> %a, i32 3, i32 2)
+ %b.t = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> %b, i32 6, i32 1)
+ %add = fadd <6 x double> %a.t, %b.t
+ ret <6 x double> %add
+}
+
define <9 x double> @lift_through_multiply(<6 x double> %a, <6 x double> %b) {
; CHECK-LABEL: define <9 x double> @lift_through_multiply(<6 x double> %a, <6 x double> %b) {
; CHECK-NEXT: entry:
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Row and column arguments for matrix_transpose indicate the shape of the operand. When hoisting the transpose to the result of the add, the add operates on the original operand's shape, and so does the hoisted transpose. This patch also adds asserts that the shape of the original add matches the shape of the hoisted transpose, and that the shape of the new add matches the cached shape for it. The assert could potentially be moved to updateShapeAndReplaceAllUsesWith. (cherry-picked from dbe4143)
Row and column arguments for matrix_transpose indicate the shape of the operand. When hoisting the transpose to the result of the add, the add operates on the original operand's shape, and so does the hoisted transpose.
This patch also adds asserts that the shape of the original add matches the shape of the hoisted transpose, and that the shape of the new add matches the cached shape for it.
The assert could potentially be moved to updateShapeAndReplaceAllUsesWith.