-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[mlir][spirv] Implement vector type legalization for function signatures #98337
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-spirv @llvm/pr-subscribers-mlir Author: Angel Zhang (angelz913) ChangesDescriptionThis PR implements a minimal version of function signature conversion to unroll vectors into 1D and with a size supported by SPIR-V (2, 3 or 4 depending on the original dimension). This PR also includes new unit tests that only check for function signature conversion. Future Plans
Patch is 29.20 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/98337.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 560b088dbe5cd..598bba63a2a82 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -40,7 +40,15 @@ def ConvertToSPIRVPass : Pass<"convert-to-spirv"> {
let description = [{
This is a generic pass to convert to SPIR-V.
}];
- let dependentDialects = ["spirv::SPIRVDialect"];
+ let dependentDialects = [
+ "spirv::SPIRVDialect",
+ "vector::VectorDialect",
+ ];
+ let options = [
+ Option<"runSignatureConversion", "run-signature-conversion", "bool",
+ /*default=*/"false",
+ "Run function signature conversion to convert vector types">
+ ];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
index 09eecafc0c8a5..112c404527927 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -17,8 +17,10 @@
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallSet.h"
+#include "mlir/Transforms/OneToNTypeConversion.h"
namespace mlir {
@@ -134,6 +136,10 @@ class SPIRVConversionTarget : public ConversionTarget {
void populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns);
+void populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns);
+
+void populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns);
+
namespace spirv {
class AccessChainOp;
diff --git a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
index b5be4654bcb25..88d7590c1daae 100644
--- a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
@@ -37,20 +37,53 @@ using namespace mlir;
namespace {
/// A pass to perform the SPIR-V conversion.
-struct ConvertToSPIRVPass final
- : impl::ConvertToSPIRVPassBase<ConvertToSPIRVPass> {
+class ConvertToSPIRVPass
+ : public impl::ConvertToSPIRVPassBase<ConvertToSPIRVPass> {
+ using impl::ConvertToSPIRVPassBase<
+ ConvertToSPIRVPass>::ConvertToSPIRVPassBase;
void runOnOperation() override {
MLIRContext *context = &getContext();
Operation *op = getOperation();
+ if (runSignatureConversion) {
+ // Unroll vectors in function inputs to native vector size.
+ llvm::errs() << "Start unrolling function inputs\n";
+ {
+ RewritePatternSet patterns(context);
+ populateFuncOpVectorRewritePatterns(patterns);
+ GreedyRewriteConfig config;
+ config.strictMode = GreedyRewriteStrictness::ExistingOps;
+ if (failed(
+ applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
+ return signalPassFailure();
+ }
+ llvm::errs() << "Finish unrolling function inputs\n";
+
+ // Unroll vectors in function outputs to native vector size.
+ llvm::errs() << "Start unrolling function outputs\n";
+ {
+ RewritePatternSet patterns(context);
+ populateReturnOpVectorRewritePatterns(patterns);
+ GreedyRewriteConfig config;
+ config.strictMode = GreedyRewriteStrictness::ExistingOps;
+ if (failed(
+ applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
+ return signalPassFailure();
+ }
+ llvm::errs() << "Finish unrolling function outputs\n";
+
+ return;
+ }
+
spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
+ std::unique_ptr<ConversionTarget> target =
+ SPIRVConversionTarget::get(targetAttr);
SPIRVTypeConverter typeConverter(targetAttr);
-
RewritePatternSet patterns(context);
ScfToSPIRVContext scfToSPIRVContext;
- // Populate patterns.
+ // Populate patterns for each dialect.
arith::populateCeilFloorDivExpandOpsPatterns(patterns);
arith::populateArithToSPIRVPatterns(typeConverter, patterns);
populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
@@ -60,9 +93,6 @@ struct ConvertToSPIRVPass final
populateSCFToSPIRVPatterns(typeConverter, scfToSPIRVContext, patterns);
ub::populateUBToSPIRVConversionPatterns(typeConverter, patterns);
- std::unique_ptr<ConversionTarget> target =
- SPIRVConversionTarget::get(targetAttr);
-
if (failed(applyPartialConversion(op, *target, std::move(patterns))))
return signalPassFailure();
}
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
index 821f82ebc0796..11af020b6c188 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
@@ -19,6 +19,7 @@ add_mlir_dialect_library(MLIRSPIRVConversion
MLIRFuncDialect
MLIRSPIRVDialect
MLIRTransformUtils
+ MLIRVectorTransforms
)
add_mlir_dialect_library(MLIRSPIRVTransforms
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 4072608dc8f87..6e793573f0262 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -11,19 +11,29 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/OneToNTypeConversion.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
-#include <functional>
+#include <cctype>
#include <optional>
#define DEBUG_TYPE "mlir-spirv-conversion"
@@ -34,6 +44,36 @@ using namespace mlir;
// Utility functions
//===----------------------------------------------------------------------===//
+static int getComputeVectorSize(int64_t size) {
+ for (int i : {4, 3, 2}) {
+ if (size % i == 0)
+ return i;
+ }
+ return 1;
+}
+
+static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
+ llvm::errs() << "Get target shape\n";
+ SmallVector<int64_t> unrollShape = llvm::to_vector<4>(vecType.getShape());
+ std::optional<SmallVector<int64_t>> targetShape =
+ SmallVector<int64_t>(1, getComputeVectorSize(vecType.getShape().back()));
+ if (!targetShape) {
+ llvm::errs() << "--no unrolling target shape defined\n";
+ return std::nullopt;
+ }
+ auto maybeShapeRatio = computeShapeRatio(unrollShape, *targetShape);
+ if (!maybeShapeRatio) {
+ llvm::errs() << "--could not compute integral shape ratio -> BAIL\n";
+ return std::nullopt;
+ }
+ if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
+ llvm::errs() << "--no unrolling needed -> SKIP\n";
+ return std::nullopt;
+ }
+ llvm::errs() << "--found an integral shape ratio to unroll to -> SUCCESS\n";
+ return targetShape;
+}
+
/// Checks that `candidates` extension requirements are possible to be satisfied
/// with the given `targetEnv`.
///
@@ -813,6 +853,281 @@ void mlir::populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
patterns.add<FuncOpConversion>(typeConverter, patterns.getContext());
}
+//===----------------------------------------------------------------------===//
+// func::FuncOp Conversion Patterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// A pattern for rewriting function signature to convert vector arguments of
+/// functions to be of valid types
+class FuncOpVectorUnroll : public OpRewritePattern<func::FuncOp> {
+public:
+ using OpRewritePattern<func::FuncOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(func::FuncOp funcOp,
+ PatternRewriter &rewriter) const override;
+};
+} // namespace
+
+LogicalResult
+FuncOpVectorUnroll::matchAndRewrite(func::FuncOp funcOp,
+ PatternRewriter &rewriter) const {
+ auto fnType = funcOp.getFunctionType();
+
+ // Create a new func op with the original type and copy the function body.
+ auto newFuncOp =
+ rewriter.create<func::FuncOp>(funcOp.getLoc(), funcOp.getName(), fnType);
+ rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
+ newFuncOp.end());
+
+ llvm::errs() << "After creating new func op and copying the function body\n";
+ newFuncOp.dump();
+
+ Location loc = newFuncOp.getBody().getLoc();
+ Block &entryBlock = newFuncOp.getBlocks().front();
+ rewriter.setInsertionPointToStart(&entryBlock);
+
+ OneToNTypeMapping oneToNTypeMapping(fnType.getInputs());
+
+ // For arguments that are of illegal types and require unrolling.
+ // `unrolledInputNums` stores the indices of arguments that result from
+ // unrolling in the new function signature. `newInputNo` is a counter.
+ SmallVector<size_t> unrolledInputNums;
+ size_t newInputNo = 0;
+
+ // For arguments that are of legal types and do not require unrolling.
+ // `tmpOps` stores a mapping from temporary operations that serve as
+ // placeholders for new arguments that will be added later. These operations
+ // will be erased once the entry block's argument list is updated.
+ DenseMap<Operation *, size_t> tmpOps;
+
+ // This counts the number of new operations created.
+ size_t newOpCount = 0;
+
+ // Enumerate through the arguments.
+ for (const auto &argType : enumerate(fnType.getInputs())) {
+ size_t origInputNo = argType.index();
+ Type origType = argType.value();
+ // Check whether the argument is of vector type.
+ auto origVecType = llvm::dyn_cast<VectorType>(origType);
+ if (!origVecType) {
+ // We need a placeholder for the old argument that will be erased later.
+ Value result = rewriter.create<arith::ConstantOp>(
+ loc, origType, rewriter.getZeroAttr(origType));
+ rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
+ tmpOps.insert({result.getDefiningOp(), newInputNo});
+ oneToNTypeMapping.addInputs(origInputNo, origType);
+ newInputNo++;
+ newOpCount++;
+ continue;
+ }
+ // Check whether the vector needs unrolling.
+ auto targetShape = getTargetShape(origVecType);
+ if (!targetShape) {
+ // We need a placeholder for the old argument that will be erased later.
+ Value result = rewriter.create<arith::ConstantOp>(
+ loc, origType, rewriter.getZeroAttr(origType));
+ rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
+ tmpOps.insert({result.getDefiningOp(), newInputNo});
+ oneToNTypeMapping.addInputs(origInputNo, origType);
+ newInputNo++;
+ newOpCount++;
+ continue;
+ }
+ llvm::errs() << "Got target shape\n";
+ VectorType unrolledType =
+ VectorType::get(*targetShape, origVecType.getElementType());
+ llvm::errs() << "Unrolled type is ";
+ unrolledType.dump();
+ SmallVector<int64_t> originalShape =
+ llvm::to_vector<4>(origVecType.getShape());
+
+ // Prepare the result vector.
+ Value result = rewriter.create<arith::ConstantOp>(
+ loc, origVecType, rewriter.getZeroAttr(origVecType));
+ newOpCount++;
+ // Prepare the placeholder for the new arguments that will be added later.
+ Value dummy = rewriter.create<arith::ConstantOp>(
+ loc, unrolledType, rewriter.getZeroAttr(unrolledType));
+ newOpCount++;
+
+ // Create the `vector.insert_strided_slice` ops.
+ SmallVector<int64_t> strides(targetShape->size(), 1);
+ SmallVector<Type> newTypes;
+ for (SmallVector<int64_t> offsets :
+ StaticTileOffsetRange(originalShape, *targetShape)) {
+ result = rewriter.create<vector::InsertStridedSliceOp>(loc, dummy, result,
+ offsets, strides);
+ newTypes.push_back(unrolledType);
+ unrolledInputNums.push_back(newInputNo);
+ newInputNo++;
+ newOpCount++;
+ }
+ rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
+ oneToNTypeMapping.addInputs(origInputNo, newTypes);
+ }
+
+ llvm::errs() << "After enumerating through the arguments\n";
+ newFuncOp.dump();
+
+ // Change the function signature.
+ auto convertedTypes = oneToNTypeMapping.getConvertedTypes();
+ auto newFnType =
+ FunctionType::get(rewriter.getContext(), TypeRange(convertedTypes),
+ TypeRange(fnType.getResults()));
+ rewriter.modifyOpInPlace(newFuncOp,
+ [&] { newFuncOp.setFunctionType(newFnType); });
+
+ llvm::errs() << "After changing function signature\n";
+ newFuncOp.dump();
+
+ // Update the arguments in the entry block.
+ entryBlock.eraseArguments(0, fnType.getNumInputs());
+ SmallVector<Location> locs(convertedTypes.size(), newFuncOp.getLoc());
+ entryBlock.addArguments(convertedTypes, locs);
+
+ llvm::errs() << "After updating the arguments in the entry block\n";
+ newFuncOp.dump();
+
+ // Replace the placeholder values with the new arguments. We assume there is
+ // only one block for now.
+ size_t idx = 0;
+ for (auto opPair : llvm::enumerate(entryBlock.getOperations())) {
+ size_t count = opPair.index();
+ Operation &op = opPair.value();
+ // We first look for operands that are placeholders for initially legal
+ // arguments.
+ for (auto operandPair : llvm::enumerate(op.getOperands())) {
+ Operation *operandOp = operandPair.value().getDefiningOp();
+ if (tmpOps.find(operandOp) != tmpOps.end())
+ rewriter.modifyOpInPlace(&op, [&] {
+ op.setOperand(operandPair.index(),
+ newFuncOp.getArgument(tmpOps[operandOp]));
+ });
+ }
+ // Since all newly created operations are in the beginning, reaching the end
+ // of them means that any later `vector.insert_strided_slice` should not be
+ // touched.
+ if (count >= newOpCount)
+ continue;
+ auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op);
+ if (vecOp) {
+ size_t unrolledInputNo = unrolledInputNums[idx];
+ rewriter.modifyOpInPlace(&op, [&] {
+ op.setOperand(0, newFuncOp.getArgument(unrolledInputNo));
+ });
+ idx++;
+ }
+ count++;
+ }
+
+ // Erase the original funcOp. The `tmpOps` do not need to be erased since
+ // they have no uses and will be handled by dead-code elimination.
+ rewriter.eraseOp(funcOp);
+ return success();
+}
+
+void mlir::populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns) {
+ patterns.add<FuncOpVectorUnroll>(patterns.getContext());
+}
+
+//===----------------------------------------------------------------------===//
+// func::ReturnOp Conversion Patterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// A pattern for rewriting function signature and the return op to convert
+/// vectors to be of valid types.
+class ReturnOpVectorUnroll : public OpRewritePattern<func::ReturnOp> {
+public:
+ using OpRewritePattern<func::ReturnOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(func::ReturnOp returnOp,
+ PatternRewriter &rewriter) const override;
+};
+} // namespace
+
+LogicalResult
+ReturnOpVectorUnroll::matchAndRewrite(func::ReturnOp returnOp,
+ PatternRewriter &rewriter) const {
+
+ // Check whether the parent funcOp is valid.
+ func::FuncOp funcOp = dyn_cast<func::FuncOp>(returnOp->getParentOp());
+ if (!funcOp)
+ return failure();
+
+ auto fnType = funcOp.getFunctionType();
+ OneToNTypeMapping oneToNTypeMapping(fnType.getResults());
+ Location loc = returnOp.getLoc();
+
+ // For the new return op.
+ SmallVector<Value> newOperands;
+
+ // Enumerate through the results.
+ for (const auto &argType : enumerate(fnType.getResults())) {
+ size_t origResultNo = argType.index();
+ Type origType = argType.value();
+ auto origVecType = llvm::dyn_cast<VectorType>(origType);
+ // Check whether the argument is of vector type.
+ if (!origVecType) {
+ oneToNTypeMapping.addInputs(origResultNo, origType);
+ newOperands.push_back(returnOp.getOperand(origResultNo));
+ continue;
+ }
+ // Check whether the vector needs unrolling.
+ auto targetShape = getTargetShape(origVecType);
+ if (!targetShape) {
+ // The original argument can be used.
+ oneToNTypeMapping.addInputs(origResultNo, origType);
+ newOperands.push_back(returnOp.getOperand(origResultNo));
+ continue;
+ }
+ llvm::errs() << "Got target shape\n";
+ VectorType unrolledType =
+ VectorType::get(*targetShape, origVecType.getElementType());
+ llvm::errs() << "Unrolled type is ";
+ unrolledType.dump();
+
+ // Create `vector.extract_strided_slice` ops to form legal vectors from the
+ // original operand of illegal type.
+ SmallVector<int64_t> originalShape =
+ llvm::to_vector<4>(origVecType.getShape());
+ SmallVector<int64_t> strides(targetShape->size(), 1);
+ SmallVector<Type> newTypes;
+ Value returnValue = returnOp.getOperand(origResultNo);
+ for (SmallVector<int64_t> offsets :
+ StaticTileOffsetRange(originalShape, *targetShape)) {
+ Value result = rewriter.create<vector::ExtractStridedSliceOp>(
+ loc, returnValue, offsets, *targetShape, strides);
+ newOperands.push_back(result);
+ newTypes.push_back(unrolledType);
+ }
+ oneToNTypeMapping.addInputs(origResultNo, newTypes);
+ }
+
+ llvm::errs() << "After enumerating through the arguments\n";
+ funcOp.dump();
+
+ // Change the function signature.
+ auto newFnType =
+ FunctionType::get(rewriter.getContext(), TypeRange(fnType.getInputs()),
+ TypeRange(oneToNTypeMapping.getConvertedTypes()));
+ rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setFunctionType(newFnType); });
+ llvm::errs() << "After changing function signature\n";
+ funcOp.dump();
+
+ // Replace the return op using the new operands. This will automatically
+ // update the entry block as well.
+ rewriter.replaceOp(returnOp,
+ rewriter.create<func::ReturnOp>(loc, newOperands));
+
+ return success();
+}
+
+void mlir::populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns) {
+ patterns.add<ReturnOpVectorUnroll>(patterns.getContext());
+}
+
//===----------------------------------------------------------------------===//
// Builtin Variables
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir b/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir
new file mode 100644
index 0000000000000..d5c777908d7e2
--- /dev/null
+++ b/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir
@@ -0,0 +1,132 @@
+// RUN: mlir-opt -convert-to-spirv="run-signature-conversion" -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: @simple_scalar
+// CHECK-SAME: (%[[ARG0:.+]]: i32)
+func.func @simple_scalar(%arg0 : i32) -> i32 {
+ // CHECK: return %[[ARG0]] : i32
+ return %arg0 : i32
+}
+
+// -----
+
+// CHECK-LABEL: @simple_vector_4
+// CHECK-SAME: (%[[ARG0:.+]]: vector<4xi32>)
+func.func @simple_vector_4(%arg0 : vector<4xi32>) -> vector<4xi32> {
+ // CHECK: return %[[ARG0]] : vector<4xi32>
+ return %arg0 : vector<4x...
[truncated]
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
I have a general question: can we use/extend existing |
Angel actually started like this, but we were concerned that there are unlikely to be any other users in the repo, so from the point of view of code maintenance it would be better to keep it on the spirv side. If someone else needs it, we can always promote it to vector transforms. |
Ok. I'm fine either way, but it looked like a generic enough utility. |
4b91a99
to
139ea57
Compare
Also, one potential issue is that PR doesn't cover |
7dac5b8
to
6d4a13c
Compare
We will deal with them later. I have included this in the "Future Steps" section in the PR description. |
applyPatternsAndFoldGreedily(op, std::move(patterns), config))) | ||
return signalPassFailure(); | ||
} | ||
return; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why we are skipping the rest of the pass?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My thinking is that this PR only takes care of the new patterns, so the new tests I added only check for the correctness of those patterns, without going through the SPIR-V dialect conversion (the rest of this pass). In the future, I will add more vector patterns and more command-line options for controlling/testing.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1, currently this looks like two different passes hiding in one class. I'd expect the rest of the code to run after signature conversion.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can have a dedicated --test-convert-to-spirv-patterns
pass or whatever to test all sorts of patterns you'd like to test in isolation. In the test pass you can have many options, like TestLinalgTransforms.cpp
. This keeps the main pass introduced here less cluttered and more targeted end-to-end so more clean, given I think we are sort of aiming for this pass to be directly consumed by downstream users eventually. WDYT?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We discussed this offline and also converged on adding fine-grained test passes
mlir/test/lib/Conversion/ConvertToSPIRV/TestSPIRVFuncSignatureConversion.cpp
Outdated
Show resolved
Hide resolved
de7e65a
to
7f8dc86
Compare
7f8dc86
to
2fb2a64
Compare
dfb6e2e
to
af0fb06
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
af0fb06
to
d702290
Compare
d702290
to
7347d06
Compare
auto it = tmpOps.find(operandOp); | ||
if (it != tmpOps.end()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This wasn't resolved
10984de
to
a8972cc
Compare
Co-authored-by: Jakub Kuderski <[email protected]>
…res (#98337) Summary: ### Description This PR implements a minimal version of function signature conversion to unroll vectors into 1D and with a size supported by SPIR-V (2, 3 or 4 depending on the original dimension). This PR also includes new unit tests that only check for function signature conversion. ### Future Plans - Check for capabilities that support vectors of size 8 or 16. - Set up `OneToNTypeConversion` and `DialectConversion` to replace the current implementation that uses `GreedyPatternRewriteDriver`. - Introduce other vector unrolling patterns to cancel out the `vector.insert_strided_slice` and `vector.extract_strided_slice` ops and fully legalize the vector types in the function body. - Handle `func::CallOp` and declarations. - Restructure the code in `SPIRVConversion.cpp`. - Create test passes for testing sets of patterns in isolation. - Optimize the way original shape is splitted into target shapes, e.g. `vector<5xi32>` can be splitted into `vector<4xi32>` and `vector<1xi32>`. --------- Co-authored-by: Jakub Kuderski <[email protected]> Test Plan: Reviewers: Subscribers: Tasks: Tags: Differential Revision: https://phabricator.intern.facebook.com/D60250907
Description
This PR implements a minimal version of function signature conversion to unroll vectors into 1D and with a size supported by SPIR-V (2, 3 or 4 depending on the original dimension). This PR also includes new unit tests that only check for function signature conversion.
Future Plans
OneToNTypeConversion
andDialectConversion
to replace the current implementation that usesGreedyPatternRewriteDriver
.vector.insert_strided_slice
andvector.extract_strided_slice
ops and fully legalize the vector types in the function body.func::CallOp
and declarations.SPIRVConversion.cpp
.vector<5xi32>
can be splitted intovector<4xi32>
andvector<1xi32>
.