Skip to content

[mlir] Add a contiguous<perm, offset> layout, use as identity layout #131663

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions mlir/include/mlir-c/BuiltinAttributes.h
Original file line number Diff line number Diff line change
Expand Up @@ -697,6 +697,13 @@ MLIR_CAPI_EXPORTED MlirAttribute
mlirStridedLayoutAttrGet(MlirContext ctx, int64_t offset, intptr_t numStrides,
const int64_t *strides);

// Creates a strided layout attribute from given strides and offset,
// canonicalizing the 0D and 1D unit stride to contiguous layout attributes. The
// returned value may not be a StridedLayoutAttr.
MLIR_CAPI_EXPORTED MlirAttribute
mlirStridedLayoutAttrGetCanonical(MlirContext ctx, int64_t offset,
intptr_t numStrides, const int64_t *strides);

// Returns the offset in the given strided layout layout attribute.
MLIR_CAPI_EXPORTED int64_t mlirStridedLayoutAttrGetOffset(MlirAttribute attr);

Expand All @@ -711,6 +718,38 @@ MLIR_CAPI_EXPORTED int64_t mlirStridedLayoutAttrGetStride(MlirAttribute attr,
/// Returns the typeID of a StridedLayout attribute.
MLIR_CAPI_EXPORTED MlirTypeID mlirStridedLayoutAttrGetTypeID(void);

//===----------------------------------------------------------------------===//
// Contiguous layout attribute.
//===----------------------------------------------------------------------===//

// Checks wheather the given attribute is a contiguous layout attribute.
MLIR_CAPI_EXPORTED bool mlirAttributeIsAContiguousLayout(MlirAttribute attr);

// Creates a contiguous layout attribute from given permutation and offset.
// There must be `rank` values in `permutation`.
MLIR_CAPI_EXPORTED MlirAttribute mlirContiguousLayoutAttrGet(
MlirContext ctx, int64_t offset, intptr_t rank, const int64_t *permutation);

// Creates a row-major contiguous layout attribute from given offset and rank.
MLIR_CAPI_EXPORTED MlirAttribute mlirContiguousLayoutAttrGetRowMajor(
MlirContext ctx, int64_t offset, int64_t rank);

// Returns the offset in the given contiguous layout attribute.
MLIR_CAPI_EXPORTED int64_t
mlirContiguousLayoutAttrGetOffset(MlirAttribute attr);

// Returns the number of permutation entries in the given contiguous layout
// attribute.
MLIR_CAPI_EXPORTED intptr_t mlirContiguousLayoutAttrGetRank(MlirAttribute attr);

// Returns the pos-th permutation entry stored in the given contiguous layout
// attribute.
MLIR_CAPI_EXPORTED int64_t
mlirContiguousLayoutAttrGetPermutationEntry(MlirAttribute attr, intptr_t pos);

/// Returns the typeID of a ContiguousLayout attribute.
MLIR_CAPI_EXPORTED MlirTypeID mlirContiguousLayoutAttrGetTypeID(void);

#ifdef __cplusplus
}
#endif
Expand Down
5 changes: 3 additions & 2 deletions mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,9 @@ def MemRefTypeAttr
class MemRef_Op<string mnemonic, list<Trait> traits = []>
: Op<MemRef_Dialect, mnemonic, traits>;

