Skip to content

Commit 3e0a76b

Browse files
authored
[Codegen][LegalizeIntegerTypes] Improve shift through stack (#96151)
Minor improvement on cc39c3b. Use an aligned stack slot to store the shifted value. Use the native register width as shifting unit, so the load of the shift result is aligned. If the shift amount is a multiple of the native register width, there is no need to do a follow-up shift after the load. I added new tests for these cases. Co-authored-by: Gergely Futo <[email protected]>
1 parent cd80ed4 commit 3e0a76b

23 files changed

+39402
-12935
lines changed

llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp

Lines changed: 40 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4608,14 +4608,23 @@ void DAGTypeLegalizer::ExpandIntRes_ShiftThroughStack(SDNode *N, SDValue &Lo,
46084608
SDValue ShAmt = N->getOperand(1);
46094609
EVT ShAmtVT = ShAmt.getValueType();
46104610

4611-
// This legalization is optimal when the shift is by a multiple of byte width,
4612-
// %x * 8 <-> %x << 3 so 3 low bits should be known zero.
4613-
bool ShiftByByteMultiple =
4614-
DAG.computeKnownBits(ShAmt).countMinTrailingZeros() >= 3;
4611+
EVT LoadVT = VT;
4612+
do {
4613+
LoadVT = TLI.getTypeToTransformTo(*DAG.getContext(), LoadVT);
4614+
} while (!TLI.isTypeLegal(LoadVT));
4615+
4616+
const unsigned ShiftUnitInBits = LoadVT.getStoreSizeInBits();
4617+
assert(ShiftUnitInBits <= VT.getScalarSizeInBits());
4618+
assert(isPowerOf2_32(ShiftUnitInBits) &&
4619+
"Shifting unit is not a power of two!");
4620+
4621+
const bool IsOneStepShift =
4622+
DAG.computeKnownBits(ShAmt).countMinTrailingZeros() >=
4623+
Log2_32(ShiftUnitInBits);
46154624

46164625
// If we can't do it as one step, we'll have two uses of shift amount,
46174626
// and thus must freeze it.
4618-
if (!ShiftByByteMultiple)
4627+
if (!IsOneStepShift)
46194628
ShAmt = DAG.getFreeze(ShAmt);
46204629

46214630
unsigned VTBitWidth = VT.getScalarSizeInBits();
@@ -4629,10 +4638,9 @@ void DAGTypeLegalizer::ExpandIntRes_ShiftThroughStack(SDNode *N, SDValue &Lo,
46294638

46304639
// Get a temporary stack slot 2x the width of our VT.
46314640
// FIXME: reuse stack slots?
4632-
// FIXME: should we be more picky about alignment?
4633-
Align StackSlotAlignment(1);
4634-
SDValue StackPtr = DAG.CreateStackTemporary(
4635-
TypeSize::getFixed(StackSlotByteWidth), StackSlotAlignment);
4641+
Align StackAlign = DAG.getReducedAlign(StackSlotVT, /*UseABI=*/false);
4642+
SDValue StackPtr =
4643+
DAG.CreateStackTemporary(StackSlotVT.getStoreSize(), StackAlign);
46364644
EVT PtrTy = StackPtr.getValueType();
46374645
SDValue Ch = DAG.getEntryNode();
46384646

@@ -4652,15 +4660,22 @@ void DAGTypeLegalizer::ExpandIntRes_ShiftThroughStack(SDNode *N, SDValue &Lo,
46524660
Init = DAG.getNode(ISD::BUILD_PAIR, dl, StackSlotVT, AllZeros, Shiftee);
46534661
}
46544662
// And spill it into the stack slot.
4655-
Ch = DAG.getStore(Ch, dl, Init, StackPtr, StackPtrInfo, StackSlotAlignment);
4663+
Ch = DAG.getStore(Ch, dl, Init, StackPtr, StackPtrInfo, StackAlign);
46564664

46574665
// Now, compute the full-byte offset into stack slot from where we can load.
4658-
// We have shift amount, which is in bits, but in multiples of byte.
4659-
// So just divide by CHAR_BIT.
4666+
// We have shift amount, which is in bits. Offset should point to an aligned
4667+
// address.
46604668
SDNodeFlags Flags;
4661-
if (ShiftByByteMultiple)
4662-
Flags.setExact(true);
4663-
SDValue ByteOffset = DAG.getNode(ISD::SRL, dl, ShAmtVT, ShAmt,
4669+
Flags.setExact(IsOneStepShift);
4670+
SDValue SrlTmp = DAG.getNode(
4671+
ISD::SRL, dl, ShAmtVT, ShAmt,
4672+
DAG.getConstant(Log2_32(ShiftUnitInBits), dl, ShAmtVT), Flags);
4673+
SDValue BitOffset =
4674+
DAG.getNode(ISD::SHL, dl, ShAmtVT, SrlTmp,
4675+
DAG.getConstant(Log2_32(ShiftUnitInBits), dl, ShAmtVT));
4676+
4677+
Flags.setExact(true);
4678+
SDValue ByteOffset = DAG.getNode(ISD::SRL, dl, ShAmtVT, BitOffset,
46644679
DAG.getConstant(3, dl, ShAmtVT), Flags);
46654680
// And clamp it, because OOB load is an immediate UB,
46664681
// while shift overflow would have *just* been poison.
@@ -4689,15 +4704,16 @@ void DAGTypeLegalizer::ExpandIntRes_ShiftThroughStack(SDNode *N, SDValue &Lo,
46894704
AdjStackPtr = DAG.getMemBasePlusOffset(AdjStackPtr, ByteOffset, dl);
46904705

46914706
// And load it! While the load is not legal, legalizing it is obvious.
4692-
SDValue Res = DAG.getLoad(
4693-
VT, dl, Ch, AdjStackPtr,
4694-
MachinePointerInfo::getUnknownStack(DAG.getMachineFunction()), Align(1));
4695-
// We've performed the shift by a CHAR_BIT * [_ShAmt / CHAR_BIT_]
4696-
4697-
// If we may still have a less-than-CHAR_BIT to shift by, do so now.
4698-
if (!ShiftByByteMultiple) {
4699-
SDValue ShAmtRem = DAG.getNode(ISD::AND, dl, ShAmtVT, ShAmt,
4700-
DAG.getConstant(7, dl, ShAmtVT));
4707+
SDValue Res =
4708+
DAG.getLoad(VT, dl, Ch, AdjStackPtr,
4709+
MachinePointerInfo::getUnknownStack(DAG.getMachineFunction()),
4710+
commonAlignment(StackAlign, LoadVT.getStoreSize()));
4711+
4712+
// If we may still have remaining bits to shift by, do so now.
4713+
if (!IsOneStepShift) {
4714+
SDValue ShAmtRem =
4715+
DAG.getNode(ISD::AND, dl, ShAmtVT, ShAmt,
4716+
DAG.getConstant(ShiftUnitInBits - 1, dl, ShAmtVT));
47014717
Res = DAG.getNode(N->getOpcode(), dl, VT, Res, ShAmtRem);
47024718
}
47034719

llvm/test/CodeGen/AArch64/wide-scalar-shift-by-byte-multiple-legalization.ll

Lines changed: 146 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -186,23 +186,68 @@ define void @lshr_32bytes(ptr %src.ptr, ptr %byteOff.ptr, ptr %dst) nounwind {
186186
; ALL-NEXT: ldr q1, [x0]
187187
; ALL-NEXT: stp x9, x8, [sp, #16]
188188
; ALL-NEXT: mov x8, sp
189-
; ALL-NEXT: and x9, x10, #0x1f
189+
; ALL-NEXT: and x9, x10, #0x18
190190
; ALL-NEXT: str q1, [sp]
191191
; ALL-NEXT: add x8, x8, x9
192+
; ALL-NEXT: lsl x9, x10, #3
192193
; ALL-NEXT: stp q0, q0, [sp, #32]
194+
; ALL-NEXT: ldp x11, x10, [x8, #16]
195+
; ALL-NEXT: mvn w13, w9
196+
; ALL-NEXT: ldp x8, x12, [x8]
197+
; ALL-NEXT: and x9, x9, #0x38
198+
; ALL-NEXT: lsl x14, x10, #1
199+
; ALL-NEXT: lsl x15, x11, #1
200+
; ALL-NEXT: lsr x11, x11, x9
201+
; ALL-NEXT: lsl x16, x12, #1
202+
; ALL-NEXT: lsr x10, x10, x9
203+
; ALL-NEXT: lsr x12, x12, x9
204+
; ALL-NEXT: lsl x14, x14, x13
205+
; ALL-NEXT: lsr x8, x8, x9
206+
; ALL-NEXT: lsl x9, x16, x13
207+
; ALL-NEXT: lsl x13, x15, x13
208+
; ALL-NEXT: orr x11, x14, x11
209+
; ALL-NEXT: orr x8, x9, x8
210+
; ALL-NEXT: orr x9, x12, x13
211+
; ALL-NEXT: stp x11, x10, [x2, #16]
212+
; ALL-NEXT: stp x8, x9, [x2]
213+
; ALL-NEXT: add sp, sp, #64
214+
; ALL-NEXT: ret
215+
%src = load i256, ptr %src.ptr, align 1
216+
%byteOff = load i256, ptr %byteOff.ptr, align 1
217+
%bitOff = shl i256 %byteOff, 3
218+
%res = lshr i256 %src, %bitOff
219+
store i256 %res, ptr %dst, align 1
220+
ret void
221+
}
222+
223+
define void @lshr_32bytes_dwordOff(ptr %src.ptr, ptr %dwordOff.ptr, ptr %dst) nounwind {
224+
; ALL-LABEL: lshr_32bytes_dwordOff:
225+
; ALL: // %bb.0:
226+
; ALL-NEXT: sub sp, sp, #64
227+
; ALL-NEXT: ldp x9, x8, [x0, #16]
228+
; ALL-NEXT: movi v0.2d, #0000000000000000
229+
; ALL-NEXT: ldr x10, [x1]
230+
; ALL-NEXT: ldr q1, [x0]
231+
; ALL-NEXT: stp x9, x8, [sp, #16]
232+
; ALL-NEXT: ubfiz x8, x10, #3, #2
233+
; ALL-NEXT: mov x9, sp
234+
; ALL-NEXT: str q1, [sp]
235+
; ALL-NEXT: stp q0, q0, [sp, #32]
236+
; ALL-NEXT: add x8, x9, x8
193237
; ALL-NEXT: ldp x10, x9, [x8, #16]
194238
; ALL-NEXT: ldr q0, [x8]
195239
; ALL-NEXT: str q0, [x2]
196240
; ALL-NEXT: stp x10, x9, [x2, #16]
197241
; ALL-NEXT: add sp, sp, #64
198242
; ALL-NEXT: ret
199243
%src = load i256, ptr %src.ptr, align 1
200-
%byteOff = load i256, ptr %byteOff.ptr, align 1
201-
%bitOff = shl i256 %byteOff, 3
244+
%dwordOff = load i256, ptr %dwordOff.ptr, align 1
245+
%bitOff = shl i256 %dwordOff, 6
202246
%res = lshr i256 %src, %bitOff
203247
store i256 %res, ptr %dst, align 1
204248
ret void
205249
}
250+
206251
define void @shl_32bytes(ptr %src.ptr, ptr %byteOff.ptr, ptr %dst) nounwind {
207252
; ALL-LABEL: shl_32bytes:
208253
; ALL: // %bb.0:
@@ -213,48 +258,139 @@ define void @shl_32bytes(ptr %src.ptr, ptr %byteOff.ptr, ptr %dst) nounwind {
213258
; ALL-NEXT: ldr q1, [x0]
214259
; ALL-NEXT: stp x9, x8, [sp, #48]
215260
; ALL-NEXT: mov x8, sp
216-
; ALL-NEXT: and x9, x10, #0x1f
261+
; ALL-NEXT: and x9, x10, #0x18
217262
; ALL-NEXT: add x8, x8, #32
218263
; ALL-NEXT: stp q0, q0, [sp]
219264
; ALL-NEXT: str q1, [sp, #32]
220265
; ALL-NEXT: sub x8, x8, x9
266+
; ALL-NEXT: lsl x9, x10, #3
267+
; ALL-NEXT: ldp x10, x11, [x8]
268+
; ALL-NEXT: ldp x12, x8, [x8, #16]
269+
; ALL-NEXT: mvn w13, w9
270+
; ALL-NEXT: and x9, x9, #0x38
271+
; ALL-NEXT: lsr x14, x10, #1
272+
; ALL-NEXT: lsr x15, x11, #1
273+
; ALL-NEXT: lsl x11, x11, x9
274+
; ALL-NEXT: lsr x16, x12, #1
275+
; ALL-NEXT: lsl x10, x10, x9
276+
; ALL-NEXT: lsl x12, x12, x9
277+
; ALL-NEXT: lsr x14, x14, x13
278+
; ALL-NEXT: lsl x8, x8, x9
279+
; ALL-NEXT: lsr x9, x16, x13
280+
; ALL-NEXT: lsr x13, x15, x13
281+
; ALL-NEXT: orr x11, x11, x14
282+
; ALL-NEXT: orr x8, x8, x9
283+
; ALL-NEXT: orr x9, x12, x13
284+
; ALL-NEXT: stp x10, x11, [x2]
285+
; ALL-NEXT: stp x9, x8, [x2, #16]
286+
; ALL-NEXT: add sp, sp, #64
287+
; ALL-NEXT: ret
288+
%src = load i256, ptr %src.ptr, align 1
289+
%byteOff = load i256, ptr %byteOff.ptr, align 1
290+
%bitOff = shl i256 %byteOff, 3
291+
%res = shl i256 %src, %bitOff
292+
store i256 %res, ptr %dst, align 1
293+
ret void
294+
}
295+
296+
define void @shl_32bytes_dwordOff(ptr %src.ptr, ptr %dwordOff.ptr, ptr %dst) nounwind {
297+
; ALL-LABEL: shl_32bytes_dwordOff:
298+
; ALL: // %bb.0:
299+
; ALL-NEXT: sub sp, sp, #64
300+
; ALL-NEXT: ldp x9, x8, [x0, #16]
301+
; ALL-NEXT: movi v0.2d, #0000000000000000
302+
; ALL-NEXT: ldr x10, [x1]
303+
; ALL-NEXT: ldr q1, [x0]
304+
; ALL-NEXT: stp x9, x8, [sp, #48]
305+
; ALL-NEXT: mov x8, sp
306+
; ALL-NEXT: ubfiz x9, x10, #3, #2
307+
; ALL-NEXT: add x8, x8, #32
308+
; ALL-NEXT: stp q0, q1, [sp, #16]
309+
; ALL-NEXT: str q0, [sp]
310+
; ALL-NEXT: sub x8, x8, x9
221311
; ALL-NEXT: ldp x9, x10, [x8, #16]
222312
; ALL-NEXT: ldr q0, [x8]
223313
; ALL-NEXT: str q0, [x2]
224314
; ALL-NEXT: stp x9, x10, [x2, #16]
225315
; ALL-NEXT: add sp, sp, #64
226316
; ALL-NEXT: ret
227317
%src = load i256, ptr %src.ptr, align 1
228-
%byteOff = load i256, ptr %byteOff.ptr, align 1
229-
%bitOff = shl i256 %byteOff, 3
318+
%dwordOff = load i256, ptr %dwordOff.ptr, align 1
319+
%bitOff = shl i256 %dwordOff, 6
230320
%res = shl i256 %src, %bitOff
231321
store i256 %res, ptr %dst, align 1
232322
ret void
233323
}
324+
234325
define void @ashr_32bytes(ptr %src.ptr, ptr %byteOff.ptr, ptr %dst) nounwind {
235326
; ALL-LABEL: ashr_32bytes:
236327
; ALL: // %bb.0:
237328
; ALL-NEXT: sub sp, sp, #64
238329
; ALL-NEXT: ldp x9, x8, [x0, #16]
239330
; ALL-NEXT: ldr x10, [x1]
240331
; ALL-NEXT: ldr q0, [x0]
241-
; ALL-NEXT: and x10, x10, #0x1f
332+
; ALL-NEXT: and x11, x10, #0x18
242333
; ALL-NEXT: stp x9, x8, [sp, #16]
243334
; ALL-NEXT: asr x8, x8, #63
244335
; ALL-NEXT: mov x9, sp
245336
; ALL-NEXT: str q0, [sp]
337+
; ALL-NEXT: add x9, x9, x11
338+
; ALL-NEXT: stp x8, x8, [sp, #48]
339+
; ALL-NEXT: stp x8, x8, [sp, #32]
340+
; ALL-NEXT: lsl x8, x10, #3
341+
; ALL-NEXT: ldp x11, x10, [x9, #16]
342+
; ALL-NEXT: ldp x9, x12, [x9]
343+
; ALL-NEXT: mvn w13, w8
344+
; ALL-NEXT: and x8, x8, #0x38
345+
; ALL-NEXT: lsl x14, x10, #1
346+
; ALL-NEXT: lsl x15, x11, #1
347+
; ALL-NEXT: lsr x11, x11, x8
348+
; ALL-NEXT: lsl x16, x12, #1
349+
; ALL-NEXT: asr x10, x10, x8
350+
; ALL-NEXT: lsr x12, x12, x8
351+
; ALL-NEXT: lsl x14, x14, x13
352+
; ALL-NEXT: lsr x8, x9, x8
353+
; ALL-NEXT: lsl x9, x16, x13
354+
; ALL-NEXT: lsl x13, x15, x13
355+
; ALL-NEXT: orr x11, x14, x11
356+
; ALL-NEXT: orr x8, x9, x8
357+
; ALL-NEXT: orr x9, x12, x13
358+
; ALL-NEXT: stp x11, x10, [x2, #16]
359+
; ALL-NEXT: stp x8, x9, [x2]
360+
; ALL-NEXT: add sp, sp, #64
361+
; ALL-NEXT: ret
362+
%src = load i256, ptr %src.ptr, align 1
363+
%byteOff = load i256, ptr %byteOff.ptr, align 1
364+
%bitOff = shl i256 %byteOff, 3
365+
%res = ashr i256 %src, %bitOff
366+
store i256 %res, ptr %dst, align 1
367+
ret void
368+
}
369+
370+
define void @ashr_32bytes_dwordOff(ptr %src.ptr, ptr %dwordOff.ptr, ptr %dst) nounwind {
371+
; ALL-LABEL: ashr_32bytes_dwordOff:
372+
; ALL: // %bb.0:
373+
; ALL-NEXT: sub sp, sp, #64
374+
; ALL-NEXT: ldp x9, x8, [x0, #16]
375+
; ALL-NEXT: ldr x10, [x1]
376+
; ALL-NEXT: ldr q0, [x0]
377+
; ALL-NEXT: stp x9, x8, [sp, #16]
378+
; ALL-NEXT: asr x8, x8, #63
379+
; ALL-NEXT: ubfiz x9, x10, #3, #2
380+
; ALL-NEXT: mov x10, sp
381+
; ALL-NEXT: str q0, [sp]
246382
; ALL-NEXT: stp x8, x8, [sp, #48]
247383
; ALL-NEXT: stp x8, x8, [sp, #32]
248-
; ALL-NEXT: add x8, x9, x10
384+
; ALL-NEXT: add x8, x10, x9
249385
; ALL-NEXT: ldp x10, x9, [x8, #16]
250386
; ALL-NEXT: ldr q0, [x8]
251387
; ALL-NEXT: str q0, [x2]
252388
; ALL-NEXT: stp x10, x9, [x2, #16]
253389
; ALL-NEXT: add sp, sp, #64
254390
; ALL-NEXT: ret
255391
%src = load i256, ptr %src.ptr, align 1
256-
%byteOff = load i256, ptr %byteOff.ptr, align 1
257-
%bitOff = shl i256 %byteOff, 3
392+
%dwordOff = load i256, ptr %dwordOff.ptr, align 1
393+
%bitOff = shl i256 %dwordOff, 6
258394
%res = ashr i256 %src, %bitOff
259395
store i256 %res, ptr %dst, align 1
260396
ret void

0 commit comments

Comments
 (0)