Skip to content

[GISEL] Add IRTranslation for shufflevector on scalable vector types #80378

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Mar 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions llvm/docs/GlobalISel/GenericOpcode.rst
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,11 @@ Concatenate two vectors and shuffle the elements according to the mask operand.
The mask operand should be an IR Constant which exactly matches the
corresponding mask for the IR shufflevector instruction.

G_SPLAT_VECTOR
^^^^^^^^^^^^^^^^

Create a vector where all elements are the scalar from the source operand.

Comment on lines +642 to +646
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unless we do the same in the DAG, I think this just adds complication down the line for pattern sharing. The buildBuildSplatVector can just create the G_BUILD_VECTOR with N copies of the input

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a ISD::SPLAT_VECTOR in DAG, used for scalable vectors. I was contemplating using G_BUILD_VECTOR for all splats in AArch64, to remove the current G_DUP and G_DUPLANE's.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Limiting G_SPLAT_VECTOR to scalable vectors would be a nice contribution. G_BUILD_VECTOR would be still be the tool for fixed size vectors.

Vector Reduction Operations
---------------------------

Expand Down
12 changes: 10 additions & 2 deletions llvm/include/llvm/CodeGen/GlobalISel/MachineIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -1063,8 +1063,7 @@ class MachineIRBuilder {

/// Build and insert \p Res = G_BUILD_VECTOR with \p Src replicated to fill
/// the number of elements
MachineInstrBuilder buildSplatVector(const DstOp &Res,
const SrcOp &Src);
MachineInstrBuilder buildSplatBuildVector(const DstOp &Res, const SrcOp &Src);

/// Build and insert \p Res = G_BUILD_VECTOR_TRUNC \p Op0, ...
///
Expand Down Expand Up @@ -1099,6 +1098,15 @@ class MachineIRBuilder {
MachineInstrBuilder buildShuffleVector(const DstOp &Res, const SrcOp &Src1,
const SrcOp &Src2, ArrayRef<int> Mask);

/// Build and insert \p Res = G_SPLAT_VECTOR \p Val
///
/// \pre setBasicBlock or setMI must have been called.
/// \pre \p Res must be a generic virtual register with vector type.
/// \pre \p Val must be a generic virtual register with scalar type.
///
/// \return a MachineInstrBuilder for the newly created instruction.
MachineInstrBuilder buildSplatVector(const DstOp &Res, const SrcOp &Val);

/// Build and insert \p Res = G_CONCAT_VECTORS \p Op0, ...
///
/// G_CONCAT_VECTORS creates a vector from the concatenation of 2 or more
Expand Down
3 changes: 3 additions & 0 deletions llvm/include/llvm/Support/TargetOpcodes.def
Original file line number Diff line number Diff line change
Expand Up @@ -736,6 +736,9 @@ HANDLE_TARGET_OPCODE(G_EXTRACT_VECTOR_ELT)
/// Generic shufflevector.
HANDLE_TARGET_OPCODE(G_SHUFFLE_VECTOR)

/// Generic splatvector.
HANDLE_TARGET_OPCODE(G_SPLAT_VECTOR)

/// Generic count trailing zeroes.
HANDLE_TARGET_OPCODE(G_CTTZ)

Expand Down
7 changes: 7 additions & 0 deletions llvm/include/llvm/Target/GenericOpcodes.td
Original file line number Diff line number Diff line change
Expand Up @@ -1450,6 +1450,13 @@ def G_SHUFFLE_VECTOR: GenericInstruction {
let hasSideEffects = false;
}

// Generic splatvector.
def G_SPLAT_VECTOR: GenericInstruction {
let OutOperandList = (outs type0:$dst);
let InOperandList = (ins type1:$val);
let hasSideEffects = false;
}

//------------------------------------------------------------------------------
// Vector reductions
//------------------------------------------------------------------------------
Expand Down
4 changes: 2 additions & 2 deletions llvm/lib/CodeGen/GlobalISel/CSEMIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ MachineInstrBuilder CSEMIRBuilder::buildConstant(const DstOp &Res,
// For vectors, CSE the element only for now.
LLT Ty = Res.getLLTTy(*getMRI());
if (Ty.isVector())
return buildSplatVector(Res, buildConstant(Ty.getElementType(), Val));
return buildSplatBuildVector(Res, buildConstant(Ty.getElementType(), Val));

FoldingSetNodeID ID;
GISelInstProfileBuilder ProfBuilder(ID, *getMRI());
Expand All @@ -336,7 +336,7 @@ MachineInstrBuilder CSEMIRBuilder::buildFConstant(const DstOp &Res,
// For vectors, CSE the element only for now.
LLT Ty = Res.getLLTTy(*getMRI());
if (Ty.isVector())
return buildSplatVector(Res, buildFConstant(Ty.getElementType(), Val));
return buildSplatBuildVector(Res, buildFConstant(Ty.getElementType(), Val));

FoldingSetNodeID ID;
GISelInstProfileBuilder ProfBuilder(ID, *getMRI());
Expand Down
27 changes: 21 additions & 6 deletions llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1598,10 +1598,10 @@ bool IRTranslator::translateGetElementPtr(const User &U,
// We might need to splat the base pointer into a vector if the offsets
// are vectors.
if (WantSplatVector && !PtrTy.isVector()) {
BaseReg =
MIRBuilder
.buildSplatVector(LLT::fixed_vector(VectorWidth, PtrTy), BaseReg)
.getReg(0);
BaseReg = MIRBuilder
.buildSplatBuildVector(LLT::fixed_vector(VectorWidth, PtrTy),
BaseReg)
.getReg(0);
PtrIRTy = FixedVectorType::get(PtrIRTy, VectorWidth);
PtrTy = getLLTForType(*PtrIRTy, *DL);
OffsetIRTy = DL->getIndexType(PtrIRTy);
Expand Down Expand Up @@ -1639,8 +1639,10 @@ bool IRTranslator::translateGetElementPtr(const User &U,
LLT IdxTy = MRI->getType(IdxReg);
if (IdxTy != OffsetTy) {
if (!IdxTy.isVector() && WantSplatVector) {
IdxReg = MIRBuilder.buildSplatVector(
OffsetTy.changeElementType(IdxTy), IdxReg).getReg(0);
IdxReg = MIRBuilder
.buildSplatBuildVector(OffsetTy.changeElementType(IdxTy),
IdxReg)
.getReg(0);
}

IdxReg = MIRBuilder.buildSExtOrTrunc(OffsetTy, IdxReg).getReg(0);
Expand Down Expand Up @@ -2997,6 +2999,19 @@ bool IRTranslator::translateExtractElement(const User &U,

bool IRTranslator::translateShuffleVector(const User &U,
MachineIRBuilder &MIRBuilder) {
// A ShuffleVector that has operates on scalable vectors is a splat vector
// where the value of the splat vector is the 0th element of the first
// operand, since the index mask operand is the zeroinitializer (undef and
// poison are treated as zeroinitializer here).
if (U.getOperand(0)->getType()->isScalableTy()) {
Value *Op0 = U.getOperand(0);
auto SplatVal = MIRBuilder.buildExtractVectorElementConstant(
LLT::scalar(Op0->getType()->getScalarSizeInBits()),
getOrCreateVReg(*Op0), 0);
MIRBuilder.buildSplatVector(getOrCreateVReg(U), SplatVal);
return true;
}

ArrayRef<int> Mask;
if (auto *SVI = dyn_cast<ShuffleVectorInst>(&U))
Mask = SVI->getShuffleMask();
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8391,7 +8391,7 @@ static Register getMemsetValue(Register Val, LLT Ty, MachineIRBuilder &MIB) {

// For vector types create a G_BUILD_VECTOR.
if (Ty.isVector())
Val = MIB.buildSplatVector(Ty, Val).getReg(0);
Val = MIB.buildSplatBuildVector(Ty, Val).getReg(0);

return Val;
}
Expand Down
16 changes: 12 additions & 4 deletions llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ MachineInstrBuilder MachineIRBuilder::buildConstant(const DstOp &Res,
auto Const = buildInstr(TargetOpcode::G_CONSTANT)
.addDef(getMRI()->createGenericVirtualRegister(EltTy))
.addCImm(&Val);
return buildSplatVector(Res, Const);
return buildSplatBuildVector(Res, Const);
}

auto Const = buildInstr(TargetOpcode::G_CONSTANT);
Expand Down Expand Up @@ -363,7 +363,7 @@ MachineInstrBuilder MachineIRBuilder::buildFConstant(const DstOp &Res,
.addDef(getMRI()->createGenericVirtualRegister(EltTy))
.addFPImm(&Val);

return buildSplatVector(Res, Const);
return buildSplatBuildVector(Res, Const);
}

auto Const = buildInstr(TargetOpcode::G_FCONSTANT);
Expand Down Expand Up @@ -711,8 +711,8 @@ MachineIRBuilder::buildBuildVectorConstant(const DstOp &Res,
return buildInstr(TargetOpcode::G_BUILD_VECTOR, Res, TmpVec);
}

MachineInstrBuilder MachineIRBuilder::buildSplatVector(const DstOp &Res,
const SrcOp &Src) {
MachineInstrBuilder MachineIRBuilder::buildSplatBuildVector(const DstOp &Res,
const SrcOp &Src) {
SmallVector<SrcOp, 8> TmpVec(Res.getLLTTy(*getMRI()).getNumElements(), Src);
return buildInstr(TargetOpcode::G_BUILD_VECTOR, Res, TmpVec);
}
Expand Down Expand Up @@ -742,6 +742,14 @@ MachineInstrBuilder MachineIRBuilder::buildShuffleSplat(const DstOp &Res,
return buildShuffleVector(DstTy, InsElt, UndefVec, ZeroMask);
}

MachineInstrBuilder MachineIRBuilder::buildSplatVector(const DstOp &Res,
const SrcOp &Src) {
LLT DstTy = Res.getLLTTy(*getMRI());
assert(Src.getLLTTy(*getMRI()) == DstTy.getElementType() &&
"Expected Src to match Dst elt ty");
return buildInstr(TargetOpcode::G_SPLAT_VECTOR, Res, Src);
}

MachineInstrBuilder MachineIRBuilder::buildShuffleVector(const DstOp &Res,
const SrcOp &Src1,
const SrcOp &Src2,
Expand Down
18 changes: 18 additions & 0 deletions llvm/lib/CodeGen/MachineVerifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1640,6 +1640,24 @@ void MachineVerifier::verifyPreISelGenericInstruction(const MachineInstr *MI) {

break;
}

case TargetOpcode::G_SPLAT_VECTOR: {
LLT DstTy = MRI->getType(MI->getOperand(0).getReg());
LLT SrcTy = MRI->getType(MI->getOperand(1).getReg());

if (!DstTy.isScalableVector())
report("Destination type must be a scalable vector", MI);

if (!SrcTy.isScalar())
report("Source type must be a scalar", MI);

if (DstTy.getScalarType() != SrcTy)
report("Element type of the destination must be the same type as the "
"source type",
MI);

break;
}
case TargetOpcode::G_DYN_STACKALLOC: {
const MachineOperand &DstOp = MI->getOperand(0);
const MachineOperand &AllocOp = MI->getOperand(1);
Expand Down
3 changes: 2 additions & 1 deletion llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20920,7 +20920,8 @@ bool RISCVTargetLowering::fallBackToDAGISel(const Instruction &Inst) const {
unsigned Op = Inst.getOpcode();
if (Op == Instruction::Add || Op == Instruction::Sub ||
Op == Instruction::And || Op == Instruction::Or ||
Op == Instruction::Xor || Op == Instruction::InsertElement)
Op == Instruction::Xor || Op == Instruction::InsertElement ||
Op == Instruction::Xor || Op == Instruction::ShuffleVector)
return false;

if (Inst.getType()->isScalableTy())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,9 @@
# DEBUG-NEXT: G_SHUFFLE_VECTOR (opcode {{[0-9]+}}): 2 type indices, 0 imm indices
# DEBUG-NEXT: .. type index coverage check SKIPPED: user-defined predicate detected
# DEBUG-NEXT: .. imm index coverage check SKIPPED: user-defined predicate detected
# DEBUG-NEXT: G_SPLAT_VECTOR (opcode 217): 2 type indices, 0 imm indices
# DEBUG-NEXT: .. type index coverage check SKIPPED: no rules defined
# DEBUG-NEXT: .. imm index coverage check SKIPPED: no rules defined
# DEBUG-NEXT: G_CTTZ (opcode {{[0-9]+}}): 2 type indices, 0 imm indices
# DEBUG-NEXT: .. type index coverage check SKIPPED: user-defined predicate detected
# DEBUG-NEXT: .. imm index coverage check SKIPPED: user-defined predicate detected
Expand Down
Loading