[mlir][linalg] Split GenericPadOpVectorizationPattern into two patterns #111349
@@ -2281,115 +2281,6 @@ LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
//----------------------------------------------------------------------------//
// Misc. vectorization patterns.
//----------------------------------------------------------------------------//

/// Helper function that retrieves the value of an IntegerAttr.
static int64_t getIntFromAttr(Attribute attr) {
  return cast<IntegerAttr>(attr).getInt();
}
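For illustration, a trivial invocation sketch (the builder-created attribute here is hypothetical, not from the patch):

    // Returns 4; cast<IntegerAttr> asserts if the attribute has another type.
    int64_t n = getIntFromAttr(rewriter.getI64IntegerAttr(4));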
/// Given an ArrayRef of OpFoldResults, return a vector of Values.
/// IntegerAttrs are converted to ConstantIndexOps. Other attribute types are
/// not supported.
static SmallVector<Value> ofrToIndexValues(RewriterBase &rewriter, Location loc,
                                           ArrayRef<OpFoldResult> ofrs) {
  SmallVector<Value> result;
  for (auto o : ofrs) {
    if (auto val = llvm::dyn_cast_if_present<Value>(o)) {
      result.push_back(val);
    } else {
      result.push_back(rewriter.create<arith::ConstantIndexOp>(
          loc, getIntFromAttr(o.template get<Attribute>())));
    }
  }
  return result;
}
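A minimal usage sketch, mirroring the call site further down in this diff (`rewriter` and `padOp` are assumed to be in scope):

    // Materialize tensor.pad's mixed static/dynamic low-padding offsets as
    // index Values; static entries become arith::ConstantIndexOps.
    SmallVector<Value> lowPad =
        ofrToIndexValues(rewriter, padOp.getLoc(), padOp.getMixedLowPad());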
/// Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and
/// InsertSliceOp. For now, only constant padding values are supported.
/// If there is enough static type information, TransferReadOps and
/// TransferWriteOps may be generated instead of InsertSliceOps.
struct GenericPadOpVectorizationPattern : public GeneralizePadOpPattern {
  GenericPadOpVectorizationPattern(MLIRContext *context,
                                   PatternBenefit benefit = 1)
      : GeneralizePadOpPattern(context, tryVectorizeCopy, benefit) {}
  /// Vectorize the copying of a tensor::PadOp's source. This is possible if
  /// each dimension size is statically known in the source type or the result
  /// type (or both).
  static LogicalResult tryVectorizeCopy(RewriterBase &rewriter,
                                        tensor::PadOp padOp, Value dest) {
    auto sourceType = padOp.getSourceType();
    auto resultType = padOp.getResultType();
    if (!VectorType::isValidElementType(sourceType.getElementType()))
      return failure();

    // Copy cannot be vectorized if the pad value is non-constant and the
    // source shape is dynamic. In case of a dynamic source shape, padding
    // must be appended by TransferReadOp, but TransferReadOp supports only
    // constant padding.
    auto padValue = padOp.getConstantPaddingValue();
    if (!padValue) {
      if (!sourceType.hasStaticShape())
        return failure();
      // Create dummy padding value.
      auto elemType = sourceType.getElementType();
      padValue = rewriter.create<arith::ConstantOp>(
          padOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
    }

    SmallVector<int64_t> vecShape;
    SmallVector<bool> readInBounds;
    SmallVector<bool> writeInBounds;
    for (unsigned i = 0; i < sourceType.getRank(); ++i) {
      if (!sourceType.isDynamicDim(i)) {
        vecShape.push_back(sourceType.getDimSize(i));
        // Source shape is statically known: Neither read nor write are
        // out-of-bounds.
        readInBounds.push_back(true);
        writeInBounds.push_back(true);
      } else if (!resultType.isDynamicDim(i)) {
        // Source shape is not statically known, but result shape is.
        // Vectorize with size of result shape. This may be larger than the
        // source size.
        vecShape.push_back(resultType.getDimSize(i));
        // Read may be out-of-bounds because the result size could be larger
        // than the source size.
        readInBounds.push_back(false);
        // Write is out-of-bounds if low padding > 0.
        writeInBounds.push_back(
            getConstantIntValue(padOp.getMixedLowPad()[i]) ==
            static_cast<int64_t>(0));
      } else {
        // Neither source nor result dim of padOp is static. Cannot vectorize
        // the copy.
        return failure();
      }
    }
    auto vecType = VectorType::get(vecShape, sourceType.getElementType());

    // Generate TransferReadOp.
    SmallVector<Value> readIndices(
        vecType.getRank(),
        rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
    auto read = rewriter.create<vector::TransferReadOp>(
        padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue,
        ArrayRef<bool>{readInBounds});

    // If `dest` is a FillOp and the TransferWriteOp would overwrite the
    // entire tensor, write directly to the FillOp's operand.
    if (llvm::equal(vecShape, resultType.getShape()) &&
        llvm::all_of(writeInBounds, [](bool b) { return b; }))
      if (auto fill = dest.getDefiningOp<FillOp>())
        dest = fill.output();

    // Generate TransferWriteOp.
    auto writeIndices =
        ofrToIndexValues(rewriter, padOp.getLoc(), padOp.getMixedLowPad());
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        padOp, read, dest, writeIndices, ArrayRef<bool>{writeInBounds});

    return success();
  }
};
/// Base pattern for rewriting tensor::PadOps whose result is consumed by a
/// given operation type OpTy.
template <typename OpTy>
@@ -2623,6 +2514,163 @@ struct PadOpVectorizationWithTransferWritePattern
  }
};
/// Returns the effective Pad value for the input op, provided it's a scalar.
///
/// Many Ops exhibit pad-like behaviour, but this isn't always explicit. If
/// this Op performs padding, retrieve the padding value provided that it's
/// a scalar and static/fixed for all the padded values. Returns an empty
/// value otherwise.
static Value getStaticPadVal(Operation *op) {
  if (!op)
    return {};

  // 1. vector.broadcast (f32 -> vector<...xf32>) - return the value that's
  // being broadcast, provided that it's a scalar.
  if (auto bcast = llvm::dyn_cast<vector::BroadcastOp>(op)) {
    auto source = bcast.getSource();
    if (llvm::isa<VectorType>(source.getType()))
      return {};

    return source;
  }

  // 2. linalg.fill - use the scalar input value that's used to fill the
  // output tensor.
  if (auto fill = llvm::dyn_cast<linalg::FillOp>(op)) {
    return fill.getInputs()[0];
  }

  // 3. tensor.generate - can't guarantee that the value is fixed without
  // analysing the body, bail out.
  if (auto generate = llvm::dyn_cast<tensor::GenerateOp>(op)) {
    return {};
  }

  // 4. vector.transfer_write - inspect the input vector that's written from.
  // If it contains a single value that has been broadcast (e.g. via
  // vector.broadcast), extract it, fail otherwise.
  if (auto xferWrite = llvm::dyn_cast<vector::TransferWriteOp>(op))
    return getStaticPadVal(xferWrite.getVector().getDefiningOp());

  // 5. tensor.insert_slice - inspect the destination tensor. If it's larger
  // than the input tensor, then, provided it's constant, we'll extract the
  // value that was used to generate it (via e.g. linalg.fill), fail
  // otherwise.
  // TODO: Clarify the semantics when the input tensor is larger than the
  // destination.
  if (auto slice = llvm::dyn_cast<tensor::InsertSliceOp>(op))
    return getStaticPadVal(slice.getDest().getDefiningOp());

  return {};
}
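A hedged illustration of how the recursion resolves (the IR in the comments and the `writeOp` handle are assumptions for this sketch, not part of the patch):

    // Given:
    //   %c0 = arith.constant 0.0 : f32
    //   %bc = vector.broadcast %c0 : f32 to vector<8xf32>
    //   vector.transfer_write %bc, %dest[%i] : vector<8xf32>, tensor<16xf32>
    // Case 4 recurses into the broadcast (case 1), which returns the scalar
    // %c0. A tensor.generate producer, by contrast, yields an empty Value.
    Value padVal = getStaticPadVal(writeOp.getOperation());
    if (!padVal)
      return failure(); // No static pad value could be inferred.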
Review discussion on the new pattern:

> It can be done in a follow-up. I think we want to have a unified API and vectorization path, which helps us manage the "future" extensibility feature/design better. So it'd be a plus if we move the implementation to ... (There was a long and old discussion in https://reviews.llvm.org/D150495, though it's not happening... but I think the point still holds...)

> I am wondering whether that would be desirable in this case ... Let me explain. My only concern is that vectorizing ... Btw, thank you for reminding me of that discussion on the overall design 🙏🏻

> Yes, good point! The plan sounds good to me, and yes, we should be careful with the insert_slice case, because some transformations rely on these tiling artifacts (i.e., extract_slice/insert_slice pairs). I think we'll move to a better state with this PR. IIRC, we'll be able to remove the other three patterns (see below) with proper pad op vectorization support. (I think it will become a transfer_read, which can get folded into other transfer ops in folders/canonicalization.)
>
>     // Try these specialized patterns first before resorting to the generic one.
>     patterns.add<PadOpVectorizationWithTransferReadPattern,
>                  PadOpVectorizationWithTransferWritePattern,
>                  PadOpVectorizationWithInsertSlicePattern>(
>         patterns.getContext(), baseBenefit.getBenefit() + 1);

/// Rewrite tensor.insert_slice as a vector.transfer_read +
/// vector.transfer_write pair. The vector size is inferred from the static
/// dims in the input and output tensors. If a dim is dynamic in both the
/// input and output tensors, bail out.
///
/// Before:
/// !t_in_type = tensor<1x2x3xf32>
/// !t_out_type = tensor<9x8x7x1x2x3xf32>
/// !v_type = vector<1x2x3xf32>
/// %inserted_slice = tensor.insert_slice %src into %dest ... : !t_in_type
/// into !t_out_type
/// After:
/// %read = vector.transfer_read %src[...], %pad ... : !t_in_type, !v_type
/// %write = vector.transfer_write %read, %dest ... : !v_type, !t_out_type
///
/// TODO: Support masking
struct InsertSliceVectorizePattern
    : public OpRewritePattern<tensor::InsertSliceOp> {
  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp sliceOp,
                                PatternRewriter &rewriter) const final {
    auto sourceType = sliceOp.getSource().getType();
    if (!VectorType::isValidElementType(sourceType.getElementType()))
      return failure();

    auto resultType = sliceOp.getResultType();

    // 1. Get the pad value.
    // TransferReadOp requires a scalar padding value. Note that:
    // * for in-bounds accesses, the value is actually irrelevant.
    // There are 2 cases in which xfer.read accesses are known to be in-bounds:
    //  1. The source shape is static (output vector sizes would be based on
    //     the source shape and hence all memory accesses would be in-bounds),
    //  2. Masking is used (output vector sizes would be user-provided, in
    //     which case it is assumed that all memory accesses are in-bounds).
    //     This remains a TODO.
    //
    // When the value is not known and not needed, use 0. Otherwise, bail out.
    Value padValue = getStaticPadVal(sliceOp);
    bool isOutOfBoundsRead = !sourceType.hasStaticShape();

    if (!padValue && isOutOfBoundsRead) {
      LDBG("Failed to get a pad value for out-of-bounds read access\n");
      return failure();
    }

    if (!padValue) {
      auto elemType = sourceType.getElementType();
      padValue = rewriter.create<arith::ConstantOp>(
          sliceOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
    }

    // 2. Get the vector shape and in-bounds attributes.
    SmallVector<int64_t> vecShape;
    SmallVector<bool> readInBounds;
    SmallVector<bool> writeInBounds;
    size_t rankDiff = resultType.getRank() - sourceType.getRank();
    for (unsigned i = 0; i < sourceType.getRank(); ++i) {
      if (!sourceType.isDynamicDim(i)) {
        vecShape.push_back(sourceType.getDimSize(i));
        // Source shape is statically known: Neither read nor write are
        // out-of-bounds.
        readInBounds.push_back(true);
        writeInBounds.push_back(true);
      } else if (!resultType.isDynamicDim(i)) {
        // Source shape is not statically known, but result shape is.
        // Vectorize with size of result shape. This may be larger than the
        // source size.
        // FIXME: Using rankDiff implies that the source tensor is inserted at
        // the end of the destination tensor. However, that's not required.
        vecShape.push_back(resultType.getDimSize(rankDiff + i));
        // Read may be out-of-bounds because the result size could be larger
        // than the source size.
        readInBounds.push_back(false);
        // Write will be in-bounds provided that the corresponding write index
        // is 0. To keep this logic simple, conservatively mark the write as
        // out-of-bounds.
        writeInBounds.push_back(false);
      } else {
        // Neither the source nor the result dim of sliceOp is static. Cannot
        // vectorize the copy.
        // TODO: Add support for masking.
        return failure();
      }
    }
    auto vecType = VectorType::get(vecShape, sourceType.getElementType());

    // 3. Generate TransferReadOp.
    SmallVector<Value> readIndices(
        vecType.getRank(),
        rewriter.create<arith::ConstantIndexOp>(sliceOp.getLoc(), 0));
    auto read = rewriter.create<vector::TransferReadOp>(
        sliceOp.getLoc(), vecType, sliceOp.getSource(), readIndices, padValue,
        ArrayRef<bool>{readInBounds});

    // 4. Generate TransferWriteOp.
    auto writeIndices = getValueOrCreateConstantIndexOp(
        rewriter, sliceOp.getLoc(), sliceOp.getMixedOffsets());

    // 5. Finalize.
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        sliceOp, read, sliceOp.getDest(), writeIndices,
        ArrayRef<bool>{writeInBounds});

    return success();
  }
};
/// Rewrite use of tensor::PadOp result in InsertSliceOp. E.g.:
/// ```
/// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
@@ -2699,8 +2747,8 @@ struct PadOpVectorizationWithInsertSlicePattern
    // Generate TransferWriteOp: Write to InsertSliceOp's dest tensor at
    // specified offsets. Write is fully in-bounds because an InsertSliceOp's
    // source must fit into the destination at the specified offsets.
-   auto writeIndices =
-       ofrToIndexValues(rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
+   auto writeIndices = getValueOrCreateConstantIndexOp(
+       rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
    SmallVector<bool> inBounds(vecRank, true);
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        insertOp, read, insertOp.getDest(), writeIndices,
@@ -2710,13 +2758,18 @@ struct PadOpVectorizationWithInsertSlicePattern
  }
};
void mlir::linalg::populateInsertSliceVectorizationPatterns(
    RewritePatternSet &patterns) {
  patterns.add<InsertSliceVectorizePattern>(patterns.getContext());
}
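As a hedged usage sketch (not part of this diff; `funcOp` and the surrounding pass context are assumptions), the two populate entry points can be combined and applied greedily:

    RewritePatternSet patterns(funcOp.getContext());
    linalg::populateInsertSliceVectorizationPatterns(patterns);
    linalg::populatePadOpVectorizationPatterns(patterns);
    // tensor.insert_slice ops (including those produced by decomposing
    // tensor.pad) are rewritten into vector.transfer_read /
    // vector.transfer_write pairs where enough static shape info exists.
    if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
      return signalPassFailure();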
void mlir::linalg::populatePadOpVectorizationPatterns(
    RewritePatternSet &patterns, PatternBenefit baseBenefit) {
  // TODO: The following pattern implements "decomposition" and optional
  // "vectorization". Separate "decomposition" into a separate pre-processing
  // pattern group.
- patterns.add<GenericPadOpVectorizationPattern>(patterns.getContext(),
-                                                baseBenefit);
+ patterns.add<GeneralizePadOpPattern>(patterns.getContext(), baseBenefit);

  // Try these specialized patterns first before resorting to the generic one.
  patterns.add<PadOpVectorizationWithTransferReadPattern,
               PadOpVectorizationWithTransferWritePattern,
Review discussion:

> ah, I was confused. Then I realized that we can broadcast a scalar to a vector (e.g., f32 -> vector<...xf32>).

> Added a small note to clarify.