@@ -43,9 +43,13 @@ using namespace mlir;
43
43
template <typename LabelT>
44
44
static LogicalResult checkExtensionRequirements (
45
45
LabelT label, const spirv::TargetEnv &targetEnv,
46
- const spirv::SPIRVType::ExtensionArrayRefVector &candidates) {
46
+ const spirv::SPIRVType::ExtensionArrayRefVector &candidates,
47
+ ArrayRef<spirv::Extension> elidedCandidates = {}) {
47
48
for (const auto &ors : candidates) {
48
- if (targetEnv.allows (ors))
49
+ if (targetEnv.allows (ors) ||
50
+ llvm::any_of (elidedCandidates, [&ors](spirv::Extension elidedExt) {
51
+ return llvm::is_contained (ors, elidedExt);
52
+ }))
49
53
continue ;
50
54
51
55
LLVM_DEBUG ({
@@ -71,9 +75,13 @@ static LogicalResult checkExtensionRequirements(
71
75
template <typename LabelT>
72
76
static LogicalResult checkCapabilityRequirements (
73
77
LabelT label, const spirv::TargetEnv &targetEnv,
74
- const spirv::SPIRVType::CapabilityArrayRefVector &candidates) {
78
+ const spirv::SPIRVType::CapabilityArrayRefVector &candidates,
79
+ ArrayRef<spirv::Capability> elidedCandidates = {}) {
75
80
for (const auto &ors : candidates) {
76
- if (targetEnv.allows (ors))
81
+ if (targetEnv.allows (ors) ||
82
+ llvm::any_of (elidedCandidates, [&ors](spirv::Capability elidedCap) {
83
+ return llvm::is_contained (ors, elidedCap);
84
+ }))
77
85
continue ;
78
86
79
87
LLVM_DEBUG ({
@@ -90,8 +98,55 @@ static LogicalResult checkCapabilityRequirements(
90
98
return success ();
91
99
}
92
100
93
- // / Returns true if the given `storageClass` needs explicit layout when used in
94
- // / Shader environments.
101
+ // / Check capabilities and extensions requirements
102
+ // / Checks that `capCandidates`, `extCandidates`, and capability
103
+ // / (`capCandidates`) infered extension requirements are possible to be
104
+ // / satisfied with the given `targetEnv`.
105
+ // / It also provides a way to relax requirements for certain capabilities and
106
+ // / extensions (e.g., `elidedCapCandidates`, `elidedExtCandidates`), this is to
107
+ // / allow passes to relax certain requirements based on an option (e.g.,
108
+ // / relaxing bitwidth requirement, see `convertScalarType()`,
109
+ // / `ConvertVectorType()`).
110
+ template <typename LabelT>
111
+ static LogicalResult checkCapabilityAndExtensionRequirements (
112
+ LabelT label, const spirv::TargetEnv &targetEnv,
113
+ const spirv::SPIRVType::CapabilityArrayRefVector &capCandidates,
114
+ const spirv::SPIRVType::ExtensionArrayRefVector &extCandidates,
115
+ ArrayRef<spirv::Capability> elidedCapCandidates = {},
116
+ ArrayRef<spirv::Extension> elidedExtCandidates = {}) {
117
+ SmallVector<ArrayRef<spirv::Extension>, 8 > updatedExtCandidates;
118
+ llvm::append_range (updatedExtCandidates, extCandidates);
119
+
120
+ if (failed (checkCapabilityRequirements (label, targetEnv, capCandidates,
121
+ elidedCapCandidates)))
122
+ return failure ();
123
+ // Add capablity infered extensions to the list of extension requirement list,
124
+ // only considers the capabilities that already available in the `targetEnv`.
125
+
126
+ // WARNING: Some capabilities are part of both the core SPIR-V
127
+ // specification and an extension (e.g., 'Groups' capability is part of both
128
+ // core specification and SPV_AMD_shader_ballot extension, hence we should
129
+ // relax the capability inferred extension for these cases).
130
+ static const spirv::Capability multiModalCaps[] = {spirv::Capability::Groups};
131
+ ArrayRef<spirv::Capability> multiModalCapsArrayRef (multiModalCaps,
132
+ std::size (multiModalCaps));
133
+
134
+ for (auto cap : targetEnv.getAttr ().getCapabilities ()) {
135
+ if (llvm::any_of (multiModalCapsArrayRef,
136
+ [&cap](spirv::Capability mMCap ) { return cap == mMCap ; }))
137
+ continue ;
138
+ std::optional<ArrayRef<spirv::Extension>> ext = getExtensions (cap);
139
+ if (ext)
140
+ updatedExtCandidates.push_back (*ext);
141
+ }
142
+ if (failed (checkExtensionRequirements (label, targetEnv, updatedExtCandidates,
143
+ elidedExtCandidates)))
144
+ return failure ();
145
+ return success ();
146
+ }
147
+
148
+ // / Returns true if the given `storageClass` needs explicit layout when used
149
+ // / in Shader environments.
95
150
static bool needsExplicitLayout (spirv::StorageClass storageClass) {
96
151
switch (storageClass) {
97
152
case spirv::StorageClass::PhysicalStorageBuffer:
0 commit comments