Skip to content

[mlir][spirv] Implement vector type legalization for function signatures #98337

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 21 commits into from
Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
11ede0c
[mlir][spirv] Implement vector type legalization in function signatures
angelz913 Jun 27, 2024
e02bc02
Function input vector unrolling working and moved pattern to SPIRV
angelz913 Jul 5, 2024
eeda7b1
Implement function result and ReturnOp vector unrolling
angelz913 Jul 8, 2024
3e1e311
Compute the target shape based on original vector shape
angelz913 Jul 8, 2024
49ca525
Fix bug in function output unrolling
angelz913 Jul 8, 2024
0b5f482
Working for signatures with legal and illegal types
angelz913 Jul 9, 2024
6a5e36e
Only keep the signature conversion, and refactor code
angelz913 Jul 9, 2024
5c807e2
Add an option for testing signature conversion
angelz913 Jul 10, 2024
2642de1
Add unit tests
angelz913 Jul 10, 2024
dbde3fc
Code formatting
angelz913 Jul 10, 2024
3ca24ea
Update mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
angelz913 Jul 10, 2024
f84789f
Update mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
angelz913 Jul 10, 2024
c640e20
Run both patterns at the same time
angelz913 Jul 10, 2024
99fc708
Code refactoring and formatting
angelz913 Jul 10, 2024
995952b
Add negative tests for function declarations and for scalable vectors
angelz913 Jul 11, 2024
e682f15
More refactoring
angelz913 Jul 12, 2024
2fb2a64
Add a new pass for testing signature conversion patterns in isolation
angelz913 Jul 15, 2024
7347d06
Update cmake and bazel dependencies
angelz913 Jul 15, 2024
a8972cc
Fix compile warnings
angelz913 Jul 17, 2024
b6d2ca7
Update mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
angelz913 Jul 17, 2024
c1db655
Code refactoring
angelz913 Jul 17, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion mlir/include/mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,15 @@ def ConvertToSPIRVPass : Pass<"convert-to-spirv"> {
let description = [{
This is a generic pass to convert to SPIR-V.
}];
let dependentDialects = ["spirv::SPIRVDialect"];
let dependentDialects = [
"spirv::SPIRVDialect",
"vector::VectorDialect",
];
let options = [
Option<"runSignatureConversion", "run-signature-conversion", "bool",
/*default=*/"true",
"Run function signature conversion to convert vector types">
];
}

//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/OneToNTypeConversion.h"
#include "llvm/ADT/SmallSet.h"

namespace mlir {
Expand Down Expand Up @@ -134,6 +136,10 @@ class SPIRVConversionTarget : public ConversionTarget {
void populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns);

void populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns);

void populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns);

namespace spirv {
class AccessChainOp;

Expand Down
20 changes: 15 additions & 5 deletions mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,31 @@ namespace {
/// A pass to perform the SPIR-V conversion.
struct ConvertToSPIRVPass final
: impl::ConvertToSPIRVPassBase<ConvertToSPIRVPass> {
using ConvertToSPIRVPassBase::ConvertToSPIRVPassBase;

void runOnOperation() override {
MLIRContext *context = &getContext();
Operation *op = getOperation();

if (runSignatureConversion) {
// Unroll vectors in function signatures to native vector size.
RewritePatternSet patterns(context);
populateFuncOpVectorRewritePatterns(patterns);
populateReturnOpVectorRewritePatterns(patterns);
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingOps;
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
return signalPassFailure();
}

spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
std::unique_ptr<ConversionTarget> target =
SPIRVConversionTarget::get(targetAttr);
SPIRVTypeConverter typeConverter(targetAttr);

RewritePatternSet patterns(context);
ScfToSPIRVContext scfToSPIRVContext;

// Populate patterns.
// Populate patterns for each dialect.
arith::populateCeilFloorDivExpandOpsPatterns(patterns);
arith::populateArithToSPIRVPatterns(typeConverter, patterns);
populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
Expand All @@ -60,9 +73,6 @@ struct ConvertToSPIRVPass final
populateSCFToSPIRVPatterns(typeConverter, scfToSPIRVContext, patterns);
ub::populateUBToSPIRVConversionPatterns(typeConverter, patterns);

std::unique_ptr<ConversionTarget> target =
SPIRVConversionTarget::get(targetAttr);

if (failed(applyPartialConversion(op, *target, std::move(patterns))))
return signalPassFailure();
}
Expand Down
6 changes: 6 additions & 0 deletions mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,15 @@ add_mlir_dialect_library(MLIRSPIRVConversion
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SPIRV

LINK_LIBS PUBLIC
MLIRArithDialect
MLIRDialectUtils
MLIRFuncDialect
MLIRIR
MLIRSPIRVDialect
MLIRSupport
MLIRTransformUtils
MLIRVectorDialect
MLIRVectorTransforms
)

add_mlir_dialect_library(MLIRSPIRVTransforms
Expand Down
Loading