18
18
#include " mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
19
19
#include " mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
20
20
#include " mlir/Dialect/SPIRV/IR/TargetAndABI.h"
21
+ #include " mlir/IR/Attributes.h"
21
22
#include " mlir/IR/BuiltinAttributes.h"
22
23
#include " mlir/IR/BuiltinTypes.h"
23
24
#include " mlir/Interfaces/FunctionInterfaces.h"
24
25
#include " mlir/Transforms/DialectConversion.h"
26
+ #include " llvm/ADT/SmallVectorExtras.h"
25
27
#include " llvm/ADT/StringExtras.h"
26
28
#include " llvm/Support/Debug.h"
27
29
@@ -54,7 +56,8 @@ using namespace mlir;
54
56
MAP_FN(spirv::StorageClass::PushConstant, 7 ) \
55
57
MAP_FN(spirv::StorageClass::UniformConstant, 8 ) \
56
58
MAP_FN(spirv::StorageClass::Input, 9 ) \
57
- MAP_FN(spirv::StorageClass::Output, 10 )
59
+ MAP_FN(spirv::StorageClass::Output, 10 ) \
60
+ MAP_FN(spirv::StorageClass::PhysicalStorageBuffer, 11 )
58
61
59
62
std::optional<spirv::StorageClass>
60
63
spirv::mapMemorySpaceToVulkanStorageClass(Attribute memorySpaceAttr) {
@@ -185,13 +188,10 @@ spirv::MemorySpaceToStorageClassConverter::MemorySpaceToStorageClassConverter(
185
188
});
186
189
187
190
addConversion ([this ](FunctionType type) {
188
- SmallVector<Type> inputs, results;
189
- inputs.reserve (type.getNumInputs ());
190
- results.reserve (type.getNumResults ());
191
- for (Type input : type.getInputs ())
192
- inputs.push_back (convertType (input));
193
- for (Type result : type.getResults ())
194
- results.push_back (convertType (result));
191
+ auto inputs = llvm::map_to_vector (
192
+ type.getInputs (), [this ](Type ty) { return convertType (ty); });
193
+ auto results = llvm::map_to_vector (
194
+ type.getResults (), [this ](Type ty) { return convertType (ty); });
195
195
return FunctionType::get (type.getContext (), inputs, results);
196
196
});
197
197
}
@@ -250,49 +250,54 @@ spirv::getMemorySpaceToStorageClassTarget(MLIRContext &context) {
250
250
namespace {
251
251
// / Converts any op that has operands/results/attributes with numeric MemRef
252
252
// / memory spaces.
253
- struct MapMemRefStoragePattern final : public ConversionPattern {
253
+ struct MapMemRefStoragePattern final : ConversionPattern {
254
254
MapMemRefStoragePattern (MLIRContext *context, TypeConverter &converter)
255
255
: ConversionPattern(converter, MatchAnyOpTypeTag(), 1 , context) {}
256
256
257
257
LogicalResult
258
258
matchAndRewrite (Operation *op, ArrayRef<Value> operands,
259
- ConversionPatternRewriter &rewriter) const override ;
260
- };
261
- } // namespace
262
-
263
- LogicalResult MapMemRefStoragePattern::matchAndRewrite (
264
- Operation *op, ArrayRef<Value> operands,
265
- ConversionPatternRewriter &rewriter) const {
266
- llvm::SmallVector<NamedAttribute, 4 > newAttrs;
267
- newAttrs.reserve (op->getAttrs ().size ());
268
- for (auto attr : op->getAttrs ()) {
269
- if (auto typeAttr = dyn_cast<TypeAttr>(attr.getValue ())) {
270
- auto newAttr = getTypeConverter ()->convertType (typeAttr.getValue ());
271
- newAttrs.emplace_back (attr.getName (), TypeAttr::get (newAttr));
272
- } else {
273
- newAttrs.push_back (attr);
259
+ ConversionPatternRewriter &rewriter) const override {
260
+ llvm::SmallVector<NamedAttribute> newAttrs;
261
+ newAttrs.reserve (op->getAttrs ().size ());
262
+ for (NamedAttribute attr : op->getAttrs ()) {
263
+ if (auto typeAttr = dyn_cast<TypeAttr>(attr.getValue ())) {
264
+ Type newAttr = getTypeConverter ()->convertType (typeAttr.getValue ());
265
+ if (!newAttr) {
266
+ return rewriter.notifyMatchFailure (
267
+ op, " type attribute conversion failed" );
268
+ }
269
+ newAttrs.emplace_back (attr.getName (), TypeAttr::get (newAttr));
270
+ } else {
271
+ newAttrs.push_back (attr);
272
+ }
274
273
}
275
- }
276
274
277
- llvm::SmallVector<Type, 4 > newResults;
278
- (void )getTypeConverter ()->convertTypes (op->getResultTypes (), newResults);
279
-
280
- OperationState state (op->getLoc (), op->getName ().getStringRef (), operands,
281
- newResults, newAttrs, op->getSuccessors ());
275
+ llvm::SmallVector<Type, 4 > newResults;
276
+ if (failed (
277
+ getTypeConverter ()->convertTypes (op->getResultTypes (), newResults)))
278
+ return rewriter.notifyMatchFailure (op, " result type conversion failed" );
279
+
280
+ OperationState state (op->getLoc (), op->getName ().getStringRef (), operands,
281
+ newResults, newAttrs, op->getSuccessors ());
282
+
283
+ for (Region ®ion : op->getRegions ()) {
284
+ Region *newRegion = state.addRegion ();
285
+ rewriter.inlineRegionBefore (region, *newRegion, newRegion->begin ());
286
+ TypeConverter::SignatureConversion result (newRegion->getNumArguments ());
287
+ if (failed (getTypeConverter ()->convertSignatureArgs (
288
+ newRegion->getArgumentTypes (), result))) {
289
+ return rewriter.notifyMatchFailure (
290
+ op, " signature argument type conversion failed" );
291
+ }
292
+ rewriter.applySignatureConversion (newRegion, result);
293
+ }
282
294
283
- for (Region ®ion : op->getRegions ()) {
284
- Region *newRegion = state.addRegion ();
285
- rewriter.inlineRegionBefore (region, *newRegion, newRegion->begin ());
286
- TypeConverter::SignatureConversion result (newRegion->getNumArguments ());
287
- (void )getTypeConverter ()->convertSignatureArgs (
288
- newRegion->getArgumentTypes (), result);
289
- rewriter.applySignatureConversion (newRegion, result);
295
+ Operation *newOp = rewriter.create (state);
296
+ rewriter.replaceOp (op, newOp->getResults ());
297
+ return success ();
290
298
}
291
-
292
- Operation *newOp = rewriter.create (state);
293
- rewriter.replaceOp (op, newOp->getResults ());
294
- return success ();
295
- }
299
+ };
300
+ } // namespace
296
301
297
302
void spirv::populateMemorySpaceToStorageClassPatterns (
298
303
spirv::MemorySpaceToStorageClassConverter &typeConverter,
@@ -308,58 +313,53 @@ namespace {
308
313
class MapMemRefStorageClassPass final
309
314
: public impl::MapMemRefStorageClassBase<MapMemRefStorageClassPass> {
310
315
public:
311
- explicit MapMemRefStorageClassPass () {
312
- memorySpaceMap = spirv::mapMemorySpaceToVulkanStorageClass;
313
- }
316
+ MapMemRefStorageClassPass () = default ;
317
+
314
318
explicit MapMemRefStorageClassPass (
315
319
const spirv::MemorySpaceToStorageClassMap &memorySpaceMap)
316
320
: memorySpaceMap(memorySpaceMap) {}
317
321
318
- LogicalResult initializeOptions (StringRef options) override ;
319
-
320
- void runOnOperation () override ;
321
-
322
- private:
323
- spirv::MemorySpaceToStorageClassMap memorySpaceMap;
324
- };
325
- } // namespace
322
+ LogicalResult initializeOptions (StringRef options) override {
323
+ if (failed (Pass::initializeOptions (options)))
324
+ return failure ();
326
325
327
- LogicalResult MapMemRefStorageClassPass::initializeOptions (StringRef options) {
328
- if (failed (Pass::initializeOptions (options)))
329
- return failure ();
326
+ if (clientAPI == " opencl" )
327
+ memorySpaceMap = spirv::mapMemorySpaceToOpenCLStorageClass;
328
+ else if (clientAPI != " vulkan" )
329
+ return failure ();
330
330
331
- if (clientAPI == " opencl" ) {
332
- memorySpaceMap = spirv::mapMemorySpaceToOpenCLStorageClass;
331
+ return success ();
333
332
}
334
333
335
- if (clientAPI != " vulkan" && clientAPI != " opencl" )
336
- return failure ();
334
+ void runOnOperation () override {
335
+ MLIRContext *context = &getContext ();
336
+ Operation *op = getOperation ();
337
+
338
+ if (spirv::TargetEnvAttr attr = spirv::lookupTargetEnv (op)) {
339
+ spirv::TargetEnv targetEnv (attr);
340
+ if (targetEnv.allows (spirv::Capability::Kernel)) {
341
+ memorySpaceMap = spirv::mapMemorySpaceToOpenCLStorageClass;
342
+ } else if (targetEnv.allows (spirv::Capability::Shader)) {
343
+ memorySpaceMap = spirv::mapMemorySpaceToVulkanStorageClass;
344
+ }
345
+ }
337
346
338
- return success ();
339
- }
347
+ std::unique_ptr<ConversionTarget> target =
348
+ spirv::getMemorySpaceToStorageClassTarget (*context);
349
+ spirv::MemorySpaceToStorageClassConverter converter (memorySpaceMap);
340
350
341
- void MapMemRefStorageClassPass::runOnOperation () {
342
- MLIRContext *context = &getContext ();
343
- Operation *op = getOperation ();
351
+ RewritePatternSet patterns (context);
352
+ spirv::populateMemorySpaceToStorageClassPatterns (converter, patterns);
344
353
345
- if (spirv::TargetEnvAttr attr = spirv::lookupTargetEnv (op)) {
346
- spirv::TargetEnv targetEnv (attr);
347
- if (targetEnv.allows (spirv::Capability::Kernel)) {
348
- memorySpaceMap = spirv::mapMemorySpaceToOpenCLStorageClass;
349
- } else if (targetEnv.allows (spirv::Capability::Shader)) {
350
- memorySpaceMap = spirv::mapMemorySpaceToVulkanStorageClass;
351
- }
354
+ if (failed (applyFullConversion (op, *target, std::move (patterns))))
355
+ return signalPassFailure ();
352
356
}
353
357
354
- auto target = spirv::getMemorySpaceToStorageClassTarget (*context);
355
- spirv::MemorySpaceToStorageClassConverter converter (memorySpaceMap);
356
-
357
- RewritePatternSet patterns (context);
358
- spirv::populateMemorySpaceToStorageClassPatterns (converter, patterns);
359
-
360
- if (failed (applyFullConversion (op, *target, std::move (patterns))))
361
- return signalPassFailure ();
362
- }
358
+ private:
359
+ spirv::MemorySpaceToStorageClassMap memorySpaceMap =
360
+ spirv::mapMemorySpaceToVulkanStorageClass;
361
+ };
362
+ } // namespace
363
363
364
364
std::unique_ptr<OperationPass<>> mlir::createMapMemRefStorageClassPass () {
365
365
return std::make_unique<MapMemRefStorageClassPass>();
0 commit comments