// Base class for ops with static/dynamic offset, sizes and strides
// attributes/arguments.
// Base class for ops with static/dynamic offset, sizes and optional strides
// attributes/arguments. When the strides are not specified, this implies a
// contiguous layout.
class MemRef_OpWithOffsetSizesAndStrides<string mnemonic,
list<Trait> traits = []>
: MemRef_Op<mnemonic, traits> {
Expand Down
26 changes: 14 additions & 12 deletions mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,10 @@ LogicalResult reshapeLikeShapesAreCompatible(
ArrayRef<int64_t> collapsedShape, ArrayRef<int64_t> expandedShape,
ArrayRef<ReassociationIndices> reassociationMaps, bool isExpandingReshape);

/// Returns true iff the type is a MemRefType and has a non-identity layout.
bool hasNonIdentityLayout(Type type);
/// Returns true iff the type is a MemRefType and has a layout that is not
/// row-major contiguous - that is, the identity layout with an optional
/// offset.
bool hasNonRowMajorContiguousLayout(Type type);

enum class ReshapeOpKind { kExpand, kCollapse };

Expand All @@ -197,9 +199,9 @@ struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> {

ShapedType resultType = reshapeOp.getResultType();

if (hasNonIdentityLayout(srcReshapeOp.getSrc().getType()) ||
hasNonIdentityLayout(reshapeOp.getSrc().getType()) ||
hasNonIdentityLayout(reshapeOp.getResult().getType()))
if (hasNonRowMajorContiguousLayout(srcReshapeOp.getSrc().getType()) ||
hasNonRowMajorContiguousLayout(reshapeOp.getSrc().getType()) ||
hasNonRowMajorContiguousLayout(reshapeOp.getResult().getType()))
return failure();

std::optional<SmallVector<ReassociationIndices>> reassociationIndices =
Expand Down Expand Up @@ -265,9 +267,9 @@ struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
ShapedType srcType = expandOp.getSrcType();
ShapedType resultType = collapseOp.getResultType();

if (hasNonIdentityLayout(collapseOp.getSrc().getType()) ||
hasNonIdentityLayout(expandOp.getSrc().getType()) ||
hasNonIdentityLayout(expandOp.getResult().getType()))
if (hasNonRowMajorContiguousLayout(collapseOp.getSrc().getType()) ||
hasNonRowMajorContiguousLayout(expandOp.getSrc().getType()) ||
hasNonRowMajorContiguousLayout(expandOp.getResult().getType()))
return failure();

int64_t srcRank = srcType.getRank();
Expand Down Expand Up @@ -331,9 +333,9 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
ShapedType srcType = collapseOp.getSrcType();
ShapedType resultType = expandOp.getResultType();

if (hasNonIdentityLayout(expandOp.getSrc().getType()) ||
hasNonIdentityLayout(collapseOp.getSrc().getType()) ||
hasNonIdentityLayout(collapseOp.getResult().getType()))
if (hasNonRowMajorContiguousLayout(expandOp.getSrc().getType()) ||
hasNonRowMajorContiguousLayout(collapseOp.getSrc().getType()) ||
hasNonRowMajorContiguousLayout(collapseOp.getResult().getType()))
return failure();

int64_t srcRank = srcType.getRank();
Expand Down Expand Up @@ -451,7 +453,7 @@ getLinearizedDimensions(ArrayRef<ReassociationIndices> reassociationIndices);
/// %4 = tensor.extract_slice %0 [%3#0, %3#1, %3#2, 0] [1, 1, 1, 10] [1, 1, 1, 1] :
/// tensor<3x7x11x10xf32> to tensor<1x1x1x10xf32>
///
/// %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] :
/// %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] :
/// tensor<1x1x1x10xf32> into tensor<1x10xf32>
/// %6 = tensor.insert_slice %5 into %arg0 [%iv, 0] [1, 10] [1, 1] :
/// tensor<1x10xf32> into tensor<10x10xf32>
Expand Down
22 changes: 22 additions & 0 deletions mlir/include/mlir/IR/BuiltinAttributes.h
Original file line number Diff line number Diff line change
Expand Up @@ -1081,6 +1081,28 @@ inline bool operator!=(StringRef lhs, StringAttr rhs) { return !(lhs == rhs); }

namespace mlir {

/// Given an N-dimensional permutation and an offset (which can use
/// ShapedType::kDynamic) to represent a dynamic value), return the
/// N-dimensional map that is permuted according to said permutation and adds
/// the offset to the final output. If the permutation has no outputs (it's a
/// 0-D map), add one result to hold the offset.
///
/// Examples:
/// =========
///
/// offset = 0, permutation = [0, 1, 2] gives
/// [](d0, d1, d2) -> (d0, d1, d2)
/// while offset = 5 gives [](d0, d1, d2) -> (d0, d1, d2 + 5)
/// and offset = ? gives [s0](d0, d1, d2) -> (d0, d1, d2 + s0).
///
/// offset = ?, permutation = [2, 1, 0] gives
/// [s0](d0, d1, d2) -> (d2, d1, d0 + s0)
///
/// Finally, offset = 0, permutation = [], gives []() -> (0), while
/// offset = ?, permutation = [] gives [s0]() -> (s0).
AffineMap makePermutedMapWithOffset(ArrayRef<int64_t> permutation,
int64_t offset, MLIRContext *context);

/// Given a list of strides (in which ShapedType::kDynamic
/// represents a dynamic value), return the single result AffineMap which
/// represents the linearized strided layout map. Dimensions correspond to the
Expand Down
91 changes: 89 additions & 2 deletions mlir/include/mlir/IR/BuiltinAttributes.td
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def Builtin_DenseArrayRawDataParameter : ArrayRefParameter<
}];
}

def Builtin_DenseArray : Builtin_Attr<"DenseArray", "dense_array",
def Builtin_DenseArray : Builtin_Attr<"DenseArray", "dense_array",
[BlobAttrInterface]> {
let summary = "A dense array of integer or floating point elements.";
let description = [{
Expand Down Expand Up @@ -494,7 +494,7 @@ def Builtin_DenseResourceElementsAttr : Builtin_Attr<"DenseResourceElements",
/// when building the attribute. The provided `blobName` is used as a hint
/// for the key of the new handle for the `blob` resource, but may be
/// changed if necessary to ensure uniqueness during insertion.
/// This base class builder does no element type specific size or alignment
/// This base class builder does no element type specific size or alignment
/// checking. Use the typed subclasses for more safety unless if performing
/// generic operations.
AttrBuilderWithInferredContext<(ins
Expand Down Expand Up @@ -1051,9 +1051,96 @@ def StridedLayoutAttr : Builtin_Attr<"StridedLayout", "strided_layout",
/// Returns true if this layout is static, i.e. the strides and offset all
/// have a known value > 0.
bool hasStaticLayout() const;

/// Get a "canonical" strided layout for the given strides.
/// This constructs a strided layout with the given `offset` and `strides`,
/// except that if either the strides are empty or equal to [1], it returns
/// the corresponding ContiguousLayoutAttr in order to guard against multiple
/// representations of the identity layout.
static ::mlir::MemRefLayoutAttrInterface getCanonical(MLIRContext *context,
int64_t offset, ::llvm::ArrayRef<int64_t> strides);
}];
}

//===----------------------------------------------------------------------===//
// ContiguousLayoutAttr
//===----------------------------------------------------------------------===//

def ContiguousLayoutAttr : Builtin_Attr<"ContiguousLayout", "contiguous_layout",
[DeclareAttrInterfaceMethods<MemRefLayoutAttrInterface,
["isIdentity", "verifyLayout"]>]> {
let summary = "An Attribute representing a contiguous layout of a shaped type";
let description = [{
Syntax:

```
contiguous-layout-attribute ::= `contiguous` `<` maybe-permutation
(`,` `offset` `:` dimension)? `>`
maybe-permutation ::= decimal-literal | `[` permutation `]`
permutation ::= decimal-literal (`,` decimal-literal)*
dimension ::= decimal-literal | `?`
```

A contiguous layout is a layout that represents a sequence of dimensions
laid out in linear memory in its canonical form. Specifically, it indicates
that if one permutes the dimensions of a memref according to `permutaton`,
they will be in a row-major contiguos form: that is, the stride (in the
sense of the strided layout) of dimension `permutation[i]` is equal
to the products of the sizes of all dimensions appearing later in the permutation.

For example, a MxN memref with a `contiguous<[1, 0]>` layout is colmn-major:
advancing in the M dimension requires moving by 1 element in linear memory,
while the N dimension requires moving by M elements. Conversely,
if the layout is `contiguous<[0, 1]>` (which can be written `contiguous<2>`
for brevity and will be omitted from printing without an offset), the stride
of the N dimension will be 1 element while the stride of the M dimension will be
N elements.

As a more complex example, `memref<AxBxCxT, contigous<[2, 0, 1], offset: D>>`
, where A, B, C, and D are potentially dynamic values, means that
the value at index `[%i, %j, %k]` is located `%k * A * B + %i * B + %j + D`
elements from the beginning of the memory underlying that memref.

The permutation must contain the integers between 0 and the rank of the memref - 1,
and must have one distinct entry for each memref dimension. The value
`[0, 1, ..., N-1]`, specifying a row-major format, may be printed as `N`
for clarity.

If an offset is specified, it is a number of elements to move within
the underlying linear memory after the permutation is applied. This offset
may be _dynamic_, meaning that it may not be known at compile time.
A dynamic offset is represented as a `?` in the assembly syntax and as
`ShapedType::kDynamic` in the code. The offset must be non-negative.

See [Dialects/Builtin.md#memreftype](MemRef type) for more information.
}];

let parameters = (ins
"int64_t":$offset,
ArrayRefParameter<
"int64_t",
"permutation (64-bit integer)"
>:$permutation
);

let builders = [
// Builder for row-major contiguous attribute.
AttrBuilder<(ins "int64_t":$offset, "int64_t":$rank)>
];
let genVerifyDecl = 1;

let extraClassDeclaration = [{
/// Print the attribute to the given output stream.
void print(raw_ostream &os) const;

/// Returns true if this layout is static, i.e. the offset has a static value.
bool hasStaticLayout() const;

/// Return true if this layout has a row-major permutation - that is, the
/// dimensions of the shape are not permuted.
bool isRowMajor() const;
}];
}

//===----------------------------------------------------------------------===//
// StringAttr
Expand Down
60 changes: 51 additions & 9 deletions mlir/include/mlir/IR/BuiltinTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -585,20 +585,27 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
layout must avoid internal aliasing, i.e., two distinct tuples of
_in-bounds_ indices must be pointing to different elements in memory. The
layout is an attribute that implements `MemRefLayoutAttrInterface`. The
bulitin dialect offers two kinds of layouts: strided and affine map, each
of which is available as an attribute. Other attributes may be used to
represent the layout as long as they can be converted to a
bulitin dialect offers three kinds of layouts: contiguous, strided and
affine map, each of which is available as an attribute. Other attributes may be
used to represent the layout as long as they can be converted to a
[semi-affine map](Affine.md/#semi-affine-maps) and implement the required
interface. Users of memref are expected to fallback to the affine
representation when handling unknown memref layouts. Multi-dimensional
affine forms are interpreted in _row-major_ fashion.

In absence of an explicit layout, a memref is considered to have a
multi-dimensional identity affine map layout. Identity layout maps do not
contribute to the MemRef type identification and are discarded on
construction. That is, a type with an explicit identity map is
row-major contiguous layout with an offset of 0, which is equivalent
to a multi-dimensional identity map. For backwards compatibility,
identity layout maps do not contribute to the MemRef type identification and
are discarded on construction. That is, a type with an explicit identity map is
`memref<?x?xf32, (i,j)->(i,j)>` is strictly the same as the one without a
layout, `memref<?x?xf32>`.
layout, `memref<?x?xf32>`, which, written explicitly, has the layout
`memref<?x?xf32, contiguous<2>>`.

The built-in layouts form a hierarchy: all contiguous layuts are strided layouts,
and all strided layouts are affine map layouts, but the reverse is not true.
Using a more specific layout may permit a greater degree of optimization in
the generated code.

##### Affine Map Layout

Expand Down Expand Up @@ -656,6 +663,37 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
Therefore, it is never subject to the implicit row-major layout
interpretation.

### Contiguous layout

The most restricted of the built-in layouts is the _contiguous_ layout, which
expresses the fact that the in-memory layout of the memref would be row-major
without padding after the associated permutation is applied. Equivalently,
a contigous layout is a strided layout where the strides are implicitly computed
from the (permuted) sizes of the memref.

This layout is necessary to allow optimizations during lowering passes in the
presence of dynamic sizes, since
`memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>` doesn't specify if it's
dimensions have padding in between tem or not - the two non-1 strides are
dynamic. By contrast, `contiguous<3, offset: ?>` indiates a row-major layout
with an offset, while `contiguous<[2, 1, 0], offset: ?>` indicates a
column-major layout. While this scheme could be expressed with an affine map,
some operations expect memrefs to be in a form compatible with the `strided`
layout, which can be difficult to detect from analyzing an affine expression.

In general, the layout `contiguous<[p0, p1, ..., pN], offset: V>`
corresponds to the affine map

```mlir
affine_map<(d0, ..., dN) -> (d[p0], d[p1], ... + d[pN] + V)>
```

where `V` is either `s0` if it is dynamic or some constant value.

For convenience, the layout `contigous<[0, 1, ..., N], offset: V>` is printed
as `contigous<N+1, offset: V>`, and the `, offset: V` segment is omitted if `V`
is `0`.

##### Codegen of Unranked Memref

Using unranked memref in codegen besides the case mentioned above is highly
Expand Down Expand Up @@ -815,6 +853,10 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
/// considering both _all_ and _only_ the trailing 3 dims,
/// - memref<5x4x3x2xi8, strided<[48, 6, 2, 1]> is _only_ contiguous when
/// considering the trailing 3 dims.
/// - memref<?x?x?xi8, contiguous<3, offset: ?>> is contiguous when
/// considering all dimensions.
/// - memref<?x?x?x?xi32, contiguous<[1, 0, 2, 3], offset: ?>> is
/// _only_ contiguous when considering the trailing 2 dimensions.
///
bool areTrailingDimsContiguous(int64_t n);

Expand All @@ -830,8 +872,8 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [

/// Returns the strides of the MemRef if the layout map is in strided form.
/// MemRefs with a layout map in strided form include:
/// 1. empty or identity layout map, in which case the stride information
/// is the canonical form computed from sizes;
/// 1. the empty layout, the identity layout affine map, and any ContigousLayoutAttr,
/// in which case the stride information is the canonical form computed from sizes;
/// 2. a StridedLayoutAttr layout;
/// 3. any other layout that be converted into a single affine map layout
/// of the form `K + k0 * d0 + ... kn * dn`, where K and ki's are
Expand Down
Loading