Commit 70995a1

[ScalarizeMaskedMemIntr] Optimize splat non-constant masks (#104537)
In cases (like the ones added in the tests) where the condition of a masked load or store is a splat but not a constant - that is, where a masked operation implements a pattern like "load this vector if the access is in-bounds, otherwise return 0" - optimize the scalarized code into a single aligned vector load/store guarded by the splatted condition. Additionally, while here, preserve aliasing metadata and value names in the cases where nothing is scalarized.

As motivation, some LLVM IR frontends generate masked loads/stores for this kind of predicated operation (where either the whole vector is accessed or it isn't) in order to take advantage of hardware primitives. But on AMDGPU, where we don't have a masked load or store, this pass would scalarize an access that was intended to be - and can be - vectorized, while also introducing expensive branches.

Fixes #104520. Pre-commit tests at #104527.
1 parent f33d519 commit 70995a1
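For context, here is an illustrative sketch (not part of the commit) of the kind of input this change targets: a frontend guards a whole vector access with one scalar condition and expresses it as a masked intrinsic whose mask is a splat of that non-constant boolean. The helper name emitGuardedLoad, the element count, and the alignment below are hypothetical; only the IRBuilder calls are existing LLVM API.

// Hypothetical frontend code emitting the predicated-load pattern. The mask is
// a splat of a single non-constant i1, so every lane is either loaded or not.
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/Support/Alignment.h"

using namespace llvm;

static Value *emitGuardedLoad(IRBuilder<> &B, Value *Ptr, Value *Cond) {
  auto *VecTy = FixedVectorType::get(B.getInt32Ty(), 4);
  // Splat the scalar i1 condition across all lanes to form the mask.
  Value *Mask = B.CreateVectorSplat(4, Cond, "mask");
  // If the condition is false, the result falls back to zeroinitializer.
  Value *PassThru = Constant::getNullValue(VecTy);
  // Before this commit, ScalarizeMaskedMemIntrin would expand this into a
  // per-lane branch or bit-test sequence on targets without masked loads;
  // with it, the pass emits a single branch around one aligned vector load.
  return B.CreateMaskedLoad(VecTy, Ptr, Align(16), Mask, PassThru, "val");
}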

File tree

5 files changed: +90, -917 lines


llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp

Lines changed: 61 additions & 3 deletions
@@ -17,6 +17,7 @@
 #include "llvm/ADT/Twine.h"
 #include "llvm/Analysis/DomTreeUpdater.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/Analysis/VectorUtils.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Constant.h"
 #include "llvm/IR/Constants.h"
@@ -161,7 +162,9 @@ static void scalarizeMaskedLoad(const DataLayout &DL, CallInst *CI,
 
   // Short-cut if the mask is all-true.
   if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
-    Value *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
+    LoadInst *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
+    NewI->copyMetadata(*CI);
+    NewI->takeName(CI);
     CI->replaceAllUsesWith(NewI);
     CI->eraseFromParent();
     return;
@@ -188,8 +191,39 @@ static void scalarizeMaskedLoad(const DataLayout &DL, CallInst *CI,
     return;
   }
 
+  // Optimize the case where the "masked load" is a predicated load - that is,
+  // where the mask is the splat of a non-constant scalar boolean. In that case,
+  // use that splated value as the guard on a conditional vector load.
+  if (isSplatValue(Mask, /*Index=*/0)) {
+    Value *Predicate = Builder.CreateExtractElement(Mask, uint64_t(0ull),
+                                                    Mask->getName() + ".first");
+    Instruction *ThenTerm =
+        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
+                                  /*BranchWeights=*/nullptr, DTU);
+
+    BasicBlock *CondBlock = ThenTerm->getParent();
+    CondBlock->setName("cond.load");
+    Builder.SetInsertPoint(CondBlock->getTerminator());
+    LoadInst *Load = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal,
+                                               CI->getName() + ".cond.load");
+    Load->copyMetadata(*CI);
+
+    BasicBlock *PostLoad = ThenTerm->getSuccessor(0);
+    Builder.SetInsertPoint(PostLoad, PostLoad->begin());
+    PHINode *Phi = Builder.CreatePHI(VecType, /*NumReservedValues=*/2);
+    Phi->addIncoming(Load, CondBlock);
+    Phi->addIncoming(Src0, IfBlock);
+    Phi->takeName(CI);
+
+    CI->replaceAllUsesWith(Phi);
+    CI->eraseFromParent();
+    ModifiedDT = true;
+    return;
+  }
   // If the mask is not v1i1, use scalar bit test operations. This generates
   // better results on X86 at least.
+  // Note: this produces worse code on AMDGPU, where the "i1" is implicitly SIMD
+  // - what's a good way to detect this?
   Value *SclrMask;
   if (VectorWidth != 1) {
     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
@@ -297,7 +331,9 @@ static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI,
 
   // Short-cut if the mask is all-true.
   if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
-    Builder.CreateAlignedStore(Src, Ptr, AlignVal);
+    StoreInst *Store = Builder.CreateAlignedStore(Src, Ptr, AlignVal);
+    Store->takeName(CI);
+    Store->copyMetadata(*CI);
     CI->eraseFromParent();
     return;
   }
@@ -319,8 +355,31 @@ static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI,
     return;
   }
 
+  // Optimize the case where the "masked store" is a predicated store - that is,
+  // when the mask is the splat of a non-constant scalar boolean. In that case,
+  // optimize to a conditional store.
+  if (isSplatValue(Mask, /*Index=*/0)) {
+    Value *Predicate = Builder.CreateExtractElement(Mask, uint64_t(0ull),
+                                                    Mask->getName() + ".first");
+    Instruction *ThenTerm =
+        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
+                                  /*BranchWeights=*/nullptr, DTU);
+    BasicBlock *CondBlock = ThenTerm->getParent();
+    CondBlock->setName("cond.store");
+    Builder.SetInsertPoint(CondBlock->getTerminator());
+
+    StoreInst *Store = Builder.CreateAlignedStore(Src, Ptr, AlignVal);
+    Store->takeName(CI);
+    Store->copyMetadata(*CI);
+
+    CI->eraseFromParent();
+    ModifiedDT = true;
+    return;
+  }
+
   // If the mask is not v1i1, use scalar bit test operations. This generates
   // better results on X86 at least.
+
   Value *SclrMask;
   if (VectorWidth != 1) {
     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
@@ -997,7 +1056,6 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
       any_of(II->args(),
              [](Value *V) { return isa<ScalableVectorType>(V->getType()); }))
     return false;
-
   switch (II->getIntrinsicID()) {
   default:
     break;
