Skip to content

Commit 2f005df

Browse files
authored
[DAG][X86] Fold mgather/mscatter/etc with splat index (#65980)
A splat index means every lane of the operation reads from (or writes to) the same memory location. Zero is generally the cheapest value to splat, so we prefer to add the splatted value to the base pointer and use a constant-zero splat as the index operand.
1 parent 8b47913 commit 2f005df

File tree

2 files changed

+26
-15
lines changed

2 files changed

+26
-15
lines changed

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

+15-2
Original file line numberDiff line numberDiff line change
@@ -11637,8 +11637,6 @@ static SDValue ConvertSelectToConcatVector(SDNode *N, SelectionDAG &DAG) {
1163711637

1163811638
bool refineUniformBase(SDValue &BasePtr, SDValue &Index, bool IndexIsScaled,
1163911639
SelectionDAG &DAG, const SDLoc &DL) {
11640-
if (Index.getOpcode() != ISD::ADD)
11641-
return false;
1164211640

1164311641
// Only perform the transformation when existing operands can be reused.
1164411642
if (IndexIsScaled)
@@ -11648,6 +11646,21 @@ bool refineUniformBase(SDValue &BasePtr, SDValue &Index, bool IndexIsScaled,
1164811646
return false;
1164911647

1165011648
EVT VT = BasePtr.getValueType();
11649+
11650+
if (SDValue SplatVal = DAG.getSplatValue(Index);
11651+
SplatVal && !isNullConstant(SplatVal) &&
11652+
SplatVal.getValueType() == VT) {
11653+
if (isNullConstant(BasePtr))
11654+
BasePtr = SplatVal;
11655+
else
11656+
BasePtr = DAG.getNode(ISD::ADD, DL, VT, BasePtr, SplatVal);
11657+
Index = DAG.getSplat(Index.getValueType(), DL, DAG.getConstant(0, DL, VT));
11658+
return true;
11659+
}
11660+
11661+
if (Index.getOpcode() != ISD::ADD)
11662+
return false;
11663+
1165111664
if (SDValue SplatVal = DAG.getSplatValue(Index.getOperand(0));
1165211665
SplatVal && SplatVal.getValueType() == VT) {
1165311666
if (isNullConstant(BasePtr))

llvm/test/CodeGen/X86/masked_gather.ll

+11-13
Original file line numberDiff line numberDiff line change
@@ -1747,29 +1747,27 @@ define <8 x i32> @gather_v8i32_v8i32(<8 x i32> %trigger) {
17471747
; AVX512F-NEXT: vptestnmd %zmm0, %zmm0, %k0
17481748
; AVX512F-NEXT: kshiftlw $8, %k0, %k0
17491749
; AVX512F-NEXT: kshiftrw $8, %k0, %k1
1750-
; AVX512F-NEXT: vpbroadcastd {{.*#+}} zmm0 = [12,12,12,12,12,12,12,12,12,12,12,12,12,12,12,12]
1750+
; AVX512F-NEXT: vpxor %xmm0, %xmm0, %xmm0
17511751
; AVX512F-NEXT: vpxor %xmm1, %xmm1, %xmm1
1752-
; AVX512F-NEXT: vpxor %xmm2, %xmm2, %xmm2
17531752
; AVX512F-NEXT: kmovw %k1, %k2
1754-
; AVX512F-NEXT: vpgatherdd c(,%zmm0), %zmm2 {%k2}
1755-
; AVX512F-NEXT: vpbroadcastd {{.*#+}} zmm0 = [28,28,28,28,28,28,28,28,28,28,28,28,28,28,28,28]
1756-
; AVX512F-NEXT: vpgatherdd c(,%zmm0), %zmm1 {%k1}
1757-
; AVX512F-NEXT: vpaddd %ymm1, %ymm2, %ymm0
1758-
; AVX512F-NEXT: vpaddd %ymm1, %ymm0, %ymm0
1753+
; AVX512F-NEXT: vpgatherdd c+12(,%zmm0), %zmm1 {%k2}
1754+
; AVX512F-NEXT: vpxor %xmm2, %xmm2, %xmm2
1755+
; AVX512F-NEXT: vpgatherdd c+28(,%zmm0), %zmm2 {%k1}
1756+
; AVX512F-NEXT: vpaddd %ymm2, %ymm1, %ymm0
1757+
; AVX512F-NEXT: vpaddd %ymm2, %ymm0, %ymm0
17591758
; AVX512F-NEXT: retq
17601759
;
17611760
; AVX512VL-LABEL: gather_v8i32_v8i32:
17621761
; AVX512VL: # %bb.0:
17631762
; AVX512VL-NEXT: vptestnmd %ymm0, %ymm0, %k1
17641763
; AVX512VL-NEXT: vpxor %xmm0, %xmm0, %xmm0
1765-
; AVX512VL-NEXT: vpbroadcastd {{.*#+}} ymm1 = [12,12,12,12,12,12,12,12]
17661764
; AVX512VL-NEXT: kmovw %k1, %k2
1765+
; AVX512VL-NEXT: vpxor %xmm1, %xmm1, %xmm1
1766+
; AVX512VL-NEXT: vpgatherdd c+12(,%ymm0), %ymm1 {%k2}
17671767
; AVX512VL-NEXT: vpxor %xmm2, %xmm2, %xmm2
1768-
; AVX512VL-NEXT: vpgatherdd c(,%ymm1), %ymm2 {%k2}
1769-
; AVX512VL-NEXT: vpbroadcastd {{.*#+}} ymm1 = [28,28,28,28,28,28,28,28]
1770-
; AVX512VL-NEXT: vpgatherdd c(,%ymm1), %ymm0 {%k1}
1771-
; AVX512VL-NEXT: vpaddd %ymm0, %ymm2, %ymm1
1772-
; AVX512VL-NEXT: vpaddd %ymm0, %ymm1, %ymm0
1768+
; AVX512VL-NEXT: vpgatherdd c+28(,%ymm0), %ymm2 {%k1}
1769+
; AVX512VL-NEXT: vpaddd %ymm2, %ymm1, %ymm0
1770+
; AVX512VL-NEXT: vpaddd %ymm2, %ymm0, %ymm0
17731771
; AVX512VL-NEXT: retq
17741772
%1 = icmp eq <8 x i32> %trigger, zeroinitializer
17751773
%2 = call <8 x i32> @llvm.masked.gather.v8i32.v8p0(<8 x ptr> getelementptr (%struct.a, <8 x ptr> <ptr @c, ptr @c, ptr @c, ptr @c, ptr @c, ptr @c, ptr @c, ptr @c>, <8 x i64> zeroinitializer, i32 0, <8 x i64> <i64 3, i64 3, i64 3, i64 3, i64 3, i64 3, i64 3, i64 3>), i32 4, <8 x i1> %1, <8 x i32> undef)

0 commit comments

Comments
 (0)