[Matrix] Convert column-vector ops feeding dot product to row-vectors. #72647
Conversation
@llvm/pr-subscribers-llvm-transforms

Author: Florian Hahn (fhahn)

Changes

Generalize the logic used to convert column-vector ops to row-vectors to support converting chains of operations. A potential next step is to further generalize this to convert column-vector ops to row-vector ops in general, not just for operands of dot products. Dot-product handling would then be driven by the general conversion, rather than the other way around.

Depends on D148429.

Differential Revision: https://reviews.llvm.org/D148430

Full diff: https://github.com/llvm/llvm-project/pull/72647.diff

2 Files Affected (see the illustrative sketch below, followed by the diff):
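Before the diff, a minimal sketch (function and value names are hypothetical) of the kind of IR this change now handles: a chain of adds on the 1x8 left-hand side of a (1x8) x (8x1) llvm.matrix.multiply, i.e. a dot product. Previously only the immediate LHS operand of the multiply was considered for flattening; with this patch the whole operand chain can be converted to row-vector form.

```llvm
; Hypothetical input: an add chain feeding the LHS of a dot product.
define <1 x i32> @dot_of_add_chain(<8 x i32> %a, <8 x i32> %b, <8 x i32> %c, <8 x i32> %d) {
entry:
  %add.1 = add <8 x i32> %a, %b
  %add.2 = add <8 x i32> %add.1, %c
  ; (1 x 8) * (8 x 1) -> 1 x 1
  %res = call <1 x i32> @llvm.matrix.multiply.v1i32.v8i32.v8i32(<8 x i32> %add.2, <8 x i32> %d, i32 1, i32 8, i32 1)
  ret <1 x i32> %res
}

declare <1 x i32> @llvm.matrix.multiply.v1i32.v8i32.v8i32(<8 x i32>, <8 x i32>, i32, i32, i32)
```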
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 72b9db1e73d73dc..c6bb43d3a78cf3e 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -1332,8 +1332,8 @@ class LowerMatrixIntrinsics {
if (!IsIntVec && !FMF.allowReassoc())
return;
- auto CanBeFlattened = [this](Value *Op) {
- if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end())
+ auto CanBeFlattened = [](Value *Op) {
+ if (match(Op, m_BinOp()))
return true;
return match(
Op, m_OneUse(m_CombineOr(
@@ -1346,6 +1346,9 @@ class LowerMatrixIntrinsics {
// the returned cost is < 0, the argument is cheaper to use in the
// dot-product lowering.
auto GetCostForArg = [this, &CanBeFlattened](Value *Op, unsigned N) {
+ if (ShapeMap.find(Op) == ShapeMap.end())
+ return InstructionCost::getInvalid();
+
if (!isa<Instruction>(Op))
return InstructionCost(0);
@@ -1356,7 +1359,7 @@ class LowerMatrixIntrinsics {
InstructionCost EmbedCost(0);
// Roughly estimate the cost for embedding the columns into a vector.
for (unsigned I = 1; I < N; ++I)
- EmbedCost -=
+ EmbedCost +=
TTI.getShuffleCost(TTI::SK_Splice, FixedVectorType::get(EltTy, 1),
std::nullopt, TTI::TCK_RecipThroughput);
return EmbedCost;
@@ -1378,7 +1381,7 @@ class LowerMatrixIntrinsics {
// vector.
InstructionCost EmbedCost(0);
for (unsigned I = 1; I < N; ++I)
- EmbedCost +=
+ EmbedCost -=
TTI.getShuffleCost(TTI::SK_Splice, FixedVectorType::get(EltTy, 1),
std::nullopt, TTI::TCK_RecipThroughput);
return EmbedCost;
@@ -1391,7 +1394,26 @@ class LowerMatrixIntrinsics {
return TTI.getMemoryOpCost(Instruction::Load, VecTy, Align(1), 0) -
N * TTI.getMemoryOpCost(Instruction::Load, EltTy, Align(1), 0);
};
- auto LHSCost = GetCostForArg(LHS, LShape.NumColumns);
+
+ SmallPtrSet<Value *, 4> Seen;
+ SmallVector<Value *> WorkList;
+ SmallVector<Value *> ToFlatten;
+ WorkList.push_back(LHS);
+ InstructionCost LHSCost(0);
+ while (!WorkList.empty()) {
+ Value *Op = WorkList.pop_back_val();
+ if (!Seen.insert(Op).second)
+ continue;
+
+ InstructionCost OpCost = GetCostForArg(Op, LShape.NumColumns);
+ if (OpCost + LHSCost >= LHSCost)
+ continue;
+
+ LHSCost += OpCost;
+ ToFlatten.push_back(Op);
+ if (auto *I = dyn_cast<Instruction>(Op))
+ WorkList.append(I->op_begin(), I->op_end());
+ }
// We compare the costs of a vector.reduce.add to sequential add.
int AddOpCode = IsIntVec ? Instruction::Add : Instruction::FAdd;
@@ -1412,16 +1434,16 @@ class LowerMatrixIntrinsics {
FusedInsts.insert(MatMul);
IRBuilder<> Builder(MatMul);
auto FlattenArg = [&Builder, &FusedInsts, &CanBeFlattened,
- this](Value *Op) -> Value * {
+ this](Value *Op) {
// Matmul must be the only user of loads because we don't use LowerLoad
// for row vectors (LowerLoad results in scalar loads and shufflevectors
// instead of single vector load).
if (!CanBeFlattened(Op))
- return Op;
+ return;
if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end()) {
ShapeMap[Op] = ShapeMap[Op].t();
- return Op;
+ return;
}
FusedInsts.insert(cast<Instruction>(Op));
@@ -1432,16 +1454,19 @@ class LowerMatrixIntrinsics {
auto *NewLoad = Builder.CreateLoad(Op->getType(), Arg);
Op->replaceAllUsesWith(NewLoad);
cast<Instruction>(Op)->eraseFromParent();
- return NewLoad;
+ return;
} else if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>(
m_Value(Arg)))) {
ToRemove.push_back(cast<Instruction>(Op));
- return Arg;
+ Op->replaceAllUsesWith(Arg);
+ return;
}
-
- return Op;
};
- LHS = FlattenArg(LHS);
+
+ for (auto *V : ToFlatten)
+ FlattenArg(V);
+
+ LHS = MatMul->getArgOperand(0);
// Insert mul/fmul and llvm.vector.reduce.fadd
Value *Mul =
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-int.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-int.ll
index 7bbd0c500485511..f15dbed1f1f5134 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-int.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-int.ll
@@ -119,44 +119,15 @@ entry:
define <1 x i32> @add_chain_feeding_dotproduct_i32_v8_1(<8 x i32> %a, <8 x i32> %b, <8 x i32> %c, <8 x i32> %d) {
; CHECK-LABEL: @add_chain_feeding_dotproduct_i32_v8_1(
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <8 x i32> [[A:%.*]], <8 x i32> poison, <1 x i32> zeroinitializer
-; CHECK-NEXT: [[SPLIT1:%.*]] = shufflevector <8 x i32> [[A]], <8 x i32> poison, <1 x i32> <i32 1>
-; CHECK-NEXT: [[SPLIT2:%.*]] = shufflevector <8 x i32> [[A]], <8 x i32> poison, <1 x i32> <i32 2>
-; CHECK-NEXT: [[SPLIT3:%.*]] = shufflevector <8 x i32> [[A]], <8 x i32> poison, <1 x i32> <i32 3>
-; CHECK-NEXT: [[SPLIT4:%.*]] = shufflevector <8 x i32> [[A]], <8 x i32> poison, <1 x i32> <i32 4>
-; CHECK-NEXT: [[SPLIT5:%.*]] = shufflevector <8 x i32> [[A]], <8 x i32> poison, <1 x i32> <i32 5>
-; CHECK-NEXT: [[SPLIT6:%.*]] = shufflevector <8 x i32> [[A]], <8 x i32> poison, <1 x i32> <i32 6>
-; CHECK-NEXT: [[SPLIT7:%.*]] = shufflevector <8 x i32> [[A]], <8 x i32> poison, <1 x i32> <i32 7>
-; CHECK-NEXT: [[SPLIT8:%.*]] = shufflevector <8 x i32> [[B:%.*]], <8 x i32> poison, <1 x i32> zeroinitializer
-; CHECK-NEXT: [[SPLIT9:%.*]] = shufflevector <8 x i32> [[B]], <8 x i32> poison, <1 x i32> <i32 1>
-; CHECK-NEXT: [[SPLIT10:%.*]] = shufflevector <8 x i32> [[B]], <8 x i32> poison, <1 x i32> <i32 2>
-; CHECK-NEXT: [[SPLIT11:%.*]] = shufflevector <8 x i32> [[B]], <8 x i32> poison, <1 x i32> <i32 3>
-; CHECK-NEXT: [[SPLIT12:%.*]] = shufflevector <8 x i32> [[B]], <8 x i32> poison, <1 x i32> <i32 4>
-; CHECK-NEXT: [[SPLIT13:%.*]] = shufflevector <8 x i32> [[B]], <8 x i32> poison, <1 x i32> <i32 5>
-; CHECK-NEXT: [[SPLIT14:%.*]] = shufflevector <8 x i32> [[B]], <8 x i32> poison, <1 x i32> <i32 6>
-; CHECK-NEXT: [[SPLIT15:%.*]] = shufflevector <8 x i32> [[B]], <8 x i32> poison, <1 x i32> <i32 7>
-; CHECK-NEXT: [[TMP0:%.*]] = add <1 x i32> [[SPLIT]], [[SPLIT8]]
-; CHECK-NEXT: [[TMP1:%.*]] = add <1 x i32> [[SPLIT1]], [[SPLIT9]]
-; CHECK-NEXT: [[TMP2:%.*]] = add <1 x i32> [[SPLIT2]], [[SPLIT10]]
-; CHECK-NEXT: [[TMP3:%.*]] = add <1 x i32> [[SPLIT3]], [[SPLIT11]]
-; CHECK-NEXT: [[TMP4:%.*]] = add <1 x i32> [[SPLIT4]], [[SPLIT12]]
-; CHECK-NEXT: [[TMP5:%.*]] = add <1 x i32> [[SPLIT5]], [[SPLIT13]]
-; CHECK-NEXT: [[TMP6:%.*]] = add <1 x i32> [[SPLIT6]], [[SPLIT14]]
-; CHECK-NEXT: [[TMP7:%.*]] = add <1 x i32> [[SPLIT7]], [[SPLIT15]]
-; CHECK-NEXT: [[TMP8:%.*]] = shufflevector <1 x i32> [[TMP0]], <1 x i32> [[TMP1]], <2 x i32> <i32 0, i32 1>
-; CHECK-NEXT: [[TMP9:%.*]] = shufflevector <1 x i32> [[TMP2]], <1 x i32> [[TMP3]], <2 x i32> <i32 0, i32 1>
-; CHECK-NEXT: [[TMP10:%.*]] = shufflevector <1 x i32> [[TMP4]], <1 x i32> [[TMP5]], <2 x i32> <i32 0, i32 1>
-; CHECK-NEXT: [[TMP11:%.*]] = shufflevector <1 x i32> [[TMP6]], <1 x i32> [[TMP7]], <2 x i32> <i32 0, i32 1>
-; CHECK-NEXT: [[TMP12:%.*]] = shufflevector <2 x i32> [[TMP8]], <2 x i32> [[TMP9]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-; CHECK-NEXT: [[TMP13:%.*]] = shufflevector <2 x i32> [[TMP10]], <2 x i32> [[TMP11]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-; CHECK-NEXT: [[TMP14:%.*]] = shufflevector <4 x i32> [[TMP12]], <4 x i32> [[TMP13]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
-; CHECK-NEXT: [[SPLIT16:%.*]] = shufflevector <8 x i32> [[TMP14]], <8 x i32> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
-; CHECK-NEXT: [[SPLIT17:%.*]] = shufflevector <8 x i32> [[C:%.*]], <8 x i32> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
-; CHECK-NEXT: [[TMP15:%.*]] = add <8 x i32> [[SPLIT16]], [[SPLIT17]]
-; CHECK-NEXT: [[TMP16:%.*]] = mul <8 x i32> [[TMP15]], [[D:%.*]]
-; CHECK-NEXT: [[TMP17:%.*]] = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TMP16]])
-; CHECK-NEXT: [[TMP18:%.*]] = insertelement <1 x i32> poison, i32 [[TMP17]], i64 0
-; CHECK-NEXT: ret <1 x i32> [[TMP18]]
+; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <8 x i32> [[A:%.*]], <8 x i32> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT: [[SPLIT1:%.*]] = shufflevector <8 x i32> [[B:%.*]], <8 x i32> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT: [[TMP0:%.*]] = add <8 x i32> [[SPLIT]], [[SPLIT1]]
+; CHECK-NEXT: [[SPLIT2:%.*]] = shufflevector <8 x i32> [[C:%.*]], <8 x i32> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT: [[TMP1:%.*]] = add <8 x i32> [[TMP0]], [[SPLIT2]]
+; CHECK-NEXT: [[TMP2:%.*]] = mul <8 x i32> [[TMP1]], [[D:%.*]]
+; CHECK-NEXT: [[TMP3:%.*]] = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TMP2]])
+; CHECK-NEXT: [[TMP4:%.*]] = insertelement <1 x i32> poison, i32 [[TMP3]], i64 0
+; CHECK-NEXT: ret <1 x i32> [[TMP4]]
;
entry:
%add.1 = add <8 x i32> %a, %b
Generalize the logic used to convert column-vector ops to row-vectors to support converting chains of operations. A potential next step is to further generalize this to convert column-vector ops to row-vector ops in general, not just for operands of dot products. Dot-product handling would then be driven by the general conversion, rather than the other way around.
Force-pushed from 0674d40 to 3dfe867.
ping :)

ping :)
Generalize the logic used to convert column-vector ops to row-vectors to support converting chains of operations.
A potential next step is to further generalize this to convert column-vector ops to row-vector ops in general, not just for operands of dot products. Dot-product handling would then be driven by the general conversion, rather than the other way around.
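For reference, a minimal sketch (names hypothetical, the identity shuffles emitted by the pass omitted) of the row-vector form such a chain is expected to lower to, matching the updated CHECK lines in the test diff above: the adds stay on the full <8 x i32> vectors and the dot product becomes an elementwise mul followed by a vector reduction.

```llvm
; Hypothetical lowered form of an add chain feeding a dot product.
define <1 x i32> @dot_of_add_chain_lowered(<8 x i32> %a, <8 x i32> %b, <8 x i32> %c, <8 x i32> %d) {
entry:
  %add.1 = add <8 x i32> %a, %b
  %add.2 = add <8 x i32> %add.1, %c
  %mul = mul <8 x i32> %add.2, %d
  %red = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %mul)
  %res = insertelement <1 x i32> poison, i32 %red, i64 0
  ret <1 x i32> %res
}

declare i32 @llvm.vector.reduce.add.v8i32(<8 x i32>)
```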