Skip to content

Commit 2b8aaef

Browse files
[GISEL] Add IRTranslation for shufflevector on scalable vector types (#80378)
This patch is stacked on #80372, #80307, and #80306. ShuffleVector on scalable vector types gets IRTranslate'd to G_SPLAT_VECTOR since a ShuffleVector that has operates on scalable vectors is a splat vector where the value of the splat vector is the 0th element of the first operand, because the index mask operand is the zeroinitializer (undef and poison are treated as zeroinitializer here). This is analogous to what happens in SelectionDAG for ShuffleVector. `buildSplatVector` is renamed to`buildBuildVectorSplatVector`. I did not make this a separate patch because it would cause problems to revert that change without reverting this change too.
1 parent 043a020 commit 2b8aaef

File tree

15 files changed

+1890
-21
lines changed

15 files changed

+1890
-21
lines changed

llvm/docs/GlobalISel/GenericOpcode.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -639,6 +639,11 @@ Concatenate two vectors and shuffle the elements according to the mask operand.
639639
The mask operand should be an IR Constant which exactly matches the
640640
corresponding mask for the IR shufflevector instruction.
641641

642+
G_SPLAT_VECTOR
643+
^^^^^^^^^^^^^^^^
644+
645+
Create a vector where all elements are the scalar from the source operand.
646+
642647
Vector Reduction Operations
643648
---------------------------
644649

llvm/include/llvm/CodeGen/GlobalISel/MachineIRBuilder.h

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1063,8 +1063,7 @@ class MachineIRBuilder {
10631063

10641064
/// Build and insert \p Res = G_BUILD_VECTOR with \p Src replicated to fill
10651065
/// the number of elements
1066-
MachineInstrBuilder buildSplatVector(const DstOp &Res,
1067-
const SrcOp &Src);
1066+
MachineInstrBuilder buildSplatBuildVector(const DstOp &Res, const SrcOp &Src);
10681067

10691068
/// Build and insert \p Res = G_BUILD_VECTOR_TRUNC \p Op0, ...
10701069
///
@@ -1099,6 +1098,15 @@ class MachineIRBuilder {
10991098
MachineInstrBuilder buildShuffleVector(const DstOp &Res, const SrcOp &Src1,
11001099
const SrcOp &Src2, ArrayRef<int> Mask);
11011100

1101+
/// Build and insert \p Res = G_SPLAT_VECTOR \p Val
1102+
///
1103+
/// \pre setBasicBlock or setMI must have been called.
1104+
/// \pre \p Res must be a generic virtual register with vector type.
1105+
/// \pre \p Val must be a generic virtual register with scalar type.
1106+
///
1107+
/// \return a MachineInstrBuilder for the newly created instruction.
1108+
MachineInstrBuilder buildSplatVector(const DstOp &Res, const SrcOp &Val);
1109+
11021110
/// Build and insert \p Res = G_CONCAT_VECTORS \p Op0, ...
11031111
///
11041112
/// G_CONCAT_VECTORS creates a vector from the concatenation of 2 or more

llvm/include/llvm/Support/TargetOpcodes.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -736,6 +736,9 @@ HANDLE_TARGET_OPCODE(G_EXTRACT_VECTOR_ELT)
736736
/// Generic shufflevector.
737737
HANDLE_TARGET_OPCODE(G_SHUFFLE_VECTOR)
738738

739+
/// Generic splatvector.
740+
HANDLE_TARGET_OPCODE(G_SPLAT_VECTOR)
741+
739742
/// Generic count trailing zeroes.
740743
HANDLE_TARGET_OPCODE(G_CTTZ)
741744

llvm/include/llvm/Target/GenericOpcodes.td

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1450,6 +1450,13 @@ def G_SHUFFLE_VECTOR: GenericInstruction {
14501450
let hasSideEffects = false;
14511451
}
14521452

1453+
// Generic splatvector.
1454+
def G_SPLAT_VECTOR: GenericInstruction {
1455+
let OutOperandList = (outs type0:$dst);
1456+
let InOperandList = (ins type1:$val);
1457+
let hasSideEffects = false;
1458+
}
1459+
14531460
//------------------------------------------------------------------------------
14541461
// Vector reductions
14551462
//------------------------------------------------------------------------------

llvm/lib/CodeGen/GlobalISel/CSEMIRBuilder.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ MachineInstrBuilder CSEMIRBuilder::buildConstant(const DstOp &Res,
309309
// For vectors, CSE the element only for now.
310310
LLT Ty = Res.getLLTTy(*getMRI());
311311
if (Ty.isVector())
312-
return buildSplatVector(Res, buildConstant(Ty.getElementType(), Val));
312+
return buildSplatBuildVector(Res, buildConstant(Ty.getElementType(), Val));
313313

314314
FoldingSetNodeID ID;
315315
GISelInstProfileBuilder ProfBuilder(ID, *getMRI());
@@ -336,7 +336,7 @@ MachineInstrBuilder CSEMIRBuilder::buildFConstant(const DstOp &Res,
336336
// For vectors, CSE the element only for now.
337337
LLT Ty = Res.getLLTTy(*getMRI());
338338
if (Ty.isVector())
339-
return buildSplatVector(Res, buildFConstant(Ty.getElementType(), Val));
339+
return buildSplatBuildVector(Res, buildFConstant(Ty.getElementType(), Val));
340340

341341
FoldingSetNodeID ID;
342342
GISelInstProfileBuilder ProfBuilder(ID, *getMRI());

llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1598,10 +1598,10 @@ bool IRTranslator::translateGetElementPtr(const User &U,
15981598
// We might need to splat the base pointer into a vector if the offsets
15991599
// are vectors.
16001600
if (WantSplatVector && !PtrTy.isVector()) {
1601-
BaseReg =
1602-
MIRBuilder
1603-
.buildSplatVector(LLT::fixed_vector(VectorWidth, PtrTy), BaseReg)
1604-
.getReg(0);
1601+
BaseReg = MIRBuilder
1602+
.buildSplatBuildVector(LLT::fixed_vector(VectorWidth, PtrTy),
1603+
BaseReg)
1604+
.getReg(0);
16051605
PtrIRTy = FixedVectorType::get(PtrIRTy, VectorWidth);
16061606
PtrTy = getLLTForType(*PtrIRTy, *DL);
16071607
OffsetIRTy = DL->getIndexType(PtrIRTy);
@@ -1639,8 +1639,10 @@ bool IRTranslator::translateGetElementPtr(const User &U,
16391639
LLT IdxTy = MRI->getType(IdxReg);
16401640
if (IdxTy != OffsetTy) {
16411641
if (!IdxTy.isVector() && WantSplatVector) {
1642-
IdxReg = MIRBuilder.buildSplatVector(
1643-
OffsetTy.changeElementType(IdxTy), IdxReg).getReg(0);
1642+
IdxReg = MIRBuilder
1643+
.buildSplatBuildVector(OffsetTy.changeElementType(IdxTy),
1644+
IdxReg)
1645+
.getReg(0);
16441646
}
16451647

16461648
IdxReg = MIRBuilder.buildSExtOrTrunc(OffsetTy, IdxReg).getReg(0);
@@ -2997,6 +2999,19 @@ bool IRTranslator::translateExtractElement(const User &U,
29972999

29983000
bool IRTranslator::translateShuffleVector(const User &U,
29993001
MachineIRBuilder &MIRBuilder) {
3002+
// A ShuffleVector that has operates on scalable vectors is a splat vector
3003+
// where the value of the splat vector is the 0th element of the first
3004+
// operand, since the index mask operand is the zeroinitializer (undef and
3005+
// poison are treated as zeroinitializer here).
3006+
if (U.getOperand(0)->getType()->isScalableTy()) {
3007+
Value *Op0 = U.getOperand(0);
3008+
auto SplatVal = MIRBuilder.buildExtractVectorElementConstant(
3009+
LLT::scalar(Op0->getType()->getScalarSizeInBits()),
3010+
getOrCreateVReg(*Op0), 0);
3011+
MIRBuilder.buildSplatVector(getOrCreateVReg(U), SplatVal);
3012+
return true;
3013+
}
3014+
30003015
ArrayRef<int> Mask;
30013016
if (auto *SVI = dyn_cast<ShuffleVectorInst>(&U))
30023017
Mask = SVI->getShuffleMask();

llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8391,7 +8391,7 @@ static Register getMemsetValue(Register Val, LLT Ty, MachineIRBuilder &MIB) {
83918391

83928392
// For vector types create a G_BUILD_VECTOR.
83938393
if (Ty.isVector())
8394-
Val = MIB.buildSplatVector(Ty, Val).getReg(0);
8394+
Val = MIB.buildSplatBuildVector(Ty, Val).getReg(0);
83958395

83968396
return Val;
83978397
}

llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ MachineInstrBuilder MachineIRBuilder::buildConstant(const DstOp &Res,
326326
auto Const = buildInstr(TargetOpcode::G_CONSTANT)
327327
.addDef(getMRI()->createGenericVirtualRegister(EltTy))
328328
.addCImm(&Val);
329-
return buildSplatVector(Res, Const);
329+
return buildSplatBuildVector(Res, Const);
330330
}
331331

332332
auto Const = buildInstr(TargetOpcode::G_CONSTANT);
@@ -363,7 +363,7 @@ MachineInstrBuilder MachineIRBuilder::buildFConstant(const DstOp &Res,
363363
.addDef(getMRI()->createGenericVirtualRegister(EltTy))
364364
.addFPImm(&Val);
365365

366-
return buildSplatVector(Res, Const);
366+
return buildSplatBuildVector(Res, Const);
367367
}
368368

369369
auto Const = buildInstr(TargetOpcode::G_FCONSTANT);
@@ -711,8 +711,8 @@ MachineIRBuilder::buildBuildVectorConstant(const DstOp &Res,
711711
return buildInstr(TargetOpcode::G_BUILD_VECTOR, Res, TmpVec);
712712
}
713713

714-
MachineInstrBuilder MachineIRBuilder::buildSplatVector(const DstOp &Res,
715-
const SrcOp &Src) {
714+
MachineInstrBuilder MachineIRBuilder::buildSplatBuildVector(const DstOp &Res,
715+
const SrcOp &Src) {
716716
SmallVector<SrcOp, 8> TmpVec(Res.getLLTTy(*getMRI()).getNumElements(), Src);
717717
return buildInstr(TargetOpcode::G_BUILD_VECTOR, Res, TmpVec);
718718
}
@@ -742,6 +742,14 @@ MachineInstrBuilder MachineIRBuilder::buildShuffleSplat(const DstOp &Res,
742742
return buildShuffleVector(DstTy, InsElt, UndefVec, ZeroMask);
743743
}
744744

745+
MachineInstrBuilder MachineIRBuilder::buildSplatVector(const DstOp &Res,
746+
const SrcOp &Src) {
747+
LLT DstTy = Res.getLLTTy(*getMRI());
748+
assert(Src.getLLTTy(*getMRI()) == DstTy.getElementType() &&
749+
"Expected Src to match Dst elt ty");
750+
return buildInstr(TargetOpcode::G_SPLAT_VECTOR, Res, Src);
751+
}
752+
745753
MachineInstrBuilder MachineIRBuilder::buildShuffleVector(const DstOp &Res,
746754
const SrcOp &Src1,
747755
const SrcOp &Src2,

llvm/lib/CodeGen/MachineVerifier.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1640,6 +1640,24 @@ void MachineVerifier::verifyPreISelGenericInstruction(const MachineInstr *MI) {
16401640

16411641
break;
16421642
}
1643+
1644+
case TargetOpcode::G_SPLAT_VECTOR: {
1645+
LLT DstTy = MRI->getType(MI->getOperand(0).getReg());
1646+
LLT SrcTy = MRI->getType(MI->getOperand(1).getReg());
1647+
1648+
if (!DstTy.isScalableVector())
1649+
report("Destination type must be a scalable vector", MI);
1650+
1651+
if (!SrcTy.isScalar())
1652+
report("Source type must be a scalar", MI);
1653+
1654+
if (DstTy.getScalarType() != SrcTy)
1655+
report("Element type of the destination must be the same type as the "
1656+
"source type",
1657+
MI);
1658+
1659+
break;
1660+
}
16431661
case TargetOpcode::G_DYN_STACKALLOC: {
16441662
const MachineOperand &DstOp = MI->getOperand(0);
16451663
const MachineOperand &AllocOp = MI->getOperand(1);

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20920,7 +20920,8 @@ bool RISCVTargetLowering::fallBackToDAGISel(const Instruction &Inst) const {
2092020920
unsigned Op = Inst.getOpcode();
2092120921
if (Op == Instruction::Add || Op == Instruction::Sub ||
2092220922
Op == Instruction::And || Op == Instruction::Or ||
20923-
Op == Instruction::Xor || Op == Instruction::InsertElement)
20923+
Op == Instruction::Xor || Op == Instruction::InsertElement ||
20924+
Op == Instruction::Xor || Op == Instruction::ShuffleVector)
2092420925
return false;
2092520926

2092620927
if (Inst.getType()->isScalableTy())

llvm/test/CodeGen/AArch64/GlobalISel/legalizer-info-validation.mir

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -625,6 +625,9 @@
625625
# DEBUG-NEXT: G_SHUFFLE_VECTOR (opcode {{[0-9]+}}): 2 type indices, 0 imm indices
626626
# DEBUG-NEXT: .. type index coverage check SKIPPED: user-defined predicate detected
627627
# DEBUG-NEXT: .. imm index coverage check SKIPPED: user-defined predicate detected
628+
# DEBUG-NEXT: G_SPLAT_VECTOR (opcode 217): 2 type indices, 0 imm indices
629+
# DEBUG-NEXT: .. type index coverage check SKIPPED: no rules defined
630+
# DEBUG-NEXT: .. imm index coverage check SKIPPED: no rules defined
628631
# DEBUG-NEXT: G_CTTZ (opcode {{[0-9]+}}): 2 type indices, 0 imm indices
629632
# DEBUG-NEXT: .. type index coverage check SKIPPED: user-defined predicate detected
630633
# DEBUG-NEXT: .. imm index coverage check SKIPPED: user-defined predicate detected

0 commit comments

Comments
 (0)