Commit c7308d4
[LSR][AArch64] Optimize chain generation based on legal addressing modes (#94453)
LSR will generate chains of related instructions with a known increment between them. With SVE, as in the test case here, this can include increments like 'vscale * 16 + 8'. The idea of this patch is that if we already have a '+8' increment calculated in the chain, we can generate a (legal) '+ vscale*16' addressing mode from it, allowing us to use the '[x16, #1, mul vl]' addressing-mode instructions.

In order to do this we keep track of the known 'bases' when generating chains in GenerateIVChain, checking for each whether the accumulated increment expression from that base folds neatly into a legal addressing mode. If none do, we fall back to the existing LeftOverExpr, whether it is legal or not.

This is mostly orthogonal to #88124, dealing with the generation of chains as opposed to the rest of LSR. The existing vscale addressing-mode work has helped greatly compared to the last time I looked at this, allowing us to check that the addressing modes are indeed legal.
1 parent a9e5f42 commit c7308d4
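
To visualise the base search, here is a minimal standalone sketch. It is not LLVM code: the Offset and Base structs, isLegalAddrMode and findFoldableBase are invented stand-ins, and isLegalAddrMode hard-codes an assumed SVE '[xN, #imm, mul vl]' rule (a pure vscale multiple with an immediate in [-8, 7], taking a 16-byte granule for the 128-bit SVE quadword). The real patch instead asks TTI::isLegalAddressingMode via canFoldIVIncExpr.

    // A standalone model of the base search in GenerateIVChain (illustrative
    // only; all names here are invented and are not LLVM APIs).
    #include <cstdint>
    #include <cstdio>
    #include <optional>
    #include <string>
    #include <utility>
    #include <vector>

    // An increment split into a fixed byte offset and a multiple of vscale.
    struct Offset {
      int64_t Fixed = 0;
      int64_t Scalable = 0; // in units of vscale bytes
    };

    // Stand-in for the legality query: accept only a pure vscale multiple
    // whose '#imm, mul vl' immediate fits in [-8, 7].
    static bool isLegalAddrMode(const Offset &O, int64_t Granule = 16) {
      if (O.Fixed != 0 || O.Scalable % Granule != 0)
        return false;
      int64_t Imm = O.Scalable / Granule;
      return Imm >= -8 && Imm <= 7;
    }

    // A materialized IV value, with its accumulated distance from the head.
    struct Base {
      Offset Accum;
      std::string Reg;
    };

    // Walk the bases newest-first and return the first one whose remainder
    // folds into a legal addressing mode, mirroring the loop in the patch.
    static std::optional<std::pair<Base, Offset>>
    findFoldableBase(const std::vector<Base> &Bases, const Offset &Accum) {
      for (auto It = Bases.rbegin(); It != Bases.rend(); ++It) {
        Offset Rem{Accum.Fixed - It->Accum.Fixed,
                   Accum.Scalable - It->Accum.Scalable};
        if (isLegalAddrMode(Rem))
          return std::make_pair(*It, Rem);
      }
      return std::nullopt;
    }

    int main() {
      // Chain head at +0 (x16); a '+8' increment forced a new base (x17).
      std::vector<Base> Bases = {{{0, 0}, "x16"}, {{8, 0}, "x17"}};
      // An access at 'vscale * 16 + 8' past the head: from x17 the remainder
      // is a pure vscale*16, i.e. the '[x17, #1, mul vl]' form.
      if (auto Hit = findFoldableBase(Bases, {8, 16}))
        std::printf("[%s, #%lld, mul vl]\n", Hit->first.Reg.c_str(),
                    static_cast<long long>(Hit->second.Scalable / 16));
      return 0;
    }

With the chain head in x16 and the '+8' base in x17, an access at 'vscale*16 + 8' resolves to '[x17, #1, mul vl]', which is the shape the updated sve-lsrchain.ll output below checks for.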

2 files changed (+98, -60)

llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp (+58, -14)
@@ -1256,7 +1256,8 @@ static bool isAMCompletelyFolded(const TargetTransformInfo &TTI,
                                  LSRUse::KindType Kind, MemAccessTy AccessTy,
                                  GlobalValue *BaseGV, int64_t BaseOffset,
                                  bool HasBaseReg, int64_t Scale,
-                                 Instruction *Fixup = nullptr);
+                                 Instruction *Fixup = nullptr,
+                                 int64_t ScalableOffset = 0);
 
 static unsigned getSetupCost(const SCEV *Reg, unsigned Depth) {
   if (isa<SCEVUnknown>(Reg) || isa<SCEVConstant>(Reg))
@@ -1675,16 +1676,18 @@ static bool isAMCompletelyFolded(const TargetTransformInfo &TTI,
                                  LSRUse::KindType Kind, MemAccessTy AccessTy,
                                  GlobalValue *BaseGV, int64_t BaseOffset,
                                  bool HasBaseReg, int64_t Scale,
-                                 Instruction *Fixup/*= nullptr*/) {
+                                 Instruction *Fixup /* = nullptr */,
+                                 int64_t ScalableOffset) {
   switch (Kind) {
   case LSRUse::Address:
     return TTI.isLegalAddressingMode(AccessTy.MemTy, BaseGV, BaseOffset,
-                                     HasBaseReg, Scale, AccessTy.AddrSpace, Fixup);
+                                     HasBaseReg, Scale, AccessTy.AddrSpace,
+                                     Fixup, ScalableOffset);
 
   case LSRUse::ICmpZero:
     // There's not even a target hook for querying whether it would be legal to
     // fold a GV into an ICmp.
-    if (BaseGV)
+    if (BaseGV || ScalableOffset != 0)
       return false;
 
     // ICmp only has two operands; don't allow more than two non-trivial parts.
@@ -1715,11 +1718,12 @@ static bool isAMCompletelyFolded(const TargetTransformInfo &TTI,
 
   case LSRUse::Basic:
     // Only handle single-register values.
-    return !BaseGV && Scale == 0 && BaseOffset == 0;
+    return !BaseGV && Scale == 0 && BaseOffset == 0 && ScalableOffset == 0;
 
   case LSRUse::Special:
     // Special case Basic to handle -1 scales.
-    return !BaseGV && (Scale == 0 || Scale == -1) && BaseOffset == 0;
+    return !BaseGV && (Scale == 0 || Scale == -1) && BaseOffset == 0 &&
+           ScalableOffset == 0;
   }
 
   llvm_unreachable("Invalid LSRUse Kind!");
@@ -1843,7 +1847,7 @@ static InstructionCost getScalingFactorCost(const TargetTransformInfo &TTI,
 static bool isAlwaysFoldable(const TargetTransformInfo &TTI,
                              LSRUse::KindType Kind, MemAccessTy AccessTy,
                              GlobalValue *BaseGV, int64_t BaseOffset,
-                             bool HasBaseReg) {
+                             bool HasBaseReg, int64_t ScalableOffset = 0) {
   // Fast-path: zero is always foldable.
   if (BaseOffset == 0 && !BaseGV) return true;
 
@@ -1859,7 +1863,7 @@ static bool isAlwaysFoldable(const TargetTransformInfo &TTI,
   }
 
   return isAMCompletelyFolded(TTI, Kind, AccessTy, BaseGV, BaseOffset,
-                              HasBaseReg, Scale);
+                              HasBaseReg, Scale, nullptr, ScalableOffset);
 }
 
 static bool isAlwaysFoldable(const TargetTransformInfo &TTI,
@@ -3165,16 +3169,30 @@ void LSRInstance::FinalizeChain(IVChain &Chain) {
 static bool canFoldIVIncExpr(const SCEV *IncExpr, Instruction *UserInst,
                              Value *Operand, const TargetTransformInfo &TTI) {
   const SCEVConstant *IncConst = dyn_cast<SCEVConstant>(IncExpr);
-  if (!IncConst || !isAddressUse(TTI, UserInst, Operand))
-    return false;
+  int64_t IncOffset = 0;
+  int64_t ScalableOffset = 0;
+  if (IncConst) {
+    if (IncConst && IncConst->getAPInt().getSignificantBits() > 64)
+      return false;
+    IncOffset = IncConst->getValue()->getSExtValue();
+  } else {
+    // Look for mul(vscale, constant), to detect ScalableOffset.
+    auto *IncVScale = dyn_cast<SCEVMulExpr>(IncExpr);
+    if (!IncVScale || IncVScale->getNumOperands() != 2 ||
+        !isa<SCEVVScale>(IncVScale->getOperand(1)))
+      return false;
+    auto *Scale = dyn_cast<SCEVConstant>(IncVScale->getOperand(0));
+    if (!Scale || Scale->getType()->getScalarSizeInBits() > 64)
+      return false;
+    ScalableOffset = Scale->getValue()->getSExtValue();
+  }
 
-  if (IncConst->getAPInt().getSignificantBits() > 64)
+  if (!isAddressUse(TTI, UserInst, Operand))
     return false;
 
   MemAccessTy AccessTy = getAccessType(TTI, UserInst, Operand);
-  int64_t IncOffset = IncConst->getValue()->getSExtValue();
   if (!isAlwaysFoldable(TTI, LSRUse::Address, AccessTy, /*BaseGV=*/nullptr,
-                        IncOffset, /*HasBaseReg=*/false))
+                        IncOffset, /*HasBaseReg=*/false, ScalableOffset))
     return false;
 
   return true;
@@ -3220,6 +3238,10 @@ void LSRInstance::GenerateIVChain(const IVChain &Chain,
   Type *IVTy = IVSrc->getType();
   Type *IntTy = SE.getEffectiveSCEVType(IVTy);
   const SCEV *LeftOverExpr = nullptr;
+  const SCEV *Accum = SE.getZero(IntTy);
+  SmallVector<std::pair<const SCEV *, Value *>> Bases;
+  Bases.emplace_back(Accum, IVSrc);
+
   for (const IVInc &Inc : Chain) {
     Instruction *InsertPt = Inc.UserInst;
     if (isa<PHINode>(InsertPt))
@@ -3232,10 +3254,31 @@ void LSRInstance::GenerateIVChain(const IVChain &Chain,
       // IncExpr was the result of subtraction of two narrow values, so must
       // be signed.
       const SCEV *IncExpr = SE.getNoopOrSignExtend(Inc.IncExpr, IntTy);
+      Accum = SE.getAddExpr(Accum, IncExpr);
       LeftOverExpr = LeftOverExpr ?
         SE.getAddExpr(LeftOverExpr, IncExpr) : IncExpr;
     }
-    if (LeftOverExpr && !LeftOverExpr->isZero()) {
+
+    // Look through each base to see if any can produce a nice addressing mode.
+    bool FoundBase = false;
+    for (auto [MapScev, MapIVOper] : reverse(Bases)) {
+      const SCEV *Remainder = SE.getMinusSCEV(Accum, MapScev);
+      if (canFoldIVIncExpr(Remainder, Inc.UserInst, Inc.IVOperand, TTI)) {
+        if (!Remainder->isZero()) {
+          Rewriter.clearPostInc();
+          Value *IncV = Rewriter.expandCodeFor(Remainder, IntTy, InsertPt);
+          const SCEV *IVOperExpr =
+              SE.getAddExpr(SE.getUnknown(MapIVOper), SE.getUnknown(IncV));
+          IVOper = Rewriter.expandCodeFor(IVOperExpr, IVTy, InsertPt);
+        } else {
+          IVOper = MapIVOper;
+        }
+
+        FoundBase = true;
+        break;
+      }
+    }
+    if (!FoundBase && LeftOverExpr && !LeftOverExpr->isZero()) {
       // Expand the IV increment.
       Rewriter.clearPostInc();
       Value *IncV = Rewriter.expandCodeFor(LeftOverExpr, IntTy, InsertPt);
@@ -3246,6 +3289,7 @@ void LSRInstance::GenerateIVChain(const IVChain &Chain,
       // If an IV increment can't be folded, use it as the next IV value.
       if (!canFoldIVIncExpr(LeftOverExpr, Inc.UserInst, Inc.IVOperand, TTI)) {
         assert(IVTy == IVOper->getType() && "inconsistent IV increment type");
+        Bases.emplace_back(Accum, IVOper);
         IVSrc = IVOper;
         LeftOverExpr = nullptr;
       }
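
The canFoldIVIncExpr change above is where scalable increments become visible to the chain logic. As a rough sketch of the shape test (with invented names, not LLVM's SCEV classes): a plain constant yields a fixed offset, a two-operand mul(constant, vscale) yields a scalable offset, and anything else, such as a single 'vscale * 16 + 8' add expression, is rejected so GenerateIVChain records a fresh base instead.

    // Illustrative-only model of the increment shapes canFoldIVIncExpr
    // accepts; Expr and its fields are invented stand-ins for
    // SCEVConstant/SCEVVScale/SCEVMulExpr.
    #include <cstdint>
    #include <optional>
    #include <vector>

    struct Expr {
      enum Kind { Constant, VScale, Add, Mul } K;
      int64_t Value = 0;     // payload for Constant
      std::vector<Expr> Ops; // operands for Add/Mul
    };

    struct IncOffsets {
      int64_t Fixed = 0;
      int64_t Scalable = 0;
    };

    // Split an increment into (fixed, scalable) parts if it has one of the
    // two foldable shapes; otherwise report that it cannot fold into an
    // addressing mode.
    static std::optional<IncOffsets> splitIncrement(const Expr &Inc) {
      if (Inc.K == Expr::Constant)
        return IncOffsets{Inc.Value, 0};
      if (Inc.K == Expr::Mul && Inc.Ops.size() == 2 &&
          Inc.Ops[0].K == Expr::Constant && Inc.Ops[1].K == Expr::VScale)
        return IncOffsets{0, Inc.Ops[0].Value};
      return std::nullopt; // e.g. add(mul(16, vscale), 8) in one increment
    }

    int main() {
      Expr VScale{Expr::VScale};
      Expr Sixteen{Expr::Constant, 16};
      Expr Inc{Expr::Mul, 0, {Sixteen, VScale}}; // vscale * 16
      auto Split = splitIncrement(Inc);          // -> {Fixed=0, Scalable=16}
      return Split && Split->Scalable == 16 ? 0 : 1;
    }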

llvm/test/CodeGen/AArch64/sve-lsrchain.ll (+40, -46)
@@ -14,24 +14,22 @@ define void @test(ptr nocapture noundef readonly %kernel, i32 noundef %kw, float
 ; CHECK-NEXT:  // %bb.2: // %for.body.us.preheader
 ; CHECK-NEXT:    ptrue p0.h
 ; CHECK-NEXT:    add x11, x2, x11, lsl #1
-; CHECK-NEXT:    mov x12, #-16 // =0xfffffffffffffff0
-; CHECK-NEXT:    ptrue p1.b
 ; CHECK-NEXT:    mov w8, wzr
+; CHECK-NEXT:    ptrue p1.b
 ; CHECK-NEXT:    mov x9, xzr
 ; CHECK-NEXT:    mov w10, wzr
-; CHECK-NEXT:    addvl x12, x12, #1
-; CHECK-NEXT:    mov x13, #4 // =0x4
-; CHECK-NEXT:    mov x14, #8 // =0x8
+; CHECK-NEXT:    mov x12, #4 // =0x4
+; CHECK-NEXT:    mov x13, #8 // =0x8
 ; CHECK-NEXT:  .LBB0_3: // %for.body.us
 ; CHECK-NEXT:    // =>This Loop Header: Depth=1
 ; CHECK-NEXT:    // Child Loop BB0_4 Depth 2
-; CHECK-NEXT:    add x15, x0, x9, lsl #2
-; CHECK-NEXT:    sbfiz x16, x8, #1, #32
-; CHECK-NEXT:    mov x17, x2
-; CHECK-NEXT:    ldp s0, s1, [x15]
-; CHECK-NEXT:    add x16, x16, #8
-; CHECK-NEXT:    ldp s2, s3, [x15, #8]
-; CHECK-NEXT:    ubfiz x15, x8, #1, #32
+; CHECK-NEXT:    add x14, x0, x9, lsl #2
+; CHECK-NEXT:    sbfiz x15, x8, #1, #32
+; CHECK-NEXT:    mov x16, x2
+; CHECK-NEXT:    ldp s0, s1, [x14]
+; CHECK-NEXT:    add x15, x15, #8
+; CHECK-NEXT:    ldp s2, s3, [x14, #8]
+; CHECK-NEXT:    ubfiz x14, x8, #1, #32
 ; CHECK-NEXT:    fcvt h0, s0
 ; CHECK-NEXT:    fcvt h1, s1
 ; CHECK-NEXT:    fcvt h2, s2
@@ -43,56 +41,52 @@ define void @test(ptr nocapture noundef readonly %kernel, i32 noundef %kw, float
 ; CHECK-NEXT:  .LBB0_4: // %for.cond.i.preheader.us
 ; CHECK-NEXT:    // Parent Loop BB0_3 Depth=1
 ; CHECK-NEXT:    // => This Inner Loop Header: Depth=2
-; CHECK-NEXT:    ld1b { z4.b }, p1/z, [x17, x15]
-; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x17]
-; CHECK-NEXT:    add x18, x17, x16
-; CHECK-NEXT:    add x3, x17, x15
+; CHECK-NEXT:    ld1b { z4.b }, p1/z, [x16, x14]
+; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x16]
+; CHECK-NEXT:    add x17, x16, x15
+; CHECK-NEXT:    add x18, x16, x14
+; CHECK-NEXT:    add x3, x17, #8
+; CHECK-NEXT:    add x4, x17, #16
 ; CHECK-NEXT:    fmad z4.h, p0/m, z0.h, z5.h
-; CHECK-NEXT:    ld1b { z5.b }, p1/z, [x17, x16]
+; CHECK-NEXT:    ld1b { z5.b }, p1/z, [x16, x15]
 ; CHECK-NEXT:    fmla z4.h, p0/m, z5.h, z1.h
-; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x18, x13, lsl #1]
+; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x17, x12, lsl #1]
 ; CHECK-NEXT:    fmla z4.h, p0/m, z5.h, z2.h
-; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x18, x14, lsl #1]
-; CHECK-NEXT:    add x18, x18, #16
+; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x17, x13, lsl #1]
 ; CHECK-NEXT:    fmla z4.h, p0/m, z5.h, z3.h
-; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x17, #1, mul vl]
-; CHECK-NEXT:    st1h { z4.h }, p0, [x17]
-; CHECK-NEXT:    ld1h { z4.h }, p0/z, [x3, #1, mul vl]
+; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x16, #1, mul vl]
+; CHECK-NEXT:    st1h { z4.h }, p0, [x16]
+; CHECK-NEXT:    ld1h { z4.h }, p0/z, [x18, #1, mul vl]
 ; CHECK-NEXT:    fmad z4.h, p0/m, z0.h, z5.h
-; CHECK-NEXT:    ld1b { z5.b }, p1/z, [x18, x12]
-; CHECK-NEXT:    add x18, x18, x12
+; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x17, #1, mul vl]
 ; CHECK-NEXT:    fmla z4.h, p0/m, z5.h, z1.h
-; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x18, x13, lsl #1]
+; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x3, #1, mul vl]
 ; CHECK-NEXT:    fmla z4.h, p0/m, z5.h, z2.h
-; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x18, x14, lsl #1]
-; CHECK-NEXT:    add x18, x18, #16
+; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x4, #1, mul vl]
 ; CHECK-NEXT:    fmla z4.h, p0/m, z5.h, z3.h
-; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x17, #2, mul vl]
-; CHECK-NEXT:    st1h { z4.h }, p0, [x17, #1, mul vl]
-; CHECK-NEXT:    ld1h { z4.h }, p0/z, [x3, #2, mul vl]
+; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x16, #2, mul vl]
+; CHECK-NEXT:    st1h { z4.h }, p0, [x16, #1, mul vl]
+; CHECK-NEXT:    ld1h { z4.h }, p0/z, [x18, #2, mul vl]
 ; CHECK-NEXT:    fmad z4.h, p0/m, z0.h, z5.h
-; CHECK-NEXT:    ld1b { z5.b }, p1/z, [x18, x12]
-; CHECK-NEXT:    add x18, x18, x12
+; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x17, #2, mul vl]
 ; CHECK-NEXT:    fmla z4.h, p0/m, z5.h, z1.h
-; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x18, x13, lsl #1]
+; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x3, #2, mul vl]
 ; CHECK-NEXT:    fmla z4.h, p0/m, z5.h, z2.h
-; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x18, x14, lsl #1]
-; CHECK-NEXT:    add x18, x18, #16
+; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x4, #2, mul vl]
 ; CHECK-NEXT:    fmla z4.h, p0/m, z5.h, z3.h
-; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x17, #3, mul vl]
-; CHECK-NEXT:    st1h { z4.h }, p0, [x17, #2, mul vl]
-; CHECK-NEXT:    ld1h { z4.h }, p0/z, [x3, #3, mul vl]
+; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x16, #3, mul vl]
+; CHECK-NEXT:    st1h { z4.h }, p0, [x16, #2, mul vl]
+; CHECK-NEXT:    ld1h { z4.h }, p0/z, [x18, #3, mul vl]
 ; CHECK-NEXT:    fmad z4.h, p0/m, z0.h, z5.h
-; CHECK-NEXT:    ld1b { z5.b }, p1/z, [x18, x12]
-; CHECK-NEXT:    add x18, x18, x12
+; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x17, #3, mul vl]
 ; CHECK-NEXT:    fmla z4.h, p0/m, z5.h, z1.h
-; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x18, x13, lsl #1]
+; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x3, #3, mul vl]
 ; CHECK-NEXT:    fmla z4.h, p0/m, z5.h, z2.h
-; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x18, x14, lsl #1]
+; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x4, #3, mul vl]
 ; CHECK-NEXT:    fmla z4.h, p0/m, z5.h, z3.h
-; CHECK-NEXT:    st1h { z4.h }, p0, [x17, #3, mul vl]
-; CHECK-NEXT:    addvl x17, x17, #4
-; CHECK-NEXT:    cmp x17, x11
+; CHECK-NEXT:    st1h { z4.h }, p0, [x16, #3, mul vl]
+; CHECK-NEXT:    addvl x16, x16, #4
+; CHECK-NEXT:    cmp x16, x11
 ; CHECK-NEXT:    b.lo .LBB0_4
 ; CHECK-NEXT:  // %bb.5: // %while.cond.i..exit_crit_edge.us
 ; CHECK-NEXT:    // in Loop: Header=BB0_3 Depth=1
