@@ -1266,7 +1266,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
 
     for (auto VT : { MVT::v4i32, MVT::v8i32, MVT::v2i64, MVT::v4i64,
                      MVT::v4f32, MVT::v8f32, MVT::v2f64, MVT::v4f64 }) {
-      setOperationAction(ISD::MLOAD,  VT, Legal);
+      setOperationAction(ISD::MLOAD,  VT, Custom);
       setOperationAction(ISD::MSTORE, VT, Legal);
     }
 
@@ -1412,15 +1412,13 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
     setTruncStoreAction(MVT::v16i32, MVT::v16i8, Legal);
     setTruncStoreAction(MVT::v16i32, MVT::v16i16, Legal);
 
-    if (!Subtarget.hasVLX()) {
-      // With 512-bit vectors and no VLX, we prefer to widen MLOAD/MSTORE
-      // to 512-bit rather than use the AVX2 instructions so that we can use
-      // k-masks.
-      for (auto VT : {MVT::v4i32, MVT::v8i32, MVT::v2i64, MVT::v4i64,
-                      MVT::v4f32, MVT::v8f32, MVT::v2f64, MVT::v4f64}) {
-        setOperationAction(ISD::MLOAD,  VT, Custom);
-        setOperationAction(ISD::MSTORE, VT, Custom);
-      }
+    // With 512-bit vectors and no VLX, we prefer to widen MLOAD/MSTORE
+    // to 512-bit rather than use the AVX2 instructions so that we can use
+    // k-masks.
+    for (auto VT : {MVT::v4i32, MVT::v8i32, MVT::v2i64, MVT::v4i64,
+                    MVT::v4f32, MVT::v8f32, MVT::v2f64, MVT::v4f64}) {
+      setOperationAction(ISD::MLOAD, VT, Subtarget.hasVLX() ? Legal : Custom);
+      setOperationAction(ISD::MSTORE, VT, Subtarget.hasVLX() ? Legal : Custom);
     }
 
     setOperationAction(ISD::TRUNCATE, MVT::v8i32, Custom);
@@ -26914,8 +26912,28 @@ static SDValue LowerMLOAD(SDValue Op, const X86Subtarget &Subtarget,
   MVT VT = Op.getSimpleValueType();
   MVT ScalarVT = VT.getScalarType();
   SDValue Mask = N->getMask();
+  MVT MaskVT = Mask.getSimpleValueType();
+  SDValue PassThru = N->getPassThru();
   SDLoc dl(Op);
 
+  // Handle AVX masked loads which don't support passthru other than 0.
+  if (MaskVT.getVectorElementType() != MVT::i1) {
+    // We also allow undef in the isel pattern.
+    if (PassThru.isUndef() || ISD::isBuildVectorAllZeros(PassThru.getNode()))
+      return Op;
+
+    SDValue NewLoad = DAG.getMaskedLoad(VT, dl, N->getChain(),
+                                        N->getBasePtr(), Mask,
+                                        getZeroVector(VT, Subtarget, DAG, dl),
+                                        N->getMemoryVT(), N->getMemOperand(),
+                                        N->getExtensionType(),
+                                        N->isExpandingLoad());
+    // Emit a blend.
+    SDValue Select = DAG.getNode(ISD::VSELECT, dl, MaskVT, Mask, NewLoad,
+                                 PassThru);
+    return DAG.getMergeValues({ Select, NewLoad.getValue(1) }, dl);
+  }
+
   assert((!N->isExpandingLoad() || Subtarget.hasAVX512()) &&
          "Expanding masked load is supported on AVX-512 target only!");
 
@@ -26934,7 +26952,7 @@ static SDValue LowerMLOAD(SDValue Op, const X86Subtarget &Subtarget,
   // VLX the vector should be widened to 512 bit
   unsigned NumEltsInWideVec = 512 / VT.getScalarSizeInBits();
   MVT WideDataVT = MVT::getVectorVT(ScalarVT, NumEltsInWideVec);
-  SDValue PassThru = ExtendToType(N->getPassThru(), WideDataVT, DAG);
+  PassThru = ExtendToType(PassThru, WideDataVT, DAG);
 
   // Mask element has to be i1.
   assert(Mask.getSimpleValueType().getScalarType() == MVT::i1 &&
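For readers unfamiliar with the AVX maskload semantics the new non-i1-mask path in LowerMLOAD relies on, here is a minimal standalone C++ sketch (not part of the patch; the helper name, flags, and test values are illustrative only). AVX/AVX2 maskmov instructions zero the disabled lanes, so a masked load whose passthru is neither undef nor all-zeros is lowered, as in the hunk above, to a zero-filling masked load followed by a blend (the VSELECT) that re-inserts the passthru value in the disabled lanes.

// Sketch of the maskload + blend lowering semantics, assuming an AVX-capable
// host. Build with e.g. `g++ -mavx example.cpp` (file name is hypothetical).
#include <immintrin.h>
#include <cstdio>

// Emulates MLOAD(ptr, mask, passthru) on plain AVX: maskload zeroes the
// disabled lanes, then a blend selects passthru for those lanes.
static __m256 masked_load_with_passthru(const float *ptr, __m256i mask,
                                        __m256 passthru) {
  __m256 zero_filled = _mm256_maskload_ps(ptr, mask);      // disabled lanes -> 0.0f
  __m256 mask_ps = _mm256_castsi256_ps(mask);              // reuse mask sign bits
  return _mm256_blendv_ps(passthru, zero_filled, mask_ps); // VSELECT(mask, load, passthru)
}

int main() {
  alignas(32) float data[8] = {1, 2, 3, 4, 5, 6, 7, 8};
  // Enable lanes 0, 2, 4, 6; maskload/blendv test each lane's sign bit.
  __m256i mask = _mm256_setr_epi32(-1, 0, -1, 0, -1, 0, -1, 0);
  __m256 passthru = _mm256_set1_ps(-9.0f);

  alignas(32) float out[8];
  _mm256_store_ps(out, masked_load_with_passthru(data, mask, passthru));
  for (float f : out)
    printf("%g ", f); // expected: 1 -9 3 -9 5 -9 7 -9
  printf("\n");
  return 0;
}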