//===----- FlattenMemRefs.cpp - MemRef ops flattener pass ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains patterns for flattening multi-rank memref ops into
// 1-D memref ops.
//
//===----------------------------------------------------------------------===//
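//
// For example, a rank-2 load (an illustrative sketch; the exact IR the
// patterns produce may differ):
//
//   %v = memref.load %m[%i, %j] : memref<4x8xf32>
//
// is rewritten, roughly, into a linearized access through a flat view of the
// same buffer:
//
//   %flat = memref.reinterpret_cast %m to offset: [0], sizes: [32],
//       strides: [1] : memref<4x8xf32> to memref<32xf32, strided<[1]>>
//   %idx = affine.apply affine_map<()[s0, s1] -> (s0 * 8 + s1)>()[%i, %j]
//   %v = memref.load %flat[%idx] : memref<32xf32, strided<[1]>>
//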

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"

#include <numeric>

namespace mlir {
namespace memref {
#define GEN_PASS_DEF_FLATTENMEMREFSPASS
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace memref
} // namespace mlir

using namespace mlir;

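/// Materializes an OpFoldResult as a Value, creating an arith constant-index
/// op when the OpFoldResult holds an attribute.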
static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc,
                                      OpFoldResult in) {
  if (Attribute offsetAttr = dyn_cast<Attribute>(in)) {
    return rewriter.create<arith::ConstantIndexOp>(
        loc, cast<IntegerAttr>(offsetAttr).getInt());
  }
  return cast<Value>(in);
}

/// Returns a collapsed memref and the linearized index to access the element
/// at the specified indices.
static std::pair<Value, Value> getFlattenMemrefAndOffset(OpBuilder &rewriter,
                                                         Location loc,
                                                         Value source,
                                                         ValueRange indices) {
  int64_t sourceOffset;
  SmallVector<int64_t, 4> sourceStrides;
  auto sourceType = cast<MemRefType>(source.getType());
  // Callers are expected to have filtered out non-strided layouts via
  // checkLayout, so this query must succeed.
  [[maybe_unused]] LogicalResult status =
      sourceType.getStridesAndOffset(sourceStrides, sourceOffset);
  assert(succeeded(status) && "expected a strided memref layout");

  memref::ExtractStridedMetadataOp stridedMetadata =
      rewriter.create<memref::ExtractStridedMetadataOp>(loc, source);

  auto typeBit = sourceType.getElementType().getIntOrFloatBitWidth();
  OpFoldResult linearizedIndices;
  memref::LinearizedMemRefInfo linearizedInfo;
  std::tie(linearizedInfo, linearizedIndices) =
      memref::getLinearizedMemRefOffsetAndSize(
          rewriter, loc, typeBit, typeBit,
          stridedMetadata.getConstifiedMixedOffset(),
          stridedMetadata.getConstifiedMixedSizes(),
          stridedMetadata.getConstifiedMixedStrides(),
          getAsOpFoldResult(indices));

  return std::make_pair(
      rewriter.create<memref::ReinterpretCastOp>(
          loc, source,
          /* offset = */ linearizedInfo.linearizedOffset,
          /* shapes = */
          ArrayRef<OpFoldResult>{linearizedInfo.linearizedSize},
          /* strides = */
          ArrayRef<OpFoldResult>{rewriter.getIndexAttr(1)}),
      getValueFromOpFoldResult(rewriter, loc, linearizedIndices));
}

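/// Returns true if the memref has rank > 1 and therefore needs to be
/// flattened.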
static bool needFlattening(Value val) {
  auto type = cast<MemRefType>(val.getType());
  return type.getRank() > 1;
}

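/// Returns true if the memref layout can be linearized, i.e. it is either the
/// identity layout or an explicit strided layout.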
static bool checkLayout(Value val) {
  auto type = cast<MemRefType>(val.getType());
  return type.getLayout().isIdentity() ||
         isa<StridedLayoutAttr>(type.getLayout());
}

namespace {
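/// Returns the memref operand accessed by `op`, or a null Value for
/// unsupported ops.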
static Value getTargetMemref(Operation *op) {
  return llvm::TypeSwitch<Operation *, Value>(op)
      .Case<memref::LoadOp, memref::StoreOp, memref::AllocaOp,
            memref::AllocOp>([](auto op) { return op.getMemref(); })
      .Case<vector::LoadOp, vector::StoreOp, vector::MaskedLoadOp,
            vector::MaskedStoreOp, vector::TransferReadOp,
            vector::TransferWriteOp>([](auto op) { return op.getBase(); })
      .Default([](auto) { return Value{}; });
}

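/// Replaces the original alloc-like op with a reinterpret_cast of the new
/// flat allocation back to the original multi-dimensional type, so existing
/// users of the result stay valid.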
template <typename T>
static void castAllocResult(T oper, T newOper, Location loc,
                            PatternRewriter &rewriter) {
  memref::ExtractStridedMetadataOp stridedMetadata =
      rewriter.create<memref::ExtractStridedMetadataOp>(loc, oper);
  rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
      oper, cast<MemRefType>(oper.getType()), newOper,
      /*offset=*/rewriter.getIndexAttr(0),
      stridedMetadata.getConstifiedMixedSizes(),
      stridedMetadata.getConstifiedMixedStrides());
}

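/// Rewrites `op` to an equivalent op on the flattened memref `flatMemref`,
/// addressed by the linearized index `offset`.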
template <typename T>
static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref,
                      Value offset) {
  Location loc = op->getLoc();
  llvm::TypeSwitch<Operation *>(op.getOperation())
      .template Case<memref::AllocOp>([&](auto oper) {
        auto newAlloc = rewriter.create<memref::AllocOp>(
            loc, cast<MemRefType>(flatMemref.getType()),
            oper.getAlignmentAttr());
        castAllocResult(oper, newAlloc, loc, rewriter);
      })
      .template Case<memref::AllocaOp>([&](auto oper) {
        auto newAlloca = rewriter.create<memref::AllocaOp>(
            loc, cast<MemRefType>(flatMemref.getType()),
            oper.getAlignmentAttr());
        castAllocResult(oper, newAlloca, loc, rewriter);
      })
      .template Case<memref::LoadOp>([&](auto oper) {
        auto newLoad = rewriter.create<memref::LoadOp>(
            loc, oper->getResultTypes(), flatMemref, ValueRange{offset});
        newLoad->setAttrs(oper->getAttrs());
        rewriter.replaceOp(oper, newLoad.getResult());
      })
      .template Case<memref::StoreOp>([&](auto oper) {
        auto newStore = rewriter.create<memref::StoreOp>(
            loc, oper->getOperands().front(), flatMemref, ValueRange{offset});
        newStore->setAttrs(oper->getAttrs());
        rewriter.replaceOp(oper, newStore);
      })
      .template Case<vector::LoadOp>([&](auto oper) {
        auto newLoad = rewriter.create<vector::LoadOp>(
            loc, oper->getResultTypes(), flatMemref, ValueRange{offset});
        newLoad->setAttrs(oper->getAttrs());
        rewriter.replaceOp(oper, newLoad.getResult());
      })
      .template Case<vector::StoreOp>([&](auto oper) {
        auto newStore = rewriter.create<vector::StoreOp>(
            loc, oper->getOperands().front(), flatMemref, ValueRange{offset});
        newStore->setAttrs(oper->getAttrs());
        rewriter.replaceOp(oper, newStore);
      })
      .template Case<vector::MaskedLoadOp>([&](auto oper) {
        auto newMaskedLoad = rewriter.create<vector::MaskedLoadOp>(
            loc, oper.getType(), flatMemref, ValueRange{offset},
            oper.getMask(), oper.getPassThru());
        newMaskedLoad->setAttrs(oper->getAttrs());
        rewriter.replaceOp(oper, newMaskedLoad.getResult());
      })
      .template Case<vector::MaskedStoreOp>([&](auto oper) {
        auto newMaskedStore = rewriter.create<vector::MaskedStoreOp>(
            loc, flatMemref, ValueRange{offset}, oper.getMask(),
            oper.getValueToStore());
        newMaskedStore->setAttrs(oper->getAttrs());
        rewriter.replaceOp(oper, newMaskedStore);
      })
      .template Case<vector::TransferReadOp>([&](auto oper) {
        auto newTransferRead = rewriter.create<vector::TransferReadOp>(
            loc, oper.getType(), flatMemref, ValueRange{offset},
            oper.getPadding());
        rewriter.replaceOp(oper, newTransferRead.getResult());
      })
      .template Case<vector::TransferWriteOp>([&](auto oper) {
        auto newTransferWrite = rewriter.create<vector::TransferWriteOp>(
            loc, oper.getVector(), flatMemref, ValueRange{offset});
        rewriter.replaceOp(oper, newTransferWrite);
      })
      .Default([&](auto oper) {
        oper->emitOpError("unimplemented: do not know how to replace op");
      });
}

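/// Returns the access indices of `op`; alloc-like ops have no indices.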
template <typename T>
static ValueRange getIndices(T op) {
  if constexpr (std::is_same_v<T, memref::AllocaOp> ||
                std::is_same_v<T, memref::AllocOp>) {
    return ValueRange{};
  } else {
    return op.getIndices();
  }
}

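/// Returns success if `op` can be flattened. Only vector transfer ops impose
/// extra constraints; every other op is accepted unconditionally.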
template <typename T>
static LogicalResult canBeFlattened(T op, PatternRewriter &rewriter) {
  return llvm::TypeSwitch<Operation *, LogicalResult>(op.getOperation())
      .template Case<vector::TransferReadOp, vector::TransferWriteOp>(
          [&](auto oper) {
            // vector.transfer_read/write can only be flattened when:
            // 1. all accesses are in-bounds, and
            // 2. the permutation map is an identity or minor identity.
            auto permutationMap = oper.getPermutationMap();
            if (!permutationMap.isIdentity() &&
                !permutationMap.isMinorIdentity()) {
              return rewriter.notifyMatchFailure(
                  oper,
                  "only identity or minor identity permutation maps are "
                  "supported");
            }
            mlir::ArrayAttr inbounds = oper.getInBounds();
            if (llvm::any_of(inbounds, [](Attribute attr) {
                  return !cast<BoolAttr>(attr).getValue();
                })) {
              return rewriter.notifyMatchFailure(
                  oper, "only in-bounds accesses are supported");
            }
            return success();
          })
      .Default([&](auto op) { return success(); });
}

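/// Generic pattern that flattens the memref operand of `T` and rewrites the
/// op to the equivalent access on the flattened memref.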
template <typename T>
struct MemRefRewritePattern : public OpRewritePattern<T> {
  using OpRewritePattern<T>::OpRewritePattern;
  LogicalResult matchAndRewrite(T op,
                                PatternRewriter &rewriter) const override {
    LogicalResult canFlatten = canBeFlattened(op, rewriter);
    if (failed(canFlatten)) {
      return canFlatten;
    }

    Value memref = getTargetMemref(op);
    if (!needFlattening(memref) || !checkLayout(memref))
      return failure();
    auto &&[flatMemref, offset] = getFlattenMemrefAndOffset(
        rewriter, op->getLoc(), memref, getIndices<T>(op));
    replaceOp<T>(op, rewriter, flatMemref, offset);
    return success();
  }
};

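/// Pass that greedily applies the flattening patterns to the whole operation.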
struct FlattenMemrefsPass
    : public mlir::memref::impl::FlattenMemrefsPassBase<FlattenMemrefsPass> {
  using Base::Base;

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<affine::AffineDialect, arith::ArithDialect,
                    memref::MemRefDialect, vector::VectorDialect>();
  }

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());

    memref::populateFlattenMemrefsPatterns(patterns);

    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace

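/// Populates `patterns` with the flattening patterns above, so downstream
/// pipelines can reuse them without running the full pass. A minimal sketch
/// (`ctx` and `rootOp` are placeholder names, not part of this API):
///
///   RewritePatternSet patterns(ctx);
///   memref::populateFlattenMemrefsPatterns(patterns);
///   if (failed(applyPatternsGreedily(rootOp, std::move(patterns))))
///     ; // handle failure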
void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) {
  patterns.insert<MemRefRewritePattern<memref::LoadOp>,
                  MemRefRewritePattern<memref::StoreOp>,
                  MemRefRewritePattern<memref::AllocOp>,
                  MemRefRewritePattern<memref::AllocaOp>,
                  MemRefRewritePattern<vector::LoadOp>,
                  MemRefRewritePattern<vector::StoreOp>,
                  MemRefRewritePattern<vector::TransferReadOp>,
                  MemRefRewritePattern<vector::TransferWriteOp>,
                  MemRefRewritePattern<vector::MaskedLoadOp>,
                  MemRefRewritePattern<vector::MaskedStoreOp>>(
      patterns.getContext());
}