[mlir][Vectorizer] Added support to Vectorize tensor.unpack #76087

Merged: 12 commits merged on Feb 20, 2024
12 changes: 5 additions & 7 deletions mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
@@ -32,13 +32,11 @@ FailureOr<RankedTensorType>
computeTransposedType(RankedTensorType rankedTensorType,
ArrayRef<int64_t> transposeVector);

/// Given a tensor::PackOp, compute the permutation vector to shuffle the
/// packed shape into the shape before any outer or inner permutations have
/// been applied.
/// E.g. for a pack from an ABCD layout to an ABCDba layout:
/// The packed shape would be ABCDba.
/// The pre-permutation shape would be AaBbCD.
SmallVector<int64_t> getPackInverseDestPermutation(PackOp packOp);
SmallVector<int64_t> getPackInverseDestPerm(tensor::PackOp packOp);
SmallVector<int64_t> getUnPackInverseSrcPerm(tensor::UnPackOp unpackOp);

SmallVector<int64_t> getUnPackInverseSrcPerm(tensor::UnPackOp,
PackingMetadata &metadata);
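
As a quick illustration of the doc comment above, the following standalone sketch (plain C++; the permutation value {0, 5, 1, 4, 2, 3} is hand-derived for this layout and purely illustrative, assuming the applyPermutation convention result[i] = input[perm[i]]) shows the packed shape ABCDba being shuffled back into the pre-permutation shape AaBbCD:

#include <cassert>
#include <string>
#include <vector>

int main() {
  // Packed layout for a pack from ABCD to ABCDba.
  std::vector<std::string> packedDims = {"A", "B", "C", "D", "b", "a"};
  // Hypothetical inverse-dest permutation for this layout.
  std::vector<int> perm = {0, 5, 1, 4, 2, 3};
  std::vector<std::string> stripMined;
  for (int idx : perm)
    stripMined.push_back(packedDims[idx]);
  // The pre-permutation (strip-mined) layout: AaBbCD.
  assert((stripMined ==
          std::vector<std::string>{"A", "a", "B", "b", "C", "D"}));
  return 0;
}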

/// A tensor.insert_slice is a cast-like operation if it merely rank-extends the
/// source tensor or inserts the source tensor into a destination tensor with
@@ -3152,7 +3152,8 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(

// TODO: Check that the correct number of vectorSizes was provided.
for (Operation *target : targets) {
if (!isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp>(target)) {
if (!isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp, tensor::UnPackOp>(
target)) {
return mlir::emitSilenceableFailure(target->getLoc())
<< "Unsupported Op, cannot vectorize";
}
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -237,7 +237,7 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
PackingMetadata packingMetadata = computePackingMetadata(
packedTensorType.getRank(), packOp.getInnerDimsPos());
SmallVector<int64_t> packedToStripMinedShapePerm =
tensor::getPackInverseDestPermutation(packOp);
tensor::getPackInverseDestPerm(packOp);

// 3. Compute the stripMinedShape: this is the packed shape before any outer
// or inner permutations have been applied.
182 changes: 157 additions & 25 deletions mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1405,8 +1405,7 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
/// permutations.
static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp,
ArrayRef<int64_t> destShape) {
return applyPermutation(destShape,
tensor::getPackInverseDestPermutation(packOp));
return applyPermutation(destShape, tensor::getPackInverseDestPerm(packOp));
}

/// Create a TransferReadOp from `source` with static shape `readShape`. If the
@@ -1547,7 +1546,7 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,

// Create TransposeOp.
auto destPermutation =
invertPermutationVector(tensor::getPackInverseDestPermutation(packOp));
invertPermutationVector(tensor::getPackInverseDestPerm(packOp));
auto transposeOp = rewriter.create<vector::TransposeOp>(
loc, shapeCastOp.getResult(), destPermutation);

@@ -1559,6 +1558,112 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
return success();
}

/// Vectorize a `tensor::UnPackOp` to these 4 ops:
///   vector::TransferReadOp - reads a vector from the source tensor
///   vector::TransposeOp - transposes the read vector
///   vector::ShapeCastOp - reshapes the data based on the target
///   vector::TransferWriteOp - writes the result vector back to the
///   destination tensor
static LogicalResult
vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
ArrayRef<int64_t> inputVectorSizes,
SmallVectorImpl<Value> &newResults) {

OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(unpackOp);

RankedTensorType unpackTensorType = unpackOp.getSourceType();

ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();

SmallVector<int64_t> readMaskShape(inputVectorSizes.begin(),
inputVectorSizes.end());
ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();

// readMaskShape is the shape of the vector used for the masked read of the
// source tensor (and thus of the mask itself). Let the vector sizes (VS)
// have rank N and the source shape (SS) have rank M, with M >= N; the inner
// tile sizes (IT) then account for the remaining M - N dims. readMaskShape
// is computed as follows:
// - Initially, readMaskShape = VS.
// - Divide each readMaskShape entry indexed by innerDimsPos by the
//   corresponding inner tile size (rounding up).
// - If outer_dims_perm is present, apply that permutation to readMaskShape.
// - Append the remaining dims of SS.
// E.g. suppose unpackTensorType.getShape() = <8x8x32x16>,
// innerDimsPos = [0, 1], innerTiles = [32, 16], vector_sizes = [512, 128],
// and outer_dims_perm = [1, 0]. Then:
// - Initial readMaskShape: [512, 128]
// - After the inner-tile division: [512/32, 128/16] = [16, 8]
// - After applying outer_dims_perm: [8, 16]
// - After appending the rest of the source shape: [8, 16, 32, 16]

for (auto [index, size] : enumerate(innerTiles)) {
readMaskShape[innerDimPos[index]] =
llvm::divideCeil(readMaskShape[innerDimPos[index]], size);
}
if (!outerDimsPerm.empty()) {
applyPermutationToVector(readMaskShape, outerDimsPerm);
}
readMaskShape.append(sourceShape.begin() + inputVectorSizes.size(),
sourceShape.end());
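
As a standalone sanity check of the computation above, here is a minimal sketch in plain C++ (std::vector instead of the MLIR/LLVM types; the helper name computeReadMaskShape is hypothetical) that reproduces the worked example from the comment:

#include <cassert>
#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;

// Mirrors the readMaskShape computation: divide the entries at innerDimsPos
// by the inner tile sizes (rounding up, like llvm::divideCeil), apply
// outerDimsPerm, then append the trailing dims of the source shape.
Shape computeReadMaskShape(Shape vectorSizes, const Shape &innerDimsPos,
                           const Shape &innerTiles, const Shape &outerDimsPerm,
                           const Shape &sourceShape) {
  Shape shape = std::move(vectorSizes);
  for (size_t i = 0; i < innerTiles.size(); ++i)
    shape[innerDimsPos[i]] =
        (shape[innerDimsPos[i]] + innerTiles[i] - 1) / innerTiles[i];
  if (!outerDimsPerm.empty()) {
    Shape permuted(shape.size());
    for (size_t i = 0; i < shape.size(); ++i)
      permuted[i] = shape[outerDimsPerm[i]];
    shape = permuted;
  }
  const size_t rank = shape.size(); // rank of the vector sizes
  shape.insert(shape.end(), sourceShape.begin() + rank, sourceShape.end());
  return shape;
}

int main() {
  // The example from the comment: source <8x8x32x16>, innerDimsPos = {0, 1},
  // innerTiles = {32, 16}, vector sizes = {512, 128}, outerDimsPerm = {1, 0}.
  Shape result = computeReadMaskShape({512, 128}, {0, 1}, {32, 16}, {1, 0},
                                      {8, 8, 32, 16});
  assert((result == Shape{8, 16, 32, 16}));
  return 0;
}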

ReifiedRankedShapedTypeDims reifiedRetShapes;
LogicalResult status =
cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
.reifyResultShapes(rewriter, reifiedRetShapes);
if (status.failed()) {
LDBG("Unable to reify result shapes of " << unpackOp);
return failure();
}
Location loc = unpackOp->getLoc();

auto padValue = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(unpackOp.getSourceType().getElementType()));

// Read the result, masking if necessary: a mask is needed whenever the
// transfer_read shape does not match the source shape.
Value readResult = createReadOrMaskedRead(
rewriter, loc, unpackOp.getSource(),
ArrayRef<int64_t>(readMaskShape.begin(), readMaskShape.end()), padValue);

PackingMetadata packMetadata;
SmallVector<int64_t> lastDimToInsertPosPerm =
tensor::getUnPackInverseSrcPerm(unpackOp, packMetadata);
ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());
SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
RankedTensorType stripMineTensorType =
RankedTensorType::get(stripMineShape, stripMineElemType);
// Transpose the read vector into the strip-mined layout of the result.
vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>(
loc, readResult, lastDimToInsertPosPerm);

// Collapse the vector to the shape required by the result.
RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
stripMineTensorType, packMetadata.reassociations);
mlir::VectorType vecCollapsedType =
VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
loc, vecCollapsedType, transposeOp->getResult(0));

// For dynamic sizes, the write-mask shape has to match the shape-cast
// shape; otherwise the verifier rejects the mask as invalid.
SmallVector<int64_t> writeMaskShape(
unpackOp.getDestType().hasStaticShape()
? inputVectorSizes
: shapeCastOp.getResultVectorType().getShape());
Operation *write =
createWriteOrMaskedWrite(rewriter, loc, shapeCastOp.getResult(),
reifiedRetShapes[0], writeMaskShape);
newResults.push_back(write->getResult(0));
return success();
}
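
To make the four ops concrete, here is a shape-level walkthrough of the same example in plain C++. The transpose permutation {1, 2, 0, 3} and the collapse groups {{0, 1}, {2, 3}} are hand-derived for this layout and meant as an illustration, not as the exact output of getUnPackInverseSrcPerm / the packing metadata:

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  using Shape = std::vector<int64_t>;
  // 1. vector.transfer_read (masked): readMaskShape from above.
  Shape read = {8, 16, 32, 16};
  // 2. vector.transpose: interleave each outer dim with its inner tile
  //    (hand-derived permutation, result[i] = read[perm[i]]).
  Shape perm = {1, 2, 0, 3};
  Shape transposed(read.size());
  for (size_t i = 0; i < perm.size(); ++i)
    transposed[i] = read[perm[i]];
  assert((transposed == Shape{16, 32, 8, 16}));
  // 3. vector.shape_cast: collapse each {outer, tile} pair, recovering the
  //    unpacked rank.
  Shape collapsed = {transposed[0] * transposed[1],
                     transposed[2] * transposed[3]};
  assert((collapsed == Shape{512, 128}));
  // 4. vector.transfer_write: the result is written to the destination
  //    tensor, masked by the input vector sizes {512, 128} against the
  //    reified result shape.
  return 0;
}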

/// Vectorize a `padOp` with (1) static result type, (2) constant padding value
/// and (3) all-zero lowPad to
/// `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
@@ -1655,6 +1760,25 @@ isValidMaskedInputVector(ArrayRef<int64_t> shape,
return success();
}

/// Preconditions for vectorizing a tensor::UnPackOp: the inner tiles must be
/// static/constant, and the input vector sizes, if provided, must be valid
/// for the result shape.
static LogicalResult
vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
ArrayRef<int64_t> inputVectorSizes) {

if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
return !getConstantIntValue(res).has_value();
})) {
LDBG("Inner-tiles must be constant: " << unpackOp << "\n");
return failure();
}
llvm::ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
if (!inputVectorSizes.empty() &&
failed(isValidMaskedInputVector(resultShape, inputVectorSizes)))
return failure();

return success();
}

static LogicalResult
vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
ArrayRef<int64_t> inputVectorSizes,
@@ -1703,9 +1827,10 @@ vectorizeLinalgOpPrecondition,
}
if (isElementwise(linalgOp))
return success();
// TODO: isaConvolutionOpInterface that can also infer from generic features.
// But we will still need stride/dilation attributes that will be annoying to
// reverse-engineer...

// TODO: isaConvolutionOpInterface that can also infer from generic
// features. But we will still need stride/dilation attributes that will be
// annoying to reverse-engineer...
if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
return success();
// TODO: the common vector shape is equal to the static loop sizes only when
@@ -1810,6 +1935,9 @@ LogicalResult mlir::linalg::vectorizeOpPrecondition(
.Case<tensor::PackOp>([&](auto packOp) {
return vectorizePackOpPrecondition(packOp, inputVectorSizes);
})
.Case<tensor::UnPackOp>([&](auto unpackOp) {
return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
})
.Default([](auto) { return failure(); });
}

@@ -1829,11 +1957,11 @@ static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
}

/// Emit a suitable vector form for an operation. If provided,
/// `inputVectorSizes` are used to vectorize this operation. `inputVectorSizes`
/// must match the rank of the iteration space of the operation and the input
/// vector sizes must be greater than or equal to their counterpart iteration
/// space sizes, if static. `inputVectorShapes` also allows the vectorization of
/// operations with dynamic shapes.
/// `inputVectorSizes` are used to vectorize this operation.
/// `inputVectorSizes` must match the rank of the iteration space of the
/// operation and the input vector sizes must be greater than or equal to
/// their counterpart iteration space sizes, if static. `inputVectorSizes`
/// also allows the vectorization of operations with dynamic shapes.
LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
ArrayRef<int64_t> inputVectorSizes,
ArrayRef<bool> inputScalableVecDims,
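
For orientation, here is a hedged usage sketch of how a caller might drive this entry point for the tensor.unpack example above (the helper name and sizes are hypothetical, and the remaining parameters are assumed to be defaulted):

static LogicalResult vectorizeUnpackExample(RewriterBase &rewriter,
                                            tensor::UnPackOp unpackOp) {
  // One vector size per destination dim; a size may exceed the static dim
  // size, in which case the read/write is masked.
  SmallVector<int64_t> vectorSizes = {512, 128};
  // No scalable vector dims in this example.
  SmallVector<bool> scalableVecDims(vectorSizes.size(), false);
  return linalg::vectorize(rewriter, unpackOp, vectorSizes, scalableVecDims);
}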
@@ -1867,8 +1995,9 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
auto vectorizeResult =
TypeSwitch<Operation *, LogicalResult>(op)
.Case<linalg::LinalgOp>([&](auto linalgOp) {
// TODO: isaConvolutionOpInterface that can also infer from generic
// features. Will require stride/dilation attributes inference.
// TODO: isaConvolutionOpInterface that can also infer from
// generic features. Will require stride/dilation attributes
// inference.
if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
FailureOr<Operation *> convOr = vectorizeConvolution(
rewriter, linalgOp, flatten1DDepthwiseConv);
@@ -1902,6 +2031,10 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
results);
})
.Case<tensor::UnPackOp>([&](auto unpackOp) {
return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
inputVectorSizes, results);
})
.Default([](auto) { return failure(); });

if (failed(vectorizeResult)) {
Expand All @@ -1919,7 +2052,6 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,

LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
memref::CopyOp copyOp) {

auto srcType = cast<MemRefType>(copyOp.getSource().getType());
auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
@@ -2833,8 +2965,8 @@ struct Conv1DGenerator
Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped,
resPadding);

// The base vectorization case for channeled convolution is input: {n,w,c},
// weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
// The base vectorization case for channeled convolution is input:
// {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
// vectorization case, we pre-transpose the input, weight, and output.
switch (conv1DOpOrder) {
case Conv1DOpOrder::W:
@@ -2877,9 +3009,9 @@ struct Conv1DGenerator
return kw * (wSize / wSizeStep) + w;
};

// Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f} or
// perform outerproduct for non-channeled convolution or
// perform simple arith operation for pooling
// Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f},
// or perform an outer product for non-channeled convolution, or perform a
// simple arith operation for pooling.
for (int64_t kw = 0; kw < kwSize; ++kw) {
for (int64_t w = 0; w < wSize; w += wSizeStep) {
switch (oper) {
@@ -2908,9 +3040,9 @@ struct Conv1DGenerator
// End vector-only rewrite part
//===------------------------------------------------------------------===//

// The base vectorization case for channeled convolution is output: {n,w,f}
// To reuse the result from base pattern vectorization case, we post
// transpose the base case result.
// The base vectorization case for channeled convolution is output:
// {n,w,f}. To reuse the result from the base pattern vectorization case, we
// post-transpose the base case result.
switch (conv1DOpOrder) {
case Conv1DOpOrder::W:
case Conv1DOpOrder::Nwc:
@@ -3348,9 +3480,9 @@ static FailureOr<Operation *>
vectorizeConvolution(RewriterBase &rewriter, LinalgOp op,
bool flatten1DDepthwiseConv) {
// The ConvolutionOpInterface gives us guarantees of existence for
// strides/dilations. However, we do not need to rely on those, we can simply
// use them if present, otherwise use the default and let the generic conv.
// matcher in the ConvGenerator succeed or fail.
// strides/dilations. However, we do not need to rely on those; we can
// simply use them if present, otherwise use the default and let the generic
// conv. matcher in the ConvGenerator succeed or fail.
auto strides = op->getAttrOfType<DenseIntElementsAttr>("strides");
auto dilations = op->getAttrOfType<DenseIntElementsAttr>("dilations");
auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;