15
15
#include " llvm/CodeGen/ReplaceWithVeclib.h"
16
16
#include " llvm/ADT/STLExtras.h"
17
17
#include " llvm/ADT/Statistic.h"
18
+ #include " llvm/ADT/StringRef.h"
18
19
#include " llvm/Analysis/DemandedBits.h"
19
20
#include " llvm/Analysis/GlobalsModRef.h"
20
21
#include " llvm/Analysis/OptimizationRemarkEmitter.h"
21
22
#include " llvm/Analysis/TargetLibraryInfo.h"
22
23
#include " llvm/Analysis/VectorUtils.h"
23
24
#include " llvm/CodeGen/Passes.h"
25
+ #include " llvm/IR/DerivedTypes.h"
24
26
#include " llvm/IR/IRBuilder.h"
25
27
#include " llvm/IR/InstIterator.h"
28
+ #include " llvm/Support/TypeSize.h"
26
29
#include " llvm/Transforms/Utils/ModuleUtils.h"
27
30
28
31
using namespace llvm ;
@@ -38,138 +41,137 @@ STATISTIC(NumTLIFuncDeclAdded,
38
41
STATISTIC (NumFuncUsedAdded,
39
42
" Number of functions added to `llvm.compiler.used`" );
40
43
41
- static bool replaceWithTLIFunction (CallInst &CI, const StringRef TLIName) {
42
- Module *M = CI.getModule ();
43
-
44
- Function *OldFunc = CI.getCalledFunction ();
45
-
46
- // Check if the vector library function is already declared in this module,
47
- // otherwise insert it.
44
+ // / Returns a vector Function that it adds to the Module \p M. When an \p
45
+ // / ScalarFunc is not null, it copies its attributes to the newly created
46
+ // / Function.
47
+ Function *getTLIFunction (Module *M, FunctionType *VectorFTy,
48
+ const StringRef TLIName,
49
+ Function *ScalarFunc = nullptr ) {
48
50
Function *TLIFunc = M->getFunction (TLIName);
49
51
if (!TLIFunc) {
50
- TLIFunc = Function::Create (OldFunc->getFunctionType (),
51
- Function::ExternalLinkage, TLIName, *M);
52
- TLIFunc->copyAttributesFrom (OldFunc);
52
+ TLIFunc =
53
+ Function::Create (VectorFTy, Function::ExternalLinkage, TLIName, *M);
54
+ if (ScalarFunc)
55
+ TLIFunc->copyAttributesFrom (ScalarFunc);
53
56
54
57
LLVM_DEBUG (dbgs () << DEBUG_TYPE << " : Added vector library function `"
55
58
<< TLIName << " ` of type `" << *(TLIFunc->getType ())
56
59
<< " ` to module.\n " );
57
60
58
61
++NumTLIFuncDeclAdded;
59
-
60
- // Add the freshly created function to llvm.compiler.used,
61
- // similar to as it is done in InjectTLIMappings
62
+ // Add the freshly created function to llvm.compiler.used, similar to as it
63
+ // is done in InjectTLIMappings.
62
64
appendToCompilerUsed (*M, {TLIFunc});
63
-
64
65
LLVM_DEBUG (dbgs () << DEBUG_TYPE << " : Adding `" << TLIName
65
66
<< " ` to `@llvm.compiler.used`.\n " );
66
67
++NumFuncUsedAdded;
67
68
}
69
+ return TLIFunc;
70
+ }
68
71
69
- // Replace the call to the vector intrinsic with a call
70
- // to the corresponding function from the vector library.
71
- IRBuilder<> IRBuilder (&CI);
72
- SmallVector<Value *> Args (CI.args ());
73
- // Preserve the operand bundles.
74
- SmallVector<OperandBundleDef, 1 > OpBundles;
75
- CI.getOperandBundlesAsDefs (OpBundles);
76
- CallInst *Replacement = IRBuilder.CreateCall (TLIFunc, Args, OpBundles);
77
- assert (OldFunc->getFunctionType () == TLIFunc->getFunctionType () &&
78
- " Expecting function types to be identical" );
79
- CI.replaceAllUsesWith (Replacement);
80
- if (isa<FPMathOperator>(Replacement)) {
81
- // Preserve fast math flags for FP math.
82
- Replacement->copyFastMathFlags (&CI);
72
+ // / Replace the call to the vector intrinsic ( \p CalltoReplace ) with a call to
73
+ // / the corresponding function from the vector library ( \p TLIVecFunc ).
74
+ static void replaceWithTLIFunction (CallInst &CalltoReplace, VFInfo &Info,
75
+ Function *TLIVecFunc) {
76
+ IRBuilder<> IRBuilder (&CalltoReplace);
77
+ SmallVector<Value *> Args (CalltoReplace.args ());
78
+ if (auto OptMaskpos = Info.getParamIndexForOptionalMask ()) {
79
+ auto *MaskTy = VectorType::get (Type::getInt1Ty (CalltoReplace.getContext ()),
80
+ Info.Shape .VF );
81
+ Args.insert (Args.begin () + OptMaskpos.value (),
82
+ Constant::getAllOnesValue (MaskTy));
83
83
}
84
84
85
- LLVM_DEBUG (dbgs () << DEBUG_TYPE << " : Replaced call to `"
86
- << OldFunc->getName () << " ` with call to `" << TLIName
87
- << " `.\n " );
88
- ++NumCallsReplaced;
89
- return true ;
85
+ // Preserve the operand bundles.
86
+ SmallVector<OperandBundleDef, 1 > OpBundles;
87
+ CalltoReplace.getOperandBundlesAsDefs (OpBundles);
88
+ CallInst *Replacement = IRBuilder.CreateCall (TLIVecFunc, Args, OpBundles);
89
+ CalltoReplace.replaceAllUsesWith (Replacement);
90
+ // Preserve fast math flags for FP math.
91
+ if (isa<FPMathOperator>(Replacement))
92
+ Replacement->copyFastMathFlags (&CalltoReplace);
90
93
}
91
94
95
+ // / Returns true when successfully replaced \p CallToReplace with a suitable
96
+ // / function taking vector arguments, based on available mappings in the \p TLI.
97
+ // / Currently only works when \p CallToReplace is a call to vectorized
98
+ // / intrinsic.
92
99
static bool replaceWithCallToVeclib (const TargetLibraryInfo &TLI,
93
- CallInst &CI ) {
94
- if (!CI .getCalledFunction ()) {
100
+ CallInst &CallToReplace ) {
101
+ if (!CallToReplace .getCalledFunction ())
95
102
return false ;
96
- }
97
103
98
- auto IntrinsicID = CI .getCalledFunction ()->getIntrinsicID ();
99
- if (IntrinsicID == Intrinsic::not_intrinsic) {
100
- // Replacement is only performed for intrinsic functions
104
+ auto IntrinsicID = CallToReplace .getCalledFunction ()->getIntrinsicID ();
105
+ // Replacement is only performed for intrinsic functions.
106
+ if (IntrinsicID == Intrinsic::not_intrinsic)
101
107
return false ;
102
- }
103
108
104
- // Convert vector arguments to scalar type and check that
105
- // all vector operands have identical vector width .
109
+ // Compute arguments types of the corresponding scalar call. Additionally
110
+ // checks if in the vector call, all vector operands have the same EC .
106
111
ElementCount VF = ElementCount::getFixed (0 );
107
- SmallVector<Type *> ScalarTypes;
108
- for (auto Arg : enumerate(CI.args ())) {
109
- auto *ArgType = Arg.value ()->getType ();
110
- // Vector calls to intrinsics can still have
111
- // scalar operands for specific arguments.
112
+ SmallVector<Type *> ScalarArgTypes;
113
+ for (auto Arg : enumerate(CallToReplace.args ())) {
114
+ auto *ArgTy = Arg.value ()->getType ();
112
115
if (isVectorIntrinsicWithScalarOpAtArg (IntrinsicID, Arg.index ())) {
113
- ScalarTypes.push_back (ArgType);
114
- } else {
115
- // The argument in this place should be a vector if
116
- // this is a call to a vector intrinsic.
117
- auto *VectorArgTy = dyn_cast<VectorType>(ArgType);
118
- if (!VectorArgTy) {
119
- // The argument is not a vector, do not perform
120
- // the replacement.
121
- return false ;
122
- }
123
- ElementCount NumElements = VectorArgTy->getElementCount ();
124
- if (NumElements.isScalable ()) {
125
- // The current implementation does not support
126
- // scalable vectors.
116
+ ScalarArgTypes.push_back (ArgTy);
117
+ } else if (auto *VectorArgTy = dyn_cast<VectorType>(ArgTy)) {
118
+ ScalarArgTypes.push_back (ArgTy->getScalarType ());
119
+ // Disallow vector arguments with different VFs. When processing the first
120
+ // vector argument, store it's VF, and for the rest ensure that they match
121
+ // it.
122
+ if (VF.isZero ())
123
+ VF = VectorArgTy->getElementCount ();
124
+ else if (VF != VectorArgTy->getElementCount ())
127
125
return false ;
128
- }
129
- if (VF.isNonZero () && VF != NumElements) {
130
- // The different arguments differ in vector size.
131
- return false ;
132
- } else {
133
- VF = NumElements;
134
- }
135
- ScalarTypes.push_back (VectorArgTy->getElementType ());
136
- }
126
+ } else
127
+ // Exit when it is supposed to be a vector argument but it isn't.
128
+ return false ;
137
129
}
138
130
139
- // Try to reconstruct the name for the scalar version of this
140
- // intrinsic using the intrinsic ID and the argument types
141
- // converted to scalar above.
142
- std::string ScalarName;
143
- if (Intrinsic::isOverloaded (IntrinsicID)) {
144
- ScalarName = Intrinsic::getName (IntrinsicID, ScalarTypes, CI.getModule ());
145
- } else {
146
- ScalarName = Intrinsic::getName (IntrinsicID).str ();
147
- }
131
+ // Try to reconstruct the name for the scalar version of this intrinsic using
132
+ // the intrinsic ID and the argument types converted to scalar above.
133
+ std::string ScalarName =
134
+ (Intrinsic::isOverloaded (IntrinsicID)
135
+ ? Intrinsic::getName (IntrinsicID, ScalarArgTypes,
136
+ CallToReplace.getModule ())
137
+ : Intrinsic::getName (IntrinsicID).str ());
138
+
139
+ // Try to find the mapping for the scalar version of this intrinsic and the
140
+ // exact vector width of the call operands in the TargetLibraryInfo. First,
141
+ // check with a non-masked variant, and if that fails try with a masked one.
142
+ const VecDesc *VD =
143
+ TLI.getVectorMappingInfo (ScalarName, VF, /* Masked*/ false );
144
+ if (!VD && !(VD = TLI.getVectorMappingInfo (ScalarName, VF, /* Masked*/ true )))
145
+ return false ;
148
146
149
- if (!TLI.isFunctionVectorizable (ScalarName)) {
150
- // The TargetLibraryInfo does not contain a vectorized version of
151
- // the scalar function.
147
+ LLVM_DEBUG (dbgs () << DEBUG_TYPE << " : Found TLI mapping from: `" << ScalarName
148
+ << " ` and vector width " << VF << " to: `"
149
+ << VD->getVectorFnName () << " `.\n " );
150
+
151
+ // Replace the call to the intrinsic with a call to the vector library
152
+ // function.
153
+ Type *ScalarRetTy = CallToReplace.getType ()->getScalarType ();
154
+ FunctionType *ScalarFTy =
155
+ FunctionType::get (ScalarRetTy, ScalarArgTypes, /* isVarArg*/ false );
156
+ const std::string MangledName = VD->getVectorFunctionABIVariantString ();
157
+ auto OptInfo = VFABI::tryDemangleForVFABI (MangledName, ScalarFTy);
158
+ if (!OptInfo)
152
159
return false ;
153
- }
154
160
155
- // Try to find the mapping for the scalar version of this intrinsic
156
- // and the exact vector width of the call operands in the
157
- // TargetLibraryInfo.
158
- StringRef TLIName = TLI.getVectorizedFunction (ScalarName, VF);
159
-
160
- LLVM_DEBUG (dbgs () << DEBUG_TYPE << " : Looking up TLI mapping for `"
161
- << ScalarName << " ` and vector width " << VF << " .\n " );
162
-
163
- if (!TLIName.empty ()) {
164
- // Found the correct mapping in the TargetLibraryInfo,
165
- // replace the call to the intrinsic with a call to
166
- // the vector library function.
167
- LLVM_DEBUG (dbgs () << DEBUG_TYPE << " : Found TLI function `" << TLIName
168
- << " `.\n " );
169
- return replaceWithTLIFunction (CI, TLIName);
170
- }
161
+ FunctionType *VectorFTy = VFABI::createFunctionType (*OptInfo, ScalarFTy);
162
+ if (!VectorFTy)
163
+ return false ;
164
+
165
+ Function *FuncToReplace = CallToReplace.getCalledFunction ();
166
+ Function *TLIFunc = getTLIFunction (CallToReplace.getModule (), VectorFTy,
167
+ VD->getVectorFnName (), FuncToReplace);
168
+ replaceWithTLIFunction (CallToReplace, *OptInfo, TLIFunc);
171
169
172
- return false ;
170
+ LLVM_DEBUG (dbgs () << DEBUG_TYPE << " : Replaced call to `"
171
+ << FuncToReplace->getName () << " ` with call to `"
172
+ << TLIFunc->getName () << " `.\n " );
173
+ ++NumCallsReplaced;
174
+ return true ;
173
175
}
174
176
175
177
static bool runImpl (const TargetLibraryInfo &TLI, Function &F) {
@@ -185,9 +187,8 @@ static bool runImpl(const TargetLibraryInfo &TLI, Function &F) {
185
187
}
186
188
// Erase the calls to the intrinsics that have been replaced
187
189
// with calls to the vector library.
188
- for (auto *CI : ReplacedCalls) {
190
+ for (auto *CI : ReplacedCalls)
189
191
CI->eraseFromParent ();
190
- }
191
192
return Changed;
192
193
}
193
194
@@ -207,10 +208,10 @@ PreservedAnalyses ReplaceWithVeclib::run(Function &F,
207
208
PA.preserve <DemandedBitsAnalysis>();
208
209
PA.preserve <OptimizationRemarkEmitterAnalysis>();
209
210
return PA;
210
- } else {
211
- // The pass did not replace any calls, hence it preserves all analyses.
212
- return PreservedAnalyses::all ();
213
211
}
212
+
213
+ // The pass did not replace any calls, hence it preserves all analyses.
214
+ return PreservedAnalyses::all ();
214
215
}
215
216
216
217
// //////////////////////////////////////////////////////////////////////////////
0 commit comments