Skip to content

[mlir][spirv] Support spirv.coopmatrix type (de-)serialization #65831

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -4053,6 +4053,8 @@ def SPIRV_KHR_CMU_MatrixA : I32EnumAttrCase<"MatrixA", 0>;
def SPIRV_KHR_CMU_MatrixB : I32EnumAttrCase<"MatrixB", 1>;
def SPIRV_KHR_CMU_MatrixAcc : I32EnumAttrCase<"MatrixAcc", 2>;

// NOTE: This is an attribute in the SPIR-V *dialect* but a constant (<id>) in
// SPIR-V proper.
def SPIRV_KHR_CooperativeMatrixUseAttr :
SPIRV_I32EnumAttr<"CooperativeMatrixUseKHR",
"valid SPIR-V Cooperative Matrix Use (KHR)",
Expand All @@ -4064,6 +4066,8 @@ def SPIRV_KHR_CooperativeMatrixUseAttr :
def SPIRV_KHR_CML_RowMajor : I32EnumAttrCase<"RowMajor", 0>;
def SPIRV_KHR_CML_ColumnMajor : I32EnumAttrCase<"ColumnMajor", 1>;

// NOTE: This is an attribute in the SPIR-V *dialect* but a constant (<id>) in
// SPIR-V proper.
def SPIRV_KHR_CooperativeMatrixLayoutAttr :
SPIRV_I32EnumAttr<"CooperativeMatrixLayoutKHR",
"valid SPIR-V Cooperative Matrix Layout (KHR)",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,14 @@ def SPIRV_KHRCooperativeMatrixLengthOp :
];

let arguments = (ins
TypeAttr:$cooperative_matrix_type
TypeAttrOf<SPIRV_AnyCooperativeMatrix>:$cooperative_matrix_type
);

let results = (outs
SPIRV_Int32:$result
);

let hasVerifier = false;
}

// -----
Expand Down
14 changes: 0 additions & 14 deletions mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,6 @@
using namespace mlir::spirv::AttrNames;

namespace mlir::spirv {
//===----------------------------------------------------------------------===//
// spirv.KHR.CooperativeMatrixLength
//===----------------------------------------------------------------------===//

LogicalResult KHRCooperativeMatrixLengthOp::verify() {
if (!isa<CooperativeMatrixType>(getCooperativeMatrixType())) {
return emitOpError(
"type attribute must be a '!spirv.coopmatrix' type, found ")
<< getCooperativeMatrixType() << " instead";
}

return success();
}

//===----------------------------------------------------------------------===//
// spirv.KHR.CooperativeMatrixLoad
//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ LogicalResult spirv::Deserializer::processInstruction(
case spirv::Opcode::OpTypeRuntimeArray:
case spirv::Opcode::OpTypeStruct:
case spirv::Opcode::OpTypePointer:
case spirv::Opcode::OpTypeCooperativeMatrixKHR:
case spirv::Opcode::OpTypeCooperativeMatrixNV:
return processType(opcode, operands);
case spirv::Opcode::OpTypeForwardPointer:
Expand Down
66 changes: 56 additions & 10 deletions mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -765,8 +765,10 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
} break;
case spirv::Opcode::OpTypeArray:
return processArrayType(operands);
case spirv::Opcode::OpTypeCooperativeMatrixKHR:
return processCooperativeMatrixTypeKHR(operands);
case spirv::Opcode::OpTypeCooperativeMatrixNV:
return processCooperativeMatrixType(operands);
return processCooperativeMatrixTypeNV(operands);
case spirv::Opcode::OpTypeFunction:
return processFunctionType(operands);
case spirv::Opcode::OpTypeJointMatrixINTEL:
Expand Down Expand Up @@ -900,32 +902,76 @@ spirv::Deserializer::processFunctionType(ArrayRef<uint32_t> operands) {
return success();
}

LogicalResult
spirv::Deserializer::processCooperativeMatrixType(ArrayRef<uint32_t> operands) {
LogicalResult spirv::Deserializer::processCooperativeMatrixTypeKHR(
ArrayRef<uint32_t> operands) {
if (operands.size() != 6) {
return emitError(unknownLoc,
"OpTypeCooperativeMatrixKHR must have element type, "
"scope, row and column parameters, and use");
}

Type elementTy = getType(operands[1]);
if (!elementTy) {
return emitError(unknownLoc,
"OpTypeCooperativeMatrixKHR references undefined <id> ")
<< operands[1];
}

std::optional<spirv::Scope> scope =
spirv::symbolizeScope(getConstantInt(operands[2]).getInt());
if (!scope) {
return emitError(
unknownLoc,
"OpTypeCooperativeMatrixKHR references undefined scope <id> ")
<< operands[2];
}

unsigned rows = getConstantInt(operands[3]).getInt();
unsigned columns = getConstantInt(operands[4]).getInt();

std::optional<spirv::CooperativeMatrixUseKHR> use =
spirv::symbolizeCooperativeMatrixUseKHR(
getConstantInt(operands[5]).getInt());
if (!use) {
return emitError(
unknownLoc,
"OpTypeCooperativeMatrixKHR references undefined use <id> ")
<< operands[5];
}

typeMap[operands[0]] =
spirv::CooperativeMatrixType::get(elementTy, rows, columns, *scope, *use);
return success();
}

LogicalResult spirv::Deserializer::processCooperativeMatrixTypeNV(
ArrayRef<uint32_t> operands) {
if (operands.size() != 5) {
return emitError(unknownLoc, "OpTypeCooperativeMatrix must have element "
return emitError(unknownLoc, "OpTypeCooperativeMatrixNV must have element "
"type and row x column parameters");
}

Type elementTy = getType(operands[1]);
if (!elementTy) {
return emitError(unknownLoc,
"OpTypeCooperativeMatrix references undefined <id> ")
"OpTypeCooperativeMatrixNV references undefined <id> ")
<< operands[1];
}

auto scope = spirv::symbolizeScope(getConstantInt(operands[2]).getInt());
std::optional<spirv::Scope> scope =
spirv::symbolizeScope(getConstantInt(operands[2]).getInt());
if (!scope) {
return emitError(unknownLoc,
"OpTypeCooperativeMatrix references undefined scope <id> ")
return emitError(
unknownLoc,
"OpTypeCooperativeMatrixNV references undefined scope <id> ")
<< operands[2];
}

unsigned rows = getConstantInt(operands[3]).getInt();
unsigned columns = getConstantInt(operands[4]).getInt();

typeMap[operands[0]] = spirv::CooperativeMatrixNVType::get(
elementTy, scope.value(), rows, columns);
typeMap[operands[0]] =
spirv::CooperativeMatrixNVType::get(elementTy, *scope, rows, columns);
return success();
}

Expand Down
4 changes: 3 additions & 1 deletion mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,9 @@ class Deserializer {

LogicalResult processArrayType(ArrayRef<uint32_t> operands);

LogicalResult processCooperativeMatrixType(ArrayRef<uint32_t> operands);
LogicalResult processCooperativeMatrixTypeKHR(ArrayRef<uint32_t> operands);

LogicalResult processCooperativeMatrixTypeNV(ArrayRef<uint32_t> operands);

LogicalResult processFunctionType(ArrayRef<uint32_t> operands);

Expand Down
22 changes: 22 additions & 0 deletions mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,28 @@ LogicalResult Serializer::prepareBasicType(
return success();
}

if (auto cooperativeMatrixType =
dyn_cast<spirv::CooperativeMatrixType>(type)) {
uint32_t elementTypeID = 0;
if (failed(processTypeImpl(loc, cooperativeMatrixType.getElementType(),
elementTypeID, serializationCtx))) {
return failure();
}
typeEnum = spirv::Opcode::OpTypeCooperativeMatrixKHR;
auto getConstantOp = [&](uint32_t id) {
auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id);
return prepareConstantInt(loc, attr);
};
operands.push_back(elementTypeID);
operands.push_back(
getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope())));
operands.push_back(getConstantOp(cooperativeMatrixType.getRows()));
operands.push_back(getConstantOp(cooperativeMatrixType.getColumns()));
operands.push_back(
getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getUse())));
return success();
}

if (auto cooperativeMatrixType =
dyn_cast<spirv::CooperativeMatrixNVType>(type)) {
uint32_t elementTypeID = 0;
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ spirv.func @cooperative_matrix_length() -> i32 "None" {
// -----

spirv.func @cooperative_matrix_length_wrong_matrix() -> i32 "None" {
// expected-error @+1 {{'spirv.KHR.CooperativeMatrixLength' op type attribute must be a '!spirv.coopmatrix'}}
// expected-error @+1 {{'cooperative_matrix_type' failed to satisfy constraint: type attribute of any SPIR-V cooperative matrix type}}
%0 = spirv.KHR.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
spirv.ReturnValue %0 : i32
}
Expand Down
93 changes: 93 additions & 0 deletions mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
// RUN: mlir-translate --no-implicit-module --test-spirv-roundtrip \
// RUN: --split-input-file %s | FileCheck %s

spirv.module Logical GLSL450 requires
#spirv.vce<v1.5, [Shader, Int8, Int16, Int64, Linkage, CooperativeMatrixKHR],
[SPV_KHR_storage_buffer_storage_class, SPV_KHR_cooperative_matrix]> {

// CHECK-LABEL: @cooperative_matrix_length
spirv.func @cooperative_matrix_length() "None" {
// CHECK: {{%.+}} = spirv.KHR.CooperativeMatrixLength : !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB>
%0 = spirv.KHR.CooperativeMatrixLength : !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB>
spirv.Return
}

// CHECK-LABEL: @cooperative_matrix_load_1
spirv.func @cooperative_matrix_load_1(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
// CHECK: {{%.+}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <RowMajor>
// CHECK-SAME: : !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
%0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor> :
!spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
spirv.Return
}

// CHECK-LABEL: @cooperative_matrix_load_2
spirv.func @cooperative_matrix_load_2(%ptr : !spirv.ptr<f32, StorageBuffer>, %stride : i64) "None" {
// CHECK: {{%.+}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <ColumnMajor>, <Volatile>
// CHECK-SAME: : !spirv.ptr<f32, StorageBuffer>, i64 -> !spirv.coopmatrix<8x16xf32, Subgroup, MatrixAcc>
%0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor>, <Volatile> :
!spirv.ptr<f32, StorageBuffer>, i64 -> !spirv.coopmatrix<8x16xf32, Subgroup, MatrixAcc>
spirv.Return
}

// CHECK-LABEL: @cooperative_matrix_store_1
spirv.func @cooperative_matrix_store_1(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
%m : !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>) "None" {
// CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, <RowMajor>
// CHECK-SAME: : !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor> :
!spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
spirv.Return
}

// CHECK-LABEL: @cooperative_matrix_store_2
spirv.func @cooperative_matrix_store_2(%ptr : !spirv.ptr<f32, Workgroup>, %stride : i64,
%m : !spirv.coopmatrix<4x8xf32, Subgroup, MatrixB>) "None" {
// CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, <ColumnMajor>, <Nontemporal>
// CHECK-SAME: : !spirv.ptr<f32, Workgroup>, !spirv.coopmatrix<4x8xf32, Subgroup, MatrixB>, i64
spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <ColumnMajor>, <Nontemporal> :
!spirv.ptr<f32, Workgroup>, !spirv.coopmatrix<4x8xf32, Subgroup, MatrixB>, i64
spirv.Return
}

// CHECK-LABEL: @cooperative_matrix_muladd
spirv.func @cooperative_matrix_muladd_1(%a : !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
%b : !spirv.coopmatrix<16x8xi16, Subgroup, MatrixB>,
%c : !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc>) "None" {
// CHECK: {{%.+}} = spirv.KHR.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} :
// CHECK-SAME: !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
// CHECK-SAME: !spirv.coopmatrix<16x8xi16, Subgroup, MatrixB>
// CHECK-SAME: -> !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc>
%p = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c : !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
!spirv.coopmatrix<16x8xi16, Subgroup, MatrixB>
-> !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc>

// CHECK-NEXT: {{%.+}} = spirv.KHR.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}}, <BSigned> :
// CHECK-SAME: !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
// CHECK-SAME: !spirv.coopmatrix<16x8xi16, Subgroup, MatrixB>
// CHECK-SAME: -> !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc>
%q = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c,
<BSigned> : !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
!spirv.coopmatrix<16x8xi16, Subgroup, MatrixB>
-> !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc>

// TODO: Handle multiple matrix operands and add relevant testcases here.
spirv.Return
}

// CHECK-LABEL: @cooperative_matrix_muladd
spirv.func @cooperative_matrix_muladd_2(%a : !spirv.coopmatrix<8x8xf32, Workgroup, MatrixA>,
%b : !spirv.coopmatrix<8x8xf32, Workgroup, MatrixB>,
%c : !spirv.coopmatrix<8x8xf32, Workgroup, MatrixAcc>) "None" {
// CHECK: {{%.+}} = spirv.KHR.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} :
// CHECK-SAME: !spirv.coopmatrix<8x8xf32, Workgroup, MatrixA>,
// CHECK-SAME: !spirv.coopmatrix<8x8xf32, Workgroup, MatrixB>
// CHECK-SAME: -> !spirv.coopmatrix<8x8xf32, Workgroup, MatrixAcc>
%r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c : !spirv.coopmatrix<8x8xf32, Workgroup, MatrixA>,
!spirv.coopmatrix<8x8xf32, Workgroup, MatrixB>
-> !spirv.coopmatrix<8x8xf32, Workgroup, MatrixAcc>

spirv.Return
}

}
38 changes: 22 additions & 16 deletions mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
Expand Down Expand Up @@ -512,6 +513,14 @@ static mlir::GenRegistration
// Serialization AutoGen
//===----------------------------------------------------------------------===//

// These enums are encoded as <id> to constant values in SPIR-V blob, but we
// directly use the constant value as attribute in SPIR-V dialect. So need
// to handle them separately from normal enum attributes.
constexpr llvm::StringLiteral constantIdEnumAttrs[] = {
"SPIRV_ScopeAttr", "SPIRV_KHR_CooperativeMatrixUseAttr",
"SPIRV_KHR_CooperativeMatrixLayoutAttr", "SPIRV_MemorySemanticsAttr",
"SPIRV_MatrixLayoutAttr"};

/// Generates code to serialize attributes of a SPIRV_Op `op` into `os`. The
/// generates code extracts the attribute with name `attrName` from
/// `operandList` of `op`.
Expand All @@ -521,12 +530,7 @@ static void emitAttributeSerialization(const Attribute &attr,
StringRef attrName, raw_ostream &os) {
os << tabs
<< formatv("if (auto attr = {0}->getAttr(\"{1}\")) {{\n", opVar, attrName);
if (attr.getAttrDefName() == "SPIRV_ScopeAttr" ||
attr.getAttrDefName() == "SPIRV_MemorySemanticsAttr" ||
attr.getAttrDefName() == "SPIRV_MatrixLayoutAttr") {
// These two enums are encoded as <id> to constant values in SPIR-V blob,
// but we directly use the constant value as attribute in SPIR-V dialect. So
// need to handle them separately from normal enum attributes.
if (llvm::is_contained(constantIdEnumAttrs, attr.getAttrDefName())) {
EnumAttr baseEnum(attr.getDef().getValueAsDef("enum"));
os << tabs
<< formatv(" {0}.push_back(prepareConstantInt({1}.getLoc(), "
Expand Down Expand Up @@ -557,11 +561,18 @@ static void emitAttributeSerialization(const Attribute &attr,
" {0}.push_back(static_cast<uint32_t>("
"llvm::cast<IntegerAttr>(attr).getValue().getZExtValue()));\n",
operandList);
} else if (attr.isEnumAttr() || attr.getAttrDefName() == "TypeAttr") {
} else if (attr.isEnumAttr() || attr.isTypeAttr()) {
// It may be the first time this type appears in the IR, so we need to
// process it.
StringRef attrTypeID = "attrTypeID";
os << tabs << formatv(" uint32_t {0} = 0;\n", attrTypeID);
os << tabs
<< formatv(" {0}.push_back(static_cast<uint32_t>("
"getTypeID(llvm::cast<TypeAttr>(attr).getValue())));\n",
operandList);
<< formatv(" if (failed(processType({0}.getLoc(), "
"llvm::cast<TypeAttr>(attr).getValue(), {1}))) {{\n",
opVar, attrTypeID);
os << tabs << " return failure();\n";
os << tabs << " }\n";
os << tabs << formatv(" {0}.push_back(attrTypeID);\n", operandList);
} else {
PrintFatalError(
loc,
Expand Down Expand Up @@ -816,12 +827,7 @@ static void emitAttributeDeserialization(const Attribute &attr,
StringRef attrList, StringRef attrName,
StringRef words, StringRef wordIndex,
raw_ostream &os) {
if (attr.getAttrDefName() == "SPIRV_ScopeAttr" ||
attr.getAttrDefName() == "SPIRV_MemorySemanticsAttr" ||
attr.getAttrDefName() == "SPIRV_MatrixLayoutAttr") {
// These two enums are encoded as <id> to constant values in SPIR-V blob,
// but we directly use the constant value as attribute in SPIR-V dialect. So
// need to handle them separately from normal enum attributes.
if (llvm::is_contained(constantIdEnumAttrs, attr.getAttrDefName())) {
EnumAttr baseEnum(attr.getDef().getValueAsDef("enum"));
os << tabs
<< formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
Expand Down