|
| 1 | +//===-- VPlanPredicator.cpp - VPlan predicator ----------------------------===// |
| 2 | +// |
| 3 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +// See https://llvm.org/LICENSE.txt for license information. |
| 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | +// |
| 7 | +//===----------------------------------------------------------------------===// |
| 8 | +/// |
| 9 | +/// \file |
| 10 | +/// This file implements predication for VPlans. |
| 11 | +/// |
| 12 | +//===----------------------------------------------------------------------===// |
| 13 | + |
| 14 | +#include "VPRecipeBuilder.h" |
| 15 | +#include "VPlan.h" |
| 16 | +#include "VPlanCFG.h" |
| 17 | +#include "VPlanTransforms.h" |
| 18 | +#include "VPlanUtils.h" |
| 19 | +#include "llvm/ADT/PostOrderIterator.h" |
| 20 | + |
| 21 | +using namespace llvm; |
| 22 | + |
| 23 | +struct VPPredicator { |
| 24 | + /// When we if-convert we need to create edge masks. We have to cache values |
| 25 | + /// so that we don't end up with exponential recursion/IR. Note that |
| 26 | + /// if-conversion currently takes place during VPlan-construction, so these |
| 27 | + /// caches are only used at that stage. |
| 28 | + using EdgeMaskCacheTy = |
| 29 | + DenseMap<std::pair<VPBasicBlock *, VPBasicBlock *>, VPValue *>; |
| 30 | + using BlockMaskCacheTy = DenseMap<VPBasicBlock *, VPValue *>; |
| 31 | + |
| 32 | + VPPredicator(VPRecipeBuilder &RecipeBuilder) : RecipeBuilder(RecipeBuilder) {} |
| 33 | + |
| 34 | + VPRecipeBuilder &RecipeBuilder; |
| 35 | + |
| 36 | + VPBuilder Builder; |
| 37 | + VPValue *createEdgeMask(VPBasicBlock *Src, VPBasicBlock *Dst) { |
| 38 | + assert(is_contained(Dst->getPredecessors(), Src) && "Invalid edge"); |
| 39 | + |
| 40 | + // Look for cached value. |
| 41 | + VPValue *EdgeMask = RecipeBuilder.getEdgeMask(Src, Dst); |
| 42 | + if (EdgeMask) |
| 43 | + return EdgeMask; |
| 44 | + |
| 45 | + VPValue *SrcMask = RecipeBuilder.getBlockInMask(Src); |
| 46 | + |
| 47 | + // The terminator has to be a branch inst! |
| 48 | + if (Src->empty() || Src->getNumSuccessors() == 1) { |
| 49 | + RecipeBuilder.setEdgeMask(Src, Dst, SrcMask); |
| 50 | + return SrcMask; |
| 51 | + } |
| 52 | + |
| 53 | + auto *Term = cast<VPInstruction>(Src->getTerminator()); |
| 54 | + if (Term->getOpcode() == Instruction::Switch) { |
| 55 | + createSwitchEdgeMasks(Term); |
| 56 | + return RecipeBuilder.getEdgeMask(Src, Dst); |
| 57 | + } |
| 58 | + |
| 59 | + auto *BI = cast<VPInstruction>(Src->getTerminator()); |
| 60 | + assert(BI->getOpcode() == VPInstruction::BranchOnCond); |
| 61 | + if (Src->getSuccessors()[0] == Src->getSuccessors()[1]) { |
| 62 | + RecipeBuilder.setEdgeMask(Src, Dst, SrcMask); |
| 63 | + return SrcMask; |
| 64 | + } |
| 65 | + |
| 66 | + EdgeMask = BI->getOperand(0); |
| 67 | + assert(EdgeMask && "No Edge Mask found for condition"); |
| 68 | + |
| 69 | + if (Src->getSuccessors()[0] != Dst) |
| 70 | + EdgeMask = Builder.createNot(EdgeMask, BI->getDebugLoc()); |
| 71 | + |
| 72 | + if (SrcMask) { // Otherwise block in-mask is all-one, no need to AND. |
| 73 | + // The bitwise 'And' of SrcMask and EdgeMask introduces new UB if SrcMask |
| 74 | + // is false and EdgeMask is poison. Avoid that by using 'LogicalAnd' |
| 75 | + // instead which generates 'select i1 SrcMask, i1 EdgeMask, i1 false'. |
| 76 | + EdgeMask = Builder.createLogicalAnd(SrcMask, EdgeMask, BI->getDebugLoc()); |
| 77 | + } |
| 78 | + |
| 79 | + RecipeBuilder.setEdgeMask(Src, Dst, EdgeMask); |
| 80 | + return EdgeMask; |
| 81 | + } |
| 82 | + |
| 83 | + VPValue *createBlockInMask(VPBasicBlock *VPBB) { |
| 84 | + Builder.setInsertPoint(VPBB, VPBB->begin()); |
| 85 | + // All-one mask is modelled as no-mask following the convention for masked |
| 86 | + // load/store/gather/scatter. Initialize BlockMask to no-mask. |
| 87 | + VPValue *BlockMask = nullptr; |
| 88 | + // This is the block mask. We OR all unique incoming edges. |
| 89 | + for (auto *Predecessor : SetVector<VPBlockBase *>( |
| 90 | + VPBB->getPredecessors().begin(), VPBB->getPredecessors().end())) { |
| 91 | + VPValue *EdgeMask = createEdgeMask(cast<VPBasicBlock>(Predecessor), VPBB); |
| 92 | + if (!EdgeMask) { // Mask of predecessor is all-one so mask of block is |
| 93 | + // too. |
| 94 | + RecipeBuilder.setBlockInMask(VPBB, EdgeMask); |
| 95 | + return EdgeMask; |
| 96 | + } |
| 97 | + |
| 98 | + if (!BlockMask) { // BlockMask has its initialized nullptr value. |
| 99 | + BlockMask = EdgeMask; |
| 100 | + continue; |
| 101 | + } |
| 102 | + |
| 103 | + BlockMask = Builder.createOr(BlockMask, EdgeMask, {}); |
| 104 | + } |
| 105 | + |
| 106 | + RecipeBuilder.setBlockInMask(VPBB, BlockMask); |
| 107 | + return BlockMask; |
| 108 | + } |
| 109 | + |
| 110 | + void createHeaderMask(VPBasicBlock *HeaderVPBB, bool FoldTail) { |
| 111 | + if (!FoldTail) { |
| 112 | + RecipeBuilder.setBlockInMask(HeaderVPBB, nullptr); |
| 113 | + return; |
| 114 | + } |
| 115 | + |
| 116 | + // Introduce the early-exit compare IV <= BTC to form header block mask. |
| 117 | + // This is used instead of IV < TC because TC may wrap, unlike BTC. Start by |
| 118 | + // constructing the desired canonical IV in the header block as its first |
| 119 | + // non-phi instructions. |
| 120 | + |
| 121 | + auto NewInsertionPoint = HeaderVPBB->getFirstNonPhi(); |
| 122 | + auto &Plan = *HeaderVPBB->getPlan(); |
| 123 | + auto *IV = new VPWidenCanonicalIVRecipe(Plan.getCanonicalIV()); |
| 124 | + HeaderVPBB->insert(IV, NewInsertionPoint); |
| 125 | + |
| 126 | + VPBuilder::InsertPointGuard Guard(Builder); |
| 127 | + Builder.setInsertPoint(HeaderVPBB, NewInsertionPoint); |
| 128 | + VPValue *BlockMask = nullptr; |
| 129 | + VPValue *BTC = Plan.getOrCreateBackedgeTakenCount(); |
| 130 | + BlockMask = Builder.createICmp(CmpInst::ICMP_ULE, IV, BTC); |
| 131 | + RecipeBuilder.setBlockInMask(HeaderVPBB, BlockMask); |
| 132 | + } |
| 133 | + |
| 134 | + void createSwitchEdgeMasks(VPInstruction *SI) { |
| 135 | + VPBasicBlock *Src = SI->getParent(); |
| 136 | + |
| 137 | + // Create masks where the terminator in Src is a switch. We create mask for |
| 138 | + // all edges at the same time. This is more efficient, as we can create and |
| 139 | + // collect compares for all cases once. |
| 140 | + VPValue *Cond = SI->getOperand(0); |
| 141 | + VPBasicBlock *DefaultDst = cast<VPBasicBlock>(Src->getSuccessors()[0]); |
| 142 | + MapVector<VPBasicBlock *, SmallVector<VPValue *>> Dst2Compares; |
| 143 | + for (const auto &[Idx, Succ] : |
| 144 | + enumerate(ArrayRef(Src->getSuccessors()).drop_front())) { |
| 145 | + VPBasicBlock *Dst = cast<VPBasicBlock>(Succ); |
| 146 | + // assert(!EdgeMaskCache.contains({Src, Dst}) && "Edge masks already |
| 147 | + // created"); |
| 148 | + // Cases whose destination is the same as default are redundant and can |
| 149 | + // be ignored - they will get there anyhow. |
| 150 | + if (Dst == DefaultDst) |
| 151 | + continue; |
| 152 | + auto &Compares = Dst2Compares[Dst]; |
| 153 | + VPValue *V = SI->getOperand(Idx + 1); |
| 154 | + Compares.push_back(Builder.createICmp(CmpInst::ICMP_EQ, Cond, V)); |
| 155 | + } |
| 156 | + |
| 157 | + // We need to handle 2 separate cases below for all entries in Dst2Compares, |
| 158 | + // which excludes destinations matching the default destination. |
| 159 | + VPValue *SrcMask = RecipeBuilder.getBlockInMask(Src); |
| 160 | + VPValue *DefaultMask = nullptr; |
| 161 | + for (const auto &[Dst, Conds] : Dst2Compares) { |
| 162 | + // 1. Dst is not the default destination. Dst is reached if any of the |
| 163 | + // cases with destination == Dst are taken. Join the conditions for each |
| 164 | + // case whose destination == Dst using an OR. |
| 165 | + VPValue *Mask = Conds[0]; |
| 166 | + for (VPValue *V : ArrayRef<VPValue *>(Conds).drop_front()) |
| 167 | + Mask = Builder.createOr(Mask, V); |
| 168 | + if (SrcMask) |
| 169 | + Mask = Builder.createLogicalAnd(SrcMask, Mask); |
| 170 | + RecipeBuilder.setEdgeMask(Src, Dst, Mask); |
| 171 | + |
| 172 | + // 2. Create the mask for the default destination, which is reached if |
| 173 | + // none of the cases with destination != default destination are taken. |
| 174 | + // Join the conditions for each case where the destination is != Dst using |
| 175 | + // an OR and negate it. |
| 176 | + DefaultMask = DefaultMask ? Builder.createOr(DefaultMask, Mask) : Mask; |
| 177 | + } |
| 178 | + |
| 179 | + if (DefaultMask) { |
| 180 | + DefaultMask = Builder.createNot(DefaultMask); |
| 181 | + if (SrcMask) |
| 182 | + DefaultMask = Builder.createLogicalAnd(SrcMask, DefaultMask); |
| 183 | + } |
| 184 | + RecipeBuilder.setEdgeMask(Src, DefaultDst, DefaultMask); |
| 185 | + } |
| 186 | +}; |
| 187 | + |
| 188 | +void VPlanTransforms::predicateAndLinearize(VPlan &Plan, bool FoldTail, |
| 189 | + VPRecipeBuilder &RecipeBuilder) { |
| 190 | + VPRegionBlock *LoopRegion = Plan.getVectorLoopRegion(); |
| 191 | + // Scan the body of the loop in a topological order to visit each basic block |
| 192 | + // after having visited its predecessor basic blocks. |
| 193 | + VPBasicBlock *Header = LoopRegion->getEntryBasicBlock(); |
| 194 | + ReversePostOrderTraversal<VPBlockShallowTraversalWrapper<VPBlockBase *>> RPOT( |
| 195 | + Header); |
| 196 | + VPPredicator Predicator(RecipeBuilder); |
| 197 | + for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(RPOT)) { |
| 198 | + // Handle VPBBs down to the latch. |
| 199 | + if (VPBB == LoopRegion->getExiting()) |
| 200 | + break; |
| 201 | + |
| 202 | + if (VPBB == Header) { |
| 203 | + Predicator.createHeaderMask(Header, FoldTail); |
| 204 | + continue; |
| 205 | + } |
| 206 | + SmallVector<VPWidenPHIRecipe *> Phis; |
| 207 | + for (VPRecipeBase &R : VPBB->phis()) |
| 208 | + Phis.push_back(cast<VPWidenPHIRecipe>(&R)); |
| 209 | + |
| 210 | + Predicator.createBlockInMask(VPBB); |
| 211 | + |
| 212 | + for (VPWidenPHIRecipe *Phi : Phis) { |
| 213 | + PHINode *IRPhi = cast<PHINode>(Phi->getUnderlyingValue()); |
| 214 | + |
| 215 | + unsigned NumIncoming = IRPhi->getNumIncomingValues(); |
| 216 | + |
| 217 | + // We know that all PHIs in non-header blocks are converted into selects, |
| 218 | + // so we don't have to worry about the insertion order and we can just use |
| 219 | + // the builder. At this point we generate the predication tree. There may |
| 220 | + // be duplications since this is a simple recursive scan, but future |
| 221 | + // optimizations will clean it up. |
| 222 | + |
| 223 | + // Map incoming IR BasicBlocks to incoming VPValues, for lookup below. |
| 224 | + // TODO: Add operands and masks in order from the VPlan predecessors. |
| 225 | + DenseMap<BasicBlock *, VPValue *> VPIncomingValues; |
| 226 | + DenseMap<BasicBlock *, VPBasicBlock *> VPIncomingBlocks; |
| 227 | + for (const auto &[Idx, Pred] : |
| 228 | + enumerate(predecessors(IRPhi->getParent()))) { |
| 229 | + VPIncomingValues[Pred] = Phi->getOperand(Idx); |
| 230 | + VPIncomingBlocks[Pred] = |
| 231 | + cast<VPBasicBlock>(VPBB->getPredecessors()[Idx]); |
| 232 | + } |
| 233 | + |
| 234 | + SmallVector<VPValue *, 2> OperandsWithMask; |
| 235 | + for (unsigned In = 0; In < NumIncoming; In++) { |
| 236 | + BasicBlock *Pred = IRPhi->getIncomingBlock(In); |
| 237 | + OperandsWithMask.push_back(VPIncomingValues.lookup(Pred)); |
| 238 | + VPValue *EdgeMask = |
| 239 | + RecipeBuilder.getEdgeMask(VPIncomingBlocks.lookup(Pred), VPBB); |
| 240 | + if (!EdgeMask) { |
| 241 | + assert(In == 0 && "Both null and non-null edge masks found"); |
| 242 | + assert(all_equal(Phi->operands()) && |
| 243 | + "Distinct incoming values with one having a full mask"); |
| 244 | + break; |
| 245 | + } |
| 246 | + OperandsWithMask.push_back(EdgeMask); |
| 247 | + } |
| 248 | + auto *Blend = new VPBlendRecipe(IRPhi, OperandsWithMask); |
| 249 | + Blend->insertBefore(Phi); |
| 250 | + Phi->replaceAllUsesWith(Blend); |
| 251 | + Phi->eraseFromParent(); |
| 252 | + RecipeBuilder.setRecipe(IRPhi, Blend); |
| 253 | + } |
| 254 | + } |
| 255 | +} |
0 commit comments