Skip to content

Commit 7f6d445

Browse files
authored
[mlir][spirv] Clean up map-memref-storage-class pass (#79937)
Clean up the code before making more substantial changes. NFC modulo extra error checking and physical storage buffer storage class handling. * Add switch case for physical storage buffer * Handle type conversion failures * Inline methods to reduce scrolling * Other minor cleanups
1 parent 659ce8f commit 7f6d445

File tree

1 file changed

+81
-81
lines changed

1 file changed

+81
-81
lines changed

mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp

Lines changed: 81 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,12 @@
1818
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
1919
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
2020
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
21+
#include "mlir/IR/Attributes.h"
2122
#include "mlir/IR/BuiltinAttributes.h"
2223
#include "mlir/IR/BuiltinTypes.h"
2324
#include "mlir/Interfaces/FunctionInterfaces.h"
2425
#include "mlir/Transforms/DialectConversion.h"
26+
#include "llvm/ADT/SmallVectorExtras.h"
2527
#include "llvm/ADT/StringExtras.h"
2628
#include "llvm/Support/Debug.h"
2729

@@ -54,7 +56,8 @@ using namespace mlir;
5456
MAP_FN(spirv::StorageClass::PushConstant, 7) \
5557
MAP_FN(spirv::StorageClass::UniformConstant, 8) \
5658
MAP_FN(spirv::StorageClass::Input, 9) \
57-
MAP_FN(spirv::StorageClass::Output, 10)
59+
MAP_FN(spirv::StorageClass::Output, 10) \
60+
MAP_FN(spirv::StorageClass::PhysicalStorageBuffer, 11)
5861

5962
std::optional<spirv::StorageClass>
6063
spirv::mapMemorySpaceToVulkanStorageClass(Attribute memorySpaceAttr) {
@@ -185,13 +188,10 @@ spirv::MemorySpaceToStorageClassConverter::MemorySpaceToStorageClassConverter(
185188
});
186189

187190
addConversion([this](FunctionType type) {
188-
SmallVector<Type> inputs, results;
189-
inputs.reserve(type.getNumInputs());
190-
results.reserve(type.getNumResults());
191-
for (Type input : type.getInputs())
192-
inputs.push_back(convertType(input));
193-
for (Type result : type.getResults())
194-
results.push_back(convertType(result));
191+
auto inputs = llvm::map_to_vector(
192+
type.getInputs(), [this](Type ty) { return convertType(ty); });
193+
auto results = llvm::map_to_vector(
194+
type.getResults(), [this](Type ty) { return convertType(ty); });
195195
return FunctionType::get(type.getContext(), inputs, results);
196196
});
197197
}
@@ -250,49 +250,54 @@ spirv::getMemorySpaceToStorageClassTarget(MLIRContext &context) {
250250
namespace {
251251
/// Converts any op that has operands/results/attributes with numeric MemRef
252252
/// memory spaces.
253-
struct MapMemRefStoragePattern final : public ConversionPattern {
253+
struct MapMemRefStoragePattern final : ConversionPattern {
254254
MapMemRefStoragePattern(MLIRContext *context, TypeConverter &converter)
255255
: ConversionPattern(converter, MatchAnyOpTypeTag(), 1, context) {}
256256

257257
LogicalResult
258258
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
259-
ConversionPatternRewriter &rewriter) const override;
260-
};
261-
} // namespace
262-
263-
LogicalResult MapMemRefStoragePattern::matchAndRewrite(
264-
Operation *op, ArrayRef<Value> operands,
265-
ConversionPatternRewriter &rewriter) const {
266-
llvm::SmallVector<NamedAttribute, 4> newAttrs;
267-
newAttrs.reserve(op->getAttrs().size());
268-
for (auto attr : op->getAttrs()) {
269-
if (auto typeAttr = dyn_cast<TypeAttr>(attr.getValue())) {
270-
auto newAttr = getTypeConverter()->convertType(typeAttr.getValue());
271-
newAttrs.emplace_back(attr.getName(), TypeAttr::get(newAttr));
272-
} else {
273-
newAttrs.push_back(attr);
259+
ConversionPatternRewriter &rewriter) const override {
260+
llvm::SmallVector<NamedAttribute> newAttrs;
261+
newAttrs.reserve(op->getAttrs().size());
262+
for (NamedAttribute attr : op->getAttrs()) {
263+
if (auto typeAttr = dyn_cast<TypeAttr>(attr.getValue())) {
264+
Type newAttr = getTypeConverter()->convertType(typeAttr.getValue());
265+
if (!newAttr) {
266+
return rewriter.notifyMatchFailure(
267+
op, "type attribute conversion failed");
268+
}
269+
newAttrs.emplace_back(attr.getName(), TypeAttr::get(newAttr));
270+
} else {
271+
newAttrs.push_back(attr);
272+
}
274273
}
275-
}
276274

277-
llvm::SmallVector<Type, 4> newResults;
278-
(void)getTypeConverter()->convertTypes(op->getResultTypes(), newResults);
279-
280-
OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
281-
newResults, newAttrs, op->getSuccessors());
275+
llvm::SmallVector<Type, 4> newResults;
276+
if (failed(
277+
getTypeConverter()->convertTypes(op->getResultTypes(), newResults)))
278+
return rewriter.notifyMatchFailure(op, "result type conversion failed");
279+
280+
OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
281+
newResults, newAttrs, op->getSuccessors());
282+
283+
for (Region &region : op->getRegions()) {
284+
Region *newRegion = state.addRegion();
285+
rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin());
286+
TypeConverter::SignatureConversion result(newRegion->getNumArguments());
287+
if (failed(getTypeConverter()->convertSignatureArgs(
288+
newRegion->getArgumentTypes(), result))) {
289+
return rewriter.notifyMatchFailure(
290+
op, "signature argument type conversion failed");
291+
}
292+
rewriter.applySignatureConversion(newRegion, result);
293+
}
282294

283-
for (Region &region : op->getRegions()) {
284-
Region *newRegion = state.addRegion();
285-
rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin());
286-
TypeConverter::SignatureConversion result(newRegion->getNumArguments());
287-
(void)getTypeConverter()->convertSignatureArgs(
288-
newRegion->getArgumentTypes(), result);
289-
rewriter.applySignatureConversion(newRegion, result);
295+
Operation *newOp = rewriter.create(state);
296+
rewriter.replaceOp(op, newOp->getResults());
297+
return success();
290298
}
291-
292-
Operation *newOp = rewriter.create(state);
293-
rewriter.replaceOp(op, newOp->getResults());
294-
return success();
295-
}
299+
};
300+
} // namespace
296301

297302
void spirv::populateMemorySpaceToStorageClassPatterns(
298303
spirv::MemorySpaceToStorageClassConverter &typeConverter,
@@ -308,58 +313,53 @@ namespace {
308313
class MapMemRefStorageClassPass final
309314
: public impl::MapMemRefStorageClassBase<MapMemRefStorageClassPass> {
310315
public:
311-
explicit MapMemRefStorageClassPass() {
312-
memorySpaceMap = spirv::mapMemorySpaceToVulkanStorageClass;
313-
}
316+
MapMemRefStorageClassPass() = default;
317+
314318
explicit MapMemRefStorageClassPass(
315319
const spirv::MemorySpaceToStorageClassMap &memorySpaceMap)
316320
: memorySpaceMap(memorySpaceMap) {}
317321

318-
LogicalResult initializeOptions(StringRef options) override;
319-
320-
void runOnOperation() override;
321-
322-
private:
323-
spirv::MemorySpaceToStorageClassMap memorySpaceMap;
324-
};
325-
} // namespace
322+
LogicalResult initializeOptions(StringRef options) override {
323+
if (failed(Pass::initializeOptions(options)))
324+
return failure();
326325

327-
LogicalResult MapMemRefStorageClassPass::initializeOptions(StringRef options) {
328-
if (failed(Pass::initializeOptions(options)))
329-
return failure();
326+
if (clientAPI == "opencl")
327+
memorySpaceMap = spirv::mapMemorySpaceToOpenCLStorageClass;
328+
else if (clientAPI != "vulkan")
329+
return failure();
330330

331-
if (clientAPI == "opencl") {
332-
memorySpaceMap = spirv::mapMemorySpaceToOpenCLStorageClass;
331+
return success();
333332
}
334333

335-
if (clientAPI != "vulkan" && clientAPI != "opencl")
336-
return failure();
334+
void runOnOperation() override {
335+
MLIRContext *context = &getContext();
336+
Operation *op = getOperation();
337+
338+
if (spirv::TargetEnvAttr attr = spirv::lookupTargetEnv(op)) {
339+
spirv::TargetEnv targetEnv(attr);
340+
if (targetEnv.allows(spirv::Capability::Kernel)) {
341+
memorySpaceMap = spirv::mapMemorySpaceToOpenCLStorageClass;
342+
} else if (targetEnv.allows(spirv::Capability::Shader)) {
343+
memorySpaceMap = spirv::mapMemorySpaceToVulkanStorageClass;
344+
}
345+
}
337346

338-
return success();
339-
}
347+
std::unique_ptr<ConversionTarget> target =
348+
spirv::getMemorySpaceToStorageClassTarget(*context);
349+
spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);
340350

341-
void MapMemRefStorageClassPass::runOnOperation() {
342-
MLIRContext *context = &getContext();
343-
Operation *op = getOperation();
351+
RewritePatternSet patterns(context);
352+
spirv::populateMemorySpaceToStorageClassPatterns(converter, patterns);
344353

345-
if (spirv::TargetEnvAttr attr = spirv::lookupTargetEnv(op)) {
346-
spirv::TargetEnv targetEnv(attr);
347-
if (targetEnv.allows(spirv::Capability::Kernel)) {
348-
memorySpaceMap = spirv::mapMemorySpaceToOpenCLStorageClass;
349-
} else if (targetEnv.allows(spirv::Capability::Shader)) {
350-
memorySpaceMap = spirv::mapMemorySpaceToVulkanStorageClass;
351-
}
354+
if (failed(applyFullConversion(op, *target, std::move(patterns))))
355+
return signalPassFailure();
352356
}
353357

354-
auto target = spirv::getMemorySpaceToStorageClassTarget(*context);
355-
spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);
356-
357-
RewritePatternSet patterns(context);
358-
spirv::populateMemorySpaceToStorageClassPatterns(converter, patterns);
359-
360-
if (failed(applyFullConversion(op, *target, std::move(patterns))))
361-
return signalPassFailure();
362-
}
358+
private:
359+
spirv::MemorySpaceToStorageClassMap memorySpaceMap =
360+
spirv::mapMemorySpaceToVulkanStorageClass;
361+
};
362+
} // namespace
363363

364364
std::unique_ptr<OperationPass<>> mlir::createMapMemRefStorageClassPass() {
365365
return std::make_unique<MapMemRefStorageClassPass>();

0 commit comments

Comments
 (0)