
Commit 15b0fab

[RISCV] Vectorize phi for loop carried @llvm.vector.reduce.fadd (llvm#78244)
LLVM vector reduction intrinsics return a scalar result, but on RISC-V the vector reduction instructions write the result to the first element of a vector register. So when a reduction in a loop uses a scalar phi, we end up with unnecessary scalar moves:

```asm
loop:
    vfmv.s.f     v10, fa0
    vfredosum.vs v8, v8, v10
    vfmv.f.s     fa0, v8
```

This mainly affects ordered fadd reductions, which have a scalar accumulator operand. This patch tries to vectorize any scalar phi that feeds into a fadd reduction in RISCVCodeGenPrepare, converting:

```llvm
loop:
  %phi = phi float [ ..., %entry ], [ %acc, %loop ]
  %acc = call float @llvm.vector.reduce.fadd.nxv2f32(float %phi, <vscale x 2 x float> %vec)
```

to

```llvm
loop:
  %phi = phi <vscale x 2 x float> [ ..., %entry ], [ %acc.vec, %loop ]
  %phi.scalar = extractelement <vscale x 2 x float> %phi, i64 0
  %acc = call float @llvm.vector.reduce.fadd.nxv2f32(float %phi.scalar, <vscale x 2 x float> %vec)
  %acc.vec = insertelement <vscale x 2 x float> poison, float %acc, i64 0
```

This eliminates the scalar -> vector -> scalar crossing during instruction selection.
1 parent 085eae6 commit 15b0fab
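For background on why the ordered form is the interesting case: @llvm.vector.reduce.fadd takes a scalar start/accumulator operand, and without the reassoc fast-math flag the reduction is sequential, so a loop-carried reduction has to thread that scalar accumulator through every iteration. The sketch below is illustrative only, not part of this commit's diff, and the function name is made up; it contrasts the ordered call with the relaxed reassoc form that a vectorized loop can instead keep in a vector accumulator:

```llvm
; Illustrative sketch only -- not from this commit.
declare float @llvm.vector.reduce.fadd.nxv2f32(float, <vscale x 2 x float>)

define float @fadd_reduction_forms(float %start, <vscale x 2 x float> %v) {
  ; Ordered reduction: %start is added first, then the lanes in element order,
  ; so a loop-carried version keeps a scalar accumulator (the case this patch
  ; targets).
  %ordered = call float @llvm.vector.reduce.fadd.nxv2f32(float %start, <vscale x 2 x float> %v)
  ; With reassoc the additions may be reassociated, so a vectorized loop can
  ; carry a vector accumulator and reduce once after the loop instead.
  %relaxed = call reassoc float @llvm.vector.reduce.fadd.nxv2f32(float %start, <vscale x 2 x float> %v)
  %sum = fadd float %ordered, %relaxed
  ret float %sum
}
```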


3 files changed (+149, -1 lines)


llvm/lib/Target/RISCV/RISCVCodeGenPrepare.cpp

Lines changed: 59 additions & 1 deletion
```diff
@@ -18,8 +18,9 @@
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/CodeGen/TargetPassConfig.h"
+#include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/InstVisitor.h"
-#include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/Intrinsics.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Pass.h"
 
```

```diff
@@ -51,6 +52,7 @@ class RISCVCodeGenPrepare : public FunctionPass,
 
   bool visitInstruction(Instruction &I) { return false; }
   bool visitAnd(BinaryOperator &BO);
+  bool visitIntrinsicInst(IntrinsicInst &I);
 };
 
 } // end anonymous namespace
```
```diff
@@ -103,6 +105,62 @@ bool RISCVCodeGenPrepare::visitAnd(BinaryOperator &BO) {
   return true;
 }
 
+// LLVM vector reduction intrinsics return a scalar result, but on RISC-V vector
+// reduction instructions write the result in the first element of a vector
+// register. So when a reduction in a loop uses a scalar phi, we end up with
+// unnecessary scalar moves:
+//
+// loop:
+//   vfmv.s.f v10, fa0
+//   vfredosum.vs v8, v8, v10
+//   vfmv.f.s fa0, v8
+//
+// This mainly affects ordered fadd reductions, since other types of reduction
+// typically use element-wise vectorization in the loop body. This tries to
+// vectorize any scalar phis that feed into a fadd reduction:
+//
+// loop:
+//   %phi = phi float [ ..., %entry ], [ %acc, %loop ]
+//   %acc = call float @llvm.vector.reduce.fadd.nxv2f32(float %phi, <vscale x 2 x float> %vec)
+//
+// ->
+//
+// loop:
+//   %phi = phi <vscale x 2 x float> [ ..., %entry ], [ %acc.vec, %loop ]
+//   %phi.scalar = extractelement <vscale x 2 x float> %phi, i64 0
+//   %acc = call float @llvm.vector.reduce.fadd.nxv2f32(float %phi.scalar, <vscale x 2 x float> %vec)
+//   %acc.vec = insertelement <vscale x 2 x float> poison, float %acc, i64 0
+//
+// This eliminates the scalar -> vector -> scalar crossing during instruction
+// selection.
+bool RISCVCodeGenPrepare::visitIntrinsicInst(IntrinsicInst &I) {
+  if (I.getIntrinsicID() != Intrinsic::vector_reduce_fadd)
+    return false;
+
+  auto *PHI = dyn_cast<PHINode>(I.getOperand(0));
+  if (!PHI || !PHI->hasOneUse() ||
+      !llvm::is_contained(PHI->incoming_values(), &I))
+    return false;
+
+  Type *VecTy = I.getOperand(1)->getType();
+  IRBuilder<> Builder(PHI);
+  auto *VecPHI = Builder.CreatePHI(VecTy, PHI->getNumIncomingValues());
+
+  for (auto *BB : PHI->blocks()) {
+    Builder.SetInsertPoint(BB->getTerminator());
+    Value *InsertElt = Builder.CreateInsertElement(
+        VecTy, PHI->getIncomingValueForBlock(BB), (uint64_t)0);
+    VecPHI->addIncoming(InsertElt, BB);
+  }
+
+  Builder.SetInsertPoint(&I);
+  I.setOperand(0, Builder.CreateExtractElement(VecPHI, (uint64_t)0));
+
+  PHI->eraseFromParent();
+
+  return true;
+}
+
 bool RISCVCodeGenPrepare::runOnFunction(Function &F) {
   if (skipFunction(F))
     return false;
```
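A note on the two early-exit checks above: the transform only fires when the scalar phi's sole user is the reduction call and the call itself is one of the phi's incoming values, i.e. the plain loop-carried accumulator pattern. The fragment below is a made-up example (not from this commit) that the pass leaves alone because the phi has a second user after the loop:

```llvm
; Hypothetical example -- not from this commit. %phi has two users (the
; reduction call and the fadd in %exit), so !PHI->hasOneUse() makes
; visitIntrinsicInst return false and the phi stays scalar.
declare float @llvm.vector.reduce.fadd.nxv2f32(float, <vscale x 2 x float>)

define float @reduce_fadd_extra_use(<vscale x 2 x float> %vec, i64 %n) {
entry:
  br label %loop

loop:
  %iv = phi i64 [ 0, %entry ], [ %iv.next, %loop ]
  %phi = phi float [ 0.000000e+00, %entry ], [ %acc, %loop ]
  %acc = call float @llvm.vector.reduce.fadd.nxv2f32(float %phi, <vscale x 2 x float> %vec)
  %iv.next = add nuw i64 %iv, 1
  %done = icmp eq i64 %iv.next, %n
  br i1 %done, label %exit, label %loop

exit:
  %res = fadd float %phi, %acc   ; second use of %phi blocks the rewrite
  ret float %res
}
```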
Lines changed: 44 additions & 0 deletions
```llvm
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
; RUN: llc < %s -mtriple=riscv64 -mattr=+v | FileCheck %s

declare i64 @llvm.vscale.i64()
declare float @llvm.vector.reduce.fadd.nxv4f32(float, <vscale x 4 x float>)

define float @reduce_fadd(ptr %f) {
; CHECK-LABEL: reduce_fadd:
; CHECK:       # %bb.0: # %entry
; CHECK-NEXT:    csrr a2, vlenb
; CHECK-NEXT:    srli a1, a2, 1
; CHECK-NEXT:    vsetvli a3, zero, e32, m1, ta, ma
; CHECK-NEXT:    vmv.s.x v8, zero
; CHECK-NEXT:    slli a2, a2, 1
; CHECK-NEXT:    li a3, 1024
; CHECK-NEXT:  .LBB0_1: # %vector.body
; CHECK-NEXT:    # =>This Inner Loop Header: Depth=1
; CHECK-NEXT:    vl2re32.v v10, (a0)
; CHECK-NEXT:    vsetvli a4, zero, e32, m2, ta, ma
; CHECK-NEXT:    vfredosum.vs v8, v10, v8
; CHECK-NEXT:    sub a3, a3, a1
; CHECK-NEXT:    add a0, a0, a2
; CHECK-NEXT:    bnez a3, .LBB0_1
; CHECK-NEXT:  # %bb.2: # %exit
; CHECK-NEXT:    vfmv.f.s fa0, v8
; CHECK-NEXT:    ret
entry:
  %vscale = tail call i64 @llvm.vscale.i64()
  %vecsize = shl nuw nsw i64 %vscale, 2
  br label %vector.body

vector.body:
  %index = phi i64 [ 0, %entry ], [ %index.next, %vector.body ]
  %vec.phi = phi float [ 0.000000e+00, %entry ], [ %acc, %vector.body ]
  %gep = getelementptr inbounds float, ptr %f, i64 %index
  %wide.load = load <vscale x 4 x float>, ptr %gep, align 4
  %acc = tail call float @llvm.vector.reduce.fadd.nxv4f32(float %vec.phi, <vscale x 4 x float> %wide.load)
  %index.next = add nuw i64 %index, %vecsize
  %done = icmp eq i64 %index.next, 1024
  br i1 %done, label %exit, label %vector.body

exit:
  ret float %acc
}
```
Lines changed: 46 additions & 0 deletions
```llvm
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
; RUN: opt %s -S -riscv-codegenprepare -mtriple=riscv64 -mattr=+v | FileCheck %s

declare i64 @llvm.vscale.i64()
declare float @llvm.vector.reduce.fadd.nxv4f32(float, <vscale x 4 x float>)

define float @reduce_fadd(ptr %f) {
; CHECK-LABEL: define float @reduce_fadd(
; CHECK-SAME: ptr [[F:%.*]]) #[[ATTR2:[0-9]+]] {
; CHECK-NEXT:  entry:
; CHECK-NEXT:    [[VSCALE:%.*]] = tail call i64 @llvm.vscale.i64()
; CHECK-NEXT:    [[VECSIZE:%.*]] = shl nuw nsw i64 [[VSCALE]], 2
; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
; CHECK:       vector.body:
; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
; CHECK-NEXT:    [[TMP0:%.*]] = phi <vscale x 4 x float> [ insertelement (<vscale x 4 x float> poison, float 0.000000e+00, i64 0), [[ENTRY]] ], [ [[TMP2:%.*]], [[VECTOR_BODY]] ]
; CHECK-NEXT:    [[GEP:%.*]] = getelementptr inbounds float, ptr [[F]], i64 [[INDEX]]
; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <vscale x 4 x float>, ptr [[GEP]], align 4
; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <vscale x 4 x float> [[TMP0]], i64 0
; CHECK-NEXT:    [[ACC:%.*]] = tail call float @llvm.vector.reduce.fadd.nxv4f32(float [[TMP1]], <vscale x 4 x float> [[WIDE_LOAD]])
; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[VECSIZE]]
; CHECK-NEXT:    [[DONE:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1024
; CHECK-NEXT:    [[TMP2]] = insertelement <vscale x 4 x float> poison, float [[ACC]], i64 0
; CHECK-NEXT:    br i1 [[DONE]], label [[EXIT:%.*]], label [[VECTOR_BODY]]
; CHECK:       exit:
; CHECK-NEXT:    ret float [[ACC]]
;

entry:
  %vscale = tail call i64 @llvm.vscale.i64()
  %vecsize = shl nuw nsw i64 %vscale, 2
  br label %vector.body

vector.body:
  %index = phi i64 [ 0, %entry ], [ %index.next, %vector.body ]
  %vec.phi = phi float [ 0.000000e+00, %entry ], [ %acc, %vector.body ]
  %gep = getelementptr inbounds float, ptr %f, i64 %index
  %wide.load = load <vscale x 4 x float>, ptr %gep, align 4
  %acc = tail call float @llvm.vector.reduce.fadd.nxv4f32(float %vec.phi, <vscale x 4 x float> %wide.load)
  %index.next = add nuw i64 %index, %vecsize
  %done = icmp eq i64 %index.next, 1024
  br i1 %done, label %exit, label %vector.body

exit:
  ret float %acc
}
```
