Skip to content

Commit f89fe08

Browse files
authored
[Matrix] Convert column-vector ops feeding dot product to row-vectors. (#72647)
Generalize the logic used to convert column-vector ops to row-vectors to support converting chains of operations. A potential next step is to further generalize this to convert column-vector ops to row-vector ops in general, not just for operands of dot products. Dot-product handling would then be driven by the general conversion, rather than the other way around. PR: #72647
1 parent 83eb812 commit f89fe08

File tree

2 files changed

+50
-51
lines changed

2 files changed

+50
-51
lines changed

llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1332,8 +1332,8 @@ class LowerMatrixIntrinsics {
13321332
if (!IsIntVec && !FMF.allowReassoc())
13331333
return;
13341334

1335-
auto CanBeFlattened = [this](Value *Op) {
1336-
if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end())
1335+
auto CanBeFlattened = [](Value *Op) {
1336+
if (match(Op, m_BinOp()))
13371337
return true;
13381338
return match(
13391339
Op, m_OneUse(m_CombineOr(
@@ -1346,6 +1346,9 @@ class LowerMatrixIntrinsics {
13461346
// the returned cost is < 0, the argument is cheaper to use in the
13471347
// dot-product lowering.
13481348
auto GetCostForArg = [this, &CanBeFlattened](Value *Op, unsigned N) {
1349+
if (ShapeMap.find(Op) == ShapeMap.end())
1350+
return InstructionCost::getInvalid();
1351+
13491352
if (!isa<Instruction>(Op))
13501353
return InstructionCost(0);
13511354

@@ -1356,7 +1359,7 @@ class LowerMatrixIntrinsics {
13561359
InstructionCost EmbedCost(0);
13571360
// Roughly estimate the cost for embedding the columns into a vector.
13581361
for (unsigned I = 1; I < N; ++I)
1359-
EmbedCost -=
1362+
EmbedCost +=
13601363
TTI.getShuffleCost(TTI::SK_Splice, FixedVectorType::get(EltTy, 1),
13611364
std::nullopt, TTI::TCK_RecipThroughput);
13621365
return EmbedCost;
@@ -1378,7 +1381,7 @@ class LowerMatrixIntrinsics {
13781381
// vector.
13791382
InstructionCost EmbedCost(0);
13801383
for (unsigned I = 1; I < N; ++I)
1381-
EmbedCost +=
1384+
EmbedCost -=
13821385
TTI.getShuffleCost(TTI::SK_Splice, FixedVectorType::get(EltTy, 1),
13831386
std::nullopt, TTI::TCK_RecipThroughput);
13841387
return EmbedCost;
@@ -1391,7 +1394,29 @@ class LowerMatrixIntrinsics {
13911394
return TTI.getMemoryOpCost(Instruction::Load, VecTy, Align(1), 0) -
13921395
N * TTI.getMemoryOpCost(Instruction::Load, EltTy, Align(1), 0);
13931396
};
1394-
auto LHSCost = GetCostForArg(LHS, LShape.NumColumns);
1397+
1398+
// Iterate over LHS and operations feeding LHS and check if it is profitable
1399+
// to flatten the visited ops. For each op, we compute the difference
1400+
// between the flattened and matrix versions.
1401+
SmallPtrSet<Value *, 4> Seen;
1402+
SmallVector<Value *> WorkList;
1403+
SmallVector<Value *> ToFlatten;
1404+
WorkList.push_back(LHS);
1405+
InstructionCost LHSCost(0);
1406+
while (!WorkList.empty()) {
1407+
Value *Op = WorkList.pop_back_val();
1408+
if (!Seen.insert(Op).second)
1409+
continue;
1410+
1411+
InstructionCost OpCost = GetCostForArg(Op, LShape.NumColumns);
1412+
if (OpCost + LHSCost >= LHSCost)
1413+
continue;
1414+
1415+
LHSCost += OpCost;
1416+
ToFlatten.push_back(Op);
1417+
if (auto *I = dyn_cast<Instruction>(Op))
1418+
WorkList.append(I->op_begin(), I->op_end());
1419+
}
13951420

13961421
// We compare the costs of a vector.reduce.add to sequential add.
13971422
int AddOpCode = IsIntVec ? Instruction::Add : Instruction::FAdd;
@@ -1412,16 +1437,16 @@ class LowerMatrixIntrinsics {
14121437
FusedInsts.insert(MatMul);
14131438
IRBuilder<> Builder(MatMul);
14141439
auto FlattenArg = [&Builder, &FusedInsts, &CanBeFlattened,
1415-
this](Value *Op) -> Value * {
1440+
this](Value *Op) {
14161441
// Matmul must be the only user of loads because we don't use LowerLoad
14171442
// for row vectors (LowerLoad results in scalar loads and shufflevectors
14181443
// instead of single vector load).
14191444
if (!CanBeFlattened(Op))
1420-
return Op;
1445+
return;
14211446

14221447
if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end()) {
14231448
ShapeMap[Op] = ShapeMap[Op].t();
1424-
return Op;
1449+
return;
14251450
}
14261451

14271452
FusedInsts.insert(cast<Instruction>(Op));
@@ -1432,16 +1457,19 @@ class LowerMatrixIntrinsics {
14321457
auto *NewLoad = Builder.CreateLoad(Op->getType(), Arg);
14331458
Op->replaceAllUsesWith(NewLoad);
14341459
cast<Instruction>(Op)->eraseFromParent();
1435-
return NewLoad;
1460+
return;
14361461
} else if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>(
14371462
m_Value(Arg)))) {
14381463
ToRemove.push_back(cast<Instruction>(Op));
1439-
return Arg;
1464+
Op->replaceAllUsesWith(Arg);
1465+
return;
14401466
}
1441-
1442-
return Op;
14431467
};
1444-
LHS = FlattenArg(LHS);
1468+
1469+
for (auto *V : ToFlatten)
1470+
FlattenArg(V);
1471+
1472+
LHS = MatMul->getArgOperand(0);
14451473

14461474
// Insert mul/fmul and llvm.vector.reduce.fadd
14471475
Value *Mul =

llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-int.ll

Lines changed: 9 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -119,44 +119,15 @@ entry:
119119
define <1 x i32> @add_chain_feeding_dotproduct_i32_v8_1(<8 x i32> %a, <8 x i32> %b, <8 x i32> %c, <8 x i32> %d) {
120120
; CHECK-LABEL: @add_chain_feeding_dotproduct_i32_v8_1(
121121
; CHECK-NEXT: entry:
122-
; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <8 x i32> [[A:%.*]], <8 x i32> poison, <1 x i32> zeroinitializer
123-
; CHECK-NEXT: [[SPLIT1:%.*]] = shufflevector <8 x i32> [[A]], <8 x i32> poison, <1 x i32> <i32 1>
124-
; CHECK-NEXT: [[SPLIT2:%.*]] = shufflevector <8 x i32> [[A]], <8 x i32> poison, <1 x i32> <i32 2>
125-
; CHECK-NEXT: [[SPLIT3:%.*]] = shufflevector <8 x i32> [[A]], <8 x i32> poison, <1 x i32> <i32 3>
126-
; CHECK-NEXT: [[SPLIT4:%.*]] = shufflevector <8 x i32> [[A]], <8 x i32> poison, <1 x i32> <i32 4>
127-
; CHECK-NEXT: [[SPLIT5:%.*]] = shufflevector <8 x i32> [[A]], <8 x i32> poison, <1 x i32> <i32 5>
128-
; CHECK-NEXT: [[SPLIT6:%.*]] = shufflevector <8 x i32> [[A]], <8 x i32> poison, <1 x i32> <i32 6>
129-
; CHECK-NEXT: [[SPLIT7:%.*]] = shufflevector <8 x i32> [[A]], <8 x i32> poison, <1 x i32> <i32 7>
130-
; CHECK-NEXT: [[SPLIT8:%.*]] = shufflevector <8 x i32> [[B:%.*]], <8 x i32> poison, <1 x i32> zeroinitializer
131-
; CHECK-NEXT: [[SPLIT9:%.*]] = shufflevector <8 x i32> [[B]], <8 x i32> poison, <1 x i32> <i32 1>
132-
; CHECK-NEXT: [[SPLIT10:%.*]] = shufflevector <8 x i32> [[B]], <8 x i32> poison, <1 x i32> <i32 2>
133-
; CHECK-NEXT: [[SPLIT11:%.*]] = shufflevector <8 x i32> [[B]], <8 x i32> poison, <1 x i32> <i32 3>
134-
; CHECK-NEXT: [[SPLIT12:%.*]] = shufflevector <8 x i32> [[B]], <8 x i32> poison, <1 x i32> <i32 4>
135-
; CHECK-NEXT: [[SPLIT13:%.*]] = shufflevector <8 x i32> [[B]], <8 x i32> poison, <1 x i32> <i32 5>
136-
; CHECK-NEXT: [[SPLIT14:%.*]] = shufflevector <8 x i32> [[B]], <8 x i32> poison, <1 x i32> <i32 6>
137-
; CHECK-NEXT: [[SPLIT15:%.*]] = shufflevector <8 x i32> [[B]], <8 x i32> poison, <1 x i32> <i32 7>
138-
; CHECK-NEXT: [[TMP0:%.*]] = add <1 x i32> [[SPLIT]], [[SPLIT8]]
139-
; CHECK-NEXT: [[TMP1:%.*]] = add <1 x i32> [[SPLIT1]], [[SPLIT9]]
140-
; CHECK-NEXT: [[TMP2:%.*]] = add <1 x i32> [[SPLIT2]], [[SPLIT10]]
141-
; CHECK-NEXT: [[TMP3:%.*]] = add <1 x i32> [[SPLIT3]], [[SPLIT11]]
142-
; CHECK-NEXT: [[TMP4:%.*]] = add <1 x i32> [[SPLIT4]], [[SPLIT12]]
143-
; CHECK-NEXT: [[TMP5:%.*]] = add <1 x i32> [[SPLIT5]], [[SPLIT13]]
144-
; CHECK-NEXT: [[TMP6:%.*]] = add <1 x i32> [[SPLIT6]], [[SPLIT14]]
145-
; CHECK-NEXT: [[TMP7:%.*]] = add <1 x i32> [[SPLIT7]], [[SPLIT15]]
146-
; CHECK-NEXT: [[TMP8:%.*]] = shufflevector <1 x i32> [[TMP0]], <1 x i32> [[TMP1]], <2 x i32> <i32 0, i32 1>
147-
; CHECK-NEXT: [[TMP9:%.*]] = shufflevector <1 x i32> [[TMP2]], <1 x i32> [[TMP3]], <2 x i32> <i32 0, i32 1>
148-
; CHECK-NEXT: [[TMP10:%.*]] = shufflevector <1 x i32> [[TMP4]], <1 x i32> [[TMP5]], <2 x i32> <i32 0, i32 1>
149-
; CHECK-NEXT: [[TMP11:%.*]] = shufflevector <1 x i32> [[TMP6]], <1 x i32> [[TMP7]], <2 x i32> <i32 0, i32 1>
150-
; CHECK-NEXT: [[TMP12:%.*]] = shufflevector <2 x i32> [[TMP8]], <2 x i32> [[TMP9]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
151-
; CHECK-NEXT: [[TMP13:%.*]] = shufflevector <2 x i32> [[TMP10]], <2 x i32> [[TMP11]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
152-
; CHECK-NEXT: [[TMP14:%.*]] = shufflevector <4 x i32> [[TMP12]], <4 x i32> [[TMP13]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
153-
; CHECK-NEXT: [[SPLIT16:%.*]] = shufflevector <8 x i32> [[TMP14]], <8 x i32> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
154-
; CHECK-NEXT: [[SPLIT17:%.*]] = shufflevector <8 x i32> [[C:%.*]], <8 x i32> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
155-
; CHECK-NEXT: [[TMP15:%.*]] = add <8 x i32> [[SPLIT16]], [[SPLIT17]]
156-
; CHECK-NEXT: [[TMP16:%.*]] = mul <8 x i32> [[TMP15]], [[D:%.*]]
157-
; CHECK-NEXT: [[TMP17:%.*]] = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TMP16]])
158-
; CHECK-NEXT: [[TMP18:%.*]] = insertelement <1 x i32> poison, i32 [[TMP17]], i64 0
159-
; CHECK-NEXT: ret <1 x i32> [[TMP18]]
122+
; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <8 x i32> [[A:%.*]], <8 x i32> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
123+
; CHECK-NEXT: [[SPLIT1:%.*]] = shufflevector <8 x i32> [[B:%.*]], <8 x i32> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
124+
; CHECK-NEXT: [[TMP0:%.*]] = add <8 x i32> [[SPLIT]], [[SPLIT1]]
125+
; CHECK-NEXT: [[SPLIT2:%.*]] = shufflevector <8 x i32> [[C:%.*]], <8 x i32> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
126+
; CHECK-NEXT: [[TMP1:%.*]] = add <8 x i32> [[TMP0]], [[SPLIT2]]
127+
; CHECK-NEXT: [[TMP2:%.*]] = mul <8 x i32> [[TMP1]], [[D:%.*]]
128+
; CHECK-NEXT: [[TMP3:%.*]] = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TMP2]])
129+
; CHECK-NEXT: [[TMP4:%.*]] = insertelement <1 x i32> poison, i32 [[TMP3]], i64 0
130+
; CHECK-NEXT: ret <1 x i32> [[TMP4]]
160131
;
161132
entry:
162133
%add.1 = add <8 x i32> %a, %b

0 commit comments

Comments
 (0)