Skip to content

Commit ebf15e9

Browse files
committed
Not-quite-working prototype with ISD node
1 parent 59720dc commit ebf15e9

File tree

8 files changed

+177
-31
lines changed

8 files changed

+177
-31
lines changed

llvm/include/llvm/CodeGen/ISDOpcodes.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1480,6 +1480,10 @@ enum NodeType {
14801480
// Output: Output Chain
14811481
EXPERIMENTAL_VECTOR_HISTOGRAM,
14821482

1483+
// experimental.vector.extract.last.active intrinsic
1484+
// Operands: Data, Mask, PassThru
1485+
VECTOR_EXTRACT_LAST_ACTIVE,
1486+
14831487
// llvm.clear_cache intrinsic
14841488
// Operands: Input Chain, Start Address, End Address
14851489
// Outputs: Output Chain

llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,10 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
155155
case ISD::ZERO_EXTEND_VECTOR_INREG:
156156
Res = PromoteIntRes_EXTEND_VECTOR_INREG(N); break;
157157

158+
case ISD::VECTOR_EXTRACT_LAST_ACTIVE:
159+
Res = PromoteIntRes_VECTOR_EXTRACT_LAST_ACTIVE(N);
160+
break;
161+
158162
case ISD::SIGN_EXTEND:
159163
case ISD::VP_SIGN_EXTEND:
160164
case ISD::ZERO_EXTEND:
@@ -2069,6 +2073,9 @@ bool DAGTypeLegalizer::PromoteIntegerOperand(SDNode *N, unsigned OpNo) {
20692073
case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
20702074
Res = PromoteIntOp_VECTOR_HISTOGRAM(N, OpNo);
20712075
break;
2076+
case ISD::VECTOR_EXTRACT_LAST_ACTIVE:
2077+
Res = PromoteIntOp_VECTOR_EXTRACT_LAST_ACTIVE(N, OpNo);
2078+
break;
20722079
}
20732080

20742081
// If the result is null, the sub-method took care of registering results etc.
@@ -2803,6 +2810,14 @@ SDValue DAGTypeLegalizer::PromoteIntOp_VECTOR_HISTOGRAM(SDNode *N,
28032810
return SDValue(DAG.UpdateNodeOperands(N, NewOps), 0);
28042811
}
28052812

2813+
/// Promote the integer operand at index \p OpNo of a
/// VECTOR_EXTRACT_LAST_ACTIVE node.
///
/// The node is updated in place: every operand is kept as-is except the one
/// at \p OpNo, which is swapped for its already-promoted counterpart.
///
/// \param N    The VECTOR_EXTRACT_LAST_ACTIVE node being legalized.
/// \param OpNo Index of the operand whose type must be promoted.
/// \returns Result 0 of the node returned by UpdateNodeOperands.
SDValue DAGTypeLegalizer::PromoteIntOp_VECTOR_EXTRACT_LAST_ACTIVE(
    SDNode *N, unsigned OpNo) {
  // Start from a copy of the current operand list.
  SmallVector<SDValue, 3> Operands(N->op_begin(), N->op_end());

  // Substitute only the operand that needs promotion.
  Operands[OpNo] = GetPromotedInteger(N->getOperand(OpNo));

  SDNode *Updated = DAG.UpdateNodeOperands(N, Operands);
  return SDValue(Updated, 0);
}
2820+
28062821
//===----------------------------------------------------------------------===//
28072822
// Integer Result Expansion
28082823
//===----------------------------------------------------------------------===//
@@ -2840,6 +2855,9 @@ void DAGTypeLegalizer::ExpandIntegerResult(SDNode *N, unsigned ResNo) {
28402855
case ISD::BUILD_PAIR: ExpandRes_BUILD_PAIR(N, Lo, Hi); break;
28412856
case ISD::EXTRACT_ELEMENT: ExpandRes_EXTRACT_ELEMENT(N, Lo, Hi); break;
28422857
case ISD::EXTRACT_VECTOR_ELT: ExpandRes_EXTRACT_VECTOR_ELT(N, Lo, Hi); break;
2858+
case ISD::VECTOR_EXTRACT_LAST_ACTIVE:
2859+
ExpandRes_VECTOR_EXTRACT_LAST_ACTIVE(N, Lo, Hi);
2860+
break;
28432861
case ISD::VAARG: ExpandRes_VAARG(N, Lo, Hi); break;
28442862

28452863
case ISD::ANY_EXTEND: ExpandIntRes_ANY_EXTEND(N, Lo, Hi); break;
@@ -6102,6 +6120,38 @@ SDValue DAGTypeLegalizer::PromoteIntRes_EXTEND_VECTOR_INREG(SDNode *N) {
61026120
return DAG.getNode(N->getOpcode(), dl, NVT, N->getOperand(0));
61036121
}
61046122

6123+
/// Promote the integer result of a VECTOR_EXTRACT_LAST_ACTIVE node.
///
/// A new node with the same opcode is built whose result type is the type
/// the target wants the original result type transformed to. All operands
/// are forwarded unchanged.
///
/// \param N The VECTOR_EXTRACT_LAST_ACTIVE node whose result needs promotion.
/// \returns The replacement node producing the promoted result type.
SDValue DAGTypeLegalizer::PromoteIntRes_VECTOR_EXTRACT_LAST_ACTIVE(SDNode *N) {
  SDLoc DL(N);
  EVT PromotedVT =
      TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));

  // NOTE(review): the operands are passed through untouched here, so the
  // third operand keeps its original type while the result is widened;
  // presumably operand promotion fixes that up afterwards -- confirm.
  return DAG.getNode(ISD::VECTOR_EXTRACT_LAST_ACTIVE, DL, PromotedVT,
                     N->ops());
}
6154+
61056155
SDValue DAGTypeLegalizer::PromoteIntRes_INSERT_VECTOR_ELT(SDNode *N) {
61066156
EVT OutVT = N->getValueType(0);
61076157
EVT NOutVT = TLI.getTypeToTransformTo(*DAG.getContext(), OutVT);

llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
378378
SDValue PromoteIntRes_VPFunnelShift(SDNode *N);
379379
SDValue PromoteIntRes_IS_FPCLASS(SDNode *N);
380380
SDValue PromoteIntRes_PATCHPOINT(SDNode *N);
381+
SDValue PromoteIntRes_VECTOR_EXTRACT_LAST_ACTIVE(SDNode *N);
381382

382383
// Integer Operand Promotion.
383384
bool PromoteIntegerOperand(SDNode *N, unsigned OpNo);
@@ -428,6 +429,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
428429
SDValue PromoteIntOp_VP_STRIDED(SDNode *N, unsigned OpNo);
429430
SDValue PromoteIntOp_VP_SPLICE(SDNode *N, unsigned OpNo);
430431
SDValue PromoteIntOp_VECTOR_HISTOGRAM(SDNode *N, unsigned OpNo);
432+
SDValue PromoteIntOp_VECTOR_EXTRACT_LAST_ACTIVE(SDNode *N, unsigned OpNo);
431433

432434
void SExtOrZExtPromotedOperands(SDValue &LHS, SDValue &RHS);
433435
void PromoteSetCCOperands(SDValue &LHS,SDValue &RHS, ISD::CondCode Code);
@@ -1214,6 +1216,8 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
12141216
void ExpandRes_BUILD_PAIR (SDNode *N, SDValue &Lo, SDValue &Hi);
12151217
void ExpandRes_EXTRACT_ELEMENT (SDNode *N, SDValue &Lo, SDValue &Hi);
12161218
void ExpandRes_EXTRACT_VECTOR_ELT(SDNode *N, SDValue &Lo, SDValue &Hi);
1219+
void ExpandRes_VECTOR_EXTRACT_LAST_ACTIVE(SDNode *N, SDValue &Lo,
1220+
SDValue &Hi);
12171221
void ExpandRes_NormalLoad (SDNode *N, SDValue &Lo, SDValue &Hi);
12181222
void ExpandRes_VAARG (SDNode *N, SDValue &Lo, SDValue &Hi);
12191223

llvm/lib/CodeGen/SelectionDAG/LegalizeTypesGeneric.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,16 @@ void DAGTypeLegalizer::ExpandRes_EXTRACT_VECTOR_ELT(SDNode *N, SDValue &Lo,
244244
std::swap(Lo, Hi);
245245
}
246246

247+
/// Expand the result of a VECTOR_EXTRACT_LAST_ACTIVE node into the two
/// halves \p Lo and \p Hi.
///
/// This is not implemented yet. Because result expansion dispatches here for
/// this opcode, the path is genuinely reachable, so failure is reported with
/// report_fatal_error (which also fires in release builds) instead of
/// llvm_unreachable (which is only an optimizer hint when assertions are
/// disabled).
///
/// \param N  The VECTOR_EXTRACT_LAST_ACTIVE node whose result needs expanding.
/// \param Lo Output for the low half of the result (currently never set).
/// \param Hi Output for the high half of the result (currently never set).
void DAGTypeLegalizer::ExpandRes_VECTOR_EXTRACT_LAST_ACTIVE(SDNode *N,
                                                            SDValue &Lo,
                                                            SDValue &Hi) {
  // FIXME: We need to do this by casting to smaller elements, deinterleaving,
  //        then performing 2 extract_last_active operations and returning the
  //        two parts.
  report_fatal_error("Implement extract_last_active expand result!");
}
256+
247257
void DAGTypeLegalizer::ExpandRes_NormalLoad(SDNode *N, SDValue &Lo,
248258
SDValue &Hi) {
249259
assert(ISD::isNormalLoad(N) && "This routine only for normal loads!");

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include "llvm/ADT/DenseMap.h"
3030
#include "llvm/ADT/SmallVector.h"
3131
#include "llvm/Analysis/TargetLibraryInfo.h"
32+
#include "llvm/Analysis/ValueTracking.h"
3233
#include "llvm/Analysis/VectorUtils.h"
3334
#include "llvm/CodeGen/ISDOpcodes.h"
3435
#include "llvm/CodeGen/SelectionDAG.h"
@@ -138,6 +139,7 @@ class VectorLegalizer {
138139
SDValue ExpandVP_FNEG(SDNode *Node);
139140
SDValue ExpandVP_FABS(SDNode *Node);
140141
SDValue ExpandVP_FCOPYSIGN(SDNode *Node);
142+
SDValue ExpandVECTOR_EXTRACT_LAST_ACTIVE(SDNode *Node);
141143
SDValue ExpandSELECT(SDNode *Node);
142144
std::pair<SDValue, SDValue> ExpandLoad(SDNode *N);
143145
SDValue ExpandStore(SDNode *N);
@@ -465,6 +467,7 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
465467
case ISD::VECTOR_COMPRESS:
466468
case ISD::SCMP:
467469
case ISD::UCMP:
470+
case ISD::VECTOR_EXTRACT_LAST_ACTIVE:
468471
Action = TLI.getOperationAction(Node->getOpcode(), Node->getValueType(0));
469472
break;
470473
case ISD::SMULFIX:
@@ -1202,6 +1205,9 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
12021205
case ISD::VECTOR_COMPRESS:
12031206
Results.push_back(TLI.expandVECTOR_COMPRESS(Node, DAG));
12041207
return;
1208+
case ISD::VECTOR_EXTRACT_LAST_ACTIVE:
1209+
Results.push_back(ExpandVECTOR_EXTRACT_LAST_ACTIVE(Node));
1210+
return;
12051211
case ISD::SCMP:
12061212
case ISD::UCMP:
12071213
Results.push_back(TLI.expandCMP(Node, DAG));
@@ -1713,6 +1719,61 @@ SDValue VectorLegalizer::ExpandVP_FCOPYSIGN(SDNode *Node) {
17131719
return DAG.getNode(ISD::BITCAST, DL, VT, CopiedSign);
17141720
}
17151721

1722+
/// Expand a VECTOR_EXTRACT_LAST_ACTIVE node (operands: Data, Mask, PassThru)
/// into generic nodes.
///
/// The expansion builds a step vector, zeroes the lanes whose mask bit is
/// clear, and takes the unsigned maximum of what remains as the index of the
/// last active lane. That lane is extracted from Data. A separate OR-reduction
/// of the mask chooses between the extracted element and PassThru, because an
/// index of zero alone cannot distinguish "only lane 0 active" from "no lane
/// active".
///
/// \param Node The VECTOR_EXTRACT_LAST_ACTIVE node to expand.
/// \returns A scalar value of PassThru's type.
SDValue VectorLegalizer::ExpandVECTOR_EXTRACT_LAST_ACTIVE(SDNode *Node) {
  SDLoc DL(Node);
  SDValue Data = Node->getOperand(0);
  SDValue Mask = Node->getOperand(1);
  SDValue PassThru = Node->getOperand(2);

  EVT DataVT = Data.getValueType();
  EVT ScalarVT = PassThru.getValueType();
  EVT BoolVT = Mask.getValueType().getScalarType();

  // Find a suitable type for a stepvector.
  ConstantRange VScaleRange(1, /*isFullSet=*/true); // Dummy value.
  if (DataVT.isScalableVector())
    VScaleRange = getVScaleRange(&DAG.getMachineFunction().getFunction(), 64);
  unsigned EltWidth = TLI.getBitWidthForCttzElements(
      ScalarVT.getTypeForEVT(*DAG.getContext()), DataVT.getVectorElementCount(),
      /*ZeroIsPoison=*/true, &VScaleRange);
  EVT StepVT = MVT::getIntegerVT(EltWidth);
  EVT StepVecVT = DataVT.changeVectorElementType(StepVT);

  // Promote to a legal type if necessary. Query through the context-taking
  // overload so that a step-vector type without a simple MVT does not trip
  // the assertion inside EVT::getSimpleVT().
  if (TLI.getTypeAction(*DAG.getContext(), StepVecVT) ==
      TargetLowering::TypePromoteInteger) {
    StepVecVT = TLI.getTypeToTransformTo(*DAG.getContext(), StepVecVT);
    StepVT = StepVecVT.getVectorElementType();
  }

  // Zero out lanes with inactive elements, then find the highest remaining
  // value from the stepvector.
  SDValue Zeroes = DAG.getConstant(0, DL, StepVecVT);
  SDValue StepVec = DAG.getStepVector(DL, StepVecVT);
  SDValue ActiveElts = DAG.getSelect(DL, StepVecVT, Mask, StepVec, Zeroes);

  // Unfortunately, VectorLegalizer does not recursively legalize all added
  // nodes, just the end result nodes. LegalizeDAG doesn't handle VSELECT at
  // all presently. So if we need to legalize a vselect then we have to do
  // it here.
  if (!TLI.isTypeLegal(StepVecVT) ||
      TLI.getOperationAction(ISD::VSELECT, StepVecVT) == TargetLowering::Expand)
    ActiveElts = LegalizeOp(ActiveElts);

  SDValue HighestIdx = DAG.getNode(ISD::VECREDUCE_UMAX, DL, StepVT, ActiveElts);

  // Extract the corresponding lane from the data vector.
  EVT ExtVT = TLI.getVectorIdxTy(DAG.getDataLayout());
  SDValue Idx = DAG.getZExtOrTrunc(HighestIdx, DL, ExtVT);
  SDValue Extract =
      DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarVT, Data, Idx);

  // If all mask lanes were inactive, choose the passthru value instead.
  SDValue AnyActive = DAG.getNode(ISD::VECREDUCE_OR, DL, BoolVT, Mask);
  return DAG.getSelect(DL, ScalarVT, AnyActive, Extract, PassThru);
}
1776+
17161777
void VectorLegalizer::ExpandFP_TO_UINT(SDNode *Node,
17171778
SmallVectorImpl<SDValue> &Results) {
17181779
// Attempt to expand using TargetLowering.

llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp

Lines changed: 42 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -6431,44 +6431,55 @@ void SelectionDAGBuilder::visitVectorExtractLastActive(const CallInst &I,
64316431
unsigned Intrinsic) {
64326432
assert(Intrinsic == Intrinsic::experimental_vector_extract_last_active &&
64336433
"Tried lowering invalid vector extract last");
6434+
6435+
// VECTOR_EXTRACT_LAST_ACTIVE,
6436+
64346437
SDLoc sdl = getCurSDLoc();
64356438
SDValue Data = getValue(I.getOperand(0));
64366439
SDValue Mask = getValue(I.getOperand(1));
64376440
SDValue PassThru = getValue(I.getOperand(2));
64386441

6439-
EVT DataVT = Data.getValueType();
6440-
EVT ScalarVT = PassThru.getValueType();
6441-
EVT BoolVT = Mask.getValueType().getScalarType();
6442-
6443-
// Find a suitable type for a stepvector.
6444-
ConstantRange VScaleRange(1, /*isFullSet=*/true); // Dummy value.
6445-
if (DataVT.isScalableVector())
6446-
VScaleRange = getVScaleRange(I.getCaller(), 64);
64476442
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
6448-
unsigned EltWidth = TLI.getBitWidthForCttzElements(
6449-
I.getType(), DataVT.getVectorElementCount(), /*ZeroIsPoison=*/true,
6450-
&VScaleRange);
6451-
MVT StepVT = MVT::getIntegerVT(EltWidth);
6452-
EVT StepVecVT = DataVT.changeVectorElementType(StepVT);
6453-
6454-
// Zero out lanes with inactive elements, then find the highest remaining
6455-
// value from the stepvector.
6456-
SDValue Zeroes = DAG.getConstant(0, sdl, StepVecVT);
6457-
SDValue StepVec = DAG.getStepVector(sdl, StepVecVT);
6458-
SDValue ActiveElts = DAG.getSelect(sdl, StepVecVT, Mask, StepVec, Zeroes);
6459-
SDValue HighestIdx =
6460-
DAG.getNode(ISD::VECREDUCE_UMAX, sdl, StepVT, ActiveElts);
6461-
6462-
// Extract the corresponding lane from the data vector
6463-
EVT ExtVT = TLI.getVectorIdxTy(DAG.getDataLayout());
6464-
SDValue Idx = DAG.getZExtOrTrunc(HighestIdx, sdl, ExtVT);
6465-
SDValue Extract =
6466-
DAG.getNode(ISD::EXTRACT_VECTOR_ELT, sdl, ScalarVT, Data, Idx);
6467-
6468-
// If all mask lanes were inactive, choose the passthru value instead.
6469-
SDValue AnyActive = DAG.getNode(ISD::VECREDUCE_OR, sdl, BoolVT, Mask);
6470-
SDValue Result = DAG.getSelect(sdl, ScalarVT, AnyActive, Extract, PassThru);
6443+
EVT ResultVT = TLI.getValueType(DAG.getDataLayout(), I.getType());
6444+
6445+
SDValue Result = DAG.getNode(ISD::VECTOR_EXTRACT_LAST_ACTIVE, sdl, ResultVT,
6446+
Data, Mask, PassThru);
6447+
64716448
setValue(&I, Result);
6449+
6450+
// EVT DataVT = Data.getValueType();
6451+
// EVT ScalarVT = PassThru.getValueType();
6452+
// EVT BoolVT = Mask.getValueType().getScalarType();
6453+
//
6454+
// // Find a suitable type for a stepvector.
6455+
// ConstantRange VScaleRange(1, /*isFullSet=*/true); // Dummy value.
6456+
// if (DataVT.isScalableVector())
6457+
// VScaleRange = getVScaleRange(I.getCaller(), 64);
6458+
// const TargetLowering &TLI = DAG.getTargetLoweringInfo();
6459+
// unsigned EltWidth = TLI.getBitWidthForCttzElements(
6460+
// I.getType(), DataVT.getVectorElementCount(), /*ZeroIsPoison=*/true,
6461+
// &VScaleRange);
6462+
// MVT StepVT = MVT::getIntegerVT(EltWidth);
6463+
// EVT StepVecVT = DataVT.changeVectorElementType(StepVT);
6464+
//
6465+
// // Zero out lanes with inactive elements, then find the highest remaining
6466+
// // value from the stepvector.
6467+
// SDValue Zeroes = DAG.getConstant(0, sdl, StepVecVT);
6468+
// SDValue StepVec = DAG.getStepVector(sdl, StepVecVT);
6469+
// SDValue ActiveElts = DAG.getSelect(sdl, StepVecVT, Mask, StepVec, Zeroes);
6470+
// SDValue HighestIdx =
6471+
// DAG.getNode(ISD::VECREDUCE_UMAX, sdl, StepVT, ActiveElts);
6472+
//
6473+
// // Extract the corresponding lane from the data vector
6474+
// EVT ExtVT = TLI.getVectorIdxTy(DAG.getDataLayout());
6475+
// SDValue Idx = DAG.getZExtOrTrunc(HighestIdx, sdl, ExtVT);
6476+
// SDValue Extract =
6477+
// DAG.getNode(ISD::EXTRACT_VECTOR_ELT, sdl, ScalarVT, Data, Idx);
6478+
//
6479+
// // If all mask lanes were inactive, choose the passthru value instead.
6480+
// SDValue AnyActive = DAG.getNode(ISD::VECREDUCE_OR, sdl, BoolVT, Mask);
6481+
// SDValue Result = DAG.getSelect(sdl, ScalarVT, AnyActive, Extract,
6482+
// PassThru); setValue(&I, Result);
64726483
}
64736484

64746485
/// Lower the call to the specified intrinsic function.

llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -567,6 +567,9 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
567567
case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
568568
return "histogram";
569569

570+
case ISD::VECTOR_EXTRACT_LAST_ACTIVE:
571+
return "extract_last_active";
572+
570573
// Vector Predication
571574
#define BEGIN_REGISTER_VP_SDNODE(SDID, LEGALARG, NAME, ...) \
572575
case ISD::SDID: \

llvm/lib/CodeGen/TargetLoweringBase.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -818,6 +818,9 @@ void TargetLoweringBase::initActions() {
818818
setOperationAction(ISD::SDOPC, VT, Expand);
819819
#include "llvm/IR/VPIntrinsics.def"
820820

821+
// Masked vector extracts default to expand.
822+
setOperationAction(ISD::VECTOR_EXTRACT_LAST_ACTIVE, VT, Expand);
823+
821824
// FP environment operations default to expand.
822825
setOperationAction(ISD::GET_FPENV, VT, Expand);
823826
setOperationAction(ISD::SET_FPENV, VT, Expand);

0 commit comments

Comments
 (0)