Skip to content

Commit e7f9d8e

Browse files
authored
[AArch64] Lower alias mask to a whilewr (#100769)
#100579 emits IR that creates a mask disabling lanes that could alias within a loop iteration, based on a pair of pointers. This PR lowers that IR to the WHILEWR instruction for AArch64.
1 parent 9fa17fe commit e7f9d8e

File tree

2 files changed

+1207
-0
lines changed

2 files changed

+1207
-0
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1523,6 +1523,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
15231523
setOperationAction(ISD::VECREDUCE_AND, VT, Custom);
15241524
setOperationAction(ISD::VECREDUCE_OR, VT, Custom);
15251525
setOperationAction(ISD::VECREDUCE_XOR, VT, Custom);
1526+
setOperationAction(ISD::OR, VT, Custom);
15261527

15271528
setOperationAction(ISD::SELECT_CC, VT, Expand);
15281529
setOperationAction(ISD::INSERT_VECTOR_ELT, VT, Custom);
@@ -13808,8 +13809,128 @@ static SDValue tryLowerToSLI(SDNode *N, SelectionDAG &DAG) {
1380813809
return ResultSLI;
1380913810
}
1381013811

13812+
/// Try to lower the construction of a pointer alias mask to a WHILEWR.
13813+
/// The mask's enabled lanes represent the elements that will not overlap across
13814+
/// one loop iteration. This tries to match:
13815+
/// or (splat (setcc_lt (sub ptrA, ptrB), -(element_size - 1))),
13816+
/// (get_active_lane_mask 0, (div (sub ptrA, ptrB), element_size))
13817+
SDValue tryWhileWRFromOR(SDValue Op, SelectionDAG &DAG,
13818+
const AArch64Subtarget &Subtarget) {
13819+
if (!Subtarget.hasSVE2())
13820+
return SDValue();
13821+
SDValue LaneMask = Op.getOperand(0);
13822+
SDValue Splat = Op.getOperand(1);
13823+
13824+
if (Splat.getOpcode() != ISD::SPLAT_VECTOR)
13825+
std::swap(LaneMask, Splat);
13826+
13827+
if (LaneMask.getOpcode() != ISD::INTRINSIC_WO_CHAIN ||
13828+
LaneMask.getConstantOperandVal(0) != Intrinsic::get_active_lane_mask ||
13829+
Splat.getOpcode() != ISD::SPLAT_VECTOR)
13830+
return SDValue();
13831+
13832+
SDValue Cmp = Splat.getOperand(0);
13833+
if (Cmp.getOpcode() != ISD::SETCC)
13834+
return SDValue();
13835+
13836+
CondCodeSDNode *Cond = cast<CondCodeSDNode>(Cmp.getOperand(2));
13837+
13838+
auto ComparatorConst = dyn_cast<ConstantSDNode>(Cmp.getOperand(1));
13839+
if (!ComparatorConst || ComparatorConst->getSExtValue() > 0 ||
13840+
Cond->get() != ISD::CondCode::SETLT)
13841+
return SDValue();
13842+
unsigned CompValue = std::abs(ComparatorConst->getSExtValue());
13843+
unsigned EltSize = CompValue + 1;
13844+
if (!isPowerOf2_64(EltSize) || EltSize > 8)
13845+
return SDValue();
13846+
13847+
SDValue Diff = Cmp.getOperand(0);
13848+
if (Diff.getOpcode() != ISD::SUB || Diff.getValueType() != MVT::i64)
13849+
return SDValue();
13850+
13851+
if (!isNullConstant(LaneMask.getOperand(1)) ||
13852+
(EltSize != 1 && LaneMask.getOperand(2).getOpcode() != ISD::SRA))
13853+
return SDValue();
13854+
13855+
// The number of elements that alias is calculated by dividing the positive
13856+
// difference between the pointers by the element size. An alias mask for i8
13857+
// elements omits the division because it would just divide by 1
13858+
if (EltSize > 1) {
13859+
SDValue DiffDiv = LaneMask.getOperand(2);
13860+
auto DiffDivConst = dyn_cast<ConstantSDNode>(DiffDiv.getOperand(1));
13861+
if (!DiffDivConst || DiffDivConst->getZExtValue() != Log2_64(EltSize))
13862+
return SDValue();
13863+
if (EltSize > 2) {
13864+
// When masking i32 or i64 elements, the positive value of the
13865+
// possibly-negative difference comes from a select of the difference if
13866+
// it's positive, otherwise the difference plus the element size if it's
13867+
// negative: pos_diff = diff < 0 ? (diff + 7) : diff
13868+
SDValue Select = DiffDiv.getOperand(0);
13869+
// Make sure the difference is being compared by the select
13870+
if (Select.getOpcode() != ISD::SELECT_CC || Select.getOperand(3) != Diff)
13871+
return SDValue();
13872+
// Make sure it's checking if the difference is less than 0
13873+
if (!isNullConstant(Select.getOperand(1)) ||
13874+
cast<CondCodeSDNode>(Select.getOperand(4))->get() !=
13875+
ISD::CondCode::SETLT)
13876+
return SDValue();
13877+
// An add creates a positive value from the negative difference
13878+
SDValue Add = Select.getOperand(2);
13879+
if (Add.getOpcode() != ISD::ADD || Add.getOperand(0) != Diff)
13880+
return SDValue();
13881+
if (auto *AddConst = dyn_cast<ConstantSDNode>(Add.getOperand(1));
13882+
!AddConst || AddConst->getZExtValue() != EltSize - 1)
13883+
return SDValue();
13884+
} else {
13885+
// When masking i16 elements, this positive value comes from adding the
13886+
// difference's sign bit to the difference itself. This is equivalent to
13887+
// the 32 bit and 64 bit case: pos_diff = diff + sign_bit (diff)
13888+
SDValue Add = DiffDiv.getOperand(0);
13889+
if (Add.getOpcode() != ISD::ADD || Add.getOperand(0) != Diff)
13890+
return SDValue();
13891+
// A logical right shift by 63 extracts the sign bit from the difference
13892+
SDValue Shift = Add.getOperand(1);
13893+
if (Shift.getOpcode() != ISD::SRL || Shift.getOperand(0) != Diff)
13894+
return SDValue();
13895+
if (auto *ShiftConst = dyn_cast<ConstantSDNode>(Shift.getOperand(1));
13896+
!ShiftConst || ShiftConst->getZExtValue() != 63)
13897+
return SDValue();
13898+
}
13899+
} else if (LaneMask.getOperand(2) != Diff)
13900+
return SDValue();
13901+
13902+
SDValue StorePtr = Diff.getOperand(0);
13903+
SDValue ReadPtr = Diff.getOperand(1);
13904+
13905+
unsigned IntrinsicID = 0;
13906+
switch (EltSize) {
13907+
case 1:
13908+
IntrinsicID = Intrinsic::aarch64_sve_whilewr_b;
13909+
break;
13910+
case 2:
13911+
IntrinsicID = Intrinsic::aarch64_sve_whilewr_h;
13912+
break;
13913+
case 4:
13914+
IntrinsicID = Intrinsic::aarch64_sve_whilewr_s;
13915+
break;
13916+
case 8:
13917+
IntrinsicID = Intrinsic::aarch64_sve_whilewr_d;
13918+
break;
13919+
default:
13920+
return SDValue();
13921+
}
13922+
SDLoc DL(Op);
13923+
SDValue ID = DAG.getConstant(IntrinsicID, DL, MVT::i32);
13924+
return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, Op.getValueType(), ID,
13925+
StorePtr, ReadPtr);
13926+
}
13927+
1381113928
SDValue AArch64TargetLowering::LowerVectorOR(SDValue Op,
1381213929
SelectionDAG &DAG) const {
13930+
if (SDValue SV =
13931+
tryWhileWRFromOR(Op, DAG, DAG.getSubtarget<AArch64Subtarget>()))
13932+
return SV;
13933+
1381313934
if (useSVEForFixedLengthVectorVT(Op.getValueType(),
1381413935
!Subtarget->isNeonAvailable()))
1381513936
return LowerToScalableOp(Op, DAG);

0 commit comments

Comments
 (0)