Skip to content

Commit c56bd7a

Browse files
authored
[mlir][linalg] Enable masked vectorisation for depthwise convolutions (#81625)
This patch adds support for masked vectorisation of depthwise 1D WC convolutions,`linalg.depthwise_conv_1d_nwc_wc`. This is implemented by adding support for masking. Two major assumptions are made: * only the channel dimension can be dynamic/scalable (i.e. the trailing dim), * when specifying vector sizes to use in the vectoriser, only the size corresponding to the channel dim is effectively used (other dims are inferred from the context). In terms of scalable vectorisation, this should be sufficient to cover all practical cases (i.e. making arbitrary dim scalable wouldn't make much sense). As for more generic cases with dynamic shapes (e.g. W or N dims being dynamic), more work would be needed. In particular, one would have to consider the filter and input/output tensors separately.
1 parent bba790d commit c56bd7a

File tree

6 files changed

+418
-40
lines changed

6 files changed

+418
-40
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,8 @@ LogicalResult promoteSubviewsPrecondition(Operation *op,
460460
LogicalResult vectorizeOpPrecondition(Operation *op,
461461
ArrayRef<int64_t> inputVectorSizes = {},
462462
ArrayRef<bool> inputScalableVecDims = {},
463-
bool vectorizeNDExtract = false);
463+
bool vectorizeNDExtract = false,
464+
bool flatten1DDepthwiseConv = false);
464465

465466
//===----------------------------------------------------------------------===//
466467
// Transformations exposed as functional-style API calls.

mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,15 @@
99
#ifndef MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_
1010
#define MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_
1111

12+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
13+
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1214
#include "mlir/Dialect/Utils/IndexingUtils.h"
1315
#include "mlir/Dialect/Vector/IR/VectorOps.h"
1416
#include "mlir/IR/BuiltinAttributes.h"
1517
#include "mlir/Support/LLVM.h"
1618

1719
#include "llvm/ADT/DenseMap.h"
20+
#include "llvm/ADT/TypeSwitch.h"
1821

1922
namespace mlir {
2023

@@ -98,6 +101,17 @@ bool isContiguousSlice(MemRefType memrefType, VectorType vectorType);
98101
std::optional<StaticTileOffsetRange>
99102
createUnrollIterator(VectorType vType, int64_t targetRank = 1);
100103

104+
/// A wrapper for getMixedSizes for vector.transfer_read and
105+
/// vector.transfer_write Ops (for source and destination, respectively).
106+
///
107+
/// Tensor and MemRef types implement their own, very similar version of
108+
/// getMixedSizes. This method will call the appropriate version (depending on
109+
/// `hasTensorSemantics`). It will also automatically extract the operand for
110+
/// which to call it on (source for "read" and destination for "write" ops).
111+
SmallVector<OpFoldResult> getMixedSizesXfer(bool hasTensorSemantics,
112+
Operation *xfer,
113+
RewriterBase &rewriter);
114+
101115
} // namespace vector
102116

103117
/// Constructs a permutation map of invariant memref indices to vector

0 commit comments

Comments
 (0)