[mlir][Vectorizer] Added support to Vectorize tensor.unpack #76087

Merged · 12 commits · Feb 20, 2024
Changes from 7 commits
10 changes: 8 additions & 2 deletions mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
@@ -38,8 +38,14 @@ computeTransposedType(RankedTensorType rankedTensorType,
/// i.e. for a pack from an ABCD layout to an ABCDba:
/// The packed shape would be ABCDba.
/// The pre-permutation shape would be AaBbCD.
SmallVector<int64_t> getPackInverseDestPermutation(PackOp packOp);

SmallVector<int64_t> getPackUnPackInverseDestPerm(
std::variant<tensor::PackOp, tensor::UnPackOp> packOp);

/// Unpack requires some packing metadata, so provide an overload that also
/// returns the metadata through a by-reference parameter.
SmallVector<int64_t> getPackUnPackInverseDestPerm(
std::variant<tensor::PackOp, tensor::UnPackOp> packOp,
PackingMetadata &PackingMetadata);
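To make the permutation concrete, here is a hand-worked example; the op, shapes, and tile sizes below are hypothetical, not taken from this patch:

// Pack from ABCD to ABCDba: tile b = 8 strip-mines B, tile a = 4 strip-mines A.
%packed = tensor.pack %src inner_dims_pos = [1, 0] inner_tiles = [8, 4]
    into %dst : tensor<16x32x64x128xf32> -> tensor<4x4x64x128x8x4xf32>
// Packed shape ABCDba = 4x4x64x128x8x4. For this op the helper would return
// [0, 5, 1, 4, 2, 3], the permutation that rearranges ABCDba into the
// pre-permutation shape AaBbCD = 4x4x4x8x64x128.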
/// A tensor.insert_slice is a cast-like operation if it merely rank-extends the
/// source tensor or inserts the source tensor into a destination tensor with
/// the same shape.
@@ -3152,7 +3152,8 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(

// TODO: Check that the correct number of vectorSizes was provided.
for (Operation *target : targets) {
if (!isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp>(target)) {
if (!isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp, tensor::UnPackOp>(
target)) {
return mlir::emitSilenceableFailure(target->getLoc())
<< "Unsupported Op, cannot vectorize";
}
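With tensor::UnPackOp admitted here, a transform script can now drive the new path. A minimal sketch, assuming %0 is a handle to a tensor.unpack op and picking illustrative vector sizes:

transform.structured.vectorize %0 vector_sizes [4, 16] : !transform.any_op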
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -237,7 +237,7 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
PackingMetadata packingMetadata = computePackingMetadata(
packedTensorType.getRank(), packOp.getInnerDimsPos());
SmallVector<int64_t> packedToStripMinedShapePerm =
tensor::getPackInverseDestPermutation(packOp);
tensor::getPackUnPackInverseDestPerm(packOp);

// 3. Compute the stripMinedShape: this is the packed shape before any outer
// or inner permutations have been applied.
195 changes: 163 additions & 32 deletions mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1406,7 +1406,7 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp,
ArrayRef<int64_t> destShape) {
return applyPermutation(destShape,
tensor::getPackInverseDestPermutation(packOp));
tensor::getPackUnPackInverseDestPerm(packOp));
}

/// Create a TransferReadOp from `source` with static shape `readShape`. If the
@@ -1420,16 +1420,28 @@ static Value createReadOrMaskedRead(OpBuilder &builder, Location loc,
auto sourceShape = dyn_cast<ShapedType>(source.getType()).getShape();
assert(sourceShape.size() == readShape.size());
auto maskType = VectorType::get(readShape, builder.getI1Type());
auto vectorType = VectorType::get(readShape, padValue.getType());
Type vecElemType = padValue != nullptr
? padValue.getType()
: cast<ShapedType>(source.getType()).getElementType();
auto vectorType = VectorType::get(readShape, vecElemType);
int64_t readRank = readShape.size();
auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
auto transferReadOp = builder.create<vector::TransferReadOp>(
loc,
/*vectorType=*/vectorType,
/*source=*/source,
/*indices=*/SmallVector<Value>(readRank, zero),
/*padding=*/padValue,
/*inBounds=*/SmallVector<bool>(readRank, true));
vector::TransferReadOp transferReadOp = nullptr;
if (padValue == nullptr) {
transferReadOp = builder.create<vector::TransferReadOp>(
loc,
/*vectorType=*/vectorType,
/*source=*/source,
/*indices=*/SmallVector<Value>(readRank, zero));
} else {
transferReadOp = builder.create<vector::TransferReadOp>(
loc,
/*vectorType=*/vectorType,
/*source=*/source,
/*indices=*/SmallVector<Value>(readRank, zero),
/*padding=*/padValue,
/*inBounds=*/SmallVector<bool>(readRank, true));
}
if (llvm::equal(readShape, sourceShape)) {
return transferReadOp;
}
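When readShape is strictly smaller than the source shape, the helper wraps the read in a mask (that code continues below the fold). The emitted IR looks roughly like this; names, shapes, and the dynamic dimension are illustrative:

%c0 = arith.constant 0 : index
%c8 = arith.constant 8 : index
%pad = arith.constant 0.0 : f32
// Mask off rows beyond the runtime extent %d0 of the leading dimension.
%d0 = tensor.dim %src, %c0 : tensor<?x8xf32>
%mask = vector.create_mask %d0, %c8 : vector<4x8xi1>
%read = vector.mask %mask {
  vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]}
      : tensor<?x8xf32>, vector<4x8xf32>
} : vector<4x8xi1> -> vector<4x8xf32>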
@@ -1547,7 +1559,7 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,

// Create TransposeOp.
auto destPermutation =
invertPermutationVector(tensor::getPackInverseDestPermutation(packOp));
invertPermutationVector(tensor::getPackUnPackInverseDestPerm(packOp));
auto transposeOp = rewriter.create<vector::TransposeOp>(
loc, shapeCastOp.getResult(), destPermutation);

@@ -1559,6 +1571,90 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
return success();
}
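End to end, the pack path is read, shape_cast, transpose, write. A rough sketch on a small static pack, with shapes of my own choosing (%c0 and %pad defined as usual):

// tensor.pack %src inner_dims_pos = [0, 1] inner_tiles = [16, 2]
//     into %dst : tensor<32x8xf32> -> tensor<2x4x16x2xf32>
// vectorizes approximately to:
%read = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]}
    : tensor<32x8xf32>, vector<32x8xf32>
%cast = vector.shape_cast %read : vector<32x8xf32> to vector<2x16x4x2xf32>
%tr = vector.transpose %cast, [0, 2, 1, 3]
    : vector<2x16x4x2xf32> to vector<2x4x16x2xf32>
%write = vector.transfer_write %tr, %dst[%c0, %c0, %c0, %c0]
    {in_bounds = [true, true, true, true]}
    : vector<2x4x16x2xf32>, tensor<2x4x16x2xf32>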

Contributor: what happens for cases with OuterDimsPerms?
Author: Added this in c7ed75e
Contributor: update doc accordingly?
Author: Fixed.

/// Vectorize a `tensor::UnPackOp` without OuterDimsPerms to these 4 Ops:
/// vector::TransferReadOp - Reads a vector from the source tensor
/// vector::TransposeOp - Transposes the read vector
/// vector::ShapeCastOp - Reshapes the data based on the target
/// vector::TransferWriteOp - Writes the result vector back to the destination
/// tensor
static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
tensor::UnPackOp unpackOp,
ArrayRef<int64_t> inputVectorSizes,
SmallVectorImpl<Value> &newResults) {

OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(unpackOp);

RankedTensorType unpackTensorType = unpackOp.getSourceType();

// readMaskShape is the shape of the vector used for the (masked) read. Given
// a vector-sizes array VS of rank N and a source shape SS of rank M (M >= N),
// it starts as:
//   readMaskShape = [VS[0], ..., VS[N-1], SS[N], ..., SS[M-1]]
// and each tiled dimension is then divided (rounding up) by its inner tile
// size.
SmallVector<int64_t> readMaskShape(unpackTensorType.getShape());
llvm::ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
llvm::ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
for (unsigned int i = 0; i < inputVectorSizes.size(); i++) {
readMaskShape[i] = inputVectorSizes[i];
}
for (auto [index, size] : enumerate(innerTiles)) {
readMaskShape[innerDimPos[index]] =
llvm::divideCeil(readMaskShape[innerDimPos[index]], size);
}
ReifiedRankedShapedTypeDims reifiedRetShapes;
LogicalResult status =
cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
.reifyResultShapes(rewriter, reifiedRetShapes);
if (status.failed()) {
LDBG("Unable to reify result shapes of " << unpackOp);
return failure();
}
Location loc = unpackOp->getLoc();

// Read result, mask if necessary.
Value readResult = createReadOrMaskedRead(
rewriter, loc, unpackOp.getSource(),
llvm::ArrayRef<int64_t>(readMaskShape.begin(), readMaskShape.end()),
nullptr);

PackingMetadata packMetadata;
SmallVector<int64_t> lastDimToInsertPosPerm = invertPermutationVector(
tensor::getPackUnPackInverseDestPerm(unpackOp, packMetadata));
ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());
SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
RankedTensorType stripMineTensorType =
RankedTensorType::Builder(stripMineShape, stripMineElemType, {})
.setShape(stripMineShape);

// Transpose the read vector so each inner tile is adjacent to its outer dim.
vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>(
loc, readResult, lastDimToInsertPosPerm);

// Collapse the vector to the size required by result.
RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
stripMineTensorType, packMetadata.reassociations);
mlir::VectorType vecCollapsedType =
VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
loc, vecCollapsedType, transposeOp->getResult(0));

// For dynamic sizes, the write-mask shape has to match the shape-cast result;
// otherwise the verifier rejects the mask as having an invalid size.
SmallVector<int64_t> writeMaskShape(
unpackOp.getDestType().hasStaticShape()
? inputVectorSizes
: shapeCastOp.getResultVectorType().getShape());
Operation *write =
createWriteOrMaskedWrite(rewriter, loc, shapeCastOp.getResult(),
reifiedRetShapes[0], writeMaskShape);
newResults.push_back(write->getResult(0));
return success();
}
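Concretely, the four ops line up as follows for a small static unpack; the shapes are hypothetical and masking is elided since every read and write is in bounds:

// tensor.unpack %src inner_dims_pos = [0, 1] inner_tiles = [16, 2]
//     into %dst : tensor<8x8x16x2xf32> -> tensor<128x16xf32>
// vectorizes approximately to:
%c0 = arith.constant 0 : index
%pad = arith.constant 0.0 : f32
%read = vector.transfer_read %src[%c0, %c0, %c0, %c0], %pad
    {in_bounds = [true, true, true, true]}
    : tensor<8x8x16x2xf32>, vector<8x8x16x2xf32>
// [D0/16, D1/2, 16, 2] -> [D0/16, 16, D1/2, 2]: each tile joins its outer dim.
%tr = vector.transpose %read, [0, 2, 1, 3]
    : vector<8x8x16x2xf32> to vector<8x16x8x2xf32>
%cast = vector.shape_cast %tr : vector<8x16x8x2xf32> to vector<128x16xf32>
%write = vector.transfer_write %cast, %dst[%c0, %c0]
    {in_bounds = [true, true]} : vector<128x16xf32>, tensor<128x16xf32>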

/// Vectorize a `padOp` with (1) static result type, (2) constant padding value
/// and (3) all-zero lowPad to
/// `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
@@ -1655,6 +1751,33 @@ isValidMaskedInputVector(ArrayRef<int64_t> shape,
return success();
}

/// Check that the inner tiles are static/constant and that any user-provided
/// input vector sizes are valid.
static LogicalResult
vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
ArrayRef<int64_t> inputVectorSizes) {

// Handling outer_dims_perm requires more work; for now, only unpacks without
// it are supported.
// TODO: Handle OuterDimsPerm.
if (!unpackOp.getOuterDimsPerm().empty()) {
LDBG("outer dimensions perms NYI for: " << unpackOp);
return failure();
}

if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
return !getConstantIntValue(res).has_value();
})) {
LDBG("Inner-tiles must be constant: " << unpackOp << "\n");
return failure();
}
llvm::ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
if (!inputVectorSizes.empty() &&
failed(isValidMaskedInputVector(resultShape, inputVectorSizes)))
return failure();

return success();
}
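For illustration, both of the following unpacks would be rejected by these checks; names and shapes are hypothetical:

// Rejected: outer_dims_perm is present (not yet implemented here).
%a = tensor.unpack %s0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
    inner_tiles = [16, 2] into %d0
    : tensor<8x16x16x2xf32> -> tensor<256x16xf32>
// Rejected: the inner tile %tile is not a compile-time constant.
%b = tensor.unpack %s1 inner_dims_pos = [0] inner_tiles = [%tile]
    into %d1 : tensor<?x128x?xf32> -> tensor<?x128xf32>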

static LogicalResult
vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
ArrayRef<int64_t> inputVectorSizes,
@@ -1703,9 +1826,10 @@ vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
}
if (isElementwise(linalgOp))
return success();
// TODO: isaConvolutionOpInterface that can also infer from generic features.
// But we will still need stride/dilation attributes that will be annoying to
// reverse-engineer...

// TODO: isaConvolutionOpInterface that can also infer from generic
// features. But we will still need stride/dilation attributes that will be
// annoying to reverse-engineer...
if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
return success();
// TODO: the common vector shape is equal to the static loop sizes only when
@@ -1810,6 +1934,9 @@ LogicalResult mlir::linalg::vectorizeOpPrecondition(
.Case<tensor::PackOp>([&](auto packOp) {
return vectorizePackOpPrecondition(packOp, inputVectorSizes);
})
.Case<tensor::UnPackOp>([&](auto unpackOp) {
return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
})
.Default([](auto) { return failure(); });
}

@@ -1829,11 +1956,11 @@ static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
}

/// Emit a suitable vector form for an operation. If provided,
/// `inputVectorSizes` are used to vectorize this operation. `inputVectorSizes`
/// must match the rank of the iteration space of the operation and the input
/// vector sizes must be greater than or equal to their counterpart iteration
/// space sizes, if static. `inputVectorShapes` also allows the vectorization of
/// operations with dynamic shapes.
/// `inputVectorSizes` are used to vectorize this operation.
/// `inputVectorSizes` must match the rank of the iteration space of the
/// operation and the input vector sizes must be greater than or equal to
/// their counterpart iteration space sizes, if static. `inputVectorSizes`
/// also allows the vectorization of operations with dynamic shapes.
LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
ArrayRef<int64_t> inputVectorSizes,
ArrayRef<bool> inputScalableVecDims,
@@ -1867,8 +1994,9 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
auto vectorizeResult =
TypeSwitch<Operation *, LogicalResult>(op)
.Case<linalg::LinalgOp>([&](auto linalgOp) {
// TODO: isaConvolutionOpInterface that can also infer from generic
// features. Will require stride/dilation attributes inference.
// TODO: isaConvolutionOpInterface that can also infer from
// generic features. Will require stride/dilation attributes
// inference.
if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
FailureOr<Operation *> convOr = vectorizeConvolution(
rewriter, linalgOp, flatten1DDepthwiseConv);
@@ -1902,6 +2030,10 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
results);
})
.Case<tensor::UnPackOp>([&](auto unpackOp) {
return vectorizeAsUnpackOp(rewriter, unpackOp, inputVectorSizes,
results);
})
.Default([](auto) { return failure(); });

if (failed(vectorizeResult)) {
Expand All @@ -1919,7 +2051,6 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,

LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
memref::CopyOp copyOp) {

auto srcType = cast<MemRefType>(copyOp.getSource().getType());
auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
@@ -2833,8 +2964,8 @@ struct Conv1DGenerator
Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped,
resPadding);

// The base vectorization case for channeled convolution is input: {n,w,c},
// weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
// The base vectorization case for channeled convolution is input:
// {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
// vectorization case, we do pre transpose on input, weight, and output.
switch (conv1DOpOrder) {
case Conv1DOpOrder::W:
@@ -2877,9 +3008,9 @@ struct Conv1DGenerator
return kw * (wSize / wSizeStep) + w;
};

// Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f} or
// perform outerproduct for non-channeled convolution or
// perform simple arith operation for pooling
// Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f}
// or perform outerproduct for non-channeled convolution or perform simple
// arith operation for pooling
for (int64_t kw = 0; kw < kwSize; ++kw) {
for (int64_t w = 0; w < wSize; w += wSizeStep) {
switch (oper) {
@@ -2908,9 +3039,9 @@ struct Conv1DGenerator
// End vector-only rewrite part
//===------------------------------------------------------------------===//

// The base vectorization case for channeled convolution is output: {n,w,f}
// To reuse the result from base pattern vectorization case, we post
// transpose the base case result.
// The base vectorization case for channeled convolution is output:
// {n,w,f}. To reuse the result from the base pattern vectorization case, we
// post-transpose the base case result.
switch (conv1DOpOrder) {
case Conv1DOpOrder::W:
case Conv1DOpOrder::Nwc:
@@ -3348,9 +3479,9 @@ static FailureOr<Operation *>
vectorizeConvolution(RewriterBase &rewriter, LinalgOp op,
bool flatten1DDepthwiseConv) {
// The ConvolutionOpInterface gives us guarantees of existence for
// strides/dilations. However, we do not need to rely on those, we can simply
// use them if present, otherwise use the default and let the generic conv.
// matcher in the ConvGenerator succeed or fail.
// strides/dilations. However, we do not need to rely on those, we can
// simply use them if present, otherwise use the default and let the generic
// conv. matcher in the ConvGenerator succeed or fail.
auto strides = op->getAttrOfType<DenseIntElementsAttr>("strides");
auto dilations = op->getAttrOfType<DenseIntElementsAttr>("dilations");
auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;