Skip to content

[Matrix] Convert column-vector ops feeding dot product to row-vectors. #72647

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into llvm:main from the matrix-to-rowvector branch
Feb 6, 2024

Conversation

fhahn
Copy link
Contributor

@fhahn fhahn commented Nov 17, 2023

Generalize the logic used to convert column-vector ops to row-vectors to support converting chains of operations.

A potential next step is to further generalize this to convert column-vector ops to row-vector ops in general, not just for operands of dot products. Dot-product handling would then be driven by the general conversion, rather than the other way around.

@llvmbot
Copy link
Member

llvmbot commented Nov 17, 2023

@llvm/pr-subscribers-llvm-transforms

Author: Florian Hahn (fhahn)

Changes

Generalize the logic used to convert column-vector ops to row-vectors to support converting chains of operations.

A potential next step is to further generalize this to convert column-vector ops to row-vector ops in general, not just for operands of dot products. Dot-product handling would then be driven by the general conversion, rather than the other way around.

Depends on D148429.

Differential Revision: https://reviews.llvm.org/D148430


Full diff: https://github.com/llvm/llvm-project/pull/72647.diff

2 Files Affected:

  • (modified) llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp (+38-13)
  • (modified) llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-int.ll (+9-38)
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 72b9db1e73d73dc..c6bb43d3a78cf3e 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -1332,8 +1332,8 @@ class LowerMatrixIntrinsics {
     if (!IsIntVec && !FMF.allowReassoc())
       return;
 
-    auto CanBeFlattened = [this](Value *Op) {
-      if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end())
+    auto CanBeFlattened = [](Value *Op) {
+      if (match(Op, m_BinOp()))
         return true;
       return match(
           Op, m_OneUse(m_CombineOr(
@@ -1346,6 +1346,9 @@ class LowerMatrixIntrinsics {
     // the returned cost is < 0, the argument is cheaper to use in the
     // dot-product lowering.
     auto GetCostForArg = [this, &CanBeFlattened](Value *Op, unsigned N) {
+      if (ShapeMap.find(Op) == ShapeMap.end())
+        return InstructionCost::getInvalid();
+
       if (!isa<Instruction>(Op))
         return InstructionCost(0);
 
@@ -1356,7 +1359,7 @@ class LowerMatrixIntrinsics {
         InstructionCost EmbedCost(0);
         // Roughly estimate the cost for embedding the columns into a vector.
         for (unsigned I = 1; I < N; ++I)
-          EmbedCost -=
+          EmbedCost +=
               TTI.getShuffleCost(TTI::SK_Splice, FixedVectorType::get(EltTy, 1),
                                  std::nullopt, TTI::TCK_RecipThroughput);
         return EmbedCost;
@@ -1378,7 +1381,7 @@ class LowerMatrixIntrinsics {
         // vector.
         InstructionCost EmbedCost(0);
         for (unsigned I = 1; I < N; ++I)
-          EmbedCost +=
+          EmbedCost -=
               TTI.getShuffleCost(TTI::SK_Splice, FixedVectorType::get(EltTy, 1),
                                  std::nullopt, TTI::TCK_RecipThroughput);
         return EmbedCost;
@@ -1391,7 +1394,26 @@ class LowerMatrixIntrinsics {
       return TTI.getMemoryOpCost(Instruction::Load, VecTy, Align(1), 0) -
              N * TTI.getMemoryOpCost(Instruction::Load, EltTy, Align(1), 0);
     };
-    auto LHSCost = GetCostForArg(LHS, LShape.NumColumns);
+
+    SmallPtrSet<Value *, 4> Seen;
+    SmallVector<Value *> WorkList;
+    SmallVector<Value *> ToFlatten;
+    WorkList.push_back(LHS);
+    InstructionCost LHSCost(0);
+    while (!WorkList.empty()) {
+      Value *Op = WorkList.pop_back_val();
+      if (!Seen.insert(Op).second)
+        continue;
+
+      InstructionCost OpCost = GetCostForArg(Op, LShape.NumColumns);
+      if (OpCost + LHSCost >= LHSCost)
+        continue;
+
+      LHSCost += OpCost;
+      ToFlatten.push_back(Op);
+      if (auto *I = dyn_cast<Instruction>(Op))
+        WorkList.append(I->op_begin(), I->op_end());
+    }
 
     // We compare the costs of a vector.reduce.add to sequential add.
     int AddOpCode = IsIntVec ? Instruction::Add : Instruction::FAdd;
@@ -1412,16 +1434,16 @@ class LowerMatrixIntrinsics {
     FusedInsts.insert(MatMul);
     IRBuilder<> Builder(MatMul);
     auto FlattenArg = [&Builder, &FusedInsts, &CanBeFlattened,
-                       this](Value *Op) -> Value * {
+                       this](Value *Op) {
       // Matmul must be the only user of loads because we don't use LowerLoad
       // for row vectors (LowerLoad results in scalar loads and shufflevectors
       // instead of single vector load).
       if (!CanBeFlattened(Op))
-        return Op;
+        return;
 
       if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end()) {
         ShapeMap[Op] = ShapeMap[Op].t();
-        return Op;
+        return;
       }
 
       FusedInsts.insert(cast<Instruction>(Op));
@@ -1432,16 +1454,19 @@ class LowerMatrixIntrinsics {
         auto *NewLoad = Builder.CreateLoad(Op->getType(), Arg);
         Op->replaceAllUsesWith(NewLoad);
         cast<Instruction>(Op)->eraseFromParent();
-        return NewLoad;
+        return;
       } else if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>(
                                m_Value(Arg)))) {
         ToRemove.push_back(cast<Instruction>(Op));
-        return Arg;
+        Op->replaceAllUsesWith(Arg);
+        return;
       }
-
-      return Op;
     };
-    LHS = FlattenArg(LHS);
+
+    for (auto *V : ToFlatten)
+      FlattenArg(V);
+
+    LHS = MatMul->getArgOperand(0);
 
     // Insert mul/fmul and llvm.vector.reduce.fadd
     Value *Mul =
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-int.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-int.ll
index 7bbd0c500485511..f15dbed1f1f5134 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-int.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-int.ll
@@ -119,44 +119,15 @@ entry:
 define <1 x i32> @add_chain_feeding_dotproduct_i32_v8_1(<8 x i32> %a, <8 x i32> %b, <8 x i32> %c, <8 x i32> %d) {
 ; CHECK-LABEL: @add_chain_feeding_dotproduct_i32_v8_1(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[SPLIT:%.*]] = shufflevector <8 x i32> [[A:%.*]], <8 x i32> poison, <1 x i32> zeroinitializer
-; CHECK-NEXT:    [[SPLIT1:%.*]] = shufflevector <8 x i32> [[A]], <8 x i32> poison, <1 x i32> <i32 1>
-; CHECK-NEXT:    [[SPLIT2:%.*]] = shufflevector <8 x i32> [[A]], <8 x i32> poison, <1 x i32> <i32 2>
-; CHECK-NEXT:    [[SPLIT3:%.*]] = shufflevector <8 x i32> [[A]], <8 x i32> poison, <1 x i32> <i32 3>
-; CHECK-NEXT:    [[SPLIT4:%.*]] = shufflevector <8 x i32> [[A]], <8 x i32> poison, <1 x i32> <i32 4>
-; CHECK-NEXT:    [[SPLIT5:%.*]] = shufflevector <8 x i32> [[A]], <8 x i32> poison, <1 x i32> <i32 5>
-; CHECK-NEXT:    [[SPLIT6:%.*]] = shufflevector <8 x i32> [[A]], <8 x i32> poison, <1 x i32> <i32 6>
-; CHECK-NEXT:    [[SPLIT7:%.*]] = shufflevector <8 x i32> [[A]], <8 x i32> poison, <1 x i32> <i32 7>
-; CHECK-NEXT:    [[SPLIT8:%.*]] = shufflevector <8 x i32> [[B:%.*]], <8 x i32> poison, <1 x i32> zeroinitializer
-; CHECK-NEXT:    [[SPLIT9:%.*]] = shufflevector <8 x i32> [[B]], <8 x i32> poison, <1 x i32> <i32 1>
-; CHECK-NEXT:    [[SPLIT10:%.*]] = shufflevector <8 x i32> [[B]], <8 x i32> poison, <1 x i32> <i32 2>
-; CHECK-NEXT:    [[SPLIT11:%.*]] = shufflevector <8 x i32> [[B]], <8 x i32> poison, <1 x i32> <i32 3>
-; CHECK-NEXT:    [[SPLIT12:%.*]] = shufflevector <8 x i32> [[B]], <8 x i32> poison, <1 x i32> <i32 4>
-; CHECK-NEXT:    [[SPLIT13:%.*]] = shufflevector <8 x i32> [[B]], <8 x i32> poison, <1 x i32> <i32 5>
-; CHECK-NEXT:    [[SPLIT14:%.*]] = shufflevector <8 x i32> [[B]], <8 x i32> poison, <1 x i32> <i32 6>
-; CHECK-NEXT:    [[SPLIT15:%.*]] = shufflevector <8 x i32> [[B]], <8 x i32> poison, <1 x i32> <i32 7>
-; CHECK-NEXT:    [[TMP0:%.*]] = add <1 x i32> [[SPLIT]], [[SPLIT8]]
-; CHECK-NEXT:    [[TMP1:%.*]] = add <1 x i32> [[SPLIT1]], [[SPLIT9]]
-; CHECK-NEXT:    [[TMP2:%.*]] = add <1 x i32> [[SPLIT2]], [[SPLIT10]]
-; CHECK-NEXT:    [[TMP3:%.*]] = add <1 x i32> [[SPLIT3]], [[SPLIT11]]
-; CHECK-NEXT:    [[TMP4:%.*]] = add <1 x i32> [[SPLIT4]], [[SPLIT12]]
-; CHECK-NEXT:    [[TMP5:%.*]] = add <1 x i32> [[SPLIT5]], [[SPLIT13]]
-; CHECK-NEXT:    [[TMP6:%.*]] = add <1 x i32> [[SPLIT6]], [[SPLIT14]]
-; CHECK-NEXT:    [[TMP7:%.*]] = add <1 x i32> [[SPLIT7]], [[SPLIT15]]
-; CHECK-NEXT:    [[TMP8:%.*]] = shufflevector <1 x i32> [[TMP0]], <1 x i32> [[TMP1]], <2 x i32> <i32 0, i32 1>
-; CHECK-NEXT:    [[TMP9:%.*]] = shufflevector <1 x i32> [[TMP2]], <1 x i32> [[TMP3]], <2 x i32> <i32 0, i32 1>
-; CHECK-NEXT:    [[TMP10:%.*]] = shufflevector <1 x i32> [[TMP4]], <1 x i32> [[TMP5]], <2 x i32> <i32 0, i32 1>
-; CHECK-NEXT:    [[TMP11:%.*]] = shufflevector <1 x i32> [[TMP6]], <1 x i32> [[TMP7]], <2 x i32> <i32 0, i32 1>
-; CHECK-NEXT:    [[TMP12:%.*]] = shufflevector <2 x i32> [[TMP8]], <2 x i32> [[TMP9]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-; CHECK-NEXT:    [[TMP13:%.*]] = shufflevector <2 x i32> [[TMP10]], <2 x i32> [[TMP11]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-; CHECK-NEXT:    [[TMP14:%.*]] = shufflevector <4 x i32> [[TMP12]], <4 x i32> [[TMP13]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
-; CHECK-NEXT:    [[SPLIT16:%.*]] = shufflevector <8 x i32> [[TMP14]], <8 x i32> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
-; CHECK-NEXT:    [[SPLIT17:%.*]] = shufflevector <8 x i32> [[C:%.*]], <8 x i32> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
-; CHECK-NEXT:    [[TMP15:%.*]] = add <8 x i32> [[SPLIT16]], [[SPLIT17]]
-; CHECK-NEXT:    [[TMP16:%.*]] = mul <8 x i32> [[TMP15]], [[D:%.*]]
-; CHECK-NEXT:    [[TMP17:%.*]] = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TMP16]])
-; CHECK-NEXT:    [[TMP18:%.*]] = insertelement <1 x i32> poison, i32 [[TMP17]], i64 0
-; CHECK-NEXT:    ret <1 x i32> [[TMP18]]
+; CHECK-NEXT:    [[SPLIT:%.*]] = shufflevector <8 x i32> [[A:%.*]], <8 x i32> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[SPLIT1:%.*]] = shufflevector <8 x i32> [[B:%.*]], <8 x i32> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[TMP0:%.*]] = add <8 x i32> [[SPLIT]], [[SPLIT1]]
+; CHECK-NEXT:    [[SPLIT2:%.*]] = shufflevector <8 x i32> [[C:%.*]], <8 x i32> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[TMP1:%.*]] = add <8 x i32> [[TMP0]], [[SPLIT2]]
+; CHECK-NEXT:    [[TMP2:%.*]] = mul <8 x i32> [[TMP1]], [[D:%.*]]
+; CHECK-NEXT:    [[TMP3:%.*]] = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TMP2]])
+; CHECK-NEXT:    [[TMP4:%.*]] = insertelement <1 x i32> poison, i32 [[TMP3]], i64 0
+; CHECK-NEXT:    ret <1 x i32> [[TMP4]]
 ;
 entry:
   %add.1 = add <8 x i32> %a, %b

Generalize the logic used to convert column-vector ops to row-vectors to
support converting chains of operations.

A potential next step is to further generalize this to convert
column-vector ops to row-vector ops in general, not just for operands of
dot products. Dot-product handling would then be driven by the general
conversion, rather than the other way around.
@fhahn fhahn force-pushed the matrix-to-rowvector branch from 0674d40 to 3dfe867 on November 27, 2023 18:50
@fhahn
Copy link
Contributor Author

fhahn commented Nov 27, 2023

ping :)

@fhahn
Copy link
Contributor Author

fhahn commented Jan 9, 2024

ping :)

@fhahn fhahn merged commit f89fe08 into llvm:main Feb 6, 2024
@fhahn fhahn deleted the matrix-to-rowvector branch February 6, 2024 13:47
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants