@@ -1715,9 +1715,17 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
   return success();
 }
 
-static LogicalResult vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv) {
+static LogicalResult
+vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
+                                   bool flatten1DDepthwiseConv) {
+  if (flatten1DDepthwiseConv) {
+    LDBG("Vectorization of flattened convs with dynamic shapes is not "
+         "supported\n");
+    return failure();
+  }
+
   if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
-    LDBG("Not a depth-wise 1D conv, dynamic shapes are not supported\n");
+    LDBG("Not a 1D depth-wise WC conv, dynamic shapes are not supported\n");
     return failure();
   }
@@ -1735,9 +1743,11 @@ static LogicalResult vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv) {
   return success();
 }
 
-static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op) {
+static LogicalResult
+vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
+                                     bool flatten1DDepthwiseConv) {
   if (isa<ConvolutionOpInterface>(op.getOperation()))
-    return vectorizeDynamicConvOpPrecondition(op);
+    return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv);
 
   // TODO: Masking only supports dynamic element-wise ops, linalg.generic ops,
   // linalg.copy ops and ops that implement ContractionOpInterface for now.
@@ -1804,10 +1814,9 @@ vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
   return success();
 }
 
-static LogicalResult
-vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
-                              ArrayRef<int64_t> inputVectorSizes,
-                              bool vectorizeNDExtract) {
+static LogicalResult vectorizeLinalgOpPrecondition(
+    LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes,
+    bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
   // tensor with dimension of 0 cannot be vectorized.
   if (llvm::is_contained(linalgOp.getStaticShape(), 0))
     return failure();
@@ -1817,8 +1826,8 @@ vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
                                       inputVectorSizes)))
     return failure();
 
-  if (linalgOp.hasDynamicShape() &&
-      failed(vectorizeDynamicLinalgOpPrecondition(linalgOp))) {
+  if (linalgOp.hasDynamicShape() && failed(vectorizeDynamicLinalgOpPrecondition(
+                                        linalgOp, flatten1DDepthwiseConv))) {
     LDBG("Dynamically-shaped op failed vectorization pre-conditions\n");
     return failure();
   }
@@ -1946,15 +1955,17 @@ vectorizeScalableVectorPrecondition(Operation *op,
 
 LogicalResult mlir::linalg::vectorizeOpPrecondition(
     Operation *op, ArrayRef<int64_t> inputVectorSizes,
-    ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract) {
+    ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
+    bool flatten1DDepthwiseConv) {
   if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
                                                  inputScalableVecDims)))
     return failure();
 
   return TypeSwitch<Operation *, LogicalResult>(op)
       .Case<linalg::LinalgOp>([&](auto linalgOp) {
         return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
-                                             vectorizeNDExtract);
+                                             vectorizeNDExtract,
+                                             flatten1DDepthwiseConv);
       })
       .Case<tensor::PadOp>([&](auto padOp) {
         return vectorizePadOpPrecondition(padOp, inputVectorSizes);
@@ -2003,7 +2014,8 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
   LLVM_DEBUG(llvm::dbgs() << "\n");
 
   if (failed(vectorizeOpPrecondition(op, inputVectorSizes, inputScalableVecDims,
-                                     vectorizeNDExtract))) {
+                                     vectorizeNDExtract,
+                                     flatten1DDepthwiseConv))) {
     LDBG("Vectorization pre-conditions failed\n");
     return failure();
   }
@@ -3180,6 +3192,10 @@ struct Conv1DGenerator
       scalableChDim = channelDimScalableFlag;
       useMasking = true;
     }
+
+    assert(!(useMasking && flatten) &&
+           "Unsupported flattened conv with dynamic shapes");
+
     // out{n, w, c}
     bindShapeDims(resShapedType, nSize, wSize);
@@ -3282,10 +3298,15 @@ struct Conv1DGenerator
       return kw * (wSize / wSizeStep) + w;
     };
 
+    // Note - the scalable flags are ignored as flattening combined with
+    // scalable vectorization is not supported.
     auto inOutFlattenSliceSizes =
         SmallVector<int64_t>{nSize, wSizeStep * cSize};
-    auto lhsCastType = VectorType::get(inOutFlattenSliceSizes, lhsEltType);
-    auto resCastType = VectorType::get(inOutFlattenSliceSizes, resEltType);
+    auto lhsTypeAfterFlattening =
+        VectorType::get(inOutFlattenSliceSizes, lhsEltType);
+    auto resTypeAfterFlattening =
+        VectorType::get(inOutFlattenSliceSizes, resEltType);
+
     // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
     for (int64_t kw = 0; kw < kwSize; ++kw) {
       for (int64_t w = 0; w < wSize; w += wSizeStep) {
@@ -3295,9 +3316,9 @@ struct Conv1DGenerator
           // Flatten the input and output vectors (collapse the channel
           // dimension)
           lhsVal = rewriter.create<vector::ShapeCastOp>(
-              loc, lhsCastType, lhsVals[linearIndex(kw, w)]);
-          resVal = rewriter.create<vector::ShapeCastOp>(loc, resCastType,
-                                                        resVals[w]);
+              loc, lhsTypeAfterFlattening, lhsVals[linearIndex(kw, w)]);
+          resVal = rewriter.create<vector::ShapeCastOp>(
+              loc, resTypeAfterFlattening, resVals[w]);
         }
         resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
                                                   rhsVals[kw], resVal, flatten);
@@ -3353,6 +3374,10 @@ struct Conv1DGenerator
     lhs = promote(rewriter, loc, lhs, resTy);
 
     if (flatten) {
+      // NOTE: The following logic won't work for scalable vectors. For this
+      // reason, "flattening" is not supported when shapes are dynamic (this
+      // should be captured by one of the pre-conditions).
+
       // There are two options for handling the filter:
       // * shape_cast(broadcast(filter))
       // * broadcast(shuffle(filter))
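
For context, here is a minimal sketch of how a caller could thread the new flatten1DDepthwiseConv flag through the public entry point touched by this patch. The wrapper function name, the include paths, and the defaulted arguments are illustrative assumptions; only the extra bool parameter on mlir::linalg::vectorize comes from the change above.

// Sketch only: `runFlatteningAwareVectorization` is a hypothetical helper used
// for illustration; it is not part of this patch.
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h" // assumed header for linalg::vectorize

using namespace mlir;

static LogicalResult runFlatteningAwareVectorization(RewriterBase &rewriter,
                                                     linalg::LinalgOp linalgOp,
                                                     bool flattenDWConv) {
  // After this patch, requesting flattening for a dynamically shaped depthwise
  // conv makes the pre-conditions fail up front (see
  // vectorizeDynamicConvOpPrecondition above) rather than producing invalid IR.
  return linalg::vectorize(rewriter, linalgOp.getOperation(),
                           /*inputVectorSizes=*/{},
                           /*inputScalableVecDims=*/{},
                           /*vectorizeNDExtract=*/false,
                           /*flatten1DDepthwiseConv=*/flattenDWConv);
}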