[mlir][Vectorizer] Added support to Vectorize tensor.unpack #76087
@@ -1406,7 +1406,7 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
 static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp,
                                               ArrayRef<int64_t> destShape) {
   return applyPermutation(destShape,
-                          tensor::getPackInverseDestPermutation(packOp));
+                          tensor::getPackUnPackInverseDestPerm(packOp));
 }

 /// Create a TransferReadOp from `source` with static shape `readShape`. If the
@@ -1420,16 +1420,28 @@ static Value createReadOrMaskedRead(OpBuilder &builder, Location loc,
   auto sourceShape = dyn_cast<ShapedType>(source.getType()).getShape();
   assert(sourceShape.size() == readShape.size());
   auto maskType = VectorType::get(readShape, builder.getI1Type());
-  auto vectorType = VectorType::get(readShape, padValue.getType());
+  Type vecElemType = padValue != nullptr
+                         ? padValue.getType()
+                         : cast<ShapedType>(source.getType()).getElementType();
+  auto vectorType = VectorType::get(readShape, vecElemType);
   int64_t readRank = readShape.size();
   auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
-  auto transferReadOp = builder.create<vector::TransferReadOp>(
-      loc,
-      /*vectorType=*/vectorType,
-      /*source=*/source,
-      /*indices=*/SmallVector<Value>(readRank, zero),
-      /*padding=*/padValue,
-      /*inBounds=*/SmallVector<bool>(readRank, true));
+  vector::TransferReadOp transferReadOp = nullptr;
+  if (padValue == nullptr) {
+    transferReadOp = builder.create<vector::TransferReadOp>(
+        loc,
+        /*vectorType=*/vectorType,
+        /*source=*/source,
+        /*indices=*/SmallVector<Value>(readRank, zero));
+  } else {
+    transferReadOp = builder.create<vector::TransferReadOp>(
+        loc,
+        /*vectorType=*/vectorType,
+        /*source=*/source,
+        /*indices=*/SmallVector<Value>(readRank, zero),
+        /*padding=*/padValue,
+        /*inBounds=*/SmallVector<bool>(readRank, true));
+  }
   if (llvm::equal(readShape, sourceShape)) {
     return transferReadOp;
   }
@@ -1547,7 +1559,7 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,

   // Create TransposeOp.
   auto destPermutation =
-      invertPermutationVector(tensor::getPackInverseDestPermutation(packOp));
+      invertPermutationVector(tensor::getPackUnPackInverseDestPerm(packOp));
   auto transposeOp = rewriter.create<vector::TransposeOp>(
       loc, shapeCastOp.getResult(), destPermutation);

@@ -1559,6 +1571,90 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
   return success();
 }

+/// Vectorize a `tensor::UnPackOp` without OuterDimsPerms to these 4 Ops:
+///   Vector::TransferReadOp - Reads a vector from the source tensor
+///   vector::TransposeOp - Transpose the Source tensor
+///   ShapeCastOp - Reshape the data based on the target.
+///   vector::TransferWriteOp. - Write the result vector back to the destination
+///   tensor
+static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
+                                         tensor::UnPackOp unpackOp,
+                                         ArrayRef<int64_t> inputVectorSizes,
+                                         SmallVectorImpl<Value> &newResults) {
+
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(unpackOp);
+
+  RankedTensorType unpackTensorType = unpackOp.getSourceType();
+
+  SmallVector<int64_t> readMaskShape(unpackTensorType.getShape());
+  llvm::ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
+  llvm::ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
+
+  for (unsigned int i = 0; i < inputVectorSizes.size(); i++) {
+    readMaskShape[i] = inputVectorSizes[i];
+  }
+
+  for (auto [index, size] : enumerate(innerTiles)) {
+    readMaskShape[innerDimPos[index]] =
+        llvm::divideCeil(readMaskShape[innerDimPos[index]], size);
+  }
+
+  // ReadMask is the size of tensor used to read and apply mask. It is
+  // set like this. Let's say the vectorSize (VS) array is size 'N' and
+  // the sourceShape(SS) is 'M' where M >= N
+  // Thus:
+  // ReadMaskShape = [VS[0], ..., VS[N-1], SS[N], ..., SS[M-1]]
+  ReifiedRankedShapedTypeDims reifiedRetShapes;
+  LogicalResult status =
+      cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
+          .reifyResultShapes(rewriter, reifiedRetShapes);
+  if (status.failed()) {
+    LDBG("Unable to reify result shapes of " << unpackOp);
+    return failure();
+  }
+  Location loc = unpackOp->getLoc();
+
+  // Read result, mask if necessary.
+  Value readResult = createReadOrMaskedRead(
+      rewriter, loc, unpackOp.getSource(),
+      llvm::ArrayRef<int64_t>(readMaskShape.begin(), readMaskShape.end()),
+      nullptr);
+
+  PackingMetadata packMetadata;
+  SmallVector<int64_t> lastDimToInsertPosPerm = invertPermutationVector(
+      tensor::getPackUnPackInverseDestPerm(unpackOp, packMetadata));
+  ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());
+  SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
+  mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
+  applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
+  RankedTensorType stripMineTensorType =
+      RankedTensorType::Builder(stripMineShape, stripMineElemType, {})
+          .setShape(stripMineShape);
+
+  // Transpose the appropriate rows to match output.
+  vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>(
+      loc, readResult, lastDimToInsertPosPerm);
+
+  // Collapse the vector to the size required by result.
+  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
+      stripMineTensorType, packMetadata.reassociations);
+  mlir::VectorType vecCollapsedType =
+      VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
+  vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
+      loc, vecCollapsedType, transposeOp->getResult(0));
+
+  // WriteMaskShape had to match the shapecast shape for dynamic sizes,
+  // otherwise the validator complains that the mask size is invalid.
+  SmallVector<int64_t> writeMaskShape(
+      unpackOp.getDestType().hasStaticShape()
+          ? inputVectorSizes
+          : shapeCastOp.getResultVectorType().getShape());
+  Operation *write =
+      createWriteOrMaskedWrite(rewriter, loc, shapeCastOp.getResult(),
+                               reifiedRetShapes[0], writeMaskShape);
+  newResults.push_back(write->getResult(0));
+  return success();
+}
+
 /// Vectorize a `padOp` with (1) static result type, (2) constant padding value
 /// and (3) all-zero lowPad to
 /// `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.

Review thread on the doc comment at the top of this hunk:
- "what happens for cases with OuterDimsPerms?" / "Added this in c7ed75e"
- "update doc accordingly?" / "Fixed."
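To make the readMaskShape computation in this hunk concrete, here is a standalone trace of the same arithmetic on hypothetical sizes (the source shape, tile sizes, and vector sizes below are made up for illustration; only the two loops mirror the code above, with a plain re-implementation of `llvm::divideCeil`).

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Plain stand-in for llvm::divideCeil, for this standalone trace only.
static int64_t divideCeil(int64_t a, int64_t b) { return (a + b - 1) / b; }

int main() {
  // Hypothetical tensor.unpack: packed source of shape 2x8x16x2,
  // inner_dims_pos = [0, 1], inner_tiles = [16, 2], and requested input
  // vector sizes [4, 16] for the two unpacked result dimensions.
  std::vector<int64_t> sourceShape = {2, 8, 16, 2};  // SS, M = 4
  std::vector<int64_t> inputVectorSizes = {4, 16};   // VS, N = 2
  std::vector<int64_t> innerDimPos = {0, 1};
  std::vector<int64_t> innerTiles = {16, 2};

  // readMaskShape starts as the source shape: [SS[0], ..., SS[M-1]].
  std::vector<int64_t> readMaskShape(sourceShape);

  // Overwrite the leading N entries with the vector sizes:
  // [VS[0], ..., VS[N-1], SS[N], ..., SS[M-1]] -> [4, 16, 16, 2].
  for (size_t i = 0; i < inputVectorSizes.size(); ++i)
    readMaskShape[i] = inputVectorSizes[i];

  // Divide the tiled outer dimensions by their tile sizes:
  // [ceil(4/16), ceil(16/2), 16, 2] -> [1, 8, 16, 2].
  for (size_t i = 0; i < innerTiles.size(); ++i)
    readMaskShape[innerDimPos[i]] =
        divideCeil(readMaskShape[innerDimPos[i]], innerTiles[i]);

  for (int64_t d : readMaskShape)
    std::printf("%lld ", static_cast<long long>(d)); // prints: 1 8 16 2
  std::printf("\n");
  return 0;
}
```

The resulting read shape (here [1, 8, 16, 2]) is what the masked transfer_read uses; the transpose, shape_cast, and transfer_write described in the doc comment then reassemble and store the unpacked layout.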
@@ -1655,6 +1751,33 @@ isValidMaskedInputVector(ArrayRef<int64_t> shape,
   return success();
 }

+/// Need to check if the inner-tiles are static/constant.
+static LogicalResult
+vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
+                              ArrayRef<int64_t> inputVectorSizes) {
+
+  // Handling this case requires a bit more change. Right now
+  // just the required attributes are handled.
+  // TODO: Handle OuterDimsPerm.
+  if (!unpackOp.getOuterDimsPerm().empty()) {
+    LDBG("outer dimensions perms NYI for: " << unpackOp);
+    return failure();
+  }
+
+  if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
+        return !getConstantIntValue(res).has_value();
+      })) {
+    LDBG("Inner-tiles must be constant: " << unpackOp << "\n");
+    return failure();
+  }
+  llvm::ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
+  if (inputVectorSizes.empty() == false &&
+      failed(isValidMaskedInputVector(resultShape, inputVectorSizes)))
+    return failure();
+
+  return success();
+}
+
 static LogicalResult
 vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
                               ArrayRef<int64_t> inputVectorSizes,
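As a quick illustration of the two rejection checks above, the following standalone sketch mirrors the same logic on plain data. It deliberately does not use the MLIR APIs: the inner tiles are modelled as optional integers, with an empty optional standing for a non-constant tile size.

```cpp
#include <cstdint>
#include <cstdio>
#include <optional>
#include <vector>

// Mirrors vectorizeUnPackOpPrecondition's first two checks on plain data:
// reject when an outer_dims_perm is present, or when any inner tile size is
// not a compile-time constant. Illustration of the logic only.
static bool precheckUnpack(const std::vector<int64_t> &outerDimsPerm,
                           const std::vector<std::optional<int64_t>> &innerTiles) {
  if (!outerDimsPerm.empty())
    return false; // OuterDimsPerm is not yet supported by this patch.
  for (const auto &tile : innerTiles)
    if (!tile.has_value())
      return false; // Dynamic (non-constant) inner tile sizes are rejected.
  return true;
}

int main() {
  // Constant 16x2 inner tiles, no outer permutation: accepted.
  std::printf("%d\n", precheckUnpack({}, {16, 2}));            // prints 1
  // A dynamic inner tile (modelled as std::nullopt): rejected.
  std::printf("%d\n", precheckUnpack({}, {16, std::nullopt})); // prints 0
  // An outer_dims_perm is present: rejected (NYI).
  std::printf("%d\n", precheckUnpack({1, 0}, {16, 2}));        // prints 0
  return 0;
}
```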
@@ -1703,9 +1826,10 @@ vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
   }
   if (isElementwise(linalgOp))
     return success();
-  // TODO: isaConvolutionOpInterface that can also infer from generic features.
-  // But we will still need stride/dilation attributes that will be annoying to
-  // reverse-engineer...
+
+  // TODO: isaConvolutionOpInterface that can also infer from generic
+  // features. But we will still need stride/dilation attributes that will be
+  // annoying to reverse-engineer...
   if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
     return success();
   // TODO: the common vector shape is equal to the static loop sizes only when
@@ -1810,6 +1934,9 @@ LogicalResult mlir::linalg::vectorizeOpPrecondition(
       .Case<tensor::PackOp>([&](auto packOp) {
         return vectorizePackOpPrecondition(packOp, inputVectorSizes);
       })
+      .Case<tensor::UnPackOp>([&](auto unpackOp) {
+        return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
+      })
       .Default([](auto) { return failure(); });
 }

@@ -1829,11 +1956,11 @@ static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
 }

 /// Emit a suitable vector form for an operation. If provided,
-/// `inputVectorSizes` are used to vectorize this operation. `inputVectorSizes`
-/// must match the rank of the iteration space of the operation and the input
-/// vector sizes must be greater than or equal to their counterpart iteration
-/// space sizes, if static. `inputVectorShapes` also allows the vectorization of
-/// operations with dynamic shapes.
+/// `inputVectorSizes` are used to vectorize this operation.
+/// `inputVectorSizes` must match the rank of the iteration space of the
+/// operation and the input vector sizes must be greater than or equal to
+/// their counterpart iteration space sizes, if static. `inputVectorShapes`
+/// also allows the vectorization of operations with dynamic shapes.
 LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
                                       ArrayRef<int64_t> inputVectorSizes,
                                       ArrayRef<bool> inputScalableVecDims,
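As a usage sketch, a caller wanting to vectorize a `tensor.unpack` with explicit vector sizes would go through this entry point roughly as follows. The snippet only passes the parameters visible in this diff and assumes any remaining parameters of `linalg::vectorize` keep their defaults; the vector sizes and the surrounding helper function are hypothetical.

```cpp
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Hypothetical helper, assumed to run inside a pattern or pass that already
// owns a rewriter. Not part of this PR.
static LogicalResult vectorizeUnpackExample(RewriterBase &rewriter,
                                            tensor::UnPackOp unpackOp) {
  // Vector sizes for the unpacked result dimensions; the unpack precondition
  // above validates them against the destination shape. Values are made up.
  SmallVector<int64_t> vectorSizes = {4, 16};
  // No scalable vector dimensions in this sketch.
  SmallVector<bool> scalableVecDims(vectorSizes.size(), false);
  return linalg::vectorize(rewriter, unpackOp.getOperation(), vectorSizes,
                           scalableVecDims);
}
```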
@@ -1867,8 +1994,9 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
   auto vectorizeResult =
       TypeSwitch<Operation *, LogicalResult>(op)
           .Case<linalg::LinalgOp>([&](auto linalgOp) {
-            // TODO: isaConvolutionOpInterface that can also infer from generic
-            // features. Will require stride/dilation attributes inference.
+            // TODO: isaConvolutionOpInterface that can also infer from
+            // generic features. Will require stride/dilation attributes
+            // inference.
             if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
               FailureOr<Operation *> convOr = vectorizeConvolution(
                   rewriter, linalgOp, flatten1DDepthwiseConv);
@@ -1902,6 +2030,10 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
             return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
                                            results);
           })
+          .Case<tensor::UnPackOp>([&](auto unpackOp) {
+            return vectorizeAsUnpackOp(rewriter, unpackOp, inputVectorSizes,
+                                       results);
+          })
           .Default([](auto) { return failure(); });

   if (failed(vectorizeResult)) {
@@ -1919,7 +2051,6 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,

 LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
                                           memref::CopyOp copyOp) {
-
   auto srcType = cast<MemRefType>(copyOp.getSource().getType());
   auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
   if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
@@ -2833,8 +2964,8 @@ struct Conv1DGenerator
     Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped,
                                                         resPadding);

-    // The base vectorization case for channeled convolution is input: {n,w,c},
-    // weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
+    // The base vectorization case for channeled convolution is input:
+    // {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
     // vectorization case, we do pre transpose on input, weight, and output.
     switch (conv1DOpOrder) {
     case Conv1DOpOrder::W:
@@ -2877,9 +3008,9 @@ struct Conv1DGenerator
       return kw * (wSize / wSizeStep) + w;
     };

-    // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f} or
-    // perform outerproduct for non-channeled convolution or
-    // perform simple arith operation for pooling
+    // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f}
+    // or perform outerproduct for non-channeled convolution or perform simple
+    // arith operation for pooling
     for (int64_t kw = 0; kw < kwSize; ++kw) {
       for (int64_t w = 0; w < wSize; w += wSizeStep) {
         switch (oper) {
@@ -2908,9 +3039,9 @@ struct Conv1DGenerator
     // End vector-only rewrite part
     //===------------------------------------------------------------------===//

-    // The base vectorization case for channeled convolution is output: {n,w,f}
-    // To reuse the result from base pattern vectorization case, we post
-    // transpose the base case result.
+    // The base vectorization case for channeled convolution is output:
+    // {n,w,f} To reuse the result from base pattern vectorization case, we
+    // post transpose the base case result.
     switch (conv1DOpOrder) {
     case Conv1DOpOrder::W:
     case Conv1DOpOrder::Nwc:
@@ -3348,9 +3479,9 @@ static FailureOr<Operation *>
 vectorizeConvolution(RewriterBase &rewriter, LinalgOp op,
                      bool flatten1DDepthwiseConv) {
   // The ConvolutionOpInterface gives us guarantees of existence for
-  // strides/dilations. However, we do not need to rely on those, we can simply
-  // use them if present, otherwise use the default and let the generic conv.
-  // matcher in the ConvGenerator succeed or fail.
+  // strides/dilations. However, we do not need to rely on those, we can
+  // simply use them if present, otherwise use the default and let the generic
+  // conv. matcher in the ConvGenerator succeed or fail.
   auto strides = op->getAttrOfType<DenseIntElementsAttr>("strides");
   auto dilations = op->getAttrOfType<DenseIntElementsAttr>("dilations");
   auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;