Commit e06a9ca

[LLVM][CodeGen][SVE] Improve lowering of fixed length masked mem ops. (#134402)
Converting fixed length masks, as used by MLOAD, to scalable vectors is done by comparing the mask to zero. When the mask is itself the result of a compare, we can instead promote the operands and regenerate the original compare. At worst this shortens the dependency chain, and in most cases it removes the need for multiple compares.
1 parent 483edfe commit e06a9ca
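
For context, a minimal sketch of the kind of source that reaches this lowering; the function name and compile flags are illustrative assumptions, not part of the patch. When a guarded loop is if-converted and vectorized for fixed-length SVE (e.g. with clang -O3 -msve-vector-bits=256), the mask operand of the resulting masked load/store is itself the result of a compare, which is the pattern the new code exploits.

```cpp
// Hypothetical example, not taken from the patch: the vectorizer turns
// this guarded loop into llvm.masked.load/llvm.masked.store calls whose
// mask comes from a compare (cond[i] == 0). Previously that fixed-length
// mask was materialized with a NEON cmeq and then re-tested with an SVE
// cmpne; with this change the compare is regenerated directly as a single
// predicated SVE compare.
void copy_if_zero(int *dst, const int *src, const int *cond, int n) {
  for (int i = 0; i < n; ++i)
    if (cond[i] == 0)
      dst[i] = src[i];
}
```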

4 files changed: 46 additions, 29 deletions

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (30 additions, 5 deletions)
```diff
@@ -20190,6 +20190,12 @@ performInsertSubvectorCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
   EVT VecVT = Vec.getValueType();
   EVT SubVT = SubVec.getValueType();
 
+  // Promote fixed length vector zeros.
+  if (VecVT.isScalableVector() && SubVT.isFixedLengthVector() &&
+      Vec.isUndef() && isZerosVector(SubVec.getNode()))
+    return VecVT.isInteger() ? DAG.getConstant(0, DL, VecVT)
+                             : DAG.getConstantFP(0, DL, VecVT);
+
   // Only do this for legal fixed vector types.
   if (!VecVT.isFixedLengthVector() ||
       !DAG.getTargetLoweringInfo().isTypeLegal(VecVT) ||
@@ -28697,17 +28703,36 @@ static SDValue convertFixedMaskToScalableVector(SDValue Mask,
   SDLoc DL(Mask);
   EVT InVT = Mask.getValueType();
   EVT ContainerVT = getContainerForFixedLengthVector(DAG, InVT);
-
-  auto Pg = getPredicateForFixedLengthVector(DAG, DL, InVT);
+  SDValue Pg = getPredicateForFixedLengthVector(DAG, DL, InVT);
 
   if (ISD::isBuildVectorAllOnes(Mask.getNode()))
     return Pg;
 
-  auto Op1 = convertToScalableVector(DAG, ContainerVT, Mask);
-  auto Op2 = DAG.getConstant(0, DL, ContainerVT);
+  bool InvertCond = false;
+  if (isBitwiseNot(Mask)) {
+    InvertCond = true;
+    Mask = Mask.getOperand(0);
+  }
+
+  SDValue Op1, Op2;
+  ISD::CondCode CC;
+
+  // When Mask is the result of a SETCC, it's better to regenerate the compare.
+  if (Mask.getOpcode() == ISD::SETCC) {
+    Op1 = convertToScalableVector(DAG, ContainerVT, Mask.getOperand(0));
+    Op2 = convertToScalableVector(DAG, ContainerVT, Mask.getOperand(1));
+    CC = cast<CondCodeSDNode>(Mask.getOperand(2))->get();
+  } else {
+    Op1 = convertToScalableVector(DAG, ContainerVT, Mask);
+    Op2 = DAG.getConstant(0, DL, ContainerVT);
+    CC = ISD::SETNE;
+  }
+
+  if (InvertCond)
+    CC = getSetCCInverse(CC, Op1.getValueType());
 
   return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, DL, Pg.getValueType(),
-                     {Pg, Op1, Op2, DAG.getCondCode(ISD::SETNE)});
+                     {Pg, Op1, Op2, DAG.getCondCode(CC)});
 }
 
 // Convert all fixed length vector loads larger than NEON to masked_loads.
```
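
The isBitwiseNot handling above depends on the identity that bitwise-negating a compare result is the same as testing the inverse condition, which is what getSetCCInverse supplies. A minimal scalar model of that identity (illustrative names, not LLVM API):

```cpp
#include <cassert>

// Scalar stand-in for ISD::CondCode, restricted to the two cases that
// matter for the fallback path; real cond codes also cover the ordered
// and unordered floating-point variants.
enum class Cond { EQ, NE };

Cond inverse(Cond CC) { return CC == Cond::EQ ? Cond::NE : Cond::EQ; }

bool cmp(int A, int B, Cond CC) { return CC == Cond::EQ ? A == B : A != B; }

int main() {
  // !(A cc B) == (A inverse(cc) B) for every input pair, which is why
  // stripping an outer NOT and inverting the condition is sound.
  for (int A = -2; A <= 2; ++A)
    for (int B = -2; B <= 2; ++B)
      for (Cond CC : {Cond::EQ, Cond::NE})
        assert(!cmp(A, B, CC) == cmp(A, B, inverse(CC)));
  return 0;
}
```

For vector compares the same inversion applies per lane; getSetCCInverse takes the value type so it can also handle the floating-point cond codes that this scalar model omits.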

llvm/test/CodeGen/AArch64/sve-fixed-length-masked-gather.ll (2 additions, 3 deletions)
```diff
@@ -460,10 +460,9 @@ define void @masked_gather_v1i64(ptr %a, ptr %b) vscale_range(2,0) #0 {
 define void @masked_gather_v2i64(ptr %a, ptr %b) vscale_range(2,0) #0 {
 ; CHECK-LABEL: masked_gather_v2i64:
 ; CHECK: // %bb.0:
-; CHECK-NEXT:    ldr q0, [x0]
 ; CHECK-NEXT:    ptrue p0.d, vl2
-; CHECK-NEXT:    cmeq v0.2d, v0.2d, #0
-; CHECK-NEXT:    cmpne p0.d, p0/z, z0.d, #0
+; CHECK-NEXT:    ldr q0, [x0]
+; CHECK-NEXT:    cmpeq p0.d, p0/z, z0.d, #0
 ; CHECK-NEXT:    ldr q0, [x1]
 ; CHECK-NEXT:    ld1d { z0.d }, p0/z, [z0.d]
 ; CHECK-NEXT:    str q0, [x0]
```

llvm/test/CodeGen/AArch64/sve-fixed-length-masked-loads.ll (12 additions, 18 deletions)
```diff
@@ -401,11 +401,10 @@ define void @masked_load_sext_v32i8i16(ptr %ap, ptr %bp, ptr %c) #0 {
 define void @masked_load_sext_v16i8i32(ptr %ap, ptr %bp, ptr %c) #0 {
 ; VBITS_GE_256-LABEL: masked_load_sext_v16i8i32:
 ; VBITS_GE_256: // %bb.0:
-; VBITS_GE_256-NEXT:    ldr q0, [x1]
 ; VBITS_GE_256-NEXT:    ptrue p0.b, vl16
+; VBITS_GE_256-NEXT:    ldr q0, [x1]
 ; VBITS_GE_256-NEXT:    mov x8, #8 // =0x8
-; VBITS_GE_256-NEXT:    cmeq v0.16b, v0.16b, #0
-; VBITS_GE_256-NEXT:    cmpne p0.b, p0/z, z0.b, #0
+; VBITS_GE_256-NEXT:    cmpeq p0.b, p0/z, z0.b, #0
 ; VBITS_GE_256-NEXT:    ld1b { z0.b }, p0/z, [x0]
 ; VBITS_GE_256-NEXT:    ptrue p0.s, vl8
 ; VBITS_GE_256-NEXT:    ext v1.16b, v0.16b, v0.16b, #8
@@ -436,11 +435,10 @@ define void @masked_load_sext_v16i8i32(ptr %ap, ptr %bp, ptr %c) #0 {
 define void @masked_load_sext_v8i8i64(ptr %ap, ptr %bp, ptr %c) #0 {
 ; VBITS_GE_256-LABEL: masked_load_sext_v8i8i64:
 ; VBITS_GE_256: // %bb.0:
-; VBITS_GE_256-NEXT:    ldr d0, [x1]
 ; VBITS_GE_256-NEXT:    ptrue p0.b, vl8
+; VBITS_GE_256-NEXT:    ldr d0, [x1]
 ; VBITS_GE_256-NEXT:    mov x8, #4 // =0x4
-; VBITS_GE_256-NEXT:    cmeq v0.8b, v0.8b, #0
-; VBITS_GE_256-NEXT:    cmpne p0.b, p0/z, z0.b, #0
+; VBITS_GE_256-NEXT:    cmpeq p0.b, p0/z, z0.b, #0
 ; VBITS_GE_256-NEXT:    ld1b { z0.b }, p0/z, [x0]
 ; VBITS_GE_256-NEXT:    ptrue p0.d, vl4
 ; VBITS_GE_256-NEXT:    sshll v0.8h, v0.8b, #0
@@ -504,11 +502,10 @@ define void @masked_load_sext_v16i16i32(ptr %ap, ptr %bp, ptr %c) #0 {
 define void @masked_load_sext_v8i16i64(ptr %ap, ptr %bp, ptr %c) #0 {
 ; VBITS_GE_256-LABEL: masked_load_sext_v8i16i64:
 ; VBITS_GE_256: // %bb.0:
-; VBITS_GE_256-NEXT:    ldr q0, [x1]
 ; VBITS_GE_256-NEXT:    ptrue p0.h, vl8
+; VBITS_GE_256-NEXT:    ldr q0, [x1]
 ; VBITS_GE_256-NEXT:    mov x8, #4 // =0x4
-; VBITS_GE_256-NEXT:    cmeq v0.8h, v0.8h, #0
-; VBITS_GE_256-NEXT:    cmpne p0.h, p0/z, z0.h, #0
+; VBITS_GE_256-NEXT:    cmpeq p0.h, p0/z, z0.h, #0
 ; VBITS_GE_256-NEXT:    ld1h { z0.h }, p0/z, [x0]
 ; VBITS_GE_256-NEXT:    ptrue p0.d, vl4
 ; VBITS_GE_256-NEXT:    ext v1.16b, v0.16b, v0.16b, #8
@@ -603,11 +600,10 @@ define void @masked_load_zext_v32i8i16(ptr %ap, ptr %bp, ptr %c) #0 {
 define void @masked_load_zext_v16i8i32(ptr %ap, ptr %bp, ptr %c) #0 {
 ; VBITS_GE_256-LABEL: masked_load_zext_v16i8i32:
 ; VBITS_GE_256: // %bb.0:
-; VBITS_GE_256-NEXT:    ldr q0, [x1]
 ; VBITS_GE_256-NEXT:    ptrue p0.b, vl16
+; VBITS_GE_256-NEXT:    ldr q0, [x1]
 ; VBITS_GE_256-NEXT:    mov x8, #8 // =0x8
-; VBITS_GE_256-NEXT:    cmeq v0.16b, v0.16b, #0
-; VBITS_GE_256-NEXT:    cmpne p0.b, p0/z, z0.b, #0
+; VBITS_GE_256-NEXT:    cmpeq p0.b, p0/z, z0.b, #0
 ; VBITS_GE_256-NEXT:    ld1b { z0.b }, p0/z, [x0]
 ; VBITS_GE_256-NEXT:    ptrue p0.s, vl8
 ; VBITS_GE_256-NEXT:    ext v1.16b, v0.16b, v0.16b, #8
@@ -638,11 +634,10 @@ define void @masked_load_zext_v16i8i32(ptr %ap, ptr %bp, ptr %c) #0 {
 define void @masked_load_zext_v8i8i64(ptr %ap, ptr %bp, ptr %c) #0 {
 ; VBITS_GE_256-LABEL: masked_load_zext_v8i8i64:
 ; VBITS_GE_256: // %bb.0:
-; VBITS_GE_256-NEXT:    ldr d0, [x1]
 ; VBITS_GE_256-NEXT:    ptrue p0.b, vl8
+; VBITS_GE_256-NEXT:    ldr d0, [x1]
 ; VBITS_GE_256-NEXT:    mov x8, #4 // =0x4
-; VBITS_GE_256-NEXT:    cmeq v0.8b, v0.8b, #0
-; VBITS_GE_256-NEXT:    cmpne p0.b, p0/z, z0.b, #0
+; VBITS_GE_256-NEXT:    cmpeq p0.b, p0/z, z0.b, #0
 ; VBITS_GE_256-NEXT:    ld1b { z0.b }, p0/z, [x0]
 ; VBITS_GE_256-NEXT:    ptrue p0.d, vl4
 ; VBITS_GE_256-NEXT:    ushll v0.8h, v0.8b, #0
@@ -706,11 +701,10 @@ define void @masked_load_zext_v16i16i32(ptr %ap, ptr %bp, ptr %c) #0 {
 define void @masked_load_zext_v8i16i64(ptr %ap, ptr %bp, ptr %c) #0 {
 ; VBITS_GE_256-LABEL: masked_load_zext_v8i16i64:
 ; VBITS_GE_256: // %bb.0:
-; VBITS_GE_256-NEXT:    ldr q0, [x1]
 ; VBITS_GE_256-NEXT:    ptrue p0.h, vl8
+; VBITS_GE_256-NEXT:    ldr q0, [x1]
 ; VBITS_GE_256-NEXT:    mov x8, #4 // =0x4
-; VBITS_GE_256-NEXT:    cmeq v0.8h, v0.8h, #0
-; VBITS_GE_256-NEXT:    cmpne p0.h, p0/z, z0.h, #0
+; VBITS_GE_256-NEXT:    cmpeq p0.h, p0/z, z0.h, #0
 ; VBITS_GE_256-NEXT:    ld1h { z0.h }, p0/z, [x0]
 ; VBITS_GE_256-NEXT:    ptrue p0.d, vl4
 ; VBITS_GE_256-NEXT:    ext v1.16b, v0.16b, v0.16b, #8
```

llvm/test/CodeGen/AArch64/sve-fixed-length-masked-scatter.ll (2 additions, 3 deletions)
```diff
@@ -433,11 +433,10 @@ define void @masked_scatter_v1i64(ptr %a, ptr %b) vscale_range(2,0) #0 {
 define void @masked_scatter_v2i64(ptr %a, ptr %b) vscale_range(2,0) #0 {
 ; CHECK-LABEL: masked_scatter_v2i64:
 ; CHECK: // %bb.0:
-; CHECK-NEXT:    ldr q0, [x0]
 ; CHECK-NEXT:    ptrue p0.d, vl2
-; CHECK-NEXT:    cmeq v1.2d, v0.2d, #0
-; CHECK-NEXT:    cmpne p0.d, p0/z, z1.d, #0
+; CHECK-NEXT:    ldr q0, [x0]
 ; CHECK-NEXT:    ldr q1, [x1]
+; CHECK-NEXT:    cmpeq p0.d, p0/z, z0.d, #0
 ; CHECK-NEXT:    st1d { z0.d }, p0, [z1.d]
 ; CHECK-NEXT:    ret
   %vals = load <2 x i64>, ptr %a
```
