Skip to content

Commit 2309b66

Browse files
justinfargnoliNoumanAmir657
authored andcommitted
Reland "[NVPTX] Prefer prmt.b32 over bfi.b32" (llvm#114326)
Fix [failure](llvm#110766 (comment)) identified by @akuegel. --- In [[NVPTX] Improve lowering of v4i8](llvm@cbafb6f) @Artem-B add the ability to lower ISD::BUILD_VECTOR with bfi PTX instructions. @Artem-B did this because: (llvm#67866 (comment)) Under the hood byte extraction/insertion ends up as BFI/BFE instructions, so we may as well do that in PTX, too. https://godbolt.org/z/Tb3zWbj9b However, the example that @Artem-B linked was targeting sm_52. On modern architectures, ptxas uses prmt.b32. [Example](https://godbolt.org/z/Ye4W1n84o). Thus, remove uses of NVPTXISD::BFI in favor of NVPTXISD::PRMT.
1 parent 866beec commit 2309b66

File tree

3 files changed

+339
-334
lines changed

3 files changed

+339
-334
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2318,32 +2318,33 @@ SDValue NVPTXTargetLowering::LowerBUILD_VECTOR(SDValue Op,
23182318
EVT VT = Op->getValueType(0);
23192319
if (!(Isv2x16VT(VT) || VT == MVT::v4i8))
23202320
return Op;
2321-
23222321
SDLoc DL(Op);
23232322

23242323
if (!llvm::all_of(Op->ops(), [](SDValue Operand) {
23252324
return Operand->isUndef() || isa<ConstantSDNode>(Operand) ||
23262325
isa<ConstantFPSDNode>(Operand);
23272326
})) {
2327+
if (VT != MVT::v4i8)
2328+
return Op;
23282329
// Lower non-const v4i8 vector as byte-wise constructed i32, which allows us
23292330
// to optimize calculation of constant parts.
2330-
if (VT == MVT::v4i8) {
2331-
SDValue C8 = DAG.getConstant(8, DL, MVT::i32);
2332-
SDValue E01 = DAG.getNode(
2333-
NVPTXISD::BFI, DL, MVT::i32,
2334-
DAG.getAnyExtOrTrunc(Op->getOperand(1), DL, MVT::i32),
2335-
DAG.getAnyExtOrTrunc(Op->getOperand(0), DL, MVT::i32), C8, C8);
2336-
SDValue E012 =
2337-
DAG.getNode(NVPTXISD::BFI, DL, MVT::i32,
2338-
DAG.getAnyExtOrTrunc(Op->getOperand(2), DL, MVT::i32),
2339-
E01, DAG.getConstant(16, DL, MVT::i32), C8);
2340-
SDValue E0123 =
2341-
DAG.getNode(NVPTXISD::BFI, DL, MVT::i32,
2342-
DAG.getAnyExtOrTrunc(Op->getOperand(3), DL, MVT::i32),
2343-
E012, DAG.getConstant(24, DL, MVT::i32), C8);
2344-
return DAG.getNode(ISD::BITCAST, DL, VT, E0123);
2345-
}
2346-
return Op;
2331+
auto GetPRMT = [&](const SDValue Left, const SDValue Right, bool Cast,
2332+
uint64_t SelectionValue) -> SDValue {
2333+
SDValue L = Left;
2334+
SDValue R = Right;
2335+
if (Cast) {
2336+
L = DAG.getAnyExtOrTrunc(L, DL, MVT::i32);
2337+
R = DAG.getAnyExtOrTrunc(R, DL, MVT::i32);
2338+
}
2339+
return DAG.getNode(
2340+
NVPTXISD::PRMT, DL, MVT::v4i8,
2341+
{L, R, DAG.getConstant(SelectionValue, DL, MVT::i32),
2342+
DAG.getConstant(NVPTX::PTXPrmtMode::NONE, DL, MVT::i32)});
2343+
};
2344+
auto PRMT__10 = GetPRMT(Op->getOperand(0), Op->getOperand(1), true, 0x3340);
2345+
auto PRMT__32 = GetPRMT(Op->getOperand(2), Op->getOperand(3), true, 0x3340);
2346+
auto PRMT3210 = GetPRMT(PRMT__10, PRMT__32, false, 0x5410);
2347+
return DAG.getNode(ISD::BITCAST, DL, VT, PRMT3210);
23472348
}
23482349

23492350
// Get value or the Nth operand as an APInt(32). Undef values treated as 0.
@@ -2374,8 +2375,8 @@ SDValue NVPTXTargetLowering::LowerBUILD_VECTOR(SDValue Op,
23742375
} else {
23752376
llvm_unreachable("Unsupported type");
23762377
}
2377-
SDValue Const = DAG.getConstant(Value, SDLoc(Op), MVT::i32);
2378-
return DAG.getNode(ISD::BITCAST, SDLoc(Op), Op->getValueType(0), Const);
2378+
SDValue Const = DAG.getConstant(Value, DL, MVT::i32);
2379+
return DAG.getNode(ISD::BITCAST, DL, Op->getValueType(0), Const);
23792380
}
23802381

23812382
SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,

0 commit comments

Comments
 (0)