@@ -1523,6 +1523,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
1523
1523
setOperationAction(ISD::VECREDUCE_AND, VT, Custom);
1524
1524
setOperationAction(ISD::VECREDUCE_OR, VT, Custom);
1525
1525
setOperationAction(ISD::VECREDUCE_XOR, VT, Custom);
1526
+ setOperationAction(ISD::OR, VT, Custom);
1526
1527
1527
1528
setOperationAction(ISD::SELECT_CC, VT, Expand);
1528
1529
setOperationAction(ISD::INSERT_VECTOR_ELT, VT, Custom);
@@ -13808,8 +13809,128 @@ static SDValue tryLowerToSLI(SDNode *N, SelectionDAG &DAG) {
13808
13809
return ResultSLI;
13809
13810
}
13810
13811
13812
+ /// Try to lower the construction of a pointer alias mask to a WHILEWR.
13813
+ /// The mask's enabled lanes represent the elements that will not overlap across
13814
+ /// one loop iteration. This tries to match:
13815
+ /// or (splat (setcc_lt (sub ptrA, ptrB), -(element_size - 1))),
13816
+ /// (get_active_lane_mask 0, (div (sub ptrA, ptrB), element_size))
13817
+ SDValue tryWhileWRFromOR(SDValue Op, SelectionDAG &DAG,
13818
+ const AArch64Subtarget &Subtarget) {
13819
+ if (!Subtarget.hasSVE2())
13820
+ return SDValue();
13821
+ SDValue LaneMask = Op.getOperand(0);
13822
+ SDValue Splat = Op.getOperand(1);
13823
+
13824
+ if (Splat.getOpcode() != ISD::SPLAT_VECTOR)
13825
+ std::swap(LaneMask, Splat);
13826
+
13827
+ if (LaneMask.getOpcode() != ISD::INTRINSIC_WO_CHAIN ||
13828
+ LaneMask.getConstantOperandVal(0) != Intrinsic::get_active_lane_mask ||
13829
+ Splat.getOpcode() != ISD::SPLAT_VECTOR)
13830
+ return SDValue();
13831
+
13832
+ SDValue Cmp = Splat.getOperand(0);
13833
+ if (Cmp.getOpcode() != ISD::SETCC)
13834
+ return SDValue();
13835
+
13836
+ CondCodeSDNode *Cond = cast<CondCodeSDNode>(Cmp.getOperand(2));
13837
+
13838
+ auto ComparatorConst = dyn_cast<ConstantSDNode>(Cmp.getOperand(1));
13839
+ if (!ComparatorConst || ComparatorConst->getSExtValue() > 0 ||
13840
+ Cond->get() != ISD::CondCode::SETLT)
13841
+ return SDValue();
13842
+ unsigned CompValue = std::abs(ComparatorConst->getSExtValue());
13843
+ unsigned EltSize = CompValue + 1;
13844
+ if (!isPowerOf2_64(EltSize) || EltSize > 8)
13845
+ return SDValue();
13846
+
13847
+ SDValue Diff = Cmp.getOperand(0);
13848
+ if (Diff.getOpcode() != ISD::SUB || Diff.getValueType() != MVT::i64)
13849
+ return SDValue();
13850
+
13851
+ if (!isNullConstant(LaneMask.getOperand(1)) ||
13852
+ (EltSize != 1 && LaneMask.getOperand(2).getOpcode() != ISD::SRA))
13853
+ return SDValue();
13854
+
13855
+ // The number of elements that alias is calculated by dividing the positive
13856
+ // difference between the pointers by the element size. An alias mask for i8
13857
+ // elements omits the division because it would just divide by 1
13858
+ if (EltSize > 1) {
13859
+ SDValue DiffDiv = LaneMask.getOperand(2);
13860
+ auto DiffDivConst = dyn_cast<ConstantSDNode>(DiffDiv.getOperand(1));
13861
+ if (!DiffDivConst || DiffDivConst->getZExtValue() != Log2_64(EltSize))
13862
+ return SDValue();
13863
+ if (EltSize > 2) {
13864
+ // When masking i32 or i64 elements, the positive value of the
13865
+ // possibly-negative difference comes from a select of the difference if
13866
+ // it's positive, otherwise the difference plus the element size if it's
13867
+ // negative: pos_diff = diff < 0 ? (diff + 7) : diff
13868
+ SDValue Select = DiffDiv.getOperand(0);
13869
+ // Make sure the difference is being compared by the select
13870
+ if (Select.getOpcode() != ISD::SELECT_CC || Select.getOperand(3) != Diff)
13871
+ return SDValue();
13872
+ // Make sure it's checking if the difference is less than 0
13873
+ if (!isNullConstant(Select.getOperand(1)) ||
13874
+ cast<CondCodeSDNode>(Select.getOperand(4))->get() !=
13875
+ ISD::CondCode::SETLT)
13876
+ return SDValue();
13877
+ // An add creates a positive value from the negative difference
13878
+ SDValue Add = Select.getOperand(2);
13879
+ if (Add.getOpcode() != ISD::ADD || Add.getOperand(0) != Diff)
13880
+ return SDValue();
13881
+ if (auto *AddConst = dyn_cast<ConstantSDNode>(Add.getOperand(1));
13882
+ !AddConst || AddConst->getZExtValue() != EltSize - 1)
13883
+ return SDValue();
13884
+ } else {
13885
+ // When masking i16 elements, this positive value comes from adding the
13886
+ // difference's sign bit to the difference itself. This is equivalent to
13887
+ // the 32 bit and 64 bit case: pos_diff = diff + sign_bit (diff)
13888
+ SDValue Add = DiffDiv.getOperand(0);
13889
+ if (Add.getOpcode() != ISD::ADD || Add.getOperand(0) != Diff)
13890
+ return SDValue();
13891
+ // A logical right shift by 63 extracts the sign bit from the difference
13892
+ SDValue Shift = Add.getOperand(1);
13893
+ if (Shift.getOpcode() != ISD::SRL || Shift.getOperand(0) != Diff)
13894
+ return SDValue();
13895
+ if (auto *ShiftConst = dyn_cast<ConstantSDNode>(Shift.getOperand(1));
13896
+ !ShiftConst || ShiftConst->getZExtValue() != 63)
13897
+ return SDValue();
13898
+ }
13899
+ } else if (LaneMask.getOperand(2) != Diff)
13900
+ return SDValue();
13901
+
13902
+ SDValue StorePtr = Diff.getOperand(0);
13903
+ SDValue ReadPtr = Diff.getOperand(1);
13904
+
13905
+ unsigned IntrinsicID = 0;
13906
+ switch (EltSize) {
13907
+ case 1:
13908
+ IntrinsicID = Intrinsic::aarch64_sve_whilewr_b;
13909
+ break;
13910
+ case 2:
13911
+ IntrinsicID = Intrinsic::aarch64_sve_whilewr_h;
13912
+ break;
13913
+ case 4:
13914
+ IntrinsicID = Intrinsic::aarch64_sve_whilewr_s;
13915
+ break;
13916
+ case 8:
13917
+ IntrinsicID = Intrinsic::aarch64_sve_whilewr_d;
13918
+ break;
13919
+ default:
13920
+ return SDValue();
13921
+ }
13922
+ SDLoc DL(Op);
13923
+ SDValue ID = DAG.getConstant(IntrinsicID, DL, MVT::i32);
13924
+ return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, Op.getValueType(), ID,
13925
+ StorePtr, ReadPtr);
13926
+ }
13927
+
13811
13928
SDValue AArch64TargetLowering::LowerVectorOR(SDValue Op,
13812
13929
SelectionDAG &DAG) const {
13930
+ if (SDValue SV =
13931
+ tryWhileWRFromOR(Op, DAG, DAG.getSubtarget<AArch64Subtarget>()))
13932
+ return SV;
13933
+
13813
13934
if (useSVEForFixedLengthVectorVT(Op.getValueType(),
13814
13935
!Subtarget->isNeonAvailable()))
13815
13936
return LowerToScalableOp(Op, DAG);
0 commit comments