Commit 85cf958

[AArch64] Improve codegen for some fixed-width partial reductions (#126529)
This patch teaches optimizeExtendOrTruncateConversion to bail out if the user of a zero-extend is a partial reduction intrinsic that we know will get lowered efficiently to a udot instruction.
1 parent 547a8bc commit 85cf958
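For reference, here is a minimal sketch of the kind of pattern this affects, written in LLVM IR and modelled on the @udot / @udot_no_bin_op tests below (the function and value names are illustrative, not taken from the patch): a zero-extend from i8 whose only user is the partial-reduction intrinsic is now left alone rather than rewritten into a tbl-based extend, so the backend can select a single udot.

define <4 x i32> @zext_partial_reduce_sketch(<4 x i32> %acc, <16 x i8> %a) {
  ; The zext's single user is the partial reduction, so
  ; optimizeExtendOrTruncateConversion now bails out here.
  %a.wide = zext <16 x i8> %a to <16 x i32>
  %r = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %a.wide)
  ret <4 x i32> %r
}

With +dotprod this is expected to lower to a udot against an all-ones vector (see the CHECK-DOT lines of udot_no_bin_op_in_loop in the test diff below) instead of going through the tbl-based zero-extension path.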

File tree

2 files changed (+255, -7 lines)

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 11 additions & 3 deletions
@@ -2055,8 +2055,9 @@ bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT,
 
 bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
     const IntrinsicInst *I) const {
-  if (I->getIntrinsicID() != Intrinsic::experimental_vector_partial_reduce_add)
-    return true;
+  assert(I->getIntrinsicID() ==
+             Intrinsic::experimental_vector_partial_reduce_add &&
+         "Unexpected intrinsic!");
   if (EnablePartialReduceNodes)
     return true;
 
@@ -16890,9 +16891,16 @@ bool AArch64TargetLowering::optimizeExtendOrTruncateConversion(
   // mul(zext(i8), sext) can be transformed into smull(zext, sext) which
   // performs one extend implicitly. If DstWidth is at most 4 * SrcWidth, at
   // most one extra extend step is needed and using tbl is not profitable.
+  // Similarly, bail out if partial_reduce(acc, zext(i8)) can be lowered to a
+  // udot instruction.
   if (SrcWidth * 4 <= DstWidth && I->hasOneUser()) {
     auto *SingleUser = cast<Instruction>(*I->user_begin());
-    if (match(SingleUser, m_c_Mul(m_Specific(I), m_SExt(m_Value()))))
+    if (match(SingleUser, m_c_Mul(m_Specific(I), m_SExt(m_Value()))) ||
+        (match(SingleUser,
+               m_Intrinsic<Intrinsic::experimental_vector_partial_reduce_add>(
+                   m_Value(), m_Specific(I))) &&
+         !shouldExpandPartialReductionIntrinsic(
+             cast<IntrinsicInst>(SingleUser))))
       return false;
   }

llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll

Lines changed: 244 additions & 4 deletions
@@ -2,7 +2,7 @@
 ; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod < %s | FileCheck %s --check-prefixes=CHECK,CHECK-DOT,CHECK-NOI8MM
 ; RUN: llc -mtriple aarch64 -mattr=+neon < %s | FileCheck %s --check-prefixes=CHECK,CHECK-NOI8MM,CHECK-NODOT
 ; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod,+i8mm < %s | FileCheck %s --check-prefixes=CHECK,CHECK-DOT,CHECK-I8MM
-; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod,+i8mm -aarch64-enable-partial-reduce-nodes < %s | FileCheck %s --check-prefixes=CHECK,CHECK-NOI8MM,CHECK-NODOT
+; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod,+i8mm -aarch64-enable-partial-reduce-nodes < %s | FileCheck %s --check-prefixes=CHECK,CHECK-NOI8MM
 
 define <4 x i32> @udot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
 ; CHECK-DOT-LABEL: udot:
@@ -27,6 +27,66 @@ define <4 x i32> @udot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
   ret <4 x i32> %partial.reduce
 }
 
+define <4 x i32> @udot_in_loop(ptr %p1, ptr %p2){
+; CHECK-DOT-LABEL: udot_in_loop:
+; CHECK-DOT: // %bb.0: // %entry
+; CHECK-DOT-NEXT: movi v1.2d, #0000000000000000
+; CHECK-DOT-NEXT: mov x8, xzr
+; CHECK-DOT-NEXT: .LBB1_1: // %vector.body
+; CHECK-DOT-NEXT: // =>This Inner Loop Header: Depth=1
+; CHECK-DOT-NEXT: ldr q2, [x0, x8]
+; CHECK-DOT-NEXT: ldr q3, [x1, x8]
+; CHECK-DOT-NEXT: mov v0.16b, v1.16b
+; CHECK-DOT-NEXT: add x8, x8, #16
+; CHECK-DOT-NEXT: udot v1.4s, v2.16b, v3.16b
+; CHECK-DOT-NEXT: cmp x8, #16
+; CHECK-DOT-NEXT: b.ne .LBB1_1
+; CHECK-DOT-NEXT: // %bb.2: // %end
+; CHECK-DOT-NEXT: ret
+;
+; CHECK-NODOT-LABEL: udot_in_loop:
+; CHECK-NODOT: // %bb.0: // %entry
+; CHECK-NODOT-NEXT: movi v1.2d, #0000000000000000
+; CHECK-NODOT-NEXT: mov x8, xzr
+; CHECK-NODOT-NEXT: .LBB1_1: // %vector.body
+; CHECK-NODOT-NEXT: // =>This Inner Loop Header: Depth=1
+; CHECK-NODOT-NEXT: ldr q0, [x0, x8]
+; CHECK-NODOT-NEXT: ldr q2, [x1, x8]
+; CHECK-NODOT-NEXT: add x8, x8, #16
+; CHECK-NODOT-NEXT: cmp x8, #16
+; CHECK-NODOT-NEXT: umull v3.8h, v0.8b, v2.8b
+; CHECK-NODOT-NEXT: umull2 v2.8h, v0.16b, v2.16b
+; CHECK-NODOT-NEXT: mov v0.16b, v1.16b
+; CHECK-NODOT-NEXT: ushll v1.4s, v2.4h, #0
+; CHECK-NODOT-NEXT: uaddw v4.4s, v0.4s, v3.4h
+; CHECK-NODOT-NEXT: uaddw2 v1.4s, v1.4s, v3.8h
+; CHECK-NODOT-NEXT: uaddw2 v2.4s, v4.4s, v2.8h
+; CHECK-NODOT-NEXT: add v1.4s, v1.4s, v2.4s
+; CHECK-NODOT-NEXT: b.ne .LBB1_1
+; CHECK-NODOT-NEXT: // %bb.2: // %end
+; CHECK-NODOT-NEXT: ret
+entry:
+  br label %vector.body
+
+vector.body:
+  %index = phi i64 [ 0, %entry ], [ %index.next, %vector.body ]
+  %acc = phi <4 x i32> [ zeroinitializer, %entry ], [ %partial.reduce, %vector.body ]
+  %gep1 = getelementptr i8, ptr %p1, i64 %index
+  %load1 = load <16 x i8>, ptr %gep1, align 16
+  %load1.wide = zext <16 x i8> %load1 to <16 x i32>
+  %gep2 = getelementptr i8, ptr %p2, i64 %index
+  %load2 = load <16 x i8>, ptr %gep2, align 16
+  %load2.wide = zext <16 x i8> %load2 to <16 x i32>
+  %mul = mul nuw nsw <16 x i32> %load1.wide, %load2.wide
+  %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %mul)
+  %index.next = add nuw i64 %index, 16
+  %cmp = icmp eq i64 %index.next, 16
+  br i1 %cmp, label %end, label %vector.body
+
+end:
+  ret <4 x i32> %acc
+}
+
 define <2 x i32> @udot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
 ; CHECK-DOT-LABEL: udot_narrow:
 ; CHECK-DOT: // %bb.0:
@@ -129,6 +189,68 @@ define <4 x i32> @usdot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
   ret <4 x i32> %partial.reduce
 }
 
+define <4 x i32> @usdot_in_loop(ptr %p1, ptr %p2){
+; CHECK-NOI8MM-LABEL: usdot_in_loop:
+; CHECK-NOI8MM: // %bb.0: // %entry
+; CHECK-NOI8MM-NEXT: movi v1.2d, #0000000000000000
+; CHECK-NOI8MM-NEXT: mov x8, xzr
+; CHECK-NOI8MM-NEXT: .LBB6_1: // %vector.body
+; CHECK-NOI8MM-NEXT: // =>This Inner Loop Header: Depth=1
+; CHECK-NOI8MM-NEXT: ldr q0, [x0, x8]
+; CHECK-NOI8MM-NEXT: ldr q2, [x1, x8]
+; CHECK-NOI8MM-NEXT: add x8, x8, #16
+; CHECK-NOI8MM-NEXT: cmp x8, #16
+; CHECK-NOI8MM-NEXT: sshll v3.8h, v0.8b, #0
+; CHECK-NOI8MM-NEXT: sshll2 v4.8h, v0.16b, #0
+; CHECK-NOI8MM-NEXT: ushll v5.8h, v2.8b, #0
+; CHECK-NOI8MM-NEXT: ushll2 v2.8h, v2.16b, #0
+; CHECK-NOI8MM-NEXT: mov v0.16b, v1.16b
+; CHECK-NOI8MM-NEXT: smlal v1.4s, v3.4h, v5.4h
+; CHECK-NOI8MM-NEXT: smull v6.4s, v4.4h, v2.4h
+; CHECK-NOI8MM-NEXT: smlal2 v1.4s, v4.8h, v2.8h
+; CHECK-NOI8MM-NEXT: smlal2 v6.4s, v3.8h, v5.8h
+; CHECK-NOI8MM-NEXT: add v1.4s, v6.4s, v1.4s
+; CHECK-NOI8MM-NEXT: b.ne .LBB6_1
+; CHECK-NOI8MM-NEXT: // %bb.2: // %end
+; CHECK-NOI8MM-NEXT: ret
+;
+; CHECK-I8MM-LABEL: usdot_in_loop:
+; CHECK-I8MM: // %bb.0: // %entry
+; CHECK-I8MM-NEXT: movi v1.2d, #0000000000000000
+; CHECK-I8MM-NEXT: mov x8, xzr
+; CHECK-I8MM-NEXT: .LBB6_1: // %vector.body
+; CHECK-I8MM-NEXT: // =>This Inner Loop Header: Depth=1
+; CHECK-I8MM-NEXT: ldr q2, [x0, x8]
+; CHECK-I8MM-NEXT: ldr q3, [x1, x8]
+; CHECK-I8MM-NEXT: mov v0.16b, v1.16b
+; CHECK-I8MM-NEXT: add x8, x8, #16
+; CHECK-I8MM-NEXT: usdot v1.4s, v3.16b, v2.16b
+; CHECK-I8MM-NEXT: cmp x8, #16
+; CHECK-I8MM-NEXT: b.ne .LBB6_1
+; CHECK-I8MM-NEXT: // %bb.2: // %end
+; CHECK-I8MM-NEXT: ret
+entry:
+  br label %vector.body
+
+vector.body:
+  %index = phi i64 [ 0, %entry ], [ %index.next, %vector.body ]
+  %acc = phi <4 x i32> [ zeroinitializer, %entry ], [ %partial.reduce, %vector.body ]
+  %gep1 = getelementptr i8, ptr %p1, i64 %index
+  %load1 = load <16 x i8>, ptr %gep1, align 16
+  %load1.wide = sext <16 x i8> %load1 to <16 x i32>
+  %gep2 = getelementptr i8, ptr %p2, i64 %index
+  %load2 = load <16 x i8>, ptr %gep2, align 16
+  %load2.wide = zext <16 x i8> %load2 to <16 x i32>
+  %mul = mul nuw nsw <16 x i32> %load1.wide, %load2.wide
+  %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %mul)
+  %index.next = add nuw i64 %index, 16
+  %cmp = icmp eq i64 %index.next, 16
+  br i1 %cmp, label %end, label %vector.body
+
+end:
+  ret <4 x i32> %acc
+}
+
 define <2 x i32> @usdot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
 ; CHECK-NOI8MM-LABEL: usdot_narrow:
 ; CHECK-NOI8MM: // %bb.0:
@@ -176,13 +298,75 @@ define <4 x i32> @sudot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) #0{
 ; CHECK-I8MM: // %bb.0:
 ; CHECK-I8MM-NEXT: usdot v0.4s, v2.16b, v1.16b
 ; CHECK-I8MM-NEXT: ret
-  %u.wide = sext <16 x i8> %u to <16 x i32>
-  %s.wide = zext <16 x i8> %s to <16 x i32>
-  %mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
+  %s.wide = sext <16 x i8> %u to <16 x i32>
+  %u.wide = zext <16 x i8> %s to <16 x i32>
+  %mult = mul nuw nsw <16 x i32> %u.wide, %s.wide
   %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %mult)
   ret <4 x i32> %partial.reduce
 }
 
+define <4 x i32> @sudot_in_loop(ptr %p1, ptr %p2){
+; CHECK-NOI8MM-LABEL: sudot_in_loop:
+; CHECK-NOI8MM: // %bb.0: // %entry
+; CHECK-NOI8MM-NEXT: movi v1.2d, #0000000000000000
+; CHECK-NOI8MM-NEXT: mov x8, xzr
+; CHECK-NOI8MM-NEXT: .LBB9_1: // %vector.body
+; CHECK-NOI8MM-NEXT: // =>This Inner Loop Header: Depth=1
+; CHECK-NOI8MM-NEXT: ldr q0, [x0, x8]
+; CHECK-NOI8MM-NEXT: ldr q2, [x1, x8]
+; CHECK-NOI8MM-NEXT: add x8, x8, #16
+; CHECK-NOI8MM-NEXT: cmp x8, #16
+; CHECK-NOI8MM-NEXT: ushll v3.8h, v0.8b, #0
+; CHECK-NOI8MM-NEXT: ushll2 v4.8h, v0.16b, #0
+; CHECK-NOI8MM-NEXT: sshll v5.8h, v2.8b, #0
+; CHECK-NOI8MM-NEXT: sshll2 v2.8h, v2.16b, #0
+; CHECK-NOI8MM-NEXT: mov v0.16b, v1.16b
+; CHECK-NOI8MM-NEXT: smlal v1.4s, v3.4h, v5.4h
+; CHECK-NOI8MM-NEXT: smull v6.4s, v4.4h, v2.4h
+; CHECK-NOI8MM-NEXT: smlal2 v1.4s, v4.8h, v2.8h
+; CHECK-NOI8MM-NEXT: smlal2 v6.4s, v3.8h, v5.8h
+; CHECK-NOI8MM-NEXT: add v1.4s, v6.4s, v1.4s
+; CHECK-NOI8MM-NEXT: b.ne .LBB9_1
+; CHECK-NOI8MM-NEXT: // %bb.2: // %end
+; CHECK-NOI8MM-NEXT: ret
+;
+; CHECK-I8MM-LABEL: sudot_in_loop:
+; CHECK-I8MM: // %bb.0: // %entry
+; CHECK-I8MM-NEXT: movi v1.2d, #0000000000000000
+; CHECK-I8MM-NEXT: mov x8, xzr
+; CHECK-I8MM-NEXT: .LBB9_1: // %vector.body
+; CHECK-I8MM-NEXT: // =>This Inner Loop Header: Depth=1
+; CHECK-I8MM-NEXT: ldr q2, [x0, x8]
+; CHECK-I8MM-NEXT: ldr q3, [x1, x8]
+; CHECK-I8MM-NEXT: mov v0.16b, v1.16b
+; CHECK-I8MM-NEXT: add x8, x8, #16
+; CHECK-I8MM-NEXT: usdot v1.4s, v2.16b, v3.16b
+; CHECK-I8MM-NEXT: cmp x8, #16
+; CHECK-I8MM-NEXT: b.ne .LBB9_1
+; CHECK-I8MM-NEXT: // %bb.2: // %end
+; CHECK-I8MM-NEXT: ret
+entry:
+  br label %vector.body
+
+vector.body:
+  %index = phi i64 [ 0, %entry ], [ %index.next, %vector.body ]
+  %acc = phi <4 x i32> [ zeroinitializer, %entry ], [ %partial.reduce, %vector.body ]
+  %gep1 = getelementptr i8, ptr %p1, i64 %index
+  %load1 = load <16 x i8>, ptr %gep1, align 16
+  %load1.wide = zext <16 x i8> %load1 to <16 x i32>
+  %gep2 = getelementptr i8, ptr %p2, i64 %index
+  %load2 = load <16 x i8>, ptr %gep2, align 16
+  %load2.wide = sext <16 x i8> %load2 to <16 x i32>
+  %mul = mul nuw nsw <16 x i32> %load1.wide, %load2.wide
+  %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %mul)
+  %index.next = add nuw i64 %index, 16
+  %cmp = icmp eq i64 %index.next, 16
+  br i1 %cmp, label %end, label %vector.body
+
+end:
+  ret <4 x i32> %acc
+}
+
 define <2 x i32> @sudot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
 ; CHECK-NOI8MM-LABEL: sudot_narrow:
 ; CHECK-NOI8MM: // %bb.0:
@@ -390,6 +574,62 @@ define <4 x i32> @udot_no_bin_op(<4 x i32> %acc, <16 x i8> %a){
   ret <4 x i32> %partial.reduce
 }
 
+define <4 x i32> @udot_no_bin_op_in_loop(ptr %p){
+; CHECK-DOT-LABEL: udot_no_bin_op_in_loop:
+; CHECK-DOT: // %bb.0: // %entry
+; CHECK-DOT-NEXT: movi v1.2d, #0000000000000000
+; CHECK-DOT-NEXT: movi v2.16b, #1
+; CHECK-DOT-NEXT: mov x8, xzr
+; CHECK-DOT-NEXT: .LBB16_1: // %vector.body
+; CHECK-DOT-NEXT: // =>This Inner Loop Header: Depth=1
+; CHECK-DOT-NEXT: ldr q3, [x0, x8]
+; CHECK-DOT-NEXT: mov v0.16b, v1.16b
+; CHECK-DOT-NEXT: add x8, x8, #16
+; CHECK-DOT-NEXT: cmp x8, #16
+; CHECK-DOT-NEXT: udot v1.4s, v3.16b, v2.16b
+; CHECK-DOT-NEXT: b.ne .LBB16_1
+; CHECK-DOT-NEXT: // %bb.2: // %end
+; CHECK-DOT-NEXT: ret
+;
+; CHECK-NODOT-LABEL: udot_no_bin_op_in_loop:
+; CHECK-NODOT: // %bb.0: // %entry
+; CHECK-NODOT-NEXT: movi v1.2d, #0000000000000000
+; CHECK-NODOT-NEXT: mov x8, xzr
+; CHECK-NODOT-NEXT: .LBB16_1: // %vector.body
+; CHECK-NODOT-NEXT: // =>This Inner Loop Header: Depth=1
+; CHECK-NODOT-NEXT: ldr q0, [x0, x8]
+; CHECK-NODOT-NEXT: add x8, x8, #16
+; CHECK-NODOT-NEXT: cmp x8, #16
+; CHECK-NODOT-NEXT: ushll v2.8h, v0.8b, #0
+; CHECK-NODOT-NEXT: ushll2 v3.8h, v0.16b, #0
+; CHECK-NODOT-NEXT: mov v0.16b, v1.16b
+; CHECK-NODOT-NEXT: ushll v1.4s, v3.4h, #0
+; CHECK-NODOT-NEXT: uaddw v4.4s, v0.4s, v2.4h
+; CHECK-NODOT-NEXT: uaddw2 v1.4s, v1.4s, v2.8h
+; CHECK-NODOT-NEXT: uaddw2 v2.4s, v4.4s, v3.8h
+; CHECK-NODOT-NEXT: add v1.4s, v1.4s, v2.4s
+; CHECK-NODOT-NEXT: b.ne .LBB16_1
+; CHECK-NODOT-NEXT: // %bb.2: // %end
+; CHECK-NODOT-NEXT: ret
+
+entry:
+  br label %vector.body
+
+vector.body:
+  %index = phi i64 [ 0, %entry ], [ %index.next, %vector.body ]
+  %acc = phi <4 x i32> [ zeroinitializer, %entry ], [ %partial.reduce, %vector.body ]
+  %gep = getelementptr i8, ptr %p, i64 %index
+  %load = load <16 x i8>, ptr %gep, align 16
+  %load.wide = zext <16 x i8> %load to <16 x i32>
+  %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %load.wide)
+  %index.next = add nuw i64 %index, 16
+  %cmp = icmp eq i64 %index.next, 16
+  br i1 %cmp, label %end, label %vector.body
+
+end:
+  ret <4 x i32> %acc
+}
+
 define <4 x i32> @sdot_no_bin_op(<4 x i32> %acc, <16 x i8> %a){
 ; CHECK-DOT-LABEL: sdot_no_bin_op:
 ; CHECK-DOT: // %bb.0:
