Skip to content

Commit 0938cdb

Browse files
committed
[X86] computeKnownBitsForTargetNode - add handling for (V)PMADDUBSW nodes
1 parent 417cd33 commit 0938cdb

File tree

2 files changed

+57
-30
lines changed

2 files changed

+57
-30
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37109,6 +37109,33 @@ static void computeKnownBitsForPMADDWD(SDValue LHS, SDValue RHS,
3710937109
/*NUW=*/false, Lo, Hi);
3711037110
}
3711137111

37112+
static void computeKnownBitsForPMADDUBSW(SDValue LHS, SDValue RHS,
37113+
KnownBits &Known,
37114+
const APInt &DemandedElts,
37115+
const SelectionDAG &DAG,
37116+
unsigned Depth) {
37117+
unsigned NumSrcElts = LHS.getValueType().getVectorNumElements();
37118+
37119+
// Multiply signed/unsigned i8 elements to create i16 values and add_sat Lo/Hi
37120+
// pairs.
37121+
APInt DemandedSrcElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
37122+
APInt DemandedLoElts =
37123+
DemandedSrcElts & APInt::getSplat(NumSrcElts, APInt(2, 0b01));
37124+
APInt DemandedHiElts =
37125+
DemandedSrcElts & APInt::getSplat(NumSrcElts, APInt(2, 0b10));
37126+
KnownBits LHSLo =
37127+
DAG.computeKnownBits(LHS, DemandedLoElts, Depth + 1).zext(16);
37128+
KnownBits LHSHi =
37129+
DAG.computeKnownBits(LHS, DemandedHiElts, Depth + 1).zext(16);
37130+
KnownBits RHSLo =
37131+
DAG.computeKnownBits(RHS, DemandedLoElts, Depth + 1).sext(16);
37132+
KnownBits RHSHi =
37133+
DAG.computeKnownBits(RHS, DemandedHiElts, Depth + 1).sext(16);
37134+
KnownBits Lo = KnownBits::mul(LHSLo, RHSLo);
37135+
KnownBits Hi = KnownBits::mul(LHSHi, RHSHi);
37136+
Known = KnownBits::sadd_sat(Lo, Hi);
37137+
}
37138+
3711237139
void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
3711337140
KnownBits &Known,
3711437141
const APInt &DemandedElts,
@@ -37294,6 +37321,16 @@ void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
3729437321
computeKnownBitsForPMADDWD(LHS, RHS, Known, DemandedElts, DAG, Depth);
3729537322
break;
3729637323
}
37324+
case X86ISD::VPMADDUBSW: {
37325+
SDValue LHS = Op.getOperand(0);
37326+
SDValue RHS = Op.getOperand(1);
37327+
assert(VT.getVectorElementType() == MVT::i16 &&
37328+
LHS.getValueType() == RHS.getValueType() &&
37329+
LHS.getValueType().getVectorElementType() == MVT::i8 &&
37330+
"Unexpected PMADDUBSW types");
37331+
computeKnownBitsForPMADDUBSW(LHS, RHS, Known, DemandedElts, DAG, Depth);
37332+
break;
37333+
}
3729737334
case X86ISD::PMULUDQ: {
3729837335
KnownBits Known2;
3729937336
Known = DAG.computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
@@ -37442,6 +37479,18 @@ void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
3744237479
computeKnownBitsForPMADDWD(LHS, RHS, Known, DemandedElts, DAG, Depth);
3744337480
break;
3744437481
}
37482+
case Intrinsic::x86_ssse3_pmadd_ub_sw_128:
37483+
case Intrinsic::x86_avx2_pmadd_ub_sw:
37484+
case Intrinsic::x86_avx512_pmaddubs_w_512: {
37485+
SDValue LHS = Op.getOperand(1);
37486+
SDValue RHS = Op.getOperand(2);
37487+
assert(VT.getScalarType() == MVT::i16 &&
37488+
LHS.getValueType() == RHS.getValueType() &&
37489+
LHS.getValueType().getScalarType() == MVT::i8 &&
37490+
"Unexpected PMADDUBSW types");
37491+
computeKnownBitsForPMADDUBSW(LHS, RHS, Known, DemandedElts, DAG, Depth);
37492+
break;
37493+
}
3744537494
case Intrinsic::x86_sse2_psad_bw:
3744637495
case Intrinsic::x86_avx2_psad_bw:
3744737496
case Intrinsic::x86_avx512_psad_bw_512: {

llvm/test/CodeGen/X86/combine-pmadd.ll

Lines changed: 8 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -73,43 +73,21 @@ define <8 x i16> @combine_pmaddubsw_zero_commute(<16 x i8> %a0, <16 x i8> %a1) {
7373
}
7474

7575
define i32 @combine_pmaddubsw_constant() {
76-
; SSE-LABEL: combine_pmaddubsw_constant:
77-
; SSE: # %bb.0:
78-
; SSE-NEXT: movdqa {{.*#+}} xmm0 = [0,1,2,3,4,5,250,7,8,9,10,11,12,13,14,15]
79-
; SSE-NEXT: pmaddubsw {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0 # [1,2,3,4,5,6,7,248,9,10,11,12,13,14,15,16]
80-
; SSE-NEXT: pextrw $3, %xmm0, %eax
81-
; SSE-NEXT: cwtl
82-
; SSE-NEXT: retq
83-
;
84-
; AVX-LABEL: combine_pmaddubsw_constant:
85-
; AVX: # %bb.0:
86-
; AVX-NEXT: vmovdqa {{.*#+}} xmm0 = [0,1,2,3,4,5,250,7,8,9,10,11,12,13,14,15]
87-
; AVX-NEXT: vpmaddubsw {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0 # [1,2,3,4,5,6,7,248,9,10,11,12,13,14,15,16]
88-
; AVX-NEXT: vpextrw $3, %xmm0, %eax
89-
; AVX-NEXT: cwtl
90-
; AVX-NEXT: retq
76+
; CHECK-LABEL: combine_pmaddubsw_constant:
77+
; CHECK: # %bb.0:
78+
; CHECK-NEXT: movl $1694, %eax # imm = 0x69E
79+
; CHECK-NEXT: retq
9180
%1 = call <8 x i16> @llvm.x86.ssse3.pmadd.ub.sw.128(<16 x i8> <i8 0, i8 1, i8 2, i8 3, i8 4, i8 5, i8 -6, i8 7, i8 8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15>, <16 x i8> <i8 1, i8 2, i8 3, i8 4, i8 5, i8 6, i8 7, i8 -8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 16>)
9281
%2 = extractelement <8 x i16> %1, i32 3 ; ((uint16_t)-6*7)+(7*-8) = (250*7)+(7*-8) = 1694
9382
%3 = sext i16 %2 to i32
9483
ret i32 %3
9584
}
9685

9786
define i32 @combine_pmaddubsw_constant_sat() {
98-
; SSE-LABEL: combine_pmaddubsw_constant_sat:
99-
; SSE: # %bb.0:
100-
; SSE-NEXT: movdqa {{.*#+}} xmm0 = [255,255,2,3,4,5,250,7,8,9,10,11,12,13,14,15]
101-
; SSE-NEXT: pmaddubsw {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0 # [128,128,3,4,5,6,7,248,9,10,11,12,13,14,15,16]
102-
; SSE-NEXT: movd %xmm0, %eax
103-
; SSE-NEXT: cwtl
104-
; SSE-NEXT: retq
105-
;
106-
; AVX-LABEL: combine_pmaddubsw_constant_sat:
107-
; AVX: # %bb.0:
108-
; AVX-NEXT: vmovdqa {{.*#+}} xmm0 = [255,255,2,3,4,5,250,7,8,9,10,11,12,13,14,15]
109-
; AVX-NEXT: vpmaddubsw {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0 # [128,128,3,4,5,6,7,248,9,10,11,12,13,14,15,16]
110-
; AVX-NEXT: vmovd %xmm0, %eax
111-
; AVX-NEXT: cwtl
112-
; AVX-NEXT: retq
87+
; CHECK-LABEL: combine_pmaddubsw_constant_sat:
88+
; CHECK: # %bb.0:
89+
; CHECK-NEXT: movl $-32768, %eax # imm = 0x8000
90+
; CHECK-NEXT: retq
11391
%1 = call <8 x i16> @llvm.x86.ssse3.pmadd.ub.sw.128(<16 x i8> <i8 -1, i8 -1, i8 2, i8 3, i8 4, i8 5, i8 -6, i8 7, i8 8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15>, <16 x i8> <i8 -128, i8 -128, i8 3, i8 4, i8 5, i8 6, i8 7, i8 -8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 16>)
11492
%2 = extractelement <8 x i16> %1, i32 0 ; add_sat_i16(((uint16_t)-1*-128),((uint16_t)-1*-128)_ = add_sat_i16(255*-128),(255*-128)) = sat_i16(-65280) = -32768
11593
%3 = sext i16 %2 to i32

0 commit comments

Comments
 (0)