clastb representation in existing IR, and AArch64 codegen #112738


Merged: 12 commits, Feb 6, 2025
4 changes: 4 additions & 0 deletions llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -840,6 +840,10 @@ def vector_insert_subvec : SDNode<"ISD::INSERT_SUBVECTOR",
def extract_subvector : SDNode<"ISD::EXTRACT_SUBVECTOR", SDTSubVecExtract, []>;
def insert_subvector : SDNode<"ISD::INSERT_SUBVECTOR", SDTSubVecInsert, []>;

def find_last_active
: SDNode<"ISD::VECTOR_FIND_LAST_ACTIVE",
SDTypeProfile<1, 1, [SDTCisInt<0>, SDTCisVec<1>]>, []>;

// Nodes for intrinsics, you should use the intrinsic itself and let tblgen use
// these internally. Don't reference these directly.
def intrinsic_void : SDNode<"ISD::INTRINSIC_VOID",
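
The find_last_active node defined above takes a predicate (mask) vector and produces an integer index. As a point of reference only, here is a minimal scalar sketch of that semantic; the helper name and the std::optional return type are inventions of this sketch, not LLVM API:

#include <cstdint>
#include <optional>
#include <vector>

// Illustrative scalar model of ISD::VECTOR_FIND_LAST_ACTIVE: the index of
// the highest-numbered active (true) lane, or nullopt when no lane is
// active. The real node produces a plain integer; the optional is only for
// clarity in this sketch.
std::optional<uint64_t> findLastActive(const std::vector<bool> &Mask) {
  std::optional<uint64_t> Last;
  for (uint64_t I = 0; I < Mask.size(); ++I)
    if (Mask[I])
      Last = I;
  return Last;
}
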
68 changes: 68 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1452,6 +1452,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::VECTOR_DEINTERLEAVE, VT, Custom);
setOperationAction(ISD::VECTOR_INTERLEAVE, VT, Custom);
}
for (auto VT : {MVT::nxv16i1, MVT::nxv8i1, MVT::nxv4i1, MVT::nxv2i1})
setOperationAction(ISD::VECTOR_FIND_LAST_ACTIVE, VT, Legal);
}

if (Subtarget->isSVEorStreamingSVEAvailable()) {
@@ -19730,6 +19732,33 @@ performLastTrueTestVectorCombine(SDNode *N,
return getPTest(DAG, N->getValueType(0), Pg, N0, AArch64CC::LAST_ACTIVE);
}

static SDValue
performExtractLastActiveCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
const AArch64Subtarget *Subtarget) {
assert(N->getOpcode() == ISD::EXTRACT_VECTOR_ELT);
SelectionDAG &DAG = DCI.DAG;
SDValue Vec = N->getOperand(0);
SDValue Idx = N->getOperand(1);

if (DCI.isBeforeLegalize() || Idx.getOpcode() != ISD::VECTOR_FIND_LAST_ACTIVE)
return SDValue();

// Only legal for 8, 16, 32, and 64 bit element types.
EVT EltVT = Vec.getValueType().getVectorElementType();
if (!is_contained(ArrayRef({MVT::i8, MVT::i16, MVT::i32, MVT::i64, MVT::f16,
MVT::bf16, MVT::f32, MVT::f64}),
EltVT.getSimpleVT().SimpleTy))
return SDValue();

Collaborator
What about f16 and bf16?

SDValue Mask = Idx.getOperand(0);
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
if (!TLI.isOperationLegal(ISD::VECTOR_FIND_LAST_ACTIVE, Mask.getValueType()))
return SDValue();

return DAG.getNode(AArch64ISD::LASTB, SDLoc(N), N->getValueType(0), Mask,
Vec);
}
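
The combine above rewrites extract_vector_elt(Vec, find_last_active(Mask)) into a single AArch64ISD::LASTB node, which extracts the vector element at the last active lane of the predicate. As a rough scalar sketch of the value being computed (an illustrative helper, not the patch's code; what the instruction yields when no lane is active is deliberately left out of the model):

#include <cassert>
#include <cstddef>
#include <vector>

// Illustrative scalar model of LASTB: the element of Vec at the last
// active lane of Mask. Only the some-lane-active case is modelled here.
template <typename T>
T lastActiveElement(const std::vector<bool> &Mask, const std::vector<T> &Vec) {
  assert(Mask.size() == Vec.size() && "sketch assumes matching lane counts");
  T Result{};
  bool AnyActive = false;
  for (size_t I = 0; I < Mask.size(); ++I)
    if (Mask[I]) {
      Result = Vec[I];
      AnyActive = true;
    }
  assert(AnyActive && "sketch only covers the some-lane-active case");
  return Result;
}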

static SDValue
performExtractVectorEltCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
const AArch64Subtarget *Subtarget) {
@@ -19738,6 +19767,8 @@ performExtractVectorEltCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
return Res;
if (SDValue Res = performLastTrueTestVectorCombine(N, DCI, Subtarget))
return Res;
if (SDValue Res = performExtractLastActiveCombine(N, DCI, Subtarget))
return Res;

SelectionDAG &DAG = DCI.DAG;
SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
@@ -24852,6 +24883,39 @@ static SDValue reassociateCSELOperandsForCSE(SDNode *N, SelectionDAG &DAG) {
}
}

static SDValue foldCSELofLASTB(SDNode *Op, SelectionDAG &DAG) {
AArch64CC::CondCode OpCC =
static_cast<AArch64CC::CondCode>(Op->getConstantOperandVal(2));

if (OpCC != AArch64CC::NE)
return SDValue();

SDValue PTest = Op->getOperand(3);
if (PTest.getOpcode() != AArch64ISD::PTEST_ANY)
return SDValue();

SDValue TruePred = PTest.getOperand(0);
SDValue AnyPred = PTest.getOperand(1);

if (TruePred.getOpcode() == AArch64ISD::REINTERPRET_CAST)
TruePred = TruePred.getOperand(0);

if (AnyPred.getOpcode() == AArch64ISD::REINTERPRET_CAST)
AnyPred = AnyPred.getOperand(0);

if (TruePred != AnyPred && TruePred.getOpcode() != AArch64ISD::PTRUE)
return SDValue();

SDValue LastB = Op->getOperand(0);
SDValue Default = Op->getOperand(1);

if (LastB.getOpcode() != AArch64ISD::LASTB || LastB.getOperand(0) != AnyPred)
return SDValue();

return DAG.getNode(AArch64ISD::CLASTB_N, SDLoc(Op), Op->getValueType(0),
AnyPred, Default, LastB.getOperand(1));
}
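
foldCSELofLASTB above recognises the shape CSEL(LASTB(P, Z), Default, NE, PTEST_ANY(P, P)), i.e. "last active element of Z if any lane of P is set, otherwise Default", and folds it to a single CLASTB_N. A scalar sketch of that select-with-fallback value (an illustrative helper only, not LLVM code):

#include <cstddef>
#include <vector>

// Illustrative scalar model of the folded pattern: the last active element
// of Vec under Mask, or Default when no lane is active. This is the value
// a single CLASTB computes in one instruction.
template <typename T>
T conditionalLastActive(const std::vector<bool> &Mask, T Default,
                        const std::vector<T> &Vec) {
  T Result = Default;
  for (size_t I = 0; I < Mask.size() && I < Vec.size(); ++I)
    if (Mask[I])
      Result = Vec[I];
  return Result;
}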

// Optimize CSEL instructions
static SDValue performCSELCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
@@ -24897,6 +24961,10 @@ static SDValue performCSELCombine(SDNode *N,
}
}

// CSEL (LASTB P, Z), X, NE(ANY P) -> CLASTB P, X, Z
if (SDValue CondLast = foldCSELofLASTB(N, DAG))
return CondLast;

return performCONDCombine(N, DCI, DAG, 2, 3);
}

14 changes: 14 additions & 0 deletions llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -3379,6 +3379,20 @@ let Predicates = [HasSVE_or_SME] in {
def : Pat<(i64 (vector_extract nxv2i64:$vec, VectorIndexD:$index)),
(UMOVvi64 (v2i64 (EXTRACT_SUBREG ZPR:$vec, zsub)), VectorIndexD:$index)>;

// Find index of last active lane. This is a fallback in case we miss the
// opportunity to fold into a lastb or clastb directly.
Comment on lines +3382 to +3383
Member

Are these fallback patterns tested in the final patch?

Collaborator

I agree it would be good to have some tests for these.

Collaborator Author

Sadly it's pretty difficult to do this once the combines have been added. I don't see a global switch to disable combining, just the target-independent combines. We do check the optimization level in a few places in AArch64ISelLowering, but mostly for TLI methods for IR-level decisions. Deliberately turning off (c)lastb pattern matching at O0 feels odd. Adding a new switch just for this feature also feels excessive.

I could potentially add a globalisel-based test, though I'm not sure how much code that requires. We've added a few new ISD nodes recently, and none have added support in globalisel.

I guess this is mostly due to it being hard to just create a selectiondag without IR and run selection over it.

def : Pat<(i64(find_last_active nxv16i1:$P1)),
(INSERT_SUBREG(IMPLICIT_DEF), (LASTB_RPZ_B $P1, (INDEX_II_B 0, 1)),
sub_32)>;
def : Pat<(i64(find_last_active nxv8i1:$P1)),
(INSERT_SUBREG(IMPLICIT_DEF), (LASTB_RPZ_H $P1, (INDEX_II_H 0, 1)),
sub_32)>;
def : Pat<(i64(find_last_active nxv4i1:$P1)),
(INSERT_SUBREG(IMPLICIT_DEF), (LASTB_RPZ_S $P1, (INDEX_II_S 0, 1)),
sub_32)>;
def : Pat<(i64(find_last_active nxv2i1:$P1)),
(LASTB_RPZ_D $P1, (INDEX_II_D 0, 1))>;
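
These fallback patterns compute the lane index directly when no lastb/clastb fold applies: INDEX_II materialises the lane numbers 0, 1, 2, ... and LASTB then picks the entry at the last active lane, which is exactly the index being sought. A scalar sketch of that equivalence (illustrative only; the all-false case is glossed over and simply returns 0 here):

#include <cstddef>
#include <cstdint>
#include <numeric>
#include <vector>

// Illustrative scalar model of the fallback lowering: build the index
// vector 0, 1, 2, ... (what INDEX_II with start 0 and step 1 produces) and
// take its element at the last active lane of Mask.
uint64_t findLastActiveViaLastB(const std::vector<bool> &Mask) {
  std::vector<uint64_t> Indices(Mask.size());
  std::iota(Indices.begin(), Indices.end(), 0); // 0, 1, 2, ...
  uint64_t Result = 0;
  for (size_t I = 0; I < Mask.size(); ++I)
    if (Mask[I])
      Result = Indices[I];
  return Result;
}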

// Move element from the bottom 128-bits of a scalable vector to a single-element vector.
// Alternative case where insertelement is just scalar_to_vector rather than vector_insert.
def : Pat<(v1f64 (scalar_to_vector