Skip to content

Commit 57f5d8f

Browse files
committed
[VPlan] Only store single vector per VPValue in VPTransformState. (NFC)
After 8ec4067 (#95842), VPTransformState only stores a single vector value per VPValue. Simplify the code by replacing the SmallVector in PerPartOutput with a single Value * and rename to VPV2Vector for clarity. Also remove the redundant Part argument from various accessors.
1 parent 31ac3d0 commit 57f5d8f

File tree

4 files changed

+145
-162
lines changed

4 files changed

+145
-162
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3075,13 +3075,13 @@ void InnerLoopVectorizer::fixNonInductionPHIs(VPlan &Plan,
30753075
VPWidenPHIRecipe *VPPhi = dyn_cast<VPWidenPHIRecipe>(&P);
30763076
if (!VPPhi)
30773077
continue;
3078-
PHINode *NewPhi = cast<PHINode>(State.get(VPPhi, 0));
3078+
PHINode *NewPhi = cast<PHINode>(State.get(VPPhi));
30793079
// Make sure the builder has a valid insert point.
30803080
Builder.SetInsertPoint(NewPhi);
30813081
for (unsigned Idx = 0; Idx < VPPhi->getNumOperands(); ++Idx) {
30823082
VPValue *Inc = VPPhi->getIncomingValue(Idx);
30833083
VPBasicBlock *VPBB = VPPhi->getIncomingBlock(Idx);
3084-
NewPhi->addIncoming(State.get(Inc, 0), State.CFG.VPBB2IRBB[VPBB]);
3084+
NewPhi->addIncoming(State.get(Inc), State.CFG.VPBB2IRBB[VPBB]);
30853085
}
30863086
}
30873087
}
@@ -9445,7 +9445,7 @@ void VPReplicateRecipe::execute(VPTransformState &State) {
94459445
assert(!State.VF.isScalable() && "VF is assumed to be non scalable.");
94469446
Value *Poison = PoisonValue::get(
94479447
VectorType::get(UI->getType(), State.VF));
9448-
State.set(this, Poison, State.Instance->Part);
9448+
State.set(this, Poison);
94499449
}
94509450
State.packScalarIntoVectorValue(this, *State.Instance);
94519451
}

llvm/lib/Transforms/Vectorize/VPlan.cpp

Lines changed: 26 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -242,8 +242,8 @@ Value *VPTransformState::get(VPValue *Def, const VPIteration &Instance) {
242242
return Data.PerPartScalars[Def][Instance.Part][0];
243243
}
244244

245-
assert(hasVectorValue(Def, Instance.Part));
246-
auto *VecPart = Data.PerPartOutput[Def][Instance.Part];
245+
assert(hasVectorValue(Def));
246+
auto *VecPart = Data.VPV2Vector[Def];
247247
if (!VecPart->getType()->isVectorTy()) {
248248
assert(Instance.Lane.isFirstLane() && "cannot get lane > 0 for scalar");
249249
return VecPart;
@@ -255,20 +255,20 @@ Value *VPTransformState::get(VPValue *Def, const VPIteration &Instance) {
255255
return Extract;
256256
}
257257

258-
Value *VPTransformState::get(VPValue *Def, unsigned Part, bool NeedsScalar) {
258+
Value *VPTransformState::get(VPValue *Def, bool NeedsScalar) {
259259
if (NeedsScalar) {
260-
assert((VF.isScalar() || Def->isLiveIn() || hasVectorValue(Def, Part) ||
260+
assert((VF.isScalar() || Def->isLiveIn() || hasVectorValue(Def) ||
261261
!vputils::onlyFirstLaneUsed(Def) ||
262-
(hasScalarValue(Def, VPIteration(Part, 0)) &&
263-
Data.PerPartScalars[Def][Part].size() == 1)) &&
262+
(hasScalarValue(Def, VPIteration(0, 0)) &&
263+
Data.PerPartScalars[Def][0].size() == 1)) &&
264264
"Trying to access a single scalar per part but has multiple scalars "
265265
"per part.");
266-
return get(Def, VPIteration(Part, 0));
266+
return get(Def, VPIteration(0, 0));
267267
}
268268

269269
// If Values have been set for this Def return the one relevant for \p Part.
270-
if (hasVectorValue(Def, Part))
271-
return Data.PerPartOutput[Def][Part];
270+
if (hasVectorValue(Def))
271+
return Data.VPV2Vector[Def];
272272

273273
auto GetBroadcastInstrs = [this, Def](Value *V) {
274274
bool SafeToHoist = Def->isDefinedOutsideLoopRegions();
@@ -290,29 +290,27 @@ Value *VPTransformState::get(VPValue *Def, unsigned Part, bool NeedsScalar) {
290290
return Shuf;
291291
};
292292

293-
if (!hasScalarValue(Def, {Part, 0})) {
293+
if (!hasScalarValue(Def, {0, 0})) {
294294
assert(Def->isLiveIn() && "expected a live-in");
295-
if (Part != 0)
296-
return get(Def, 0);
297295
Value *IRV = Def->getLiveInIRValue();
298296
Value *B = GetBroadcastInstrs(IRV);
299-
set(Def, B, Part);
297+
set(Def, B);
300298
return B;
301299
}
302300

303-
Value *ScalarValue = get(Def, {Part, 0});
301+
Value *ScalarValue = get(Def, {0, 0});
304302
// If we aren't vectorizing, we can just copy the scalar map values over
305303
// to the vector map.
306304
if (VF.isScalar()) {
307-
set(Def, ScalarValue, Part);
305+
set(Def, ScalarValue);
308306
return ScalarValue;
309307
}
310308

311309
bool IsUniform = vputils::isUniformAfterVectorization(Def);
312310

313311
unsigned LastLane = IsUniform ? 0 : VF.getKnownMinValue() - 1;
314312
// Check if there is a scalar value for the selected lane.
315-
if (!hasScalarValue(Def, {Part, LastLane})) {
313+
if (!hasScalarValue(Def, {0, LastLane})) {
316314
// At the moment, VPWidenIntOrFpInductionRecipes, VPScalarIVStepsRecipes and
317315
// VPExpandSCEVRecipes can also be uniform.
318316
assert((isa<VPWidenIntOrFpInductionRecipe>(Def->getDefiningRecipe()) ||
@@ -323,7 +321,7 @@ Value *VPTransformState::get(VPValue *Def, unsigned Part, bool NeedsScalar) {
323321
LastLane = 0;
324322
}
325323

326-
auto *LastInst = cast<Instruction>(get(Def, {Part, LastLane}));
324+
auto *LastInst = cast<Instruction>(get(Def, {0, LastLane}));
327325
// Set the insert point after the last scalarized instruction or after the
328326
// last PHI, if LastInst is a PHI. This ensures the insertelement sequence
329327
// will directly follow the scalar definitions.
@@ -343,15 +341,15 @@ Value *VPTransformState::get(VPValue *Def, unsigned Part, bool NeedsScalar) {
343341
Value *VectorValue = nullptr;
344342
if (IsUniform) {
345343
VectorValue = GetBroadcastInstrs(ScalarValue);
346-
set(Def, VectorValue, Part);
344+
set(Def, VectorValue);
347345
} else {
348346
// Initialize packing with insertelements to start from undef.
349347
assert(!VF.isScalable() && "VF is assumed to be non scalable.");
350348
Value *Undef = PoisonValue::get(VectorType::get(LastInst->getType(), VF));
351-
set(Def, Undef, Part);
349+
set(Def, Undef);
352350
for (unsigned Lane = 0; Lane < VF.getKnownMinValue(); ++Lane)
353-
packScalarIntoVectorValue(Def, {Part, Lane});
354-
VectorValue = get(Def, Part);
351+
packScalarIntoVectorValue(Def, {0, Lane});
352+
VectorValue = get(Def);
355353
}
356354
Builder.restoreIP(OldIP);
357355
return VectorValue;
@@ -406,10 +404,10 @@ void VPTransformState::setDebugLocFrom(DebugLoc DL) {
406404
void VPTransformState::packScalarIntoVectorValue(VPValue *Def,
407405
const VPIteration &Instance) {
408406
Value *ScalarInst = get(Def, Instance);
409-
Value *VectorValue = get(Def, Instance.Part);
407+
Value *VectorValue = get(Def);
410408
VectorValue = Builder.CreateInsertElement(
411409
VectorValue, ScalarInst, Instance.Lane.getAsRuntimeExpr(Builder, VF));
412-
set(Def, VectorValue, Instance.Part);
410+
set(Def, VectorValue);
413411
}
414412

415413
BasicBlock *
@@ -1074,12 +1072,12 @@ void VPlan::execute(VPTransformState *State) {
10741072
isa<VPWidenIntOrFpInductionRecipe>(&R)) {
10751073
PHINode *Phi = nullptr;
10761074
if (isa<VPWidenIntOrFpInductionRecipe>(&R)) {
1077-
Phi = cast<PHINode>(State->get(R.getVPSingleValue(), 0));
1075+
Phi = cast<PHINode>(State->get(R.getVPSingleValue()));
10781076
} else {
10791077
auto *WidenPhi = cast<VPWidenPointerInductionRecipe>(&R);
10801078
assert(!WidenPhi->onlyScalarsGenerated(State->VF.isScalable()) &&
10811079
"recipe generating only scalars should have been replaced");
1082-
auto *GEP = cast<GetElementPtrInst>(State->get(WidenPhi, 0));
1080+
auto *GEP = cast<GetElementPtrInst>(State->get(WidenPhi));
10831081
Phi = cast<PHINode>(GEP->getPointerOperand());
10841082
}
10851083

@@ -1092,7 +1090,7 @@ void VPlan::execute(VPTransformState *State) {
10921090

10931091
// Use the steps for the last part as backedge value for the induction.
10941092
if (auto *IV = dyn_cast<VPWidenIntOrFpInductionRecipe>(&R))
1095-
Inc->setOperand(0, State->get(IV->getLastUnrolledPartOperand(), 0));
1093+
Inc->setOperand(0, State->get(IV->getLastUnrolledPartOperand()));
10961094
continue;
10971095
}
10981096

@@ -1101,8 +1099,8 @@ void VPlan::execute(VPTransformState *State) {
11011099
isa<VPCanonicalIVPHIRecipe, VPEVLBasedIVPHIRecipe>(PhiR) ||
11021100
(isa<VPReductionPHIRecipe>(PhiR) &&
11031101
cast<VPReductionPHIRecipe>(PhiR)->isInLoop());
1104-
Value *Phi = State->get(PhiR, 0, NeedsScalar);
1105-
Value *Val = State->get(PhiR->getBackedgeValue(), 0, NeedsScalar);
1102+
Value *Phi = State->get(PhiR, NeedsScalar);
1103+
Value *Val = State->get(PhiR->getBackedgeValue(), NeedsScalar);
11061104
cast<PHINode>(Phi)->addIncoming(Val, VectorLatchBB);
11071105
}
11081106

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 16 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -263,30 +263,22 @@ struct VPTransformState {
263263
std::optional<VPIteration> Instance;
264264

265265
struct DataState {
266-
/// A type for vectorized values in the new loop. Each value from the
267-
/// original loop, when vectorized, is represented by UF vector values in
268-
/// the new unrolled loop, where UF is the unroll factor.
269-
typedef SmallVector<Value *, 2> PerPartValuesTy;
270-
271-
DenseMap<VPValue *, PerPartValuesTy> PerPartOutput;
266+
// Each value from the original loop, when vectorized, is represented by a
267+
// vector value in the map.
268+
DenseMap<VPValue *, Value *> VPV2Vector;
272269

273270
using ScalarsPerPartValuesTy = SmallVector<SmallVector<Value *, 4>, 2>;
274271
DenseMap<VPValue *, ScalarsPerPartValuesTy> PerPartScalars;
275272
} Data;
276273

277-
/// Get the generated vector Value for a given VPValue \p Def and a given \p
278-
/// Part if \p IsScalar is false, otherwise return the generated scalar
279-
/// for \p Part. \See set.
280-
Value *get(VPValue *Def, unsigned Part, bool IsScalar = false);
274+
/// Get the generated vector Value for a given VPValue \p Def if \p IsScalar
275+
/// is false, otherwise return the generated scalar. \See set.
276+
Value *get(VPValue *Def, bool IsScalar = false);
281277

282278
/// Get the generated Value for a given VPValue and given Part and Lane.
283279
Value *get(VPValue *Def, const VPIteration &Instance);
284280

285-
bool hasVectorValue(VPValue *Def, unsigned Part) {
286-
auto I = Data.PerPartOutput.find(Def);
287-
return I != Data.PerPartOutput.end() && Part < I->second.size() &&
288-
I->second[Part];
289-
}
281+
bool hasVectorValue(VPValue *Def) { return Data.VPV2Vector.contains(Def); }
290282

291283
bool hasScalarValue(VPValue *Def, VPIteration Instance) {
292284
auto I = Data.PerPartScalars.find(Def);
@@ -298,28 +290,22 @@ struct VPTransformState {
298290
I->second[Instance.Part][CacheIdx];
299291
}
300292

301-
/// Set the generated vector Value for a given VPValue and a given Part, if \p
302-
/// IsScalar is false. If \p IsScalar is true, set the scalar in (Part, 0).
303-
void set(VPValue *Def, Value *V, unsigned Part, bool IsScalar = false) {
293+
/// Set the generated vector Value for a given VPValue, if \p
294+
/// IsScalar is false. If \p IsScalar is true, set the scalar in lane 0.
295+
void set(VPValue *Def, Value *V, bool IsScalar = false) {
304296
if (IsScalar) {
305-
set(Def, V, VPIteration(Part, 0));
297+
set(Def, V, VPIteration(0, 0));
306298
return;
307299
}
308300
assert((VF.isScalar() || V->getType()->isVectorTy()) &&
309-
"scalar values must be stored as (Part, 0)");
310-
if (!Data.PerPartOutput.count(Def)) {
311-
DataState::PerPartValuesTy Entry(1);
312-
Data.PerPartOutput[Def] = Entry;
313-
}
314-
Data.PerPartOutput[Def][Part] = V;
301+
"scalar values must be stored as (0, 0)");
302+
Data.VPV2Vector[Def] = V;
315303
}
316304

317305
/// Reset an existing vector value for \p Def and a given \p Part.
318-
void reset(VPValue *Def, Value *V, unsigned Part) {
319-
auto Iter = Data.PerPartOutput.find(Def);
320-
assert(Iter != Data.PerPartOutput.end() &&
321-
"need to overwrite existing value");
322-
Iter->second[Part] = V;
306+
void reset(VPValue *Def, Value *V) {
307+
assert(Data.VPV2Vector.contains(Def) && "need to overwrite existing value");
308+
Data.VPV2Vector[Def] = V;
323309
}
324310

325311
/// Set the generated scalar \p V for \p Def and the given \p Instance.

0 commit comments

Comments
 (0)