Skip to content

[mlir] Add optional layout attribute to VectorType #71916

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -495,4 +495,28 @@ def MemRefLayoutAttrInterface : AttrInterface<"MemRefLayoutAttrInterface"> {
];
}

//===----------------------------------------------------------------------===//
// VectorLayoutAttrInterface
//===----------------------------------------------------------------------===//

def VectorLayoutAttrInterface : AttrInterface<"VectorLayoutAttrInterface"> {
let cppNamespace = "::mlir";

let description = [{
This interface is used for attributes that can represent the Vector type's
layout semantics, such as being able to map the vector indices to those
of the vector fragments held by individiual threads.
}];

let methods = [
InterfaceMethod<
"Check if the current layout is applicable to the provided shape",
"::mlir::LogicalResult", "verifyLayout",
(ins "::llvm::ArrayRef<int64_t>":$shape,
"::mlir::Type":$elementType,
"::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError)
>
];
}

#endif // MLIR_IR_BUILTINATTRIBUTEINTERFACES_TD_
14 changes: 11 additions & 3 deletions mlir/include/mlir/IR/BuiltinTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -307,12 +307,14 @@ class VectorType::Builder {
/// Build from another VectorType.
explicit Builder(VectorType other)
: elementType(other.getElementType()), shape(other.getShape()),
scalableDims(other.getScalableDims()) {}
scalableDims(other.getScalableDims()), layout(other.getLayout()) {}

/// Build from scratch.
Builder(ArrayRef<int64_t> shape, Type elementType,
ArrayRef<bool> scalableDims = {})
: elementType(elementType), shape(shape), scalableDims(scalableDims) {}
ArrayRef<bool> scalableDims = {},
VectorLayoutAttrInterface layout = {})
: elementType(elementType), shape(shape), scalableDims(scalableDims),
layout(layout) {}

Builder &setShape(ArrayRef<int64_t> newShape,
ArrayRef<bool> newIsScalableDim = {}) {
Expand Down Expand Up @@ -342,6 +344,11 @@ class VectorType::Builder {
return *this;
}

Builder &setLayout(VectorLayoutAttrInterface newLayout) {
layout = newLayout;
return *this;
}

operator VectorType() {
return VectorType::get(shape, elementType, scalableDims);
}
Expand All @@ -350,6 +357,7 @@ class VectorType::Builder {
Type elementType;
CopyOnWriteArrayRef<int64_t> shape;
CopyOnWriteArrayRef<bool> scalableDims;
VectorLayoutAttrInterface layout;
};

/// Given an `originalShape` and a `reducedShape` assumed to be a subset of
Expand Down
23 changes: 19 additions & 4 deletions mlir/include/mlir/IR/BuiltinTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -1029,11 +1029,13 @@ def Builtin_Vector : Builtin_Type<"Vector", [ShapedTypeInterface], "Type"> {
Syntax:

```
vector-type ::= `vector` `<` vector-dim-list vector-element-type `>`
vector-type ::= `vector` `<` vector-dim-list vector-element-type
(`,` layout-specification)? `>`
vector-element-type ::= float-type | integer-type | index-type
vector-dim-list := (static-dim-list `x`)?
static-dim-list ::= static-dim (`x` static-dim)*
static-dim ::= (decimal-literal | `[` decimal-literal `]`)
layout-specification ::= attribute-value
```

The vector type represents a SIMD style vector used by target-specific
Expand All @@ -1050,6 +1052,16 @@ def Builtin_Vector : Builtin_Type<"Vector", [ShapedTypeInterface], "Type"> {
declarations, `vector<0x42xi32>` is invalid because it is interpreted as a
2D vector with shape `(0, 42)` and zero shapes are not allowed.

##### Layout

A vector may optionally have a layout that can be used to capture
the mapping of the vector indices to a an arbitrary coordinate sytem.
An example of such a mapping is the mapping of vector indices to
indices of the vector fragments that are held by individual threads
in a SIMT execution model. Such layouts are common in a wide variety of
GPU matrix multiplication instructions. The layout can be any attribute
that implements `VectorLayoutAttrInterface`.

Examples:

```mlir
Expand All @@ -1068,17 +1080,20 @@ def Builtin_Vector : Builtin_Type<"Vector", [ShapedTypeInterface], "Type"> {
// A 3D mixed fixed/scalable vector in which only the inner dimension is
// scalable.
vector<2x[4]x8xf32>

```
}];
let parameters = (ins
ArrayRefParameter<"int64_t">:$shape,
"Type":$elementType,
ArrayRefParameter<"bool">:$scalableDims
ArrayRefParameter<"bool">:$scalableDims,
"VectorLayoutAttrInterface":$layout
);
let builders = [
TypeBuilderWithInferredContext<(ins
"ArrayRef<int64_t>":$shape, "Type":$elementType,
CArg<"ArrayRef<bool>", "{}">:$scalableDims
CArg<"ArrayRef<bool>", "{}">:$scalableDims,
CArg<"VectorLayoutAttrInterface", "{}">:$layout
), [{
// While `scalableDims` is optional, its default value should be
// `false` for every dim in `shape`.
Expand All @@ -1087,7 +1102,7 @@ def Builtin_Vector : Builtin_Type<"Vector", [ShapedTypeInterface], "Type"> {
isScalableVec.resize(shape.size(), false);
scalableDims = isScalableVec;
}
return $_get(elementType.getContext(), shape, elementType, scalableDims);
return $_get(elementType.getContext(), shape, elementType, scalableDims, layout);
}]>
];
let extraClassDeclaration = [{
Expand Down
27 changes: 25 additions & 2 deletions mlir/lib/AsmParser/TypeParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -459,14 +459,37 @@ VectorType Parser::parseVectorType() {
// Parse the element type.
auto typeLoc = getToken().getLoc();
auto elementType = parseType();
if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
if (!elementType)
return nullptr;

if (!VectorType::isValidElementType(elementType))
return emitError(typeLoc, "vector elements must be int/index/float type"),
nullptr;

return VectorType::get(dimensions, elementType, scalableDims);
VectorLayoutAttrInterface layout;
auto parseElt = [&]() -> ParseResult {
Attribute attr = parseAttribute();
if (!attr)
return failure();
if (isa<VectorLayoutAttrInterface>(attr)) {
layout = cast<VectorLayoutAttrInterface>(attr);
}
return success();
};

// Parse the vector layout
if (!consumeIf(Token::greater)) {
if (parseToken(Token::comma, "expected ',' or '>' in vector type") ||
parseCommaSeparatedListUntil(Token::greater, parseElt,
/*allowEmptyList=*/false)) {
return nullptr;
}
}

if (!layout)
return VectorType::get(dimensions, elementType, scalableDims);

return VectorType::get(dimensions, elementType, scalableDims, layout);
}

/// Parse a dimension list in a vector type. This populates the dimension list.
Expand Down
5 changes: 5 additions & 0 deletions mlir/lib/IR/AsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2572,6 +2572,11 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
os << 'x';
}
printType(vectorTy.getElementType());
VectorLayoutAttrInterface layout = vectorTy.getLayout();
if (layout) {
os << ", ";
printAttribute(vectorTy.getLayout(), AttrTypeElision::May);
}
os << '>';
})
.Case<RankedTensorType>([&](RankedTensorType tensorTy) {
Expand Down
8 changes: 7 additions & 1 deletion mlir/lib/IR/BuiltinTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,8 @@ LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,

LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<int64_t> shape, Type elementType,
ArrayRef<bool> scalableDims) {
ArrayRef<bool> scalableDims,
VectorLayoutAttrInterface layout) {
if (!isValidElementType(elementType))
return emitError()
<< "vector elements must be int/index/float type but got "
Expand All @@ -242,6 +243,11 @@ LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
return emitError() << "number of dims must match, got "
<< scalableDims.size() << " and " << shape.size();

if (layout) {
if (failed(layout.verifyLayout(shape, elementType, emitError)))
return emitError() << "Layout verification failed!";
}

return success();
}

Expand Down
1 change: 1 addition & 0 deletions mlir/unittests/Interfaces/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ add_mlir_unittest(MLIRInterfacesTests
DataLayoutInterfacesTest.cpp
InferIntRangeInterfaceTest.cpp
InferTypeOpInterfaceTest.cpp
VectorLayoutInterfaceTest.cpp
)

target_link_libraries(MLIRInterfacesTests
Expand Down
Loading