Skip to content

Commit b74176a

Browse files
authored
[AArch64][SelectionDAG] Generate clastb for extract.last.active (#112738)
This patch improves SVE codegen for the vector extract last active intrinsic, using either the lastb instruction (if the passthru value was poison or undef), or the clastb instruction.
1 parent c9d0a46 commit b74176a

File tree

4 files changed

+239
-78
lines changed

4 files changed

+239
-78
lines changed

llvm/include/llvm/Target/TargetSelectionDAG.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -840,6 +840,10 @@ def vector_insert_subvec : SDNode<"ISD::INSERT_SUBVECTOR",
840840
def extract_subvector : SDNode<"ISD::EXTRACT_SUBVECTOR", SDTSubVecExtract, []>;
841841
def insert_subvector : SDNode<"ISD::INSERT_SUBVECTOR", SDTSubVecInsert, []>;
842842

843+
def find_last_active
844+
: SDNode<"ISD::VECTOR_FIND_LAST_ACTIVE",
845+
SDTypeProfile<1, 1, [SDTCisInt<0>, SDTCisVec<1>]>, []>;
846+
843847
// Nodes for intrinsics, you should use the intrinsic itself and let tblgen use
844848
// these internally. Don't reference these directly.
845849
def intrinsic_void : SDNode<"ISD::INTRINSIC_VOID",

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1452,6 +1452,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
14521452
setOperationAction(ISD::VECTOR_DEINTERLEAVE, VT, Custom);
14531453
setOperationAction(ISD::VECTOR_INTERLEAVE, VT, Custom);
14541454
}
1455+
for (auto VT : {MVT::nxv16i1, MVT::nxv8i1, MVT::nxv4i1, MVT::nxv2i1})
1456+
setOperationAction(ISD::VECTOR_FIND_LAST_ACTIVE, VT, Legal);
14551457
}
14561458

14571459
if (Subtarget->isSVEorStreamingSVEAvailable()) {
@@ -19730,6 +19732,33 @@ performLastTrueTestVectorCombine(SDNode *N,
1973019732
return getPTest(DAG, N->getValueType(0), Pg, N0, AArch64CC::LAST_ACTIVE);
1973119733
}
1973219734

19735+
static SDValue
19736+
performExtractLastActiveCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
19737+
const AArch64Subtarget *Subtarget) {
19738+
assert(N->getOpcode() == ISD::EXTRACT_VECTOR_ELT);
19739+
SelectionDAG &DAG = DCI.DAG;
19740+
SDValue Vec = N->getOperand(0);
19741+
SDValue Idx = N->getOperand(1);
19742+
19743+
if (DCI.isBeforeLegalize() || Idx.getOpcode() != ISD::VECTOR_FIND_LAST_ACTIVE)
19744+
return SDValue();
19745+
19746+
// Only legal for 8, 16, 32, and 64 bit element types.
19747+
EVT EltVT = Vec.getValueType().getVectorElementType();
19748+
if (!is_contained(ArrayRef({MVT::i8, MVT::i16, MVT::i32, MVT::i64, MVT::f16,
19749+
MVT::bf16, MVT::f32, MVT::f64}),
19750+
EltVT.getSimpleVT().SimpleTy))
19751+
return SDValue();
19752+
19753+
SDValue Mask = Idx.getOperand(0);
19754+
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
19755+
if (!TLI.isOperationLegal(ISD::VECTOR_FIND_LAST_ACTIVE, Mask.getValueType()))
19756+
return SDValue();
19757+
19758+
return DAG.getNode(AArch64ISD::LASTB, SDLoc(N), N->getValueType(0), Mask,
19759+
Vec);
19760+
}
19761+
1973319762
static SDValue
1973419763
performExtractVectorEltCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
1973519764
const AArch64Subtarget *Subtarget) {
@@ -19738,6 +19767,8 @@ performExtractVectorEltCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
1973819767
return Res;
1973919768
if (SDValue Res = performLastTrueTestVectorCombine(N, DCI, Subtarget))
1974019769
return Res;
19770+
if (SDValue Res = performExtractLastActiveCombine(N, DCI, Subtarget))
19771+
return Res;
1974119772

1974219773
SelectionDAG &DAG = DCI.DAG;
1974319774
SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
@@ -24852,6 +24883,39 @@ static SDValue reassociateCSELOperandsForCSE(SDNode *N, SelectionDAG &DAG) {
2485224883
}
2485324884
}
2485424885

24886+
static SDValue foldCSELofLASTB(SDNode *Op, SelectionDAG &DAG) {
24887+
AArch64CC::CondCode OpCC =
24888+
static_cast<AArch64CC::CondCode>(Op->getConstantOperandVal(2));
24889+
24890+
if (OpCC != AArch64CC::NE)
24891+
return SDValue();
24892+
24893+
SDValue PTest = Op->getOperand(3);
24894+
if (PTest.getOpcode() != AArch64ISD::PTEST_ANY)
24895+
return SDValue();
24896+
24897+
SDValue TruePred = PTest.getOperand(0);
24898+
SDValue AnyPred = PTest.getOperand(1);
24899+
24900+
if (TruePred.getOpcode() == AArch64ISD::REINTERPRET_CAST)
24901+
TruePred = TruePred.getOperand(0);
24902+
24903+
if (AnyPred.getOpcode() == AArch64ISD::REINTERPRET_CAST)
24904+
AnyPred = AnyPred.getOperand(0);
24905+
24906+
if (TruePred != AnyPred && TruePred.getOpcode() != AArch64ISD::PTRUE)
24907+
return SDValue();
24908+
24909+
SDValue LastB = Op->getOperand(0);
24910+
SDValue Default = Op->getOperand(1);
24911+
24912+
if (LastB.getOpcode() != AArch64ISD::LASTB || LastB.getOperand(0) != AnyPred)
24913+
return SDValue();
24914+
24915+
return DAG.getNode(AArch64ISD::CLASTB_N, SDLoc(Op), Op->getValueType(0),
24916+
AnyPred, Default, LastB.getOperand(1));
24917+
}
24918+
2485524919
// Optimize CSEL instructions
2485624920
static SDValue performCSELCombine(SDNode *N,
2485724921
TargetLowering::DAGCombinerInfo &DCI,
@@ -24897,6 +24961,10 @@ static SDValue performCSELCombine(SDNode *N,
2489724961
}
2489824962
}
2489924963

24964+
// CSEL (LASTB P, Z), X, NE(ANY P) -> CLASTB P, X, Z
24965+
if (SDValue CondLast = foldCSELofLASTB(N, DAG))
24966+
return CondLast;
24967+
2490024968
return performCONDCombine(N, DCI, DAG, 2, 3);
2490124969
}
2490224970

llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3379,6 +3379,20 @@ let Predicates = [HasSVE_or_SME] in {
33793379
def : Pat<(i64 (vector_extract nxv2i64:$vec, VectorIndexD:$index)),
33803380
(UMOVvi64 (v2i64 (EXTRACT_SUBREG ZPR:$vec, zsub)), VectorIndexD:$index)>;
33813381

3382+
// Find index of last active lane. This is a fallback in case we miss the
3383+
// opportunity to fold into a lastb or clastb directly.
3384+
def : Pat<(i64(find_last_active nxv16i1:$P1)),
3385+
(INSERT_SUBREG(IMPLICIT_DEF), (LASTB_RPZ_B $P1, (INDEX_II_B 0, 1)),
3386+
sub_32)>;
3387+
def : Pat<(i64(find_last_active nxv8i1:$P1)),
3388+
(INSERT_SUBREG(IMPLICIT_DEF), (LASTB_RPZ_H $P1, (INDEX_II_H 0, 1)),
3389+
sub_32)>;
3390+
def : Pat<(i64(find_last_active nxv4i1:$P1)),
3391+
(INSERT_SUBREG(IMPLICIT_DEF), (LASTB_RPZ_S $P1, (INDEX_II_S 0, 1)),
3392+
sub_32)>;
3393+
def : Pat<(i64(find_last_active nxv2i1:$P1)), (LASTB_RPZ_D $P1, (INDEX_II_D 0,
3394+
1))>;
3395+
33823396
// Move element from the bottom 128-bits of a scalable vector to a single-element vector.
33833397
// Alternative case where insertelement is just scalar_to_vector rather than vector_insert.
33843398
def : Pat<(v1f64 (scalar_to_vector

0 commit comments

Comments (0)