|
18 | 18 | #include "llvm/ADT/Statistic.h"
|
19 | 19 | #include "llvm/Analysis/ValueTracking.h"
|
20 | 20 | #include "llvm/CodeGen/TargetPassConfig.h"
|
| 21 | +#include "llvm/IR/IRBuilder.h" |
21 | 22 | #include "llvm/IR/InstVisitor.h"
|
22 |
| -#include "llvm/IR/PatternMatch.h" |
| 23 | +#include "llvm/IR/Intrinsics.h" |
23 | 24 | #include "llvm/InitializePasses.h"
|
24 | 25 | #include "llvm/Pass.h"
|
25 | 26 |
|
@@ -51,6 +52,7 @@ class RISCVCodeGenPrepare : public FunctionPass,
|
51 | 52 |
|
52 | 53 | bool visitInstruction(Instruction &I) { return false; }
|
53 | 54 | bool visitAnd(BinaryOperator &BO);
|
| 55 | + bool visitIntrinsicInst(IntrinsicInst &I); |
54 | 56 | };
|
55 | 57 |
|
56 | 58 | } // end anonymous namespace
|
@@ -103,6 +105,62 @@ bool RISCVCodeGenPrepare::visitAnd(BinaryOperator &BO) {
|
103 | 105 | return true;
|
104 | 106 | }
|
105 | 107 |
|
| 108 | +// LLVM vector reduction intrinsics return a scalar result, but on RISC-V vector |
| 109 | +// reduction instructions write the result in the first element of a vector |
| 110 | +// register. So when a reduction in a loop uses a scalar phi, we end up with |
| 111 | +// unnecessary scalar moves: |
| 112 | +// |
| 113 | +// loop: |
| 114 | +// vfmv.s.f v10, fa0 |
| 115 | +// vfredosum.vs v8, v8, v10 |
| 116 | +// vfmv.f.s fa0, v8 |
| 117 | +// |
| 118 | +// This mainly affects ordered fadd reductions, since other types of reduction |
| 119 | +// typically use element-wise vectorisation in the loop body. This tries to |
| 120 | +// vectorize any scalar phis that feed into a fadd reduction: |
| 121 | +// |
| 122 | +// loop: |
| 123 | +// %phi = phi <float> [ ..., %entry ], [ %acc, %loop ] |
| 124 | +// %acc = call float @llvm.vector.reduce.fadd.nxv4f32(float %phi, <vscale x 2 x float> %vec) |
| 125 | +// |
| 126 | +// -> |
| 127 | +// |
| 128 | +// loop: |
| 129 | +// %phi = phi <vscale x 2 x float> [ ..., %entry ], [ %acc.vec, %loop ] |
| 130 | +// %phi.scalar = extractelement <vscale x 2 x float> %phi, i64 0 |
| 131 | +// %acc = call float @llvm.vector.reduce.fadd.nxv4f32(float %x, <vscale x 2 x float> %vec) |
| 132 | +// %acc.vec = insertelement <vscale x 2 x float> poison, float %acc.next, i64 0 |
| 133 | +// |
| 134 | +// Which eliminates the scalar -> vector -> scalar crossing during instruction |
| 135 | +// selection. |
| 136 | +bool RISCVCodeGenPrepare::visitIntrinsicInst(IntrinsicInst &I) { |
| 137 | + if (I.getIntrinsicID() != Intrinsic::vector_reduce_fadd) |
| 138 | + return false; |
| 139 | + |
| 140 | + auto *PHI = dyn_cast<PHINode>(I.getOperand(0)); |
| 141 | + if (!PHI || !PHI->hasOneUse() || |
| 142 | + !llvm::is_contained(PHI->incoming_values(), &I)) |
| 143 | + return false; |
| 144 | + |
| 145 | + Type *VecTy = I.getOperand(1)->getType(); |
| 146 | + IRBuilder<> Builder(PHI); |
| 147 | + auto *VecPHI = Builder.CreatePHI(VecTy, PHI->getNumIncomingValues()); |
| 148 | + |
| 149 | + for (auto *BB : PHI->blocks()) { |
| 150 | + Builder.SetInsertPoint(BB->getTerminator()); |
| 151 | + Value *InsertElt = Builder.CreateInsertElement( |
| 152 | + VecTy, PHI->getIncomingValueForBlock(BB), (uint64_t)0); |
| 153 | + VecPHI->addIncoming(InsertElt, BB); |
| 154 | + } |
| 155 | + |
| 156 | + Builder.SetInsertPoint(&I); |
| 157 | + I.setOperand(0, Builder.CreateExtractElement(VecPHI, (uint64_t)0)); |
| 158 | + |
| 159 | + PHI->eraseFromParent(); |
| 160 | + |
| 161 | + return true; |
| 162 | +} |
| 163 | + |
106 | 164 | bool RISCVCodeGenPrepare::runOnFunction(Function &F) {
|
107 | 165 | if (skipFunction(F))
|
108 | 166 | return false;
|
|
0 commit comments