Skip to content

Commit 1258c3f

Browse files
authored
[mlir][spirv] Support spirv.coopmatrix type (de-)serialization (#65831)
Extend SPIR-V target serialization and deserialization to handle coop matrix types. Add a roundtrip test. In addition to `FileCheck` checks, the resulting spirv binary also passes `spir-val` (external tool). Also fix a type attribute bug surfaced by the `CooperativeMatrixLength` op. Multiple matrix operand attributes will be handled in a future patch to reduce the scope.
1 parent 30e688e commit 1258c3f

File tree

11 files changed

+205
-43
lines changed

11 files changed

+205
-43
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4053,6 +4053,8 @@ def SPIRV_KHR_CMU_MatrixA : I32EnumAttrCase<"MatrixA", 0>;
40534053
def SPIRV_KHR_CMU_MatrixB : I32EnumAttrCase<"MatrixB", 1>;
40544054
def SPIRV_KHR_CMU_MatrixAcc : I32EnumAttrCase<"MatrixAcc", 2>;
40554055

4056+
// NOTE: This is an attribute in the SPIR-V *dialect* but a constant (<id>) in
4057+
// SPIR-V proper.
40564058
def SPIRV_KHR_CooperativeMatrixUseAttr :
40574059
SPIRV_I32EnumAttr<"CooperativeMatrixUseKHR",
40584060
"valid SPIR-V Cooperative Matrix Use (KHR)",
@@ -4064,6 +4066,8 @@ def SPIRV_KHR_CooperativeMatrixUseAttr :
40644066
def SPIRV_KHR_CML_RowMajor : I32EnumAttrCase<"RowMajor", 0>;
40654067
def SPIRV_KHR_CML_ColumnMajor : I32EnumAttrCase<"ColumnMajor", 1>;
40664068

4069+
// NOTE: This is an attribute in the SPIR-V *dialect* but a constant (<id>) in
4070+
// SPIR-V proper.
40674071
def SPIRV_KHR_CooperativeMatrixLayoutAttr :
40684072
SPIRV_I32EnumAttr<"CooperativeMatrixLayoutKHR",
40694073
"valid SPIR-V Cooperative Matrix Layout (KHR)",

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,14 @@ def SPIRV_KHRCooperativeMatrixLengthOp :
5555
];
5656

5757
let arguments = (ins
58-
TypeAttr:$cooperative_matrix_type
58+
TypeAttrOf<SPIRV_AnyCooperativeMatrix>:$cooperative_matrix_type
5959
);
6060

6161
let results = (outs
6262
SPIRV_Int32:$result
6363
);
64+
65+
let hasVerifier = false;
6466
}
6567

6668
// -----

mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,20 +19,6 @@
1919
using namespace mlir::spirv::AttrNames;
2020

2121
namespace mlir::spirv {
22-
//===----------------------------------------------------------------------===//
23-
// spirv.KHR.CooperativeMatrixLength
24-
//===----------------------------------------------------------------------===//
25-
26-
LogicalResult KHRCooperativeMatrixLengthOp::verify() {
27-
if (!isa<CooperativeMatrixType>(getCooperativeMatrixType())) {
28-
return emitOpError(
29-
"type attribute must be a '!spirv.coopmatrix' type, found ")
30-
<< getCooperativeMatrixType() << " instead";
31-
}
32-
33-
return success();
34-
}
35-
3622
//===----------------------------------------------------------------------===//
3723
// spirv.KHR.CooperativeMatrixLoad
3824
//===----------------------------------------------------------------------===//

mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ LogicalResult spirv::Deserializer::processInstruction(
164164
case spirv::Opcode::OpTypeRuntimeArray:
165165
case spirv::Opcode::OpTypeStruct:
166166
case spirv::Opcode::OpTypePointer:
167+
case spirv::Opcode::OpTypeCooperativeMatrixKHR:
167168
case spirv::Opcode::OpTypeCooperativeMatrixNV:
168169
return processType(opcode, operands);
169170
case spirv::Opcode::OpTypeForwardPointer:

mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp

Lines changed: 56 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -765,8 +765,10 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
765765
} break;
766766
case spirv::Opcode::OpTypeArray:
767767
return processArrayType(operands);
768+
case spirv::Opcode::OpTypeCooperativeMatrixKHR:
769+
return processCooperativeMatrixTypeKHR(operands);
768770
case spirv::Opcode::OpTypeCooperativeMatrixNV:
769-
return processCooperativeMatrixType(operands);
771+
return processCooperativeMatrixTypeNV(operands);
770772
case spirv::Opcode::OpTypeFunction:
771773
return processFunctionType(operands);
772774
case spirv::Opcode::OpTypeJointMatrixINTEL:
@@ -900,32 +902,76 @@ spirv::Deserializer::processFunctionType(ArrayRef<uint32_t> operands) {
900902
return success();
901903
}
902904

903-
LogicalResult
904-
spirv::Deserializer::processCooperativeMatrixType(ArrayRef<uint32_t> operands) {
905+
LogicalResult spirv::Deserializer::processCooperativeMatrixTypeKHR(
906+
ArrayRef<uint32_t> operands) {
907+
if (operands.size() != 6) {
908+
return emitError(unknownLoc,
909+
"OpTypeCooperativeMatrixKHR must have element type, "
910+
"scope, row and column parameters, and use");
911+
}
912+
913+
Type elementTy = getType(operands[1]);
914+
if (!elementTy) {
915+
return emitError(unknownLoc,
916+
"OpTypeCooperativeMatrixKHR references undefined <id> ")
917+
<< operands[1];
918+
}
919+
920+
std::optional<spirv::Scope> scope =
921+
spirv::symbolizeScope(getConstantInt(operands[2]).getInt());
922+
if (!scope) {
923+
return emitError(
924+
unknownLoc,
925+
"OpTypeCooperativeMatrixKHR references undefined scope <id> ")
926+
<< operands[2];
927+
}
928+
929+
unsigned rows = getConstantInt(operands[3]).getInt();
930+
unsigned columns = getConstantInt(operands[4]).getInt();
931+
932+
std::optional<spirv::CooperativeMatrixUseKHR> use =
933+
spirv::symbolizeCooperativeMatrixUseKHR(
934+
getConstantInt(operands[5]).getInt());
935+
if (!use) {
936+
return emitError(
937+
unknownLoc,
938+
"OpTypeCooperativeMatrixKHR references undefined use <id> ")
939+
<< operands[5];
940+
}
941+
942+
typeMap[operands[0]] =
943+
spirv::CooperativeMatrixType::get(elementTy, rows, columns, *scope, *use);
944+
return success();
945+
}
946+
947+
LogicalResult spirv::Deserializer::processCooperativeMatrixTypeNV(
948+
ArrayRef<uint32_t> operands) {
905949
if (operands.size() != 5) {
906-
return emitError(unknownLoc, "OpTypeCooperativeMatrix must have element "
950+
return emitError(unknownLoc, "OpTypeCooperativeMatrixNV must have element "
907951
"type and row x column parameters");
908952
}
909953

910954
Type elementTy = getType(operands[1]);
911955
if (!elementTy) {
912956
return emitError(unknownLoc,
913-
"OpTypeCooperativeMatrix references undefined <id> ")
957+
"OpTypeCooperativeMatrixNV references undefined <id> ")
914958
<< operands[1];
915959
}
916960

917-
auto scope = spirv::symbolizeScope(getConstantInt(operands[2]).getInt());
961+
std::optional<spirv::Scope> scope =
962+
spirv::symbolizeScope(getConstantInt(operands[2]).getInt());
918963
if (!scope) {
919-
return emitError(unknownLoc,
920-
"OpTypeCooperativeMatrix references undefined scope <id> ")
964+
return emitError(
965+
unknownLoc,
966+
"OpTypeCooperativeMatrixNV references undefined scope <id> ")
921967
<< operands[2];
922968
}
923969

924970
unsigned rows = getConstantInt(operands[3]).getInt();
925971
unsigned columns = getConstantInt(operands[4]).getInt();
926972

927-
typeMap[operands[0]] = spirv::CooperativeMatrixNVType::get(
928-
elementTy, scope.value(), rows, columns);
973+
typeMap[operands[0]] =
974+
spirv::CooperativeMatrixNVType::get(elementTy, *scope, rows, columns);
929975
return success();
930976
}
931977

mlir/lib/Target/SPIRV/Deserialization/Deserializer.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,9 @@ class Deserializer {
254254

255255
LogicalResult processArrayType(ArrayRef<uint32_t> operands);
256256

257-
LogicalResult processCooperativeMatrixType(ArrayRef<uint32_t> operands);
257+
LogicalResult processCooperativeMatrixTypeKHR(ArrayRef<uint32_t> operands);
258+
259+
LogicalResult processCooperativeMatrixTypeNV(ArrayRef<uint32_t> operands);
258260

259261
LogicalResult processFunctionType(ArrayRef<uint32_t> operands);
260262

mlir/lib/Target/SPIRV/Serialization/Serializer.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -593,6 +593,28 @@ LogicalResult Serializer::prepareBasicType(
593593
return success();
594594
}
595595

596+
if (auto cooperativeMatrixType =
597+
dyn_cast<spirv::CooperativeMatrixType>(type)) {
598+
uint32_t elementTypeID = 0;
599+
if (failed(processTypeImpl(loc, cooperativeMatrixType.getElementType(),
600+
elementTypeID, serializationCtx))) {
601+
return failure();
602+
}
603+
typeEnum = spirv::Opcode::OpTypeCooperativeMatrixKHR;
604+
auto getConstantOp = [&](uint32_t id) {
605+
auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id);
606+
return prepareConstantInt(loc, attr);
607+
};
608+
operands.push_back(elementTypeID);
609+
operands.push_back(
610+
getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope())));
611+
operands.push_back(getConstantOp(cooperativeMatrixType.getRows()));
612+
operands.push_back(getConstantOp(cooperativeMatrixType.getColumns()));
613+
operands.push_back(
614+
getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getUse())));
615+
return success();
616+
}
617+
596618
if (auto cooperativeMatrixType =
597619
dyn_cast<spirv::CooperativeMatrixNVType>(type)) {
598620
uint32_t elementTypeID = 0;

mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ spirv.func @cooperative_matrix_length() -> i32 "None" {
1414
// -----
1515

1616
spirv.func @cooperative_matrix_length_wrong_matrix() -> i32 "None" {
17-
// expected-error @+1 {{'spirv.KHR.CooperativeMatrixLength' op type attribute must be a '!spirv.coopmatrix'}}
17+
// expected-error @+1 {{'cooperative_matrix_type' failed to satisfy constraint: type attribute of any SPIR-V cooperative matrix type}}
1818
%0 = spirv.KHR.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
1919
spirv.ReturnValue %0 : i32
2020
}
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
// RUN: mlir-translate --no-implicit-module --test-spirv-roundtrip \
2+
// RUN: --split-input-file %s | FileCheck %s
3+
4+
spirv.module Logical GLSL450 requires
5+
#spirv.vce<v1.5, [Shader, Int8, Int16, Int64, Linkage, CooperativeMatrixKHR],
6+
[SPV_KHR_storage_buffer_storage_class, SPV_KHR_cooperative_matrix]> {
7+
8+
// CHECK-LABEL: @cooperative_matrix_length
9+
spirv.func @cooperative_matrix_length() "None" {
10+
// CHECK: {{%.+}} = spirv.KHR.CooperativeMatrixLength : !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB>
11+
%0 = spirv.KHR.CooperativeMatrixLength : !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB>
12+
spirv.Return
13+
}
14+
15+
// CHECK-LABEL: @cooperative_matrix_load_1
16+
spirv.func @cooperative_matrix_load_1(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
17+
// CHECK: {{%.+}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <RowMajor>
18+
// CHECK-SAME: : !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
19+
%0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor> :
20+
!spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
21+
spirv.Return
22+
}
23+
24+
// CHECK-LABEL: @cooperative_matrix_load_2
25+
spirv.func @cooperative_matrix_load_2(%ptr : !spirv.ptr<f32, StorageBuffer>, %stride : i64) "None" {
26+
// CHECK: {{%.+}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <ColumnMajor>, <Volatile>
27+
// CHECK-SAME: : !spirv.ptr<f32, StorageBuffer>, i64 -> !spirv.coopmatrix<8x16xf32, Subgroup, MatrixAcc>
28+
%0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor>, <Volatile> :
29+
!spirv.ptr<f32, StorageBuffer>, i64 -> !spirv.coopmatrix<8x16xf32, Subgroup, MatrixAcc>
30+
spirv.Return
31+
}
32+
33+
// CHECK-LABEL: @cooperative_matrix_store_1
34+
spirv.func @cooperative_matrix_store_1(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
35+
%m : !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>) "None" {
36+
// CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, <RowMajor>
37+
// CHECK-SAME: : !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
38+
spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor> :
39+
!spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
40+
spirv.Return
41+
}
42+
43+
// CHECK-LABEL: @cooperative_matrix_store_2
44+
spirv.func @cooperative_matrix_store_2(%ptr : !spirv.ptr<f32, Workgroup>, %stride : i64,
45+
%m : !spirv.coopmatrix<4x8xf32, Subgroup, MatrixB>) "None" {
46+
// CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, <ColumnMajor>, <Nontemporal>
47+
// CHECK-SAME: : !spirv.ptr<f32, Workgroup>, !spirv.coopmatrix<4x8xf32, Subgroup, MatrixB>, i64
48+
spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <ColumnMajor>, <Nontemporal> :
49+
!spirv.ptr<f32, Workgroup>, !spirv.coopmatrix<4x8xf32, Subgroup, MatrixB>, i64
50+
spirv.Return
51+
}
52+
53+
// CHECK-LABEL: @cooperative_matrix_muladd
54+
spirv.func @cooperative_matrix_muladd_1(%a : !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
55+
%b : !spirv.coopmatrix<16x8xi16, Subgroup, MatrixB>,
56+
%c : !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc>) "None" {
57+
// CHECK: {{%.+}} = spirv.KHR.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} :
58+
// CHECK-SAME: !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
59+
// CHECK-SAME: !spirv.coopmatrix<16x8xi16, Subgroup, MatrixB>
60+
// CHECK-SAME: -> !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc>
61+
%p = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c : !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
62+
!spirv.coopmatrix<16x8xi16, Subgroup, MatrixB>
63+
-> !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc>
64+
65+
// CHECK-NEXT: {{%.+}} = spirv.KHR.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}}, <BSigned> :
66+
// CHECK-SAME: !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
67+
// CHECK-SAME: !spirv.coopmatrix<16x8xi16, Subgroup, MatrixB>
68+
// CHECK-SAME: -> !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc>
69+
%q = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c,
70+
<BSigned> : !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
71+
!spirv.coopmatrix<16x8xi16, Subgroup, MatrixB>
72+
-> !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc>
73+
74+
// TODO: Handle multiple matrix operands and add relevant testcases here.
75+
spirv.Return
76+
}
77+
78+
// CHECK-LABEL: @cooperative_matrix_muladd
79+
spirv.func @cooperative_matrix_muladd_2(%a : !spirv.coopmatrix<8x8xf32, Workgroup, MatrixA>,
80+
%b : !spirv.coopmatrix<8x8xf32, Workgroup, MatrixB>,
81+
%c : !spirv.coopmatrix<8x8xf32, Workgroup, MatrixAcc>) "None" {
82+
// CHECK: {{%.+}} = spirv.KHR.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} :
83+
// CHECK-SAME: !spirv.coopmatrix<8x8xf32, Workgroup, MatrixA>,
84+
// CHECK-SAME: !spirv.coopmatrix<8x8xf32, Workgroup, MatrixB>
85+
// CHECK-SAME: -> !spirv.coopmatrix<8x8xf32, Workgroup, MatrixAcc>
86+
%r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c : !spirv.coopmatrix<8x8xf32, Workgroup, MatrixA>,
87+
!spirv.coopmatrix<8x8xf32, Workgroup, MatrixB>
88+
-> !spirv.coopmatrix<8x8xf32, Workgroup, MatrixAcc>
89+
90+
spirv.Return
91+
}
92+
93+
}

mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "mlir/TableGen/Format.h"
1717
#include "mlir/TableGen/GenInfo.h"
1818
#include "mlir/TableGen/Operator.h"
19+
#include "llvm/ADT/STLExtras.h"
1920
#include "llvm/ADT/Sequence.h"
2021
#include "llvm/ADT/SmallVector.h"
2122
#include "llvm/ADT/StringExtras.h"
@@ -512,6 +513,14 @@ static mlir::GenRegistration
512513
// Serialization AutoGen
513514
//===----------------------------------------------------------------------===//
514515

516+
// These enums are encoded as <id> to constant values in SPIR-V blob, but we
517+
// directly use the constant value as attribute in SPIR-V dialect. So need
518+
// to handle them separately from normal enum attributes.
519+
constexpr llvm::StringLiteral constantIdEnumAttrs[] = {
520+
"SPIRV_ScopeAttr", "SPIRV_KHR_CooperativeMatrixUseAttr",
521+
"SPIRV_KHR_CooperativeMatrixLayoutAttr", "SPIRV_MemorySemanticsAttr",
522+
"SPIRV_MatrixLayoutAttr"};
523+
515524
/// Generates code to serialize attributes of a SPIRV_Op `op` into `os`. The
516525
/// generates code extracts the attribute with name `attrName` from
517526
/// `operandList` of `op`.
@@ -521,12 +530,7 @@ static void emitAttributeSerialization(const Attribute &attr,
521530
StringRef attrName, raw_ostream &os) {
522531
os << tabs
523532
<< formatv("if (auto attr = {0}->getAttr(\"{1}\")) {{\n", opVar, attrName);
524-
if (attr.getAttrDefName() == "SPIRV_ScopeAttr" ||
525-
attr.getAttrDefName() == "SPIRV_MemorySemanticsAttr" ||
526-
attr.getAttrDefName() == "SPIRV_MatrixLayoutAttr") {
527-
// These two enums are encoded as <id> to constant values in SPIR-V blob,
528-
// but we directly use the constant value as attribute in SPIR-V dialect. So
529-
// need to handle them separately from normal enum attributes.
533+
if (llvm::is_contained(constantIdEnumAttrs, attr.getAttrDefName())) {
530534
EnumAttr baseEnum(attr.getDef().getValueAsDef("enum"));
531535
os << tabs
532536
<< formatv(" {0}.push_back(prepareConstantInt({1}.getLoc(), "
@@ -557,11 +561,18 @@ static void emitAttributeSerialization(const Attribute &attr,
557561
" {0}.push_back(static_cast<uint32_t>("
558562
"llvm::cast<IntegerAttr>(attr).getValue().getZExtValue()));\n",
559563
operandList);
560-
} else if (attr.isEnumAttr() || attr.getAttrDefName() == "TypeAttr") {
564+
} else if (attr.isEnumAttr() || attr.isTypeAttr()) {
565+
// It may be the first time this type appears in the IR, so we need to
566+
// process it.
567+
StringRef attrTypeID = "attrTypeID";
568+
os << tabs << formatv(" uint32_t {0} = 0;\n", attrTypeID);
561569
os << tabs
562-
<< formatv(" {0}.push_back(static_cast<uint32_t>("
563-
"getTypeID(llvm::cast<TypeAttr>(attr).getValue())));\n",
564-
operandList);
570+
<< formatv(" if (failed(processType({0}.getLoc(), "
571+
"llvm::cast<TypeAttr>(attr).getValue(), {1}))) {{\n",
572+
opVar, attrTypeID);
573+
os << tabs << " return failure();\n";
574+
os << tabs << " }\n";
575+
os << tabs << formatv(" {0}.push_back(attrTypeID);\n", operandList);
565576
} else {
566577
PrintFatalError(
567578
loc,
@@ -816,12 +827,7 @@ static void emitAttributeDeserialization(const Attribute &attr,
816827
StringRef attrList, StringRef attrName,
817828
StringRef words, StringRef wordIndex,
818829
raw_ostream &os) {
819-
if (attr.getAttrDefName() == "SPIRV_ScopeAttr" ||
820-
attr.getAttrDefName() == "SPIRV_MemorySemanticsAttr" ||
821-
attr.getAttrDefName() == "SPIRV_MatrixLayoutAttr") {
822-
// These two enums are encoded as <id> to constant values in SPIR-V blob,
823-
// but we directly use the constant value as attribute in SPIR-V dialect. So
824-
// need to handle them separately from normal enum attributes.
830+
if (llvm::is_contained(constantIdEnumAttrs, attr.getAttrDefName())) {
825831
EnumAttr baseEnum(attr.getDef().getValueAsDef("enum"));
826832
os << tabs
827833
<< formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "

0 commit comments

Comments
 (0)