Skip to content

Commit 46aa23a

Browse files
committed
[mlir][spirv] Extend capabilities and extensions requirements checking.
Allow a way to relax requirements for certain capabilities and extensions (e.g., `elidedCandidates`). Also add a combined check for capabilities and extensions in `checkCapabilityAndExtensionRequirements`. This function checks capabilities, extensions, and capability infered extension requirements.
1 parent e0f86ca commit 46aa23a

File tree

1 file changed

+61
-6
lines changed

1 file changed

+61
-6
lines changed

mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp

Lines changed: 61 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,13 @@ using namespace mlir;
4343
template <typename LabelT>
4444
static LogicalResult checkExtensionRequirements(
4545
LabelT label, const spirv::TargetEnv &targetEnv,
46-
const spirv::SPIRVType::ExtensionArrayRefVector &candidates) {
46+
const spirv::SPIRVType::ExtensionArrayRefVector &candidates,
47+
ArrayRef<spirv::Extension> elidedCandidates = {}) {
4748
for (const auto &ors : candidates) {
48-
if (targetEnv.allows(ors))
49+
if (targetEnv.allows(ors) ||
50+
llvm::any_of(elidedCandidates, [&ors](spirv::Extension elidedExt) {
51+
return llvm::is_contained(ors, elidedExt);
52+
}))
4953
continue;
5054

5155
LLVM_DEBUG({
@@ -71,9 +75,13 @@ static LogicalResult checkExtensionRequirements(
7175
template <typename LabelT>
7276
static LogicalResult checkCapabilityRequirements(
7377
LabelT label, const spirv::TargetEnv &targetEnv,
74-
const spirv::SPIRVType::CapabilityArrayRefVector &candidates) {
78+
const spirv::SPIRVType::CapabilityArrayRefVector &candidates,
79+
ArrayRef<spirv::Capability> elidedCandidates = {}) {
7580
for (const auto &ors : candidates) {
76-
if (targetEnv.allows(ors))
81+
if (targetEnv.allows(ors) ||
82+
llvm::any_of(elidedCandidates, [&ors](spirv::Capability elidedCap) {
83+
return llvm::is_contained(ors, elidedCap);
84+
}))
7785
continue;
7886

7987
LLVM_DEBUG({
@@ -90,8 +98,55 @@ static LogicalResult checkCapabilityRequirements(
9098
return success();
9199
}
92100

93-
/// Returns true if the given `storageClass` needs explicit layout when used in
94-
/// Shader environments.
101+
/// Check capabilities and extensions requirements
102+
/// Checks that `capCandidates`, `extCandidates`, and capability
103+
/// (`capCandidates`) infered extension requirements are possible to be
104+
/// satisfied with the given `targetEnv`.
105+
/// It also provides a way to relax requirements for certain capabilities and
106+
/// extensions (e.g., `elidedCapCandidates`, `elidedExtCandidates`), this is to
107+
/// allow passes to relax certain requirements based on an option (e.g.,
108+
/// relaxing bitwidth requirement, see `convertScalarType()`,
109+
/// `ConvertVectorType()`).
110+
template <typename LabelT>
111+
static LogicalResult checkCapabilityAndExtensionRequirements(
112+
LabelT label, const spirv::TargetEnv &targetEnv,
113+
const spirv::SPIRVType::CapabilityArrayRefVector &capCandidates,
114+
const spirv::SPIRVType::ExtensionArrayRefVector &extCandidates,
115+
ArrayRef<spirv::Capability> elidedCapCandidates = {},
116+
ArrayRef<spirv::Extension> elidedExtCandidates = {}) {
117+
SmallVector<ArrayRef<spirv::Extension>, 8> updatedExtCandidates;
118+
llvm::append_range(updatedExtCandidates, extCandidates);
119+
120+
if (failed(checkCapabilityRequirements(label, targetEnv, capCandidates,
121+
elidedCapCandidates)))
122+
return failure();
123+
// Add capablity infered extensions to the list of extension requirement list,
124+
// only considers the capabilities that already available in the `targetEnv`.
125+
126+
// WARNING: Some capabilities are part of both the core SPIR-V
127+
// specification and an extension (e.g., 'Groups' capability is part of both
128+
// core specification and SPV_AMD_shader_ballot extension, hence we should
129+
// relax the capability inferred extension for these cases).
130+
static const spirv::Capability multiModalCaps[] = {spirv::Capability::Groups};
131+
ArrayRef<spirv::Capability> multiModalCapsArrayRef(multiModalCaps,
132+
std::size(multiModalCaps));
133+
134+
for (auto cap : targetEnv.getAttr().getCapabilities()) {
135+
if (llvm::any_of(multiModalCapsArrayRef,
136+
[&cap](spirv::Capability mMCap) { return cap == mMCap; }))
137+
continue;
138+
std::optional<ArrayRef<spirv::Extension>> ext = getExtensions(cap);
139+
if (ext)
140+
updatedExtCandidates.push_back(*ext);
141+
}
142+
if (failed(checkExtensionRequirements(label, targetEnv, updatedExtCandidates,
143+
elidedExtCandidates)))
144+
return failure();
145+
return success();
146+
}
147+
148+
/// Returns true if the given `storageClass` needs explicit layout when used
149+
/// in Shader environments.
95150
static bool needsExplicitLayout(spirv::StorageClass storageClass) {
96151
switch (storageClass) {
97152
case spirv::StorageClass::PhysicalStorageBuffer:

0 commit comments

Comments
 (0)