Skip to content

Commit 8fb67d0

Browse files
committed
[mlir][spirv] Use combined-check for type related extension and capability requirements
Replace the seperate extension and capability checking with combined check `checkCapabilityAndExtensionRequirements()`. This makes the code flow simpler. Also adds the extra check for capability inferred extension check. Need for capability inferred extension check: If a capability is a requirement, the respective extension that implements it should also become an extension requirement, there were no support for that check, as a result, the extension requirement had to be added separately. This separate requirement addition causes problem when a feature is enabled by multiple capability, and one of the capability is part of an extension. E.g., vector size of 16 can be enabled by both "Vector16" and "vectorAnyINTEL" capability, however, only "vectorAnyINTEL" has an extension requirement ("SPV_INTEL_vector_compute"). Since the process of adding capability and extension requirement are independent, there is no way, to handle cases like this. Therefore, for cases like this, enable adding capability requirement initially, then do the check for capability inferred extension.
1 parent 43ebcad commit 8fb67d0

File tree

2 files changed

+92
-50
lines changed

2 files changed

+92
-50
lines changed

mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp

Lines changed: 41 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -285,8 +285,8 @@ convertScalarType(const spirv::TargetEnv &targetEnv,
285285
type.getCapabilities(capabilities, storageClass);
286286

287287
// If all requirements are met, then we can accept this type as-is.
288-
if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
289-
succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
288+
if (succeeded(checkCapabilityAndExtensionRequirements(
289+
type, targetEnv, capabilities, extensions)))
290290
return type;
291291

292292
// Otherwise we need to adjust the type, which really means adjusting the
@@ -397,15 +397,35 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
397397
cast<spirv::CompositeType>(type).getExtensions(extensions, storageClass);
398398
cast<spirv::CompositeType>(type).getCapabilities(capabilities, storageClass);
399399

400+
// If the bit-width related capabilities and extensions are not met
401+
// for lower bit-width (<32-bit), convert it to 32-bit
402+
auto elementType =
403+
convertScalarType(targetEnv, options, scalarType, storageClass);
404+
if (!elementType)
405+
return nullptr;
406+
type = VectorType::get(type.getShape(), elementType);
407+
408+
SmallVector<spirv::Capability, 4> elidedCaps;
409+
SmallVector<spirv::Extension, 4> elidedExts;
410+
411+
// Relax the bitwidth requirements for capabilities and extensions
412+
if (options.emulateLT32BitScalarTypes) {
413+
elidedCaps.push_back(spirv::Capability::Int8);
414+
elidedCaps.push_back(spirv::Capability::Int16);
415+
elidedCaps.push_back(spirv::Capability::Float16);
416+
}
417+
// For capabilities whose requirements were relaxed, relax requirements for
418+
// the extensions that were infered by those capabilities (e.g., elidedCaps)
419+
for (spirv::Capability cap : elidedCaps) {
420+
std::optional<ArrayRef<spirv::Extension>> ext = spirv::getExtensions(cap);
421+
if (ext)
422+
llvm::append_range(elidedExts, *ext);
423+
}
400424
// If all requirements are met, then we can accept this type as-is.
401-
if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
402-
succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
425+
if (succeeded(checkCapabilityAndExtensionRequirements(
426+
type, targetEnv, capabilities, extensions, elidedCaps, elidedExts)))
403427
return type;
404428

405-
auto elementType =
406-
convertScalarType(targetEnv, options, scalarType, storageClass);
407-
if (elementType)
408-
return VectorType::get(type.getShape(), elementType);
409429
return nullptr;
410430
}
411431

@@ -711,8 +731,9 @@ std::optional<Value> castToSourceType(const spirv::TargetEnv &targetEnv,
711731
SmallVector<ArrayRef<spirv::Capability>, 2> caps;
712732
scalarType.getExtensions(exts);
713733
scalarType.getCapabilities(caps);
714-
if (failed(checkCapabilityRequirements(type, targetEnv, caps)) ||
715-
failed(checkExtensionRequirements(type, targetEnv, exts))) {
734+
735+
if (failed(checkCapabilityAndExtensionRequirements(type, targetEnv, caps,
736+
exts))) {
716737
auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
717738
return castOp.getResult(0);
718739
}
@@ -1205,16 +1226,18 @@ bool SPIRVConversionTarget::isLegalOp(Operation *op) {
12051226
SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
12061227
SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
12071228
for (Type valueType : valueTypes) {
1208-
typeExtensions.clear();
1209-
cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions);
1210-
if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
1211-
typeExtensions)))
1212-
return false;
1213-
12141229
typeCapabilities.clear();
12151230
cast<spirv::SPIRVType>(valueType).getCapabilities(typeCapabilities);
1216-
if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
1217-
typeCapabilities)))
1231+
typeExtensions.clear();
1232+
cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions);
1233+
// Checking for capability and extension requirements along with capability
1234+
// infered extensions.
1235+
// If a capability is present, the extension that
1236+
// supports it should also be present, this reduces the burden of adding
1237+
// extension requirement that may or maynot be added in
1238+
// CompositeType::getExtensions().
1239+
if (failed(checkCapabilityAndExtensionRequirements(
1240+
op->getName(), this->targetEnv, typeCapabilities, typeExtensions)))
12181241
return false;
12191242
}
12201243

0 commit comments

Comments
 (0)