Skip to content

[AArch64] Improve index selection for histograms #111150

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Oct 22, 2024
30 changes: 26 additions & 4 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1114,7 +1114,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
ISD::INSERT_VECTOR_ELT, ISD::EXTRACT_VECTOR_ELT,
ISD::VECREDUCE_ADD, ISD::STEP_VECTOR});

setTargetDAGCombine({ISD::MGATHER, ISD::MSCATTER});
setTargetDAGCombine(
{ISD::MGATHER, ISD::MSCATTER, ISD::EXPERIMENTAL_VECTOR_HISTOGRAM});

setTargetDAGCombine(ISD::FP_EXTEND);

Expand Down Expand Up @@ -24079,12 +24080,32 @@ static bool findMoreOptimalIndexType(const MaskedGatherScatterSDNode *N,

static SDValue performMaskedGatherScatterCombine(
SDNode *N, TargetLowering::DAGCombinerInfo &DCI, SelectionDAG &DAG) {
MaskedGatherScatterSDNode *MGS = cast<MaskedGatherScatterSDNode>(N);
assert(MGS && "Can only combine gather load or scatter store nodes");

if (!DCI.isBeforeLegalize())
return SDValue();

if (N->getOpcode() == ISD::EXPERIMENTAL_VECTOR_HISTOGRAM) {
MaskedHistogramSDNode *HG = cast<MaskedHistogramSDNode>(N);
assert(HG &&
"Can only combine gather load, scatter store or histogram nodes");

SDValue Index = HG->getIndex();
if (ISD::isExtOpcode(Index->getOpcode())) {
SDLoc DL(HG);
SDValue ExtOp = Index.getOperand(0);
SDValue Ops[] = {HG->getChain(), HG->getInc(), HG->getMask(),
HG->getBasePtr(), ExtOp, HG->getScale(),
HG->getIntID()};
return DAG.getMaskedHistogram(DAG.getVTList(MVT::Other),
HG->getMemoryVT(), DL, Ops,
HG->getMemOperand(), HG->getIndexType());
}
return SDValue();
}

MaskedGatherScatterSDNode *MGS = cast<MaskedGatherScatterSDNode>(N);
assert(MGS &&
"Can only combine gather load, scatter store or histogram nodes");

SDLoc DL(MGS);
SDValue Chain = MGS->getChain();
SDValue Scale = MGS->getScale();
Expand Down Expand Up @@ -26277,6 +26298,7 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
return performMSTORECombine(N, DCI, DAG, Subtarget);
case ISD::MGATHER:
case ISD::MSCATTER:
case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
return performMaskedGatherScatterCombine(N, DCI, DAG);
case ISD::FP_EXTEND:
return performFPExtendCombine(N, DAG, DCI, Subtarget);
Expand Down
74 changes: 74 additions & 0 deletions llvm/test/CodeGen/AArch64/sve2-histcnt.ll
Original file line number Diff line number Diff line change
Expand Up @@ -267,5 +267,79 @@ define void @histogram_i16_8_lane(ptr %base, <vscale x 8 x i32> %indices, i16 %i
ret void
}

; Histogram update on i32 buckets: the <vscale x 4 x i32> indices are
; zero-extended to i64 before the bucket addresses are formed with a GEP off
; %base, and each active lane adds the constant 1 to its bucket.
; The expected code works directly on the 32-bit index vector (z0.s) instead of
; a widened 64-bit index.
; NOTE(review): the IR zero-extends the indices, but the expected load/store
; addressing below uses the sign-extending `sxtw` form. The two only agree for
; indices whose top bit is clear -- confirm this is intended.
define void @histogram_i32_zextend(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask) #0 {
; CHECK-LABEL: histogram_i32_zextend:
; CHECK: // %bb.0:
; CHECK-NEXT: histcnt z1.s, p0/z, z0.s, z0.s
; CHECK-NEXT: mov z3.s, #1 // =0x1
; CHECK-NEXT: ld1w { z2.s }, p0/z, [x0, z0.s, sxtw #2]
; CHECK-NEXT: ptrue p1.s
; CHECK-NEXT: mad z1.s, p1/m, z3.s, z2.s
; CHECK-NEXT: st1w { z1.s }, p0, [x0, z0.s, sxtw #2]
; CHECK-NEXT: ret
%extended = zext <vscale x 4 x i32> %indices to <vscale x 4 x i64>
%buckets = getelementptr i32, ptr %base, <vscale x 4 x i64> %extended
call void @llvm.experimental.vector.histogram.add.nxv4p0.i32(<vscale x 4 x ptr> %buckets, i32 1, <vscale x 4 x i1> %mask)
ret void
}

; Eight-lane variant of the zero-extended-index histogram test: the
; <vscale x 8 x i32> indices are zero-extended to i64, and the increment is the
; runtime value %inc rather than a constant.
; The expected code splits the predicate into low and high halves
; (punpklo/punpkhi) and emits one histcnt/ld1w/mad/st1w sequence per half, using
; z0.s and z1.s as the 32-bit index vectors.
; NOTE(review): as in the four-lane zext test, the indices are zero-extended in
; the IR while the expected addressing uses `sxtw` -- confirm this is intended.
define void @histogram_i32_8_lane_zextend(ptr %base, <vscale x 8 x i32> %indices, i32 %inc, <vscale x 8 x i1> %mask) #0 {
; CHECK-LABEL: histogram_i32_8_lane_zextend:
; CHECK: // %bb.0:
; CHECK-NEXT: punpklo p1.h, p0.b
; CHECK-NEXT: mov z4.s, w1
; CHECK-NEXT: ptrue p2.s
; CHECK-NEXT: histcnt z2.s, p1/z, z0.s, z0.s
; CHECK-NEXT: ld1w { z3.s }, p1/z, [x0, z0.s, sxtw #2]
; CHECK-NEXT: punpkhi p0.h, p0.b
; CHECK-NEXT: mad z2.s, p2/m, z4.s, z3.s
; CHECK-NEXT: st1w { z2.s }, p1, [x0, z0.s, sxtw #2]
; CHECK-NEXT: histcnt z0.s, p0/z, z1.s, z1.s
; CHECK-NEXT: ld1w { z2.s }, p0/z, [x0, z1.s, sxtw #2]
; CHECK-NEXT: mad z0.s, p2/m, z4.s, z2.s
; CHECK-NEXT: st1w { z0.s }, p0, [x0, z1.s, sxtw #2]
; CHECK-NEXT: ret
%extended = zext <vscale x 8 x i32> %indices to <vscale x 8 x i64>
%buckets = getelementptr i32, ptr %base, <vscale x 8 x i64> %extended
call void @llvm.experimental.vector.histogram.add.nxv8p0.i32(<vscale x 8 x ptr> %buckets, i32 %inc, <vscale x 8 x i1> %mask)
ret void
}
; Histogram update on i32 buckets: the <vscale x 4 x i32> indices are
; sign-extended to i64 before the bucket addresses are formed with a GEP off
; %base, and each active lane adds the constant 1 to its bucket.
; The expected code works directly on the 32-bit index vector (z0.s) and uses
; `sxtw #2` addressing for the load and store.
define void @histogram_i32_sextend(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask) #0{
; CHECK-LABEL: histogram_i32_sextend:
; CHECK: // %bb.0:
; CHECK-NEXT: histcnt z1.s, p0/z, z0.s, z0.s
; CHECK-NEXT: mov z3.s, #1 // =0x1
; CHECK-NEXT: ld1w { z2.s }, p0/z, [x0, z0.s, sxtw #2]
; CHECK-NEXT: ptrue p1.s
; CHECK-NEXT: mad z1.s, p1/m, z3.s, z2.s
; CHECK-NEXT: st1w { z1.s }, p0, [x0, z0.s, sxtw #2]
; CHECK-NEXT: ret
%extended = sext <vscale x 4 x i32> %indices to <vscale x 4 x i64>
%buckets = getelementptr i32, ptr %base, <vscale x 4 x i64> %extended
call void @llvm.experimental.vector.histogram.add.nxv4p0.i32(<vscale x 4 x ptr> %buckets, i32 1, <vscale x 4 x i1> %mask)
ret void
}
; Eight-lane variant of the sign-extended-index histogram test: the
; <vscale x 8 x i32> indices are sign-extended to i64, and the increment is the
; runtime value %inc rather than a constant.
; The expected code splits the predicate into low and high halves
; (punpklo/punpkhi) and emits one histcnt/ld1w/mad/st1w sequence per half, using
; z0.s and z1.s as the 32-bit index vectors with `sxtw #2` addressing.
define void @histogram_i32_8_lane_sextend(ptr %base, <vscale x 8 x i32> %indices, i32 %inc, <vscale x 8 x i1> %mask) #0 {
; CHECK-LABEL: histogram_i32_8_lane_sextend:
; CHECK: // %bb.0:
; CHECK-NEXT: punpklo p1.h, p0.b
; CHECK-NEXT: mov z4.s, w1
; CHECK-NEXT: ptrue p2.s
; CHECK-NEXT: histcnt z2.s, p1/z, z0.s, z0.s
; CHECK-NEXT: ld1w { z3.s }, p1/z, [x0, z0.s, sxtw #2]
; CHECK-NEXT: punpkhi p0.h, p0.b
; CHECK-NEXT: mad z2.s, p2/m, z4.s, z3.s
; CHECK-NEXT: st1w { z2.s }, p1, [x0, z0.s, sxtw #2]
; CHECK-NEXT: histcnt z0.s, p0/z, z1.s, z1.s
; CHECK-NEXT: ld1w { z2.s }, p0/z, [x0, z1.s, sxtw #2]
; CHECK-NEXT: mad z0.s, p2/m, z4.s, z2.s
; CHECK-NEXT: st1w { z0.s }, p0, [x0, z1.s, sxtw #2]
; CHECK-NEXT: ret
%extended = sext <vscale x 8 x i32> %indices to <vscale x 8 x i64>
%buckets = getelementptr i32, ptr %base, <vscale x 8 x i64> %extended
call void @llvm.experimental.vector.histogram.add.nxv8p0.i32(<vscale x 8 x ptr> %buckets, i32 %inc, <vscale x 8 x i1> %mask)
ret void
}


attributes #0 = { "target-features"="+sve2" vscale_range(1, 16) }
Loading