Skip to content

Commit 53b3be7

Browse files
authored
[mlir][spirv] Fix coop matrix load (llvm#65712)
- Fix order of operands/attributes - Allow for stride to be any integer type - Use ODS for parsing/printing - Update examples and tests - Fix a typo in SPIR-V tblgen code
1 parent b3a14ca commit 53b3be7

File tree

4 files changed

+54
-76
lines changed

4 files changed

+54
-76
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -101,20 +101,32 @@ def SPIRV_KHRCooperativeMatrixLoadOp : SPIRV_KhrVendorOp<"CooperativeMatrixLoad"
101101
``` {.ebnf}
102102
cooperative-matrix-load-op ::= ssa-id `=` `spirv.KHR.CooperativeMatrixLoad`
103103
ssa-use `,` ssa-use `,`
104-
cooperative-matrix-layout `,`
105-
(`[` memory-operand `]`)? ` : `
106-
pointer-type `as` cooperative-matrix-type
104+
`<` cooperative-matrix-layout `>`
105+
(`,` `<` memory-operand `>`)? `:`
106+
pointer-type `,` stride-type `->` cooperative-matrix-type
107107
```
108108

109+
TODO: In the SPIR-V spec, `stride` is an optional argument. We should also
110+
support this optionality in the SPIR-V dialect.
111+
109112
#### Example:
110113

111114
```
112-
%0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, RowMajor
113-
: !spirv.ptr<i32, StorageBuffer>
114-
as !spirv.KHR.coopmatrix<16x8xi32, Workgroup, MatrixA>
115+
%0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor>
116+
: !spirv.ptr<i32, StorageBuffer>, i32
117+
-> !spirv.KHR.coopmatrix<16x8xi32, Workgroup, MatrixA>
118+
119+
%1 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor>, <Volatile>
120+
: !spirv.ptr<f32, StorageBuffer>, i64
121+
-> !spirv.KHR.coopmatrix<8x8xf32, Subgroup, MatrixAcc>
115122
```
116123
}];
117124

125+
let assemblyFormat = [{
126+
$pointer `,` $stride `,` $matrix_layout ( `,` $memory_operand^ )? attr-dict `:`
127+
type(operands) `->` type($result)
128+
}];
129+
118130
let availability = [
119131
MinVersion<SPIRV_V_1_6>,
120132
MaxVersion<SPIRV_V_1_6>,
@@ -124,8 +136,8 @@ def SPIRV_KHRCooperativeMatrixLoadOp : SPIRV_KhrVendorOp<"CooperativeMatrixLoad"
124136

125137
let arguments = (ins
126138
SPIRV_AnyPtr:$pointer,
127-
SPIRV_Integer:$stride,
128139
SPIRV_KHR_CooperativeMatrixLayoutAttr:$matrix_layout,
140+
SPIRV_Integer:$stride,
129141
OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_operand
130142
);
131143

mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp

Lines changed: 0 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -37,49 +37,6 @@ LogicalResult KHRCooperativeMatrixLengthOp::verify() {
3737
// spirv.KHR.CooperativeMatrixLoad
3838
//===----------------------------------------------------------------------===//
3939

40-
ParseResult KHRCooperativeMatrixLoadOp::parse(OpAsmParser &parser,
41-
OperationState &result) {
42-
std::array<OpAsmParser::UnresolvedOperand, 2> operandInfo = {};
43-
if (parser.parseOperand(operandInfo[0]) || parser.parseComma())
44-
return failure();
45-
if (parser.parseOperand(operandInfo[1]) || parser.parseComma())
46-
return failure();
47-
48-
CooperativeMatrixLayoutKHR layout;
49-
if (parseEnumKeywordAttr<CooperativeMatrixLayoutKHRAttr>(
50-
layout, parser, result, kKhrCooperativeMatrixLayoutAttrName)) {
51-
return failure();
52-
}
53-
54-
if (parseMemoryAccessAttributes(parser, result, kMemoryOperandAttrName))
55-
return failure();
56-
57-
Type ptrType;
58-
Type elementType;
59-
if (parser.parseColon() || parser.parseType(ptrType) ||
60-
parser.parseKeywordType("as", elementType)) {
61-
return failure();
62-
}
63-
result.addTypes(elementType);
64-
65-
Type strideType = parser.getBuilder().getIntegerType(32);
66-
if (parser.resolveOperands(operandInfo, {ptrType, strideType},
67-
parser.getNameLoc(), result.operands)) {
68-
return failure();
69-
}
70-
71-
return success();
72-
}
73-
74-
void KHRCooperativeMatrixLoadOp::print(OpAsmPrinter &printer) {
75-
printer << " " << getPointer() << ", " << getStride() << ", "
76-
<< getMatrixLayout();
77-
// Print optional memory operand attribute.
78-
if (auto memOperand = getMemoryOperand())
79-
printer << " [\"" << memOperand << "\"]";
80-
printer << " : " << getPointer().getType() << " as " << getType();
81-
}
82-
8340
static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
8441
Type coopMatrix) {
8542
auto pointerType = cast<PointerType>(pointer);

mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir

Lines changed: 32 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -23,37 +23,46 @@ spirv.func @cooperative_matrix_length_wrong_matrix() -> i32 "None" {
2323

2424
// CHECK-LABEL: @cooperative_matrix_load
2525
spirv.func @cooperative_matrix_load(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
26-
// CHECK: {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, RowMajor :
27-
// CHECK-SAME: !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
28-
%0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, RowMajor :
29-
!spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
26+
// CHECK: {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <RowMajor> :
27+
// CHECK-SAME: !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
28+
%0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor> :
29+
!spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
3030
spirv.Return
3131
}
3232

3333
// CHECK-LABEL: @cooperative_matrix_load_memoperand
3434
spirv.func @cooperative_matrix_load_memoperand(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
35-
// CHECK: {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, ColumnMajor ["Volatile"] :
36-
// CHECK-SAME: !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
37-
%0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, ColumnMajor ["Volatile"] :
38-
!spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
35+
// CHECK: {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <ColumnMajor>, <Volatile> :
36+
// CHECK-SAME: !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
37+
%0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor>, <Volatile> :
38+
!spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
3939
spirv.Return
4040
}
4141

4242
// CHECK-LABEL: @cooperative_matrix_load_vector_ptr_type
4343
spirv.func @cooperative_matrix_load_vector_ptr_type(%ptr : !spirv.ptr<vector<4xi32>, StorageBuffer>, %stride : i32) "None" {
44-
// CHECK: {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, RowMajor ["Volatile"] :
45-
// CHECK-SAME: !spirv.ptr<vector<4xi32>, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>
46-
%0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, RowMajor ["Volatile"] :
47-
!spirv.ptr<vector<4xi32>, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>
44+
// CHECK: {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <RowMajor>, <Volatile> :
45+
// CHECK-SAME: !spirv.ptr<vector<4xi32>, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>
46+
%0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor>, <Volatile> :
47+
!spirv.ptr<vector<4xi32>, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>
4848
spirv.Return
4949
}
5050

5151
// CHECK-LABEL: @cooperative_matrix_load_function
5252
spirv.func @cooperative_matrix_load_function(%ptr : !spirv.ptr<i32, Function>, %stride : i32) "None" {
53-
// CHECK: {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, RowMajor :
54-
// CHECK-SAME: !spirv.ptr<i32, Function> as !spirv.coopmatrix<8x16xi32, Subgroup, MatrixAcc>
55-
%0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, RowMajor :
56-
!spirv.ptr<i32, Function> as !spirv.coopmatrix<8x16xi32, Subgroup, MatrixAcc>
53+
// CHECK: {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <RowMajor> :
54+
// CHECK-SAME: !spirv.ptr<i32, Function>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixAcc>
55+
%0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor> :
56+
!spirv.ptr<i32, Function>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixAcc>
57+
spirv.Return
58+
}
59+
60+
// CHECK-LABEL: @cooperative_matrix_load_stride_i16
61+
spirv.func @cooperative_matrix_load_stride_i16(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i16) "None" {
62+
// CHECK: {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <RowMajor> :
63+
// CHECK-SAME: !spirv.ptr<i32, StorageBuffer>, i16 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
64+
%0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor> :
65+
!spirv.ptr<i32, StorageBuffer>, i16 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
5766
spirv.Return
5867
}
5968

@@ -82,8 +91,8 @@ spirv.func @cooperative_matrix_store_memoperand(%ptr : !spirv.ptr<i32, StorageBu
8291

8392
spirv.func @cooperative_matrix_load_bad_ptr(%ptr : !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer>, %stride : i32) "None" {
8493
// expected-error @+1 {{Pointer must point to a scalar or vector type}}
85-
%0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, ColumnMajor :
86-
!spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
94+
%0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor> :
95+
!spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
8796
spirv.Return
8897
}
8998

@@ -92,25 +101,25 @@ spirv.func @cooperative_matrix_load_bad_ptr(%ptr : !spirv.ptr<!spirv.struct<(f32
92101
spirv.func @cooperative_matrix_load_missing_attr(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
93102
// expected-error @+1 {{expected ','}}
94103
%0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride :
95-
!spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
104+
!spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
96105
spirv.Return
97106
}
98107

99108
// -----
100109

101110
spirv.func @cooperative_matrix_load_missing_attr(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
102-
// expected-error @+1 {{expected valid keyword}}
111+
// expected-error @+1 {{expected '<'}}
103112
%0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, :
104-
!spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup, MatrixA>
113+
!spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.NV.coopmatrix<8x16xi32, Subgroup, MatrixA>
105114
spirv.Return
106115
}
107116

108117
// -----
109118

110119
spirv.func @cooperative_matrix_load_bad_result(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
111120
// expected-error @+1 {{op result #0 must be any SPIR-V cooperative matrix type}}
112-
%0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, ColumnMajor :
113-
!spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
121+
%0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor> :
122+
!spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.NV.coopmatrix<8x16xi32, Subgroup>
114123
spirv.Return
115124
}
116125

mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -929,9 +929,9 @@ static void emitOperandDeserialization(const Operator &op, ArrayRef<SMLoc> loc,
929929
if (auto *valueArg = llvm::dyn_cast_if_present<NamedTypeConstraint *>(argument)) {
930930
if (valueArg->isVariableLength()) {
931931
if (i != e - 1) {
932-
PrintFatalError(loc, "SPIR-V ops can have Variadic<..> or "
933-
"std::optional<...> arguments only if "
934-
"it's the last argument");
932+
PrintFatalError(
933+
loc, "SPIR-V ops can have Variadic<..> or "
934+
"Optional<...> arguments only if it's the last argument");
935935
}
936936
os << tabs
937937
<< formatv("for (; {0} < {1}.size(); ++{0})", wordIndex, words);

0 commit comments

Comments
 (0)