-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[mlir][spirv] Add support for VectorAnyINTEL capability #68034
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
aaeb968
cce7b28
4ae99c6
1041658
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4146,7 +4146,12 @@ def SPIRV_Int32 : TypeAlias<I32, "Int32">; | |
def SPIRV_Float32 : TypeAlias<F32, "Float32">; | ||
def SPIRV_Float : FloatOfWidths<[16, 32, 64]>; | ||
def SPIRV_Float16or32 : FloatOfWidths<[16, 32]>; | ||
def SPIRV_Vector : VectorOfLengthAndType<[2, 3, 4, 8, 16], | ||
// Remove the vector size restriction. | ||
// Although the vector size can be upto (2^64-1), uint64, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How does this actually work out for sizes > uint32 range? In SPIR-V the |
||
// 2^32-1 (UNINT32_MAX>) is a more realistic number, it should serve the purpose | ||
// for all practical cases. | ||
// Also unsigned is used for the number elements for composite tyeps. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. typo: types |
||
def SPIRV_Vector : VectorOfLengthRangeAndType<[2, 0xFFFFFFFF], | ||
[SPIRV_Bool, SPIRV_Integer, SPIRV_Float]>; | ||
// Component type check is done in the type parser for the following SPIR-V | ||
// dialect-specific types so we use "Any" here. | ||
|
@@ -4206,10 +4211,10 @@ class SPIRV_JointMatrixOfType<list<Type> allowedTypes> : | |
"Joint Matrix">; | ||
|
||
class SPIRV_ScalarOrVectorOf<Type type> : | ||
AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>]>; | ||
AnyTypeOf<[type, VectorOfLengthRangeAndType<[2, 0xFFFFFFFF], [type]>]>; | ||
|
||
class SPIRV_ScalarOrVectorOrCoopMatrixOf<Type type> : | ||
AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>, | ||
AnyTypeOf<[type, VectorOfLengthRangeAndType<[2, 0xFFFFFFFF], [type]>, | ||
SPIRV_CoopMatrixOfType<[type]>, SPIRV_CoopMatrixNVOfType<[type]>]>; | ||
|
||
class SPIRV_MatrixOrCoopMatrixOf<Type type> : | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -546,6 +546,76 @@ class ScalableVectorOfRankAndLengthAndType<list<int> allowedRanks, | |
ScalableVectorOfLength<allowedLengths>.summary, | ||
"::mlir::VectorType">; | ||
|
||
// Whether the number of elements of a vector is from the given | ||
// `allowedRanges` list, the list has two values, start and end | ||
// of the range (inclusive). | ||
class IsVectorOfLengthRangePred<list<int> allowedRanges> | ||
: And<[IsVectorTypePred, | ||
And<[CPred<[{$_self.cast<::mlir::VectorType>().getNumElements()>= }] # allowedRanges[0]>, | ||
CPred<[{$_self.cast<::mlir::VectorType>().getNumElements() <= }] # allowedRanges[1]>]>]>; | ||
|
||
// Whether the number of elements of a fixed-length vector is from the given | ||
// `allowedRanges` list, the list has two values, start and end of the range (inclusive). | ||
class IsFixedVectorOfLengthRangePred<list<int> allowedRanges> | ||
: And<[IsFixedVectorTypePred, | ||
And<[CPred<[{$_self.cast<::mlir::VectorType>().getNumElements() >= }] # allowedRanges[0]>, | ||
CPred<[{$_self.cast<::mlir::VectorType>().getNumElements() <= }] # allowedRanges[1]>]>]>; | ||
|
||
// Whether the minimum number of elements of a scalable vector is from the given | ||
// `allowedRanges` list, the list has two values, start and end of the range (inclusive). | ||
class IsScalableVectorOfMinLengthRangePred<list<int> allowedRanges> | ||
: And<[IsScalableVectorTypePred, | ||
And<[CPred<[{$_self.cast<::mlir::VectorType>().getNumElements() >= }] # allowedRanges[0]>, | ||
CPred<[{$_self.cast<::mlir::VectorType>().getNumElements() <= }] # allowedRanges[1]>]>]>; | ||
|
||
// Any vector where the number of elements is from the given | ||
// `allowedRanges` list. | ||
class VectorOfLengthRange<list<int> allowedRanges> | ||
: Type<IsVectorOfLengthRangePred<allowedRanges>, | ||
" of length " # !interleave(allowedRanges, "-"), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we use |
||
"::mlir::VectorType">; | ||
|
||
// Any fixed-length vector where the number of elements is from the given | ||
// `allowedRanges` list. | ||
class FixedVectorOfLengthRange<list<int> allowedRanges> | ||
: Type<IsFixedVectorOfLengthRangePred<allowedRanges>, | ||
" of length " # !interleave(allowedRanges, "-"), | ||
"::mlir::VectorType">; | ||
|
||
// Any scalable vector where the minimum number of elements is from the given | ||
// `allowedRanges` list. | ||
class ScalableVectorOfMinLengthRange<list<int> allowedRanges> | ||
: Type<IsScalableVectorOfMinLengthRangePred<allowedRanges>, | ||
" of length " # !interleave(allowedRanges, "-"), | ||
"::mlir::VectorType">; | ||
|
||
// Any vector where the number of elements is from the given | ||
// `allowedRanges` list and the type is from the given `allowedTypes` | ||
// list. | ||
class VectorOfLengthRangeAndType<list<int> allowedRanges, list<Type> allowedTypes> | ||
: Type<And<[VectorOf<allowedTypes>.predicate, VectorOfLengthRange<allowedRanges>.predicate]>, | ||
VectorOf<allowedTypes>.summary # VectorOfLengthRange<allowedRanges>.summary, | ||
"::mlir::VectorType">; | ||
|
||
// Any fixed-length vector where the number of elements is from the given | ||
// `allowedRanges` list and the type is from the given `allowedTypes` | ||
// list. | ||
class FixedVectorOfLengthRangeAndType<list<int> allowedRanges, list<Type> allowedTypes> | ||
: Type< | ||
And<[FixedVectorOf<allowedTypes>.predicate, FixedVectorOfLengthRange<allowedRanges>.predicate]>, | ||
FixedVectorOf<allowedTypes>.summary # FixedVectorOfLengthRange<allowedRanges>.summary, | ||
"::mlir::VectorType">; | ||
|
||
// Any scalable vector where the minimum number of elements is from the given | ||
// `allowedRanges` list and the type is from the given `allowedTypes` | ||
// list. | ||
class ScalableVectorOfMinLengthRangeAndType<list<int> allowedRanges, list<Type> allowedTypes> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could we exclude any predicated not used by the SPIR-V dialect? |
||
: Type< | ||
And<[ScalableVectorOf<allowedTypes>.predicate, ScalableVectorOfMinLengthRange<allowedRanges>.predicate]>, | ||
ScalableVectorOf<allowedTypes>.summary # ScalableVectorOfMinLengthRange<allowedRanges>.summary, | ||
"::mlir::VectorType">; | ||
|
||
|
||
def AnyVector : VectorOf<[AnyType]>; | ||
// Temporary vector type clone that allows gradual transition to 0-D vectors. | ||
def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>; | ||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -184,9 +184,12 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect, | |||||
parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t; | ||||||
return Type(); | ||||||
} | ||||||
if (t.getNumElements() > 4) { | ||||||
// Number of elements should be between [2 - 2^32 -1], | ||||||
// since getNumElements() returns an unsigned, the upper limit check is | ||||||
// unnecessary. | ||||||
if (t.getNumElements() < 2) { | ||||||
parser.emitError( | ||||||
typeLoc, "vector length has to be less than or equal to 4 but found ") | ||||||
typeLoc, "vector length has to be between [2 - 2^32 -1] but found ") | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: I find this format a bit difficult to parse with two
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. +1 |
||||||
<< t.getNumElements(); | ||||||
return Type(); | ||||||
} | ||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -101,9 +101,11 @@ bool CompositeType::classof(Type type) { | |
} | ||
|
||
bool CompositeType::isValid(VectorType type) { | ||
return type.getRank() == 1 && | ||
llvm::is_contained({2, 3, 4, 8, 16}, type.getNumElements()) && | ||
llvm::isa<ScalarType>(type.getElementType()); | ||
// Number of elements should be between [2 - 2^32 -1], | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here. |
||
// since getNumElements() returns an unsigned, the upper limit check is | ||
// unnecessary. | ||
return type.getRank() == 1 && llvm::isa<ScalarType>(type.getElementType()) && | ||
type.getNumElements() >= 2; | ||
} | ||
|
||
Type CompositeType::getElementType(unsigned index) const { | ||
|
@@ -171,9 +173,17 @@ void CompositeType::getCapabilities( | |
.Case<VectorType>([&](VectorType type) { | ||
auto vecSize = getNumElements(); | ||
if (vecSize == 8 || vecSize == 16) { | ||
static const Capability caps[] = {Capability::Vector16}; | ||
ArrayRef<Capability> ref(caps, std::size(caps)); | ||
capabilities.push_back(ref); | ||
static constexpr Capability caps[] = {Capability::Vector16, | ||
Capability::VectorAnyINTEL}; | ||
capabilities.push_back(caps); | ||
} | ||
// VectorAnyINTEL capability removes the vector size restriction and | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure I follow the logic here. We are pushing duplicated capabilities here? Shouldn't we do the check |
||
// allows the vector size to be up to (2^32-1). | ||
// Vector16 capability allows the vector size to be 8 and 16 | ||
SmallVector<unsigned, 5> allowedVecRange = {2, 3, 4, 8, 16}; | ||
if (vecSize >= 2 && !llvm::is_contained(allowedVecRange, vecSize)) { | ||
static constexpr Capability caps[] = {Capability::VectorAnyINTEL}; | ||
capabilities.push_back(caps); | ||
} | ||
return llvm::cast<ScalarType>(type.getElementType()) | ||
.getCapabilities(capabilities, storage); | ||
|
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -43,9 +43,13 @@ using namespace mlir; | |||||||
template <typename LabelT> | ||||||||
static LogicalResult checkExtensionRequirements( | ||||||||
LabelT label, const spirv::TargetEnv &targetEnv, | ||||||||
const spirv::SPIRVType::ExtensionArrayRefVector &candidates) { | ||||||||
const spirv::SPIRVType::ExtensionArrayRefVector &candidates, | ||||||||
const ArrayRef<spirv::Extension> elidedCandidates = {}) { | ||||||||
for (const auto &ors : candidates) { | ||||||||
if (targetEnv.allows(ors)) | ||||||||
if (targetEnv.allows(ors) || | ||||||||
llvm::any_of(elidedCandidates, [&ors](spirv::Extension elidedExt) { | ||||||||
return llvm::is_contained(ors, elidedExt); | ||||||||
})) | ||||||||
continue; | ||||||||
|
||||||||
LLVM_DEBUG({ | ||||||||
|
@@ -71,9 +75,13 @@ static LogicalResult checkExtensionRequirements( | |||||||
template <typename LabelT> | ||||||||
static LogicalResult checkCapabilityRequirements( | ||||||||
LabelT label, const spirv::TargetEnv &targetEnv, | ||||||||
const spirv::SPIRVType::CapabilityArrayRefVector &candidates) { | ||||||||
const spirv::SPIRVType::CapabilityArrayRefVector &candidates, | ||||||||
const ArrayRef<spirv::Capability> elidedCandidates = {}) { | ||||||||
for (const auto &ors : candidates) { | ||||||||
if (targetEnv.allows(ors)) | ||||||||
if (targetEnv.allows(ors) || | ||||||||
llvm::any_of(elidedCandidates, [&ors](spirv::Capability elidedCap) { | ||||||||
return llvm::is_contained(ors, elidedCap); | ||||||||
})) | ||||||||
continue; | ||||||||
|
||||||||
LLVM_DEBUG({ | ||||||||
|
@@ -90,8 +98,55 @@ static LogicalResult checkCapabilityRequirements( | |||||||
return success(); | ||||||||
} | ||||||||
|
||||||||
/// Returns true if the given `storageClass` needs explicit layout when used in | ||||||||
/// Shader environments. | ||||||||
/// Check capabilities and extensions requirements | ||||||||
/// Checks that `capCandidates`, `extCandidates`, and capability | ||||||||
/// (`capCandidates`) infered extension requirements are possible to be | ||||||||
/// satisfied with the given `targetEnv`. | ||||||||
/// It also provides a way to relax requirements for certain capabilities and | ||||||||
/// extensions (e.g., `elidedCapCandidates`, `elidedExtCandidates`), this is to | ||||||||
/// allow passes to relax certain requirements based on an option (e.g., | ||||||||
/// relaxing bitwidth requirement, see `convertScalarType()`, | ||||||||
/// `ConvertVectorType()`). | ||||||||
template <typename LabelT> | ||||||||
static LogicalResult checkCapabilityAndExtensionRequirements( | ||||||||
LabelT label, const spirv::TargetEnv &targetEnv, | ||||||||
const spirv::SPIRVType::CapabilityArrayRefVector &capCandidates, | ||||||||
const spirv::SPIRVType::ExtensionArrayRefVector &extCandidates, | ||||||||
const ArrayRef<spirv::Capability> elidedCapCandidates = {}, | ||||||||
const ArrayRef<spirv::Extension> elidedExtCandidates = {}) { | ||||||||
Comment on lines
+115
to
+116
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In mlir, we generally don't use |
||||||||
SmallVector<ArrayRef<spirv::Extension>, 8> updatedExtCandidates; | ||||||||
llvm::append_range(updatedExtCandidates, extCandidates); | ||||||||
|
||||||||
if (failed(checkCapabilityRequirements(label, targetEnv, capCandidates, | ||||||||
elidedCapCandidates))) | ||||||||
return failure(); | ||||||||
// Add capablity infered extensions to the list of extension requirement list, | ||||||||
// only considers the capabilities that already available in the `targetEnv`. | ||||||||
|
||||||||
// WARNING: Some capabilities are part of both the core SPIR-V | ||||||||
// specification and an extension (e.g., 'Groups' capability is part of both | ||||||||
// core specification and SPV_AMD_shader_ballot extension, hence we should | ||||||||
// relax the capability inferred extension for these cases). | ||||||||
static const spirv::Capability multiModalCaps[] = {spirv::Capability::Groups}; | ||||||||
ArrayRef<spirv::Capability> multiModalCapsArrayRef(multiModalCaps, | ||||||||
std::size(multiModalCaps)); | ||||||||
|
||||||||
for (auto cap : targetEnv.getAttr().getCapabilities()) { | ||||||||
if (llvm::any_of(multiModalCapsArrayRef, | ||||||||
[&cap](spirv::Capability mMCap) { return cap == mMCap; })) | ||||||||
continue; | ||||||||
std::optional<ArrayRef<spirv::Extension>> ext = getExtensions(cap); | ||||||||
if (ext) | ||||||||
updatedExtCandidates.push_back(*ext); | ||||||||
} | ||||||||
if (failed(checkExtensionRequirements(label, targetEnv, updatedExtCandidates, | ||||||||
elidedExtCandidates))) | ||||||||
return failure(); | ||||||||
return success(); | ||||||||
} | ||||||||
|
||||||||
/// Returns true if the given `storageClass` needs explicit layout when used | ||||||||
/// in Shader environments. | ||||||||
static bool needsExplicitLayout(spirv::StorageClass storageClass) { | ||||||||
switch (storageClass) { | ||||||||
case spirv::StorageClass::PhysicalStorageBuffer: | ||||||||
|
@@ -230,8 +285,8 @@ convertScalarType(const spirv::TargetEnv &targetEnv, | |||||||
type.getCapabilities(capabilities, storageClass); | ||||||||
|
||||||||
// If all requirements are met, then we can accept this type as-is. | ||||||||
if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) && | ||||||||
succeeded(checkExtensionRequirements(type, targetEnv, extensions))) | ||||||||
if (succeeded(checkCapabilityAndExtensionRequirements( | ||||||||
type, targetEnv, capabilities, extensions))) | ||||||||
return type; | ||||||||
|
||||||||
// Otherwise we need to adjust the type, which really means adjusting the | ||||||||
|
@@ -342,15 +397,35 @@ convertVectorType(const spirv::TargetEnv &targetEnv, | |||||||
cast<spirv::CompositeType>(type).getExtensions(extensions, storageClass); | ||||||||
cast<spirv::CompositeType>(type).getCapabilities(capabilities, storageClass); | ||||||||
|
||||||||
// If the bit-width related capabilities and extensions are not met | ||||||||
// for lower bit-width (<32-bit), convert it to 32-bit | ||||||||
auto elementType = | ||||||||
convertScalarType(targetEnv, options, scalarType, storageClass); | ||||||||
if (!elementType) | ||||||||
return nullptr; | ||||||||
type = VectorType::get(type.getShape(), elementType); | ||||||||
|
||||||||
SmallVector<spirv::Capability, 4> elidedCaps; | ||||||||
SmallVector<spirv::Extension, 4> elidedExts; | ||||||||
|
||||||||
// Relax the bitwidth requirements for capabilities and extensions | ||||||||
if (options.emulateLT32BitScalarTypes) { | ||||||||
elidedCaps.push_back(spirv::Capability::Int8); | ||||||||
elidedCaps.push_back(spirv::Capability::Int16); | ||||||||
elidedCaps.push_back(spirv::Capability::Float16); | ||||||||
} | ||||||||
// For capabilities whose requirements were relaxed, relax requirements for | ||||||||
// the extensions that were infered by those capabilities (e.g., elidedCaps) | ||||||||
for (spirv::Capability cap : elidedCaps) { | ||||||||
std::optional<ArrayRef<spirv::Extension>> ext = spirv::getExtensions(cap); | ||||||||
if (ext) | ||||||||
Comment on lines
+420
to
+421
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
llvm::append_range(elidedExts, *ext); | ||||||||
} | ||||||||
// If all requirements are met, then we can accept this type as-is. | ||||||||
if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) && | ||||||||
succeeded(checkExtensionRequirements(type, targetEnv, extensions))) | ||||||||
if (succeeded(checkCapabilityAndExtensionRequirements( | ||||||||
type, targetEnv, capabilities, extensions, elidedCaps, elidedExts))) | ||||||||
return type; | ||||||||
|
||||||||
auto elementType = | ||||||||
convertScalarType(targetEnv, options, scalarType, storageClass); | ||||||||
if (elementType) | ||||||||
return VectorType::get(type.getShape(), elementType); | ||||||||
return nullptr; | ||||||||
} | ||||||||
|
||||||||
|
@@ -656,8 +731,9 @@ std::optional<Value> castToSourceType(const spirv::TargetEnv &targetEnv, | |||||||
SmallVector<ArrayRef<spirv::Capability>, 2> caps; | ||||||||
scalarType.getExtensions(exts); | ||||||||
scalarType.getCapabilities(caps); | ||||||||
if (failed(checkCapabilityRequirements(type, targetEnv, caps)) || | ||||||||
failed(checkExtensionRequirements(type, targetEnv, exts))) { | ||||||||
|
||||||||
if (failed(checkCapabilityAndExtensionRequirements(type, targetEnv, caps, | ||||||||
exts))) { | ||||||||
auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs); | ||||||||
return castOp.getResult(0); | ||||||||
} | ||||||||
|
@@ -1150,16 +1226,18 @@ bool SPIRVConversionTarget::isLegalOp(Operation *op) { | |||||||
SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions; | ||||||||
SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities; | ||||||||
for (Type valueType : valueTypes) { | ||||||||
typeExtensions.clear(); | ||||||||
cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions); | ||||||||
if (failed(checkExtensionRequirements(op->getName(), this->targetEnv, | ||||||||
typeExtensions))) | ||||||||
return false; | ||||||||
|
||||||||
typeCapabilities.clear(); | ||||||||
cast<spirv::SPIRVType>(valueType).getCapabilities(typeCapabilities); | ||||||||
if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv, | ||||||||
typeCapabilities))) | ||||||||
typeExtensions.clear(); | ||||||||
cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions); | ||||||||
// Checking for capability and extension requirements along with capability | ||||||||
// infered extensions. | ||||||||
// If a capability is present, the extension that | ||||||||
// supports it should also be present, this reduces the burden of adding | ||||||||
// extension requirement that may or maynot be added in | ||||||||
// CompositeType::getExtensions(). | ||||||||
if (failed(checkCapabilityAndExtensionRequirements( | ||||||||
op->getName(), this->targetEnv, typeCapabilities, typeExtensions))) | ||||||||
return false; | ||||||||
} | ||||||||
|
||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This "Remove the vector size restriction." sentense does not need to be in the comment--it's the goal of this patch; but not a proper doc of
SPIRV_Vector
.