Skip to content

Commit 17c567b

Browse files
committed
[X86] combineVPMADD - add constant folding support for PMADDWD/PMADDUBSW instructions
1 parent f447597 commit 17c567b

File tree

2 files changed

+47
-57
lines changed

2 files changed

+47
-57
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56922,16 +56922,41 @@ static SDValue combinePMULDQ(SDNode *N, SelectionDAG &DAG,
5692256922
// Simplify VPMADDUBSW/VPMADDWD operations.
5692356923
static SDValue combineVPMADD(SDNode *N, SelectionDAG &DAG,
5692456924
TargetLowering::DAGCombinerInfo &DCI) {
56925-
EVT VT = N->getValueType(0);
56925+
MVT VT = N->getSimpleValueType(0);
5692656926
SDValue LHS = N->getOperand(0);
5692756927
SDValue RHS = N->getOperand(1);
56928+
unsigned Opc = N->getOpcode();
56929+
bool IsPMADDWD = Opc == X86ISD::VPMADDWD;
56930+
assert((Opc == X86ISD::VPMADDWD || Opc == X86ISD::VPMADDUBSW) &&
56931+
"Unexpected PMADD opcode");
5692856932

5692956933
// Multiply by zero.
5693056934
// Don't return LHS/RHS as it may contain UNDEFs.
5693156935
if (ISD::isBuildVectorAllZeros(LHS.getNode()) ||
5693256936
ISD::isBuildVectorAllZeros(RHS.getNode()))
5693356937
return DAG.getConstant(0, SDLoc(N), VT);
5693456938

56939+
// Constant folding.
56940+
APInt LHSUndefs, RHSUndefs;
56941+
SmallVector<APInt> LHSBits, RHSBits;
56942+
unsigned SrcEltBits = LHS.getScalarValueSizeInBits();
56943+
unsigned DstEltBits = VT.getScalarSizeInBits();
56944+
if (getTargetConstantBitsFromNode(LHS, SrcEltBits, LHSUndefs, LHSBits) &&
56945+
getTargetConstantBitsFromNode(RHS, SrcEltBits, RHSUndefs, RHSBits)) {
56946+
SmallVector<APInt> Result;
56947+
for (unsigned I = 0, E = LHSBits.size(); I != E; I += 2) {
56948+
APInt LHSLo = LHSBits[I + 0], LHSHi = LHSBits[I + 1];
56949+
APInt RHSLo = RHSBits[I + 0], RHSHi = RHSBits[I + 1];
56950+
LHSLo = IsPMADDWD ? LHSLo.sext(DstEltBits) : LHSLo.zext(DstEltBits);
56951+
LHSHi = IsPMADDWD ? LHSHi.sext(DstEltBits) : LHSHi.zext(DstEltBits);
56952+
APInt Lo = LHSLo * RHSLo.sext(DstEltBits);
56953+
APInt Hi = LHSHi * RHSHi.sext(DstEltBits);
56954+
APInt Res = IsPMADDWD ? (Lo + Hi) : Lo.sadd_sat(Hi);
56955+
Result.push_back(Res);
56956+
}
56957+
return getConstVector(Result, VT, DAG, SDLoc(N));
56958+
}
56959+
5693556960
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
5693656961
APInt DemandedElts = APInt::getAllOnes(VT.getVectorNumElements());
5693756962
if (TLI.SimplifyDemandedVectorElts(SDValue(N, 0), DemandedElts, DCI))

llvm/test/CodeGen/X86/combine-pmadd.ll

Lines changed: 21 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -88,43 +88,33 @@ define <4 x i32> @combine_pmaddwd_demandedelts(<8 x i16> %a0, <8 x i16> %a1) {
8888
ret <4 x i32> %4
8989
}
9090

91-
; TODO: [2] = (-5*13)+(6*-15) = -155 = 4294967141
91+
; [2]: (-5*13)+(6*-15) = -155 = 4294967141
9292
define <4 x i32> @combine_pmaddwd_constant() {
9393
; SSE-LABEL: combine_pmaddwd_constant:
9494
; SSE: # %bb.0:
95-
; SSE-NEXT: pmovsxbw {{.*#+}} xmm0 = [65535,2,3,65532,65531,6,7,65528]
96-
; SSE-NEXT: pmaddwd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0 # [65531,7,65527,65525,13,65521,17,65517]
95+
; SSE-NEXT: movaps {{.*#+}} xmm0 = [19,17,4294967141,271]
9796
; SSE-NEXT: retq
9897
;
9998
; AVX-LABEL: combine_pmaddwd_constant:
10099
; AVX: # %bb.0:
101-
; AVX-NEXT: vpmovsxbw {{.*#+}} xmm0 = [65535,2,3,65532,65531,6,7,65528]
102-
; AVX-NEXT: vpmaddwd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0 # [65531,7,65527,65525,13,65521,17,65517]
100+
; AVX-NEXT: vmovaps {{.*#+}} xmm0 = [19,17,4294967141,271]
103101
; AVX-NEXT: retq
104102
%1 = call <4 x i32> @llvm.x86.sse2.pmadd.wd(<8 x i16> <i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8>, <8 x i16> <i16 -5, i16 7, i16 -9, i16 -11, i16 13, i16 -15, i16 17, i16 -19>)
105103
ret <4 x i32> %1
106104
}
107105

108106
; ensure we don't assume pmaddwd performs add nsw
109-
; TODO: (-32768*-32768)+(-32768*-32768) = 0x80000000 = 2147483648
107+
; [0]: (-32768*-32768)+(-32768*-32768) = 0x80000000 = 2147483648
110108
define <4 x i32> @combine_pmaddwd_constant_nsw() {
111109
; SSE-LABEL: combine_pmaddwd_constant_nsw:
112110
; SSE: # %bb.0:
113-
; SSE-NEXT: movdqa {{.*#+}} xmm0 = [32768,32768,32768,32768,32768,32768,32768,32768]
114-
; SSE-NEXT: pmaddwd %xmm0, %xmm0
111+
; SSE-NEXT: movaps {{.*#+}} xmm0 = [2147483648,2147483648,2147483648,2147483648]
115112
; SSE-NEXT: retq
116113
;
117-
; AVX1-LABEL: combine_pmaddwd_constant_nsw:
118-
; AVX1: # %bb.0:
119-
; AVX1-NEXT: vbroadcastss {{.*#+}} xmm0 = [32768,32768,32768,32768,32768,32768,32768,32768]
120-
; AVX1-NEXT: vpmaddwd %xmm0, %xmm0, %xmm0
121-
; AVX1-NEXT: retq
122-
;
123-
; AVX2-LABEL: combine_pmaddwd_constant_nsw:
124-
; AVX2: # %bb.0:
125-
; AVX2-NEXT: vpbroadcastw {{.*#+}} xmm0 = [32768,32768,32768,32768,32768,32768,32768,32768]
126-
; AVX2-NEXT: vpmaddwd %xmm0, %xmm0, %xmm0
127-
; AVX2-NEXT: retq
114+
; AVX-LABEL: combine_pmaddwd_constant_nsw:
115+
; AVX: # %bb.0:
116+
; AVX-NEXT: vbroadcastss {{.*#+}} xmm0 = [2147483648,2147483648,2147483648,2147483648]
117+
; AVX-NEXT: retq
128118
%1 = insertelement <8 x i16> undef, i16 32768, i32 0
129119
%2 = shufflevector <8 x i16> %1, <8 x i16> undef, <8 x i32> zeroinitializer
130120
%3 = call <4 x i32> @llvm.x86.sse2.pmadd.wd(<8 x i16> %2, <8 x i16> %2)
@@ -213,51 +203,26 @@ define <8 x i16> @combine_pmaddubsw_demandedelts(<16 x i8> %a0, <16 x i8> %a1) {
213203
ret <8 x i16> %4
214204
}
215205

216-
; TODO
206+
; [3]: ((uint8_t)-6*7)+(7*-8) = (250*7)+(7*-8) = 1694
217207
define i32 @combine_pmaddubsw_constant() {
218-
; SSE-LABEL: combine_pmaddubsw_constant:
219-
; SSE: # %bb.0:
220-
; SSE-NEXT: movdqa {{.*#+}} xmm0 = [0,1,2,3,4,5,250,7,8,9,10,11,12,13,14,15]
221-
; SSE-NEXT: pmaddubsw {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0 # [1,2,3,4,5,6,7,248,9,10,11,12,13,14,15,16]
222-
; SSE-NEXT: pextrw $3, %xmm0, %eax
223-
; SSE-NEXT: cwtl
224-
; SSE-NEXT: retq
225-
;
226-
; AVX-LABEL: combine_pmaddubsw_constant:
227-
; AVX: # %bb.0:
228-
; AVX-NEXT: vmovdqa {{.*#+}} xmm0 = [0,1,2,3,4,5,250,7,8,9,10,11,12,13,14,15]
229-
; AVX-NEXT: vpmaddubsw {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0 # [1,2,3,4,5,6,7,248,9,10,11,12,13,14,15,16]
230-
; AVX-NEXT: vpextrw $3, %xmm0, %eax
231-
; AVX-NEXT: cwtl
232-
; AVX-NEXT: retq
208+
; CHECK-LABEL: combine_pmaddubsw_constant:
209+
; CHECK: # %bb.0:
210+
; CHECK-NEXT: movl $1694, %eax # imm = 0x69E
211+
; CHECK-NEXT: retq
233212
%1 = call <8 x i16> @llvm.x86.ssse3.pmadd.ub.sw.128(<16 x i8> <i8 0, i8 1, i8 2, i8 3, i8 4, i8 5, i8 -6, i8 7, i8 8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15>, <16 x i8> <i8 1, i8 2, i8 3, i8 4, i8 5, i8 6, i8 7, i8 -8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 16>)
234-
%2 = extractelement <8 x i16> %1, i32 3 ; ((uint16_t)-6*7)+(7*-8) = (250*7)+(7*-8) = 1694
213+
%2 = extractelement <8 x i16> %1, i32 3
235214
%3 = sext i16 %2 to i32
236215
ret i32 %3
237216
}
238217

239-
; TODO
218+
; [0]: add_sat_i16(((uint8_t)-1*-128),((uint8_t)-1*-128)) = add_sat_i16((255*-128),(255*-128)) = sat_i16(-65280) = -32768
240219
define i32 @combine_pmaddubsw_constant_sat() {
241-
; SSE-LABEL: combine_pmaddubsw_constant_sat:
242-
; SSE: # %bb.0:
243-
; SSE-NEXT: movdqa {{.*#+}} xmm0 = [255,255,2,3,4,5,250,7,8,9,10,11,12,13,14,15]
244-
; SSE-NEXT: pmaddubsw {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0 # [128,128,3,4,5,6,7,248,9,10,11,12,13,14,15,16]
245-
; SSE-NEXT: movd %xmm0, %eax
246-
; SSE-NEXT: cwtl
247-
; SSE-NEXT: retq
248-
;
249-
; AVX-LABEL: combine_pmaddubsw_constant_sat:
250-
; AVX: # %bb.0:
251-
; AVX-NEXT: vmovdqa {{.*#+}} xmm0 = [255,255,2,3,4,5,250,7,8,9,10,11,12,13,14,15]
252-
; AVX-NEXT: vpmaddubsw {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0 # [128,128,3,4,5,6,7,248,9,10,11,12,13,14,15,16]
253-
; AVX-NEXT: vmovd %xmm0, %eax
254-
; AVX-NEXT: cwtl
255-
; AVX-NEXT: retq
220+
; CHECK-LABEL: combine_pmaddubsw_constant_sat:
221+
; CHECK: # %bb.0:
222+
; CHECK-NEXT: movl $-32768, %eax # imm = 0x8000
223+
; CHECK-NEXT: retq
256224
%1 = call <8 x i16> @llvm.x86.ssse3.pmadd.ub.sw.128(<16 x i8> <i8 -1, i8 -1, i8 2, i8 3, i8 4, i8 5, i8 -6, i8 7, i8 8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15>, <16 x i8> <i8 -128, i8 -128, i8 3, i8 4, i8 5, i8 6, i8 7, i8 -8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 16>)
257-
%2 = extractelement <8 x i16> %1, i32 0 ; add_sat_i16(((uint16_t)-1*-128),((uint16_t)-1*-128)_ = add_sat_i16(255*-128),(255*-128)) = sat_i16(-65280) = -32768
225+
%2 = extractelement <8 x i16> %1, i32 0
258226
%3 = sext i16 %2 to i32
259227
ret i32 %3
260228
}
261-
262-
;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
263-
; CHECK: {{.*}}

0 commit comments

Comments
 (0)