@@ -1452,6 +1452,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
1452
1452
setOperationAction(ISD::VECTOR_DEINTERLEAVE, VT, Custom);
1453
1453
setOperationAction(ISD::VECTOR_INTERLEAVE, VT, Custom);
1454
1454
}
1455
+ for (auto VT : {MVT::nxv16i1, MVT::nxv8i1, MVT::nxv4i1, MVT::nxv2i1})
1456
+ setOperationAction(ISD::VECTOR_FIND_LAST_ACTIVE, VT, Legal);
1455
1457
}
1456
1458
1457
1459
if (Subtarget->isSVEorStreamingSVEAvailable()) {
@@ -19730,6 +19732,33 @@ performLastTrueTestVectorCombine(SDNode *N,
19730
19732
return getPTest(DAG, N->getValueType(0), Pg, N0, AArch64CC::LAST_ACTIVE);
19731
19733
}
19732
19734
19735
+ static SDValue
19736
+ performExtractLastActiveCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
19737
+ const AArch64Subtarget *Subtarget) {
19738
+ assert(N->getOpcode() == ISD::EXTRACT_VECTOR_ELT);
19739
+ SelectionDAG &DAG = DCI.DAG;
19740
+ SDValue Vec = N->getOperand(0);
19741
+ SDValue Idx = N->getOperand(1);
19742
+
19743
+ if (DCI.isBeforeLegalize() || Idx.getOpcode() != ISD::VECTOR_FIND_LAST_ACTIVE)
19744
+ return SDValue();
19745
+
19746
+ // Only legal for 8, 16, 32, and 64 bit element types.
19747
+ EVT EltVT = Vec.getValueType().getVectorElementType();
19748
+ if (!is_contained(ArrayRef({MVT::i8, MVT::i16, MVT::i32, MVT::i64, MVT::f16,
19749
+ MVT::bf16, MVT::f32, MVT::f64}),
19750
+ EltVT.getSimpleVT().SimpleTy))
19751
+ return SDValue();
19752
+
19753
+ SDValue Mask = Idx.getOperand(0);
19754
+ const TargetLowering &TLI = DAG.getTargetLoweringInfo();
19755
+ if (!TLI.isOperationLegal(ISD::VECTOR_FIND_LAST_ACTIVE, Mask.getValueType()))
19756
+ return SDValue();
19757
+
19758
+ return DAG.getNode(AArch64ISD::LASTB, SDLoc(N), N->getValueType(0), Mask,
19759
+ Vec);
19760
+ }
19761
+
19733
19762
static SDValue
19734
19763
performExtractVectorEltCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
19735
19764
const AArch64Subtarget *Subtarget) {
@@ -19738,6 +19767,8 @@ performExtractVectorEltCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
19738
19767
return Res;
19739
19768
if (SDValue Res = performLastTrueTestVectorCombine(N, DCI, Subtarget))
19740
19769
return Res;
19770
+ if (SDValue Res = performExtractLastActiveCombine(N, DCI, Subtarget))
19771
+ return Res;
19741
19772
19742
19773
SelectionDAG &DAG = DCI.DAG;
19743
19774
SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
@@ -24852,6 +24883,39 @@ static SDValue reassociateCSELOperandsForCSE(SDNode *N, SelectionDAG &DAG) {
24852
24883
}
24853
24884
}
24854
24885
24886
+ static SDValue foldCSELofLASTB(SDNode *Op, SelectionDAG &DAG) {
24887
+ AArch64CC::CondCode OpCC =
24888
+ static_cast<AArch64CC::CondCode>(Op->getConstantOperandVal(2));
24889
+
24890
+ if (OpCC != AArch64CC::NE)
24891
+ return SDValue();
24892
+
24893
+ SDValue PTest = Op->getOperand(3);
24894
+ if (PTest.getOpcode() != AArch64ISD::PTEST_ANY)
24895
+ return SDValue();
24896
+
24897
+ SDValue TruePred = PTest.getOperand(0);
24898
+ SDValue AnyPred = PTest.getOperand(1);
24899
+
24900
+ if (TruePred.getOpcode() == AArch64ISD::REINTERPRET_CAST)
24901
+ TruePred = TruePred.getOperand(0);
24902
+
24903
+ if (AnyPred.getOpcode() == AArch64ISD::REINTERPRET_CAST)
24904
+ AnyPred = AnyPred.getOperand(0);
24905
+
24906
+ if (TruePred != AnyPred && TruePred.getOpcode() != AArch64ISD::PTRUE)
24907
+ return SDValue();
24908
+
24909
+ SDValue LastB = Op->getOperand(0);
24910
+ SDValue Default = Op->getOperand(1);
24911
+
24912
+ if (LastB.getOpcode() != AArch64ISD::LASTB || LastB.getOperand(0) != AnyPred)
24913
+ return SDValue();
24914
+
24915
+ return DAG.getNode(AArch64ISD::CLASTB_N, SDLoc(Op), Op->getValueType(0),
24916
+ AnyPred, Default, LastB.getOperand(1));
24917
+ }
24918
+
24855
24919
// Optimize CSEL instructions
24856
24920
static SDValue performCSELCombine(SDNode *N,
24857
24921
TargetLowering::DAGCombinerInfo &DCI,
@@ -24897,6 +24961,10 @@ static SDValue performCSELCombine(SDNode *N,
24897
24961
}
24898
24962
}
24899
24963
24964
+ // CSEL (LASTB P, Z), X, NE(ANY P) -> CLASTB P, X, Z
24965
+ if (SDValue CondLast = foldCSELofLASTB(N, DAG))
24966
+ return CondLast;
24967
+
24900
24968
return performCONDCombine(N, DCI, DAG, 2, 3);
24901
24969
}
24902
24970
0 commit comments