Skip to content

[AArch64] Push mul into extend operands #94960

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 44 additions & 36 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17720,6 +17720,47 @@ static SDValue performMulVectorCmpZeroCombine(SDNode *N, SelectionDAG &DAG) {
return DAG.getNode(AArch64ISD::NVCAST, DL, VT, CM);
}

// Transform vector add(zext i8 to i32, zext i8 to i32)
// into sext(add(zext(i8 to i16), zext(i8 to i16)) to i32)
// This allows extra uses of saddl/uaddl at the lower vector widths, and less
// extends.
static SDValue performVectorExtCombine(SDNode *N, SelectionDAG &DAG) {
EVT VT = N->getValueType(0);
if (!VT.isFixedLengthVector() || VT.getSizeInBits() <= 128 ||
(N->getOperand(0).getOpcode() != ISD::ZERO_EXTEND &&
N->getOperand(0).getOpcode() != ISD::SIGN_EXTEND) ||
(N->getOperand(1).getOpcode() != ISD::ZERO_EXTEND &&
N->getOperand(1).getOpcode() != ISD::SIGN_EXTEND) ||
N->getOperand(0).getOperand(0).getValueType() !=
N->getOperand(1).getOperand(0).getValueType())
return SDValue();

if (N->getOpcode() == ISD::MUL &&
N->getOperand(0).getOpcode() != N->getOperand(1).getOpcode())
return SDValue();

SDValue N0 = N->getOperand(0).getOperand(0);
SDValue N1 = N->getOperand(1).getOperand(0);
EVT InVT = N0.getValueType();

EVT S1 = InVT.getScalarType();
EVT S2 = VT.getScalarType();
if ((S2 == MVT::i32 && S1 == MVT::i8) ||
(S2 == MVT::i64 && (S1 == MVT::i8 || S1 == MVT::i16))) {
SDLoc DL(N);
EVT HalfVT = EVT::getVectorVT(*DAG.getContext(),
S2.getHalfSizedIntegerVT(*DAG.getContext()),
VT.getVectorElementCount());
SDValue NewN0 = DAG.getNode(N->getOperand(0).getOpcode(), DL, HalfVT, N0);
SDValue NewN1 = DAG.getNode(N->getOperand(1).getOpcode(), DL, HalfVT, N1);
SDValue NewOp = DAG.getNode(N->getOpcode(), DL, HalfVT, NewN0, NewN1);
return DAG.getNode(N->getOpcode() == ISD::MUL ? N->getOperand(0).getOpcode()
: (unsigned)ISD::SIGN_EXTEND,
DL, VT, NewOp);
}
return SDValue();
}

static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI,
const AArch64Subtarget *Subtarget) {
Expand All @@ -17728,6 +17769,8 @@ static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG,
return Ext;
if (SDValue Ext = performMulVectorCmpZeroCombine(N, DAG))
return Ext;
if (SDValue Ext = performVectorExtCombine(N, DAG))
return Ext;

if (DCI.isBeforeLegalizeOps())
return SDValue();
Expand Down Expand Up @@ -19604,41 +19647,6 @@ static SDValue foldADCToCINC(SDNode *N, SelectionDAG &DAG) {
return DAG.getNode(AArch64ISD::CSINC, DL, VT, LHS, LHS, CC, Cond);
}

// Transform vector add(zext i8 to i32, zext i8 to i32)
// into sext(add(zext(i8 to i16), zext(i8 to i16)) to i32)
// This allows extra uses of saddl/uaddl at the lower vector widths, and less
// extends.
static SDValue performVectorAddSubExtCombine(SDNode *N, SelectionDAG &DAG) {
EVT VT = N->getValueType(0);
if (!VT.isFixedLengthVector() || VT.getSizeInBits() <= 128 ||
(N->getOperand(0).getOpcode() != ISD::ZERO_EXTEND &&
N->getOperand(0).getOpcode() != ISD::SIGN_EXTEND) ||
(N->getOperand(1).getOpcode() != ISD::ZERO_EXTEND &&
N->getOperand(1).getOpcode() != ISD::SIGN_EXTEND) ||
N->getOperand(0).getOperand(0).getValueType() !=
N->getOperand(1).getOperand(0).getValueType())
return SDValue();

SDValue N0 = N->getOperand(0).getOperand(0);
SDValue N1 = N->getOperand(1).getOperand(0);
EVT InVT = N0.getValueType();

EVT S1 = InVT.getScalarType();
EVT S2 = VT.getScalarType();
if ((S2 == MVT::i32 && S1 == MVT::i8) ||
(S2 == MVT::i64 && (S1 == MVT::i8 || S1 == MVT::i16))) {
SDLoc DL(N);
EVT HalfVT = EVT::getVectorVT(*DAG.getContext(),
S2.getHalfSizedIntegerVT(*DAG.getContext()),
VT.getVectorElementCount());
SDValue NewN0 = DAG.getNode(N->getOperand(0).getOpcode(), DL, HalfVT, N0);
SDValue NewN1 = DAG.getNode(N->getOperand(1).getOpcode(), DL, HalfVT, N1);
SDValue NewOp = DAG.getNode(N->getOpcode(), DL, HalfVT, NewN0, NewN1);
return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, NewOp);
}
return SDValue();
}

static SDValue performBuildVectorCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
SelectionDAG &DAG) {
Expand Down Expand Up @@ -20260,7 +20268,7 @@ static SDValue performAddSubCombine(SDNode *N,
return Val;
if (SDValue Val = performNegCSelCombine(N, DCI.DAG))
return Val;
if (SDValue Val = performVectorAddSubExtCombine(N, DCI.DAG))
if (SDValue Val = performVectorExtCombine(N, DCI.DAG))
return Val;
if (SDValue Val = performAddCombineForShiftedOperands(N, DCI.DAG))
return Val;
Expand Down
108 changes: 41 additions & 67 deletions llvm/test/CodeGen/AArch64/aarch64-wide-mul.ll
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,12 @@ entry:
define <16 x i32> @mul_i32(<16 x i8> %a, <16 x i8> %b) {
; CHECK-SD-LABEL: mul_i32:
; CHECK-SD: // %bb.0: // %entry
; CHECK-SD-NEXT: ushll v2.8h, v0.8b, #0
; CHECK-SD-NEXT: ushll v4.8h, v1.8b, #0
; CHECK-SD-NEXT: ushll2 v5.8h, v0.16b, #0
; CHECK-SD-NEXT: ushll2 v6.8h, v1.16b, #0
; CHECK-SD-NEXT: umull v0.4s, v2.4h, v4.4h
; CHECK-SD-NEXT: umull2 v1.4s, v2.8h, v4.8h
; CHECK-SD-NEXT: umull2 v3.4s, v5.8h, v6.8h
; CHECK-SD-NEXT: umull v2.4s, v5.4h, v6.4h
; CHECK-SD-NEXT: umull v2.8h, v0.8b, v1.8b
; CHECK-SD-NEXT: umull2 v4.8h, v0.16b, v1.16b
; CHECK-SD-NEXT: ushll v0.4s, v2.4h, #0
; CHECK-SD-NEXT: ushll2 v3.4s, v4.8h, #0
; CHECK-SD-NEXT: ushll2 v1.4s, v2.8h, #0
; CHECK-SD-NEXT: ushll v2.4s, v4.4h, #0
; CHECK-SD-NEXT: ret
;
; CHECK-GI-LABEL: mul_i32:
Expand All @@ -59,26 +57,20 @@ entry:
define <16 x i64> @mul_i64(<16 x i8> %a, <16 x i8> %b) {
; CHECK-SD-LABEL: mul_i64:
; CHECK-SD: // %bb.0: // %entry
; CHECK-SD-NEXT: ushll v2.8h, v0.8b, #0
; CHECK-SD-NEXT: ushll2 v0.8h, v0.16b, #0
; CHECK-SD-NEXT: ushll v3.8h, v1.8b, #0
; CHECK-SD-NEXT: ushll2 v1.8h, v1.16b, #0
; CHECK-SD-NEXT: ushll v4.4s, v2.4h, #0
; CHECK-SD-NEXT: ushll v5.4s, v0.4h, #0
; CHECK-SD-NEXT: ushll v6.4s, v3.4h, #0
; CHECK-SD-NEXT: umull v2.8h, v0.8b, v1.8b
; CHECK-SD-NEXT: umull2 v0.8h, v0.16b, v1.16b
; CHECK-SD-NEXT: ushll v3.4s, v2.4h, #0
; CHECK-SD-NEXT: ushll2 v2.4s, v2.8h, #0
; CHECK-SD-NEXT: ushll v16.4s, v1.4h, #0
; CHECK-SD-NEXT: ushll2 v7.4s, v3.8h, #0
; CHECK-SD-NEXT: ushll2 v17.4s, v0.8h, #0
; CHECK-SD-NEXT: ushll2 v18.4s, v1.8h, #0
; CHECK-SD-NEXT: umull2 v1.2d, v4.4s, v6.4s
; CHECK-SD-NEXT: umull v0.2d, v4.2s, v6.2s
; CHECK-SD-NEXT: umull2 v3.2d, v2.4s, v7.4s
; CHECK-SD-NEXT: umull v2.2d, v2.2s, v7.2s
; CHECK-SD-NEXT: umull v4.2d, v5.2s, v16.2s
; CHECK-SD-NEXT: umull2 v7.2d, v17.4s, v18.4s
; CHECK-SD-NEXT: umull2 v5.2d, v5.4s, v16.4s
; CHECK-SD-NEXT: umull v6.2d, v17.2s, v18.2s
; CHECK-SD-NEXT: ushll v5.4s, v0.4h, #0
; CHECK-SD-NEXT: ushll2 v6.4s, v0.8h, #0
; CHECK-SD-NEXT: ushll2 v1.2d, v3.4s, #0
; CHECK-SD-NEXT: ushll v0.2d, v3.2s, #0
; CHECK-SD-NEXT: ushll2 v3.2d, v2.4s, #0
; CHECK-SD-NEXT: ushll v2.2d, v2.2s, #0
; CHECK-SD-NEXT: ushll v4.2d, v5.2s, #0
; CHECK-SD-NEXT: ushll2 v7.2d, v6.4s, #0
; CHECK-SD-NEXT: ushll2 v5.2d, v5.4s, #0
; CHECK-SD-NEXT: ushll v6.2d, v6.2s, #0
; CHECK-SD-NEXT: ret
;
; CHECK-GI-LABEL: mul_i64:
Expand Down Expand Up @@ -139,17 +131,12 @@ entry:
define <16 x i32> @mla_i32(<16 x i8> %a, <16 x i8> %b, <16 x i32> %c) {
; CHECK-SD-LABEL: mla_i32:
; CHECK-SD: // %bb.0: // %entry
; CHECK-SD-NEXT: ushll v6.8h, v0.8b, #0
; CHECK-SD-NEXT: ushll v7.8h, v1.8b, #0
; CHECK-SD-NEXT: ushll2 v0.8h, v0.16b, #0
; CHECK-SD-NEXT: ushll2 v1.8h, v1.16b, #0
; CHECK-SD-NEXT: umlal v2.4s, v6.4h, v7.4h
; CHECK-SD-NEXT: umlal2 v3.4s, v6.8h, v7.8h
; CHECK-SD-NEXT: umlal2 v5.4s, v0.8h, v1.8h
; CHECK-SD-NEXT: umlal v4.4s, v0.4h, v1.4h
; CHECK-SD-NEXT: mov v0.16b, v2.16b
; CHECK-SD-NEXT: mov v1.16b, v3.16b
; CHECK-SD-NEXT: mov v2.16b, v4.16b
; CHECK-SD-NEXT: umull2 v7.8h, v0.16b, v1.16b
; CHECK-SD-NEXT: umull v6.8h, v0.8b, v1.8b
; CHECK-SD-NEXT: uaddw2 v5.4s, v5.4s, v7.8h
; CHECK-SD-NEXT: uaddw v0.4s, v2.4s, v6.4h
; CHECK-SD-NEXT: uaddw2 v1.4s, v3.4s, v6.8h
; CHECK-SD-NEXT: uaddw v2.4s, v4.4s, v7.4h
; CHECK-SD-NEXT: mov v3.16b, v5.16b
; CHECK-SD-NEXT: ret
;
Expand Down Expand Up @@ -179,35 +166,22 @@ entry:
define <16 x i64> @mla_i64(<16 x i8> %a, <16 x i8> %b, <16 x i64> %c) {
; CHECK-SD-LABEL: mla_i64:
; CHECK-SD: // %bb.0: // %entry
; CHECK-SD-NEXT: mov v17.16b, v7.16b
; CHECK-SD-NEXT: mov v16.16b, v6.16b
; CHECK-SD-NEXT: ushll v6.8h, v0.8b, #0
; CHECK-SD-NEXT: ushll2 v0.8h, v0.16b, #0
; CHECK-SD-NEXT: ushll v7.8h, v1.8b, #0
; CHECK-SD-NEXT: ushll2 v1.8h, v1.16b, #0
; CHECK-SD-NEXT: ushll v18.4s, v6.4h, #0
; CHECK-SD-NEXT: ushll2 v21.4s, v6.8h, #0
; CHECK-SD-NEXT: ushll v19.4s, v0.4h, #0
; CHECK-SD-NEXT: ushll v20.4s, v7.4h, #0
; CHECK-SD-NEXT: ushll v22.4s, v1.4h, #0
; CHECK-SD-NEXT: ushll2 v23.4s, v7.8h, #0
; CHECK-SD-NEXT: ldp q6, q7, [sp]
; CHECK-SD-NEXT: ushll2 v0.4s, v0.8h, #0
; CHECK-SD-NEXT: ushll2 v1.4s, v1.8h, #0
; CHECK-SD-NEXT: umlal2 v3.2d, v18.4s, v20.4s
; CHECK-SD-NEXT: umlal v2.2d, v18.2s, v20.2s
; CHECK-SD-NEXT: umlal v16.2d, v19.2s, v22.2s
; CHECK-SD-NEXT: umlal2 v5.2d, v21.4s, v23.4s
; CHECK-SD-NEXT: umlal v4.2d, v21.2s, v23.2s
; CHECK-SD-NEXT: umlal2 v17.2d, v19.4s, v22.4s
; CHECK-SD-NEXT: umlal2 v7.2d, v0.4s, v1.4s
; CHECK-SD-NEXT: umlal v6.2d, v0.2s, v1.2s
; CHECK-SD-NEXT: mov v0.16b, v2.16b
; CHECK-SD-NEXT: mov v1.16b, v3.16b
; CHECK-SD-NEXT: mov v2.16b, v4.16b
; CHECK-SD-NEXT: mov v3.16b, v5.16b
; CHECK-SD-NEXT: mov v4.16b, v16.16b
; CHECK-SD-NEXT: mov v5.16b, v17.16b
; CHECK-SD-NEXT: umull v16.8h, v0.8b, v1.8b
; CHECK-SD-NEXT: umull2 v0.8h, v0.16b, v1.16b
; CHECK-SD-NEXT: ldp q20, q21, [sp]
; CHECK-SD-NEXT: ushll v17.4s, v16.4h, #0
; CHECK-SD-NEXT: ushll2 v16.4s, v16.8h, #0
; CHECK-SD-NEXT: ushll2 v19.4s, v0.8h, #0
; CHECK-SD-NEXT: ushll v18.4s, v0.4h, #0
; CHECK-SD-NEXT: uaddw2 v1.2d, v3.2d, v17.4s
; CHECK-SD-NEXT: uaddw v0.2d, v2.2d, v17.2s
; CHECK-SD-NEXT: uaddw2 v3.2d, v5.2d, v16.4s
; CHECK-SD-NEXT: uaddw v2.2d, v4.2d, v16.2s
; CHECK-SD-NEXT: uaddw2 v16.2d, v21.2d, v19.4s
; CHECK-SD-NEXT: uaddw v4.2d, v6.2d, v18.2s
; CHECK-SD-NEXT: uaddw2 v5.2d, v7.2d, v18.4s
; CHECK-SD-NEXT: uaddw v6.2d, v20.2d, v19.2s
; CHECK-SD-NEXT: mov v7.16b, v16.16b
; CHECK-SD-NEXT: ret
;
; CHECK-GI-LABEL: mla_i64:
Expand Down
32 changes: 14 additions & 18 deletions llvm/test/CodeGen/AArch64/addp-shuffle.ll
Original file line number Diff line number Diff line change
Expand Up @@ -136,15 +136,13 @@ define <4 x double> @deinterleave_shuffle_v8f64(<8 x double> %a) {
define <4 x i32> @udot(<4 x i32> %z, <16 x i8> %a, <16 x i8> %b) {
; CHECK-LABEL: udot:
; CHECK: // %bb.0:
; CHECK-NEXT: ushll v3.8h, v1.8b, #0
; CHECK-NEXT: ushll v4.8h, v2.8b, #0
; CHECK-NEXT: ushll2 v1.8h, v1.16b, #0
; CHECK-NEXT: ushll2 v2.8h, v2.16b, #0
; CHECK-NEXT: umull2 v5.4s, v3.8h, v4.8h
; CHECK-NEXT: umull v3.4s, v3.4h, v4.4h
; CHECK-NEXT: umull2 v4.4s, v1.8h, v2.8h
; CHECK-NEXT: umull v1.4s, v1.4h, v2.4h
; CHECK-NEXT: addp v2.4s, v3.4s, v5.4s
; CHECK-NEXT: umull v3.8h, v1.8b, v2.8b
; CHECK-NEXT: umull2 v1.8h, v1.16b, v2.16b
; CHECK-NEXT: ushll2 v2.4s, v3.8h, #0
; CHECK-NEXT: ushll v3.4s, v3.4h, #0
; CHECK-NEXT: ushll2 v4.4s, v1.8h, #0
; CHECK-NEXT: ushll v1.4s, v1.4h, #0
; CHECK-NEXT: addp v2.4s, v3.4s, v2.4s
; CHECK-NEXT: addp v1.4s, v1.4s, v4.4s
; CHECK-NEXT: addp v1.4s, v2.4s, v1.4s
; CHECK-NEXT: add v0.4s, v0.4s, v1.4s
Expand All @@ -165,15 +163,13 @@ define <4 x i32> @udot(<4 x i32> %z, <16 x i8> %a, <16 x i8> %b) {
define <4 x i32> @sdot(<4 x i32> %z, <16 x i8> %a, <16 x i8> %b) {
; CHECK-LABEL: sdot:
; CHECK: // %bb.0:
; CHECK-NEXT: sshll v3.8h, v1.8b, #0
; CHECK-NEXT: sshll v4.8h, v2.8b, #0
; CHECK-NEXT: sshll2 v1.8h, v1.16b, #0
; CHECK-NEXT: sshll2 v2.8h, v2.16b, #0
; CHECK-NEXT: smull2 v5.4s, v3.8h, v4.8h
; CHECK-NEXT: smull v3.4s, v3.4h, v4.4h
; CHECK-NEXT: smull2 v4.4s, v1.8h, v2.8h
; CHECK-NEXT: smull v1.4s, v1.4h, v2.4h
; CHECK-NEXT: addp v2.4s, v3.4s, v5.4s
; CHECK-NEXT: smull v3.8h, v1.8b, v2.8b
; CHECK-NEXT: smull2 v1.8h, v1.16b, v2.16b
; CHECK-NEXT: sshll2 v2.4s, v3.8h, #0
; CHECK-NEXT: sshll v3.4s, v3.4h, #0
; CHECK-NEXT: sshll2 v4.4s, v1.8h, #0
; CHECK-NEXT: sshll v1.4s, v1.4h, #0
; CHECK-NEXT: addp v2.4s, v3.4s, v2.4s
; CHECK-NEXT: addp v1.4s, v1.4s, v4.4s
; CHECK-NEXT: addp v1.4s, v2.4s, v1.4s
; CHECK-NEXT: add v0.4s, v0.4s, v1.4s
Expand Down
Loading
Loading