Commit b367636

fixup! [mlir][linalg] Add scalable vectorisation for depthwise convolutions
* Add missing dynamic dimension in a test
* Make sure "flattening" + "masked vectorisation" are not allowed
Parent: 2528f8e
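To illustrate the second point, here is a minimal caller-side sketch of the updated pre-condition API (a hypothetical helper, not part of this commit; it relies only on the vectorizeOpPrecondition declaration changed below). When flatten1DDepthwiseConv is requested for an op with dynamic shapes — i.e. one that would need masked vectorisation — the pre-condition now fails up front:

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

using namespace mlir;

// Hypothetical helper: pre-check an op before attempting flattened
// vectorisation. If the op has dynamic shapes (which require masked
// vectorisation), the new flatten1DDepthwiseConv check makes this fail.
static LogicalResult canVectorizeFlattened(Operation *op,
                                           ArrayRef<int64_t> vectorSizes,
                                           ArrayRef<bool> scalableVecDims) {
  return linalg::vectorizeOpPrecondition(op, vectorSizes, scalableVecDims,
                                         /*vectorizeNDExtract=*/false,
                                         /*flatten1DDepthwiseConv=*/true);
}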

5 files changed: 53 additions & 27 deletions

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 2 additions & 1 deletion

@@ -460,7 +460,8 @@ LogicalResult promoteSubviewsPrecondition(Operation *op,
 LogicalResult vectorizeOpPrecondition(Operation *op,
                                       ArrayRef<int64_t> inputVectorSizes = {},
                                       ArrayRef<bool> inputScalableVecDims = {},
-                                      bool vectorizeNDExtract = false);
+                                      bool vectorizeNDExtract = false,
+                                      bool flatten1DDepthwiseConv = false);
 
 //===----------------------------------------------------------------------===//
 // Transformations exposed as functional-style API calls.
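Because the new parameter is defaulted, existing callers of vectorizeOpPrecondition keep compiling and behaving as before. A minimal sketch of such a call site (hypothetical, shown only to illustrate the defaulted argument):

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

using namespace mlir;

// Hypothetical pre-existing call site: the trailing arguments are omitted,
// so flatten1DDepthwiseConv defaults to false and the old behaviour applies.
static LogicalResult preCheck(Operation *op) {
  return linalg::vectorizeOpPrecondition(op);
}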

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 43 additions & 18 deletions

@@ -1715,9 +1715,17 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
   return success();
 }
 
-static LogicalResult vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv) {
+static LogicalResult
+vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
+                                   bool flatten1DDepthwiseConv) {
+  if (flatten1DDepthwiseConv) {
+    LDBG("Vectorization of flattened convs with dynamic shapes is not "
+         "supported\n");
+    return failure();
+  }
+
   if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
-    LDBG("Not a depth-wise 1D conv, dynamic shapes are not supported\n");
+    LDBG("Not a 1D depth-wise WC conv, dynamic shapes are not supported\n");
     return failure();
   }
 
@@ -1735,9 +1743,11 @@ static LogicalResult vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv) {
   return success();
 }
 
-static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op) {
+static LogicalResult
+vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
+                                     bool flatten1DDepthwiseConv) {
   if (isa<ConvolutionOpInterface>(op.getOperation()))
-    return vectorizeDynamicConvOpPrecondition(op);
+    return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv);
 
   // TODO: Masking only supports dynamic element-wise ops, linalg.generic ops,
   // linalg.copy ops and ops that implement ContractionOpInterface for now.
@@ -1804,10 +1814,9 @@ vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
   return success();
 }
 
-static LogicalResult
-vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
-                              ArrayRef<int64_t> inputVectorSizes,
-                              bool vectorizeNDExtract) {
+static LogicalResult vectorizeLinalgOpPrecondition(
+    LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes,
+    bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
   // tensor with dimension of 0 cannot be vectorized.
   if (llvm::is_contained(linalgOp.getStaticShape(), 0))
     return failure();
@@ -1817,8 +1826,8 @@ vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
                                       inputVectorSizes)))
     return failure();
 
-  if (linalgOp.hasDynamicShape() &&
-      failed(vectorizeDynamicLinalgOpPrecondition(linalgOp))) {
+  if (linalgOp.hasDynamicShape() && failed(vectorizeDynamicLinalgOpPrecondition(
+                                        linalgOp, flatten1DDepthwiseConv))) {
     LDBG("Dynamically-shaped op failed vectorization pre-conditions\n");
     return failure();
   }
@@ -1946,15 +1955,17 @@ vectorizeScalableVectorPrecondition(Operation *op,
 
 LogicalResult mlir::linalg::vectorizeOpPrecondition(
     Operation *op, ArrayRef<int64_t> inputVectorSizes,
-    ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract) {
+    ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
+    bool flatten1DDepthwiseConv) {
   if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
                                                  inputScalableVecDims)))
     return failure();
 
   return TypeSwitch<Operation *, LogicalResult>(op)
       .Case<linalg::LinalgOp>([&](auto linalgOp) {
         return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
-                                             vectorizeNDExtract);
+                                             vectorizeNDExtract,
+                                             flatten1DDepthwiseConv);
       })
       .Case<tensor::PadOp>([&](auto padOp) {
         return vectorizePadOpPrecondition(padOp, inputVectorSizes);
@@ -2003,7 +2014,8 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
   LLVM_DEBUG(llvm::dbgs() << "\n");
 
   if (failed(vectorizeOpPrecondition(op, inputVectorSizes, inputScalableVecDims,
-                                     vectorizeNDExtract))) {
+                                     vectorizeNDExtract,
+                                     flatten1DDepthwiseConv))) {
     LDBG("Vectorization pre-conditions failed\n");
     return failure();
   }
@@ -3180,6 +3192,10 @@ struct Conv1DGenerator
       scalableChDim = channelDimScalableFlag;
       useMasking = true;
     }
+
+    assert(!(useMasking && flatten) &&
+           "Unsupported flattened conv with dynamic shapes");
+
     // out{n, w, c}
     bindShapeDims(resShapedType, nSize, wSize);
 
@@ -3282,10 +3298,15 @@ struct Conv1DGenerator
       return kw * (wSize / wSizeStep) + w;
     };
 
+    // Note - the scalable flags are ignored as flattening combined with
+    // scalable vectorization is not supported.
     auto inOutFlattenSliceSizes =
         SmallVector<int64_t>{nSize, wSizeStep * cSize};
-    auto lhsCastType = VectorType::get(inOutFlattenSliceSizes, lhsEltType);
-    auto resCastType = VectorType::get(inOutFlattenSliceSizes, resEltType);
+    auto lhsTypeAfterFlattening =
+        VectorType::get(inOutFlattenSliceSizes, lhsEltType);
+    auto resTypeAfterFlattening =
+        VectorType::get(inOutFlattenSliceSizes, resEltType);
+
     // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
     for (int64_t kw = 0; kw < kwSize; ++kw) {
       for (int64_t w = 0; w < wSize; w += wSizeStep) {
@@ -3295,9 +3316,9 @@ struct Conv1DGenerator
           // Flatten the input and output vectors (collapse the channel
           // dimension)
           lhsVal = rewriter.create<vector::ShapeCastOp>(
-              loc, lhsCastType, lhsVals[linearIndex(kw, w)]);
-          resVal = rewriter.create<vector::ShapeCastOp>(loc, resCastType,
-                                                        resVals[w]);
+              loc, lhsTypeAfterFlattening, lhsVals[linearIndex(kw, w)]);
+          resVal = rewriter.create<vector::ShapeCastOp>(
+              loc, resTypeAfterFlattening, resVals[w]);
         }
         resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
                                                   rhsVals[kw], resVal, flatten);
@@ -3353,6 +3374,10 @@ struct Conv1DGenerator
       lhs = promote(rewriter, loc, lhs, resTy);
 
     if (flatten) {
+      // NOTE: This following logic won't work for scalable vectors. For this
+      // reason, "flattening" is not supported when shapes are dynamic (this
+      // should be captured by one of the pre-conditions).
+
       // There are two options for handling the filter:
       //   * shape_cast(broadcast(filter))
       //   * broadcast(shuffle(filter))
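For context on the flattening path touched above: the generator collapses the channel dimension into the width step before the multiply-accumulate, e.g. a vector<1x8x4xf32> slice becomes vector<1x32xf32>. Below is a standalone sketch of that type computation (illustrative only, assuming f32 elements; getFlattenedSliceType is a hypothetical name, not code from this patch). Because wSizeStep * cSize is a static product, it does not compose with dynamic channel sizes, which is exactly what the new pre-condition and assert guard against:

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

// Illustrative sketch: build the flattened slice type used when collapsing
// the channel dimension into the width step, e.g. {1, 8, 4} -> {1, 32}.
static VectorType getFlattenedSliceType(MLIRContext *ctx, int64_t nSize,
                                        int64_t wSizeStep, int64_t cSize) {
  Type f32 = Float32Type::get(ctx);
  SmallVector<int64_t> inOutFlattenSliceSizes{nSize, wSizeStep * cSize};
  return VectorType::get(inOutFlattenSliceSizes, f32);
}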

mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp

Lines changed: 3 additions & 3 deletions

@@ -306,14 +306,14 @@ SmallVector<OpFoldResult> vector::getMixedSizesXfer(bool hasTensorSemantics,
                                                     RewriterBase &rewriter) {
   auto loc = xfer->getLoc();
 
-  Value blah = TypeSwitch<Operation *, Value>(xfer)
+  Value base = TypeSwitch<Operation *, Value>(xfer)
                    .Case<vector::TransferReadOp>(
                        [&](auto readOp) { return readOp.getSource(); })
                    .Case<vector::TransferWriteOp>(
                        [&](auto writeOp) { return writeOp.getOperand(1); });
 
   SmallVector<OpFoldResult> mixedSourceDims =
-      hasTensorSemantics ? tensor::getMixedSizes(rewriter, loc, blah)
-                         : memref::getMixedSizes(rewriter, loc, blah);
+      hasTensorSemantics ? tensor::getMixedSizes(rewriter, loc, base)
+                         : memref::getMixedSizes(rewriter, loc, base);
   return mixedSourceDims;
 }

mlir/test/Dialect/Linalg/vectorization-unsupported.mlir

Lines changed: 5 additions & 3 deletions

@@ -19,12 +19,14 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-func.func @depthwise_conv1d_ncw_cw(%input: memref<3x5x4xf32>, %filter: memref<5x1xf32>, %output: memref<3x5x4xf32>) {
+// Masked vectorisation of 1D depthwise CW convs is not yet supported
+
+func.func @depthwise_conv1d_ncw_cw(%input: memref<3x?x4xf32>, %filter: memref<?x1xf32>, %output: memref<3x?x4xf32>) {
   // expected-error @+1 {{Attempted to vectorize, but failed}}
   linalg.depthwise_conv_1d_ncw_cw
     {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
-    ins(%input, %filter : memref<3x5x4xf32>, memref<5x1xf32>)
-    outs(%output : memref<3x5x4xf32>)
+    ins(%input, %filter : memref<3x?x4xf32>, memref<?x1xf32>)
+    outs(%output : memref<3x?x4xf32>)
   return
 }
 

mlir/test/Dialect/Linalg/vectorize-conv-masked-and-scalable.mlir

Lines changed: 0 additions & 2 deletions

@@ -120,8 +120,6 @@
 // CHECK: %[[OUT:.*]] = vector.mask %[[MASK_OUT]] { vector.transfer_write %[[OUT_INS]], %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] : vector<1x8x[4]xi8>, tensor<1x8x?xi8> } : vector<1x8x[4]xi1> -> tensor<1x8x?xi8>
 // CHECK: return %[[OUT]] : tensor<1x8x?xi8>
 
-
-
 // -----
 
 func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dilation_2(
