Skip to content

[mlir][spirv] Add support for VectorAnyINTEL capability #68034

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -4146,7 +4146,12 @@ def SPIRV_Int32 : TypeAlias<I32, "Int32">;
def SPIRV_Float32 : TypeAlias<F32, "Float32">;
def SPIRV_Float : FloatOfWidths<[16, 32, 64]>;
def SPIRV_Float16or32 : FloatOfWidths<[16, 32]>;
def SPIRV_Vector : VectorOfLengthAndType<[2, 3, 4, 8, 16],
// Remove the vector size restriction.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This "Remove the vector size restriction." sentense does not need to be in the comment--it's the goal of this patch; but not a proper doc of SPIRV_Vector.

// Although the vector size can be upto (2^64-1), uint64,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How does this actually work out for sizes > uint32 range? In SPIR-V the OpTypeVector's component count is spec'ed to be a unsigned 32-bit integer.. Did the Intel spec somehow change the definition there? Could you point me to the spec?

// 2^32-1 (UNINT32_MAX>) is a more realistic number, it should serve the purpose
// for all practical cases.
// Also unsigned is used for the number elements for composite tyeps.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo: types

def SPIRV_Vector : VectorOfLengthRangeAndType<[2, 0xFFFFFFFF],
[SPIRV_Bool, SPIRV_Integer, SPIRV_Float]>;
// Component type check is done in the type parser for the following SPIR-V
// dialect-specific types so we use "Any" here.
Expand Down Expand Up @@ -4206,10 +4211,10 @@ class SPIRV_JointMatrixOfType<list<Type> allowedTypes> :
"Joint Matrix">;

class SPIRV_ScalarOrVectorOf<Type type> :
AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>]>;
AnyTypeOf<[type, VectorOfLengthRangeAndType<[2, 0xFFFFFFFF], [type]>]>;

class SPIRV_ScalarOrVectorOrCoopMatrixOf<Type type> :
AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>,
AnyTypeOf<[type, VectorOfLengthRangeAndType<[2, 0xFFFFFFFF], [type]>,
SPIRV_CoopMatrixOfType<[type]>, SPIRV_CoopMatrixNVOfType<[type]>]>;

class SPIRV_MatrixOrCoopMatrixOf<Type type> :
Expand Down
70 changes: 70 additions & 0 deletions mlir/include/mlir/IR/CommonTypeConstraints.td
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,76 @@ class ScalableVectorOfRankAndLengthAndType<list<int> allowedRanks,
ScalableVectorOfLength<allowedLengths>.summary,
"::mlir::VectorType">;

// Whether the number of elements of a vector is from the given
// `allowedRanges` list, the list has two values, start and end
// of the range (inclusive).
class IsVectorOfLengthRangePred<list<int> allowedRanges>
: And<[IsVectorTypePred,
And<[CPred<[{$_self.cast<::mlir::VectorType>().getNumElements()>= }] # allowedRanges[0]>,
CPred<[{$_self.cast<::mlir::VectorType>().getNumElements() <= }] # allowedRanges[1]>]>]>;

// Whether the number of elements of a fixed-length vector is from the given
// `allowedRanges` list, the list has two values, start and end of the range (inclusive).
class IsFixedVectorOfLengthRangePred<list<int> allowedRanges>
: And<[IsFixedVectorTypePred,
And<[CPred<[{$_self.cast<::mlir::VectorType>().getNumElements() >= }] # allowedRanges[0]>,
CPred<[{$_self.cast<::mlir::VectorType>().getNumElements() <= }] # allowedRanges[1]>]>]>;

// Whether the minimum number of elements of a scalable vector is from the given
// `allowedRanges` list, the list has two values, start and end of the range (inclusive).
class IsScalableVectorOfMinLengthRangePred<list<int> allowedRanges>
: And<[IsScalableVectorTypePred,
And<[CPred<[{$_self.cast<::mlir::VectorType>().getNumElements() >= }] # allowedRanges[0]>,
CPred<[{$_self.cast<::mlir::VectorType>().getNumElements() <= }] # allowedRanges[1]>]>]>;

// Any vector where the number of elements is from the given
// `allowedRanges` list.
class VectorOfLengthRange<list<int> allowedRanges>
: Type<IsVectorOfLengthRangePred<allowedRanges>,
" of length " # !interleave(allowedRanges, "-"),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use to here instead of -? Using - reads like minus to me..

"::mlir::VectorType">;

// Any fixed-length vector where the number of elements is from the given
// `allowedRanges` list.
class FixedVectorOfLengthRange<list<int> allowedRanges>
: Type<IsFixedVectorOfLengthRangePred<allowedRanges>,
" of length " # !interleave(allowedRanges, "-"),
"::mlir::VectorType">;

// Any scalable vector where the minimum number of elements is from the given
// `allowedRanges` list.
class ScalableVectorOfMinLengthRange<list<int> allowedRanges>
: Type<IsScalableVectorOfMinLengthRangePred<allowedRanges>,
" of length " # !interleave(allowedRanges, "-"),
"::mlir::VectorType">;

// Any vector where the number of elements is from the given
// `allowedRanges` list and the type is from the given `allowedTypes`
// list.
class VectorOfLengthRangeAndType<list<int> allowedRanges, list<Type> allowedTypes>
: Type<And<[VectorOf<allowedTypes>.predicate, VectorOfLengthRange<allowedRanges>.predicate]>,
VectorOf<allowedTypes>.summary # VectorOfLengthRange<allowedRanges>.summary,
"::mlir::VectorType">;

// Any fixed-length vector where the number of elements is from the given
// `allowedRanges` list and the type is from the given `allowedTypes`
// list.
class FixedVectorOfLengthRangeAndType<list<int> allowedRanges, list<Type> allowedTypes>
: Type<
And<[FixedVectorOf<allowedTypes>.predicate, FixedVectorOfLengthRange<allowedRanges>.predicate]>,
FixedVectorOf<allowedTypes>.summary # FixedVectorOfLengthRange<allowedRanges>.summary,
"::mlir::VectorType">;

// Any scalable vector where the minimum number of elements is from the given
// `allowedRanges` list and the type is from the given `allowedTypes`
// list.
class ScalableVectorOfMinLengthRangeAndType<list<int> allowedRanges, list<Type> allowedTypes>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we exclude any predicated not used by the SPIR-V dialect?

: Type<
And<[ScalableVectorOf<allowedTypes>.predicate, ScalableVectorOfMinLengthRange<allowedRanges>.predicate]>,
ScalableVectorOf<allowedTypes>.summary # ScalableVectorOfMinLengthRange<allowedRanges>.summary,
"::mlir::VectorType">;


def AnyVector : VectorOf<[AnyType]>;
// Temporary vector type clone that allows gradual transition to 0-D vectors.
def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>;
Expand Down
7 changes: 5 additions & 2 deletions mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,9 +184,12 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect,
parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
return Type();
}
if (t.getNumElements() > 4) {
// Number of elements should be between [2 - 2^32 -1],
// since getNumElements() returns an unsigned, the upper limit check is
// unnecessary.
if (t.getNumElements() < 2) {
parser.emitError(
typeLoc, "vector length has to be less than or equal to 4 but found ")
typeLoc, "vector length has to be between [2 - 2^32 -1] but found ")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I find this format a bit difficult to parse with two - signs. What do you think about something like this:

Suggested change
typeLoc, "vector length has to be between [2 - 2^32 -1] but found ")
typeLoc, "vector length must be in the range [2, 2^32), but found ")

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

<< t.getNumElements();
return Type();
}
Expand Down
22 changes: 16 additions & 6 deletions mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,11 @@ bool CompositeType::classof(Type type) {
}

bool CompositeType::isValid(VectorType type) {
return type.getRank() == 1 &&
llvm::is_contained({2, 3, 4, 8, 16}, type.getNumElements()) &&
llvm::isa<ScalarType>(type.getElementType());
// Number of elements should be between [2 - 2^32 -1],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here.

// since getNumElements() returns an unsigned, the upper limit check is
// unnecessary.
return type.getRank() == 1 && llvm::isa<ScalarType>(type.getElementType()) &&
type.getNumElements() >= 2;
}

Type CompositeType::getElementType(unsigned index) const {
Expand Down Expand Up @@ -171,9 +173,17 @@ void CompositeType::getCapabilities(
.Case<VectorType>([&](VectorType type) {
auto vecSize = getNumElements();
if (vecSize == 8 || vecSize == 16) {
static const Capability caps[] = {Capability::Vector16};
ArrayRef<Capability> ref(caps, std::size(caps));
capabilities.push_back(ref);
static constexpr Capability caps[] = {Capability::Vector16,
Capability::VectorAnyINTEL};
capabilities.push_back(caps);
}
// VectorAnyINTEL capability removes the vector size restriction and
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure I follow the logic here. We are pushing duplicated capabilities here? Shouldn't we do the check vecSize > 4 to include VectorAnyINTEL first, and then additionally add Vector16 if vecSize == 8 || vecSize == 16?

// allows the vector size to be up to (2^32-1).
// Vector16 capability allows the vector size to be 8 and 16
SmallVector<unsigned, 5> allowedVecRange = {2, 3, 4, 8, 16};
if (vecSize >= 2 && !llvm::is_contained(allowedVecRange, vecSize)) {
static constexpr Capability caps[] = {Capability::VectorAnyINTEL};
capabilities.push_back(caps);
}
return llvm::cast<ScalarType>(type.getElementType())
.getCapabilities(capabilities, storage);
Expand Down
126 changes: 102 additions & 24 deletions mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,13 @@ using namespace mlir;
template <typename LabelT>
static LogicalResult checkExtensionRequirements(
LabelT label, const spirv::TargetEnv &targetEnv,
const spirv::SPIRVType::ExtensionArrayRefVector &candidates) {
const spirv::SPIRVType::ExtensionArrayRefVector &candidates,
const ArrayRef<spirv::Extension> elidedCandidates = {}) {
for (const auto &ors : candidates) {
if (targetEnv.allows(ors))
if (targetEnv.allows(ors) ||
llvm::any_of(elidedCandidates, [&ors](spirv::Extension elidedExt) {
return llvm::is_contained(ors, elidedExt);
}))
continue;

LLVM_DEBUG({
Expand All @@ -71,9 +75,13 @@ static LogicalResult checkExtensionRequirements(
template <typename LabelT>
static LogicalResult checkCapabilityRequirements(
LabelT label, const spirv::TargetEnv &targetEnv,
const spirv::SPIRVType::CapabilityArrayRefVector &candidates) {
const spirv::SPIRVType::CapabilityArrayRefVector &candidates,
const ArrayRef<spirv::Capability> elidedCandidates = {}) {
for (const auto &ors : candidates) {
if (targetEnv.allows(ors))
if (targetEnv.allows(ors) ||
llvm::any_of(elidedCandidates, [&ors](spirv::Capability elidedCap) {
return llvm::is_contained(ors, elidedCap);
}))
continue;

LLVM_DEBUG({
Expand All @@ -90,8 +98,55 @@ static LogicalResult checkCapabilityRequirements(
return success();
}

/// Returns true if the given `storageClass` needs explicit layout when used in
/// Shader environments.
/// Check capabilities and extensions requirements
/// Checks that `capCandidates`, `extCandidates`, and capability
/// (`capCandidates`) infered extension requirements are possible to be
/// satisfied with the given `targetEnv`.
/// It also provides a way to relax requirements for certain capabilities and
/// extensions (e.g., `elidedCapCandidates`, `elidedExtCandidates`), this is to
/// allow passes to relax certain requirements based on an option (e.g.,
/// relaxing bitwidth requirement, see `convertScalarType()`,
/// `ConvertVectorType()`).
template <typename LabelT>
static LogicalResult checkCapabilityAndExtensionRequirements(
LabelT label, const spirv::TargetEnv &targetEnv,
const spirv::SPIRVType::CapabilityArrayRefVector &capCandidates,
const spirv::SPIRVType::ExtensionArrayRefVector &extCandidates,
const ArrayRef<spirv::Capability> elidedCapCandidates = {},
const ArrayRef<spirv::Extension> elidedExtCandidates = {}) {
Comment on lines +115 to +116
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In mlir, we generally don't use const for arguments passed by value

SmallVector<ArrayRef<spirv::Extension>, 8> updatedExtCandidates;
llvm::append_range(updatedExtCandidates, extCandidates);

if (failed(checkCapabilityRequirements(label, targetEnv, capCandidates,
elidedCapCandidates)))
return failure();
// Add capablity infered extensions to the list of extension requirement list,
// only considers the capabilities that already available in the `targetEnv`.

// WARNING: Some capabilities are part of both the core SPIR-V
// specification and an extension (e.g., 'Groups' capability is part of both
// core specification and SPV_AMD_shader_ballot extension, hence we should
// relax the capability inferred extension for these cases).
static const spirv::Capability multiModalCaps[] = {spirv::Capability::Groups};
ArrayRef<spirv::Capability> multiModalCapsArrayRef(multiModalCaps,
std::size(multiModalCaps));

for (auto cap : targetEnv.getAttr().getCapabilities()) {
if (llvm::any_of(multiModalCapsArrayRef,
[&cap](spirv::Capability mMCap) { return cap == mMCap; }))
continue;
std::optional<ArrayRef<spirv::Extension>> ext = getExtensions(cap);
if (ext)
updatedExtCandidates.push_back(*ext);
}
if (failed(checkExtensionRequirements(label, targetEnv, updatedExtCandidates,
elidedExtCandidates)))
return failure();
return success();
}

/// Returns true if the given `storageClass` needs explicit layout when used
/// in Shader environments.
static bool needsExplicitLayout(spirv::StorageClass storageClass) {
switch (storageClass) {
case spirv::StorageClass::PhysicalStorageBuffer:
Expand Down Expand Up @@ -230,8 +285,8 @@ convertScalarType(const spirv::TargetEnv &targetEnv,
type.getCapabilities(capabilities, storageClass);

// If all requirements are met, then we can accept this type as-is.
if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
if (succeeded(checkCapabilityAndExtensionRequirements(
type, targetEnv, capabilities, extensions)))
return type;

// Otherwise we need to adjust the type, which really means adjusting the
Expand Down Expand Up @@ -342,15 +397,35 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
cast<spirv::CompositeType>(type).getExtensions(extensions, storageClass);
cast<spirv::CompositeType>(type).getCapabilities(capabilities, storageClass);

// If the bit-width related capabilities and extensions are not met
// for lower bit-width (<32-bit), convert it to 32-bit
auto elementType =
convertScalarType(targetEnv, options, scalarType, storageClass);
if (!elementType)
return nullptr;
type = VectorType::get(type.getShape(), elementType);

SmallVector<spirv::Capability, 4> elidedCaps;
SmallVector<spirv::Extension, 4> elidedExts;

// Relax the bitwidth requirements for capabilities and extensions
if (options.emulateLT32BitScalarTypes) {
elidedCaps.push_back(spirv::Capability::Int8);
elidedCaps.push_back(spirv::Capability::Int16);
elidedCaps.push_back(spirv::Capability::Float16);
}
// For capabilities whose requirements were relaxed, relax requirements for
// the extensions that were infered by those capabilities (e.g., elidedCaps)
for (spirv::Capability cap : elidedCaps) {
std::optional<ArrayRef<spirv::Extension>> ext = spirv::getExtensions(cap);
if (ext)
Comment on lines +420 to +421
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
std::optional<ArrayRef<spirv::Extension>> ext = spirv::getExtensions(cap);
if (ext)
if (std::optional<ArrayRef<spirv::Extension>> ext = spirv::getExtensions(cap))

llvm::append_range(elidedExts, *ext);
}
// If all requirements are met, then we can accept this type as-is.
if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
if (succeeded(checkCapabilityAndExtensionRequirements(
type, targetEnv, capabilities, extensions, elidedCaps, elidedExts)))
return type;

auto elementType =
convertScalarType(targetEnv, options, scalarType, storageClass);
if (elementType)
return VectorType::get(type.getShape(), elementType);
return nullptr;
}

Expand Down Expand Up @@ -656,8 +731,9 @@ std::optional<Value> castToSourceType(const spirv::TargetEnv &targetEnv,
SmallVector<ArrayRef<spirv::Capability>, 2> caps;
scalarType.getExtensions(exts);
scalarType.getCapabilities(caps);
if (failed(checkCapabilityRequirements(type, targetEnv, caps)) ||
failed(checkExtensionRequirements(type, targetEnv, exts))) {

if (failed(checkCapabilityAndExtensionRequirements(type, targetEnv, caps,
exts))) {
auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
return castOp.getResult(0);
}
Expand Down Expand Up @@ -1150,16 +1226,18 @@ bool SPIRVConversionTarget::isLegalOp(Operation *op) {
SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
for (Type valueType : valueTypes) {
typeExtensions.clear();
cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions);
if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
typeExtensions)))
return false;

typeCapabilities.clear();
cast<spirv::SPIRVType>(valueType).getCapabilities(typeCapabilities);
if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
typeCapabilities)))
typeExtensions.clear();
cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions);
// Checking for capability and extension requirements along with capability
// infered extensions.
// If a capability is present, the extension that
// supports it should also be present, this reduces the burden of adding
// extension requirement that may or maynot be added in
// CompositeType::getExtensions().
if (failed(checkCapabilityAndExtensionRequirements(
op->getName(), this->targetEnv, typeCapabilities, typeExtensions)))
return false;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ module attributes {
#spirv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64, Shader], []>, #spirv.resource_limits<>>
} {

func.func @unsupported_5elem_vector(%arg0: vector<5xi32>) {
func.func @unsupported_5elem_vector(%arg0: vector<5xi32>, %arg1: vector<5xi32>) {
// expected-error@+1 {{failed to legalize operation 'arith.subi'}}
%1 = arith.subi %arg0, %arg0: vector<5xi32>
%1 = arith.subi %arg0, %arg1: vector<5xi32>
return
}

Expand Down
40 changes: 40 additions & 0 deletions mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1407,3 +1407,43 @@ func.func @float_scalar(%arg0: f16) {
}

} // end module

// -----

//===----------------------------------------------------------------------===//
// VectorAnyINTEL support
//===----------------------------------------------------------------------===//

// Check that with VectorAnyINTEL, VectorComputeINTEL capability,
// and SPV_INTEL_vector_compute extension, any sized (2-2^32 -1) vector is allowed.
module attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64, Kernel, VectorAnyINTEL], [SPV_INTEL_vector_compute]>, #spirv.resource_limits<>>
} {

// CHECK-LABEL: @any_vector
func.func @any_vector(%arg0: vector<16xi32>, %arg1: vector<16xi32>) {
// CHECK: spirv.ISub %{{.+}}, %{{.+}}: vector<16xi32>
%0 = arith.subi %arg0, %arg1: vector<16xi32>
return
}

// CHECK-LABEL: @max_vector
func.func @max_vector(%arg0: vector<4294967295xi32>, %arg1: vector<4294967295xi32>) {
// CHECK: spirv.ISub %{{.+}}, %{{.+}}: vector<4294967295xi32>
%0 = arith.subi %arg0, %arg1: vector<4294967295xi32>
return
}


// Check float vector types of any size.
// CHECK-LABEL: @float_vector58
func.func @float_vector58(%arg0: vector<5xf16>, %arg1: vector<8xf64>) {
// CHECK: spirv.FAdd %{{.*}}, %{{.*}}: vector<5xf16>
%0 = arith.addf %arg0, %arg0: vector<5xf16>
// CHECK: spirv.FMul %{{.*}}, %{{.*}}: vector<8xf64>
%1 = arith.mulf %arg1, %arg1: vector<8xf64>
return
}

} // end module
Loading