Skip to content

[mlir][spirv] Implement vector type legalization for function signatures #98337

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 21 commits into from
Jul 17, 2024

Conversation

angelz913
Copy link
Contributor

@angelz913 angelz913 commented Jul 10, 2024

Description

This PR implements a minimal version of function signature conversion to unroll vectors into 1D and with a size supported by SPIR-V (2, 3 or 4 depending on the original dimension). This PR also includes new unit tests that only check for function signature conversion.

Future Plans

  • Check for capabilities that support vectors of size 8 or 16.
  • Set up OneToNTypeConversion and DialectConversion to replace the current implementation that uses GreedyPatternRewriteDriver.
  • Introduce other vector unrolling patterns to cancel out the vector.insert_strided_slice and vector.extract_strided_slice ops and fully legalize the vector types in the function body.
  • Handle func::CallOp and declarations.
  • Restructure the code in SPIRVConversion.cpp.
  • Create test passes for testing sets of patterns in isolation.
  • Optimize the way original shape is splitted into target shapes, e.g. vector<5xi32> can be splitted into vector<4xi32> and vector<1xi32>.

@llvmbot
Copy link
Member

llvmbot commented Jul 10, 2024

@llvm/pr-subscribers-mlir-spirv

@llvm/pr-subscribers-mlir

Author: Angel Zhang (angelz913)

Changes

Description

This PR implements a minimal version of function signature conversion to unroll vectors into 1D and with a size supported by SPIR-V (2, 3 or 4 depending on the original dimension). This PR also includes new unit tests that only check for function signature conversion.

Future Plans

  • Check for capabilities that support vectors of size 8 or 16.
  • Set up OneToNTypeConversion and DialectConversion to replace the current implementation that uses GreedyPatternRewriteDriver.
  • Introduce other vector unrolling patterns to cancel out the vector.insert_strided_slice and vector.extract_strided_slice ops and fully legalize the vector types in the function body.

Patch is 29.20 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/98337.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Conversion/Passes.td (+9-1)
  • (modified) mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h (+6)
  • (modified) mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp (+37-7)
  • (modified) mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt (+1)
  • (modified) mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp (+316-1)
  • (added) mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir (+132)
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 560b088dbe5cd..598bba63a2a82 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -40,7 +40,15 @@ def ConvertToSPIRVPass : Pass<"convert-to-spirv"> {
   let description = [{
     This is a generic pass to convert to SPIR-V.
   }];
-  let dependentDialects = ["spirv::SPIRVDialect"];
+  let dependentDialects = [
+    "spirv::SPIRVDialect",
+    "vector::VectorDialect",
+  ];
+  let options = [
+    Option<"runSignatureConversion", "run-signature-conversion", "bool",
+    /*default=*/"false",
+    "Run function signature conversion to convert vector types">
+  ];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
index 09eecafc0c8a5..112c404527927 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -17,8 +17,10 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/SmallSet.h"
+#include "mlir/Transforms/OneToNTypeConversion.h"
 
 namespace mlir {
 
@@ -134,6 +136,10 @@ class SPIRVConversionTarget : public ConversionTarget {
 void populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                         RewritePatternSet &patterns);
 
+void populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns);
+
+void populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns);
+
 namespace spirv {
 class AccessChainOp;
 
diff --git a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
index b5be4654bcb25..88d7590c1daae 100644
--- a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
@@ -37,20 +37,53 @@ using namespace mlir;
 namespace {
 
 /// A pass to perform the SPIR-V conversion.
-struct ConvertToSPIRVPass final
-    : impl::ConvertToSPIRVPassBase<ConvertToSPIRVPass> {
+class ConvertToSPIRVPass
+    : public impl::ConvertToSPIRVPassBase<ConvertToSPIRVPass> {
+  using impl::ConvertToSPIRVPassBase<
+      ConvertToSPIRVPass>::ConvertToSPIRVPassBase;
 
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     Operation *op = getOperation();
 
+    if (runSignatureConversion) {
+      // Unroll vectors in function inputs to native vector size.
+      llvm::errs() << "Start unrolling function inputs\n";
+      {
+        RewritePatternSet patterns(context);
+        populateFuncOpVectorRewritePatterns(patterns);
+        GreedyRewriteConfig config;
+        config.strictMode = GreedyRewriteStrictness::ExistingOps;
+        if (failed(
+                applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
+          return signalPassFailure();
+      }
+      llvm::errs() << "Finish unrolling function inputs\n";
+
+      // Unroll vectors in function outputs to native vector size.
+      llvm::errs() << "Start unrolling function outputs\n";
+      {
+        RewritePatternSet patterns(context);
+        populateReturnOpVectorRewritePatterns(patterns);
+        GreedyRewriteConfig config;
+        config.strictMode = GreedyRewriteStrictness::ExistingOps;
+        if (failed(
+                applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
+          return signalPassFailure();
+      }
+      llvm::errs() << "Finish unrolling function outputs\n";
+
+      return;
+    }
+
     spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
+    std::unique_ptr<ConversionTarget> target =
+        SPIRVConversionTarget::get(targetAttr);
     SPIRVTypeConverter typeConverter(targetAttr);
-
     RewritePatternSet patterns(context);
     ScfToSPIRVContext scfToSPIRVContext;
 
-    // Populate patterns.
+    // Populate patterns for each dialect.
     arith::populateCeilFloorDivExpandOpsPatterns(patterns);
     arith::populateArithToSPIRVPatterns(typeConverter, patterns);
     populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
@@ -60,9 +93,6 @@ struct ConvertToSPIRVPass final
     populateSCFToSPIRVPatterns(typeConverter, scfToSPIRVContext, patterns);
     ub::populateUBToSPIRVConversionPatterns(typeConverter, patterns);
 
-    std::unique_ptr<ConversionTarget> target =
-        SPIRVConversionTarget::get(targetAttr);
-
     if (failed(applyPartialConversion(op, *target, std::move(patterns))))
       return signalPassFailure();
   }
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
index 821f82ebc0796..11af020b6c188 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
@@ -19,6 +19,7 @@ add_mlir_dialect_library(MLIRSPIRVConversion
   MLIRFuncDialect
   MLIRSPIRVDialect
   MLIRTransformUtils
+  MLIRVectorTransforms
 )
 
 add_mlir_dialect_library(MLIRSPIRVTransforms
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 4072608dc8f87..6e793573f0262 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -11,19 +11,29 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/OneToNTypeConversion.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/MathExtras.h"
 
-#include <functional>
+#include <cctype>
 #include <optional>
 
 #define DEBUG_TYPE "mlir-spirv-conversion"
@@ -34,6 +44,36 @@ using namespace mlir;
 // Utility functions
 //===----------------------------------------------------------------------===//
 
+static int getComputeVectorSize(int64_t size) {
+  for (int i : {4, 3, 2}) {
+    if (size % i == 0)
+      return i;
+  }
+  return 1;
+}
+
+static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
+  llvm::errs() << "Get target shape\n";
+  SmallVector<int64_t> unrollShape = llvm::to_vector<4>(vecType.getShape());
+  std::optional<SmallVector<int64_t>> targetShape =
+      SmallVector<int64_t>(1, getComputeVectorSize(vecType.getShape().back()));
+  if (!targetShape) {
+    llvm::errs() << "--no unrolling target shape defined\n";
+    return std::nullopt;
+  }
+  auto maybeShapeRatio = computeShapeRatio(unrollShape, *targetShape);
+  if (!maybeShapeRatio) {
+    llvm::errs() << "--could not compute integral shape ratio -> BAIL\n";
+    return std::nullopt;
+  }
+  if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
+    llvm::errs() << "--no unrolling needed -> SKIP\n";
+    return std::nullopt;
+  }
+  llvm::errs() << "--found an integral shape ratio to unroll to -> SUCCESS\n";
+  return targetShape;
+}
+
 /// Checks that `candidates` extension requirements are possible to be satisfied
 /// with the given `targetEnv`.
 ///
@@ -813,6 +853,281 @@ void mlir::populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
   patterns.add<FuncOpConversion>(typeConverter, patterns.getContext());
 }
 
+//===----------------------------------------------------------------------===//
+// func::FuncOp Conversion Patterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// A pattern for rewriting function signature to convert vector arguments of
+/// functions to be of valid types
+class FuncOpVectorUnroll : public OpRewritePattern<func::FuncOp> {
+public:
+  using OpRewritePattern<func::FuncOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(func::FuncOp funcOp,
+                                PatternRewriter &rewriter) const override;
+};
+} // namespace
+
+LogicalResult
+FuncOpVectorUnroll::matchAndRewrite(func::FuncOp funcOp,
+                                    PatternRewriter &rewriter) const {
+  auto fnType = funcOp.getFunctionType();
+
+  // Create a new func op with the original type and copy the function body.
+  auto newFuncOp =
+      rewriter.create<func::FuncOp>(funcOp.getLoc(), funcOp.getName(), fnType);
+  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
+                              newFuncOp.end());
+
+  llvm::errs() << "After creating new func op and copying the function body\n";
+  newFuncOp.dump();
+
+  Location loc = newFuncOp.getBody().getLoc();
+  Block &entryBlock = newFuncOp.getBlocks().front();
+  rewriter.setInsertionPointToStart(&entryBlock);
+
+  OneToNTypeMapping oneToNTypeMapping(fnType.getInputs());
+
+  // For arguments that are of illegal types and require unrolling.
+  // `unrolledInputNums` stores the indices of arguments that result from
+  // unrolling in the new function signature. `newInputNo` is a counter.
+  SmallVector<size_t> unrolledInputNums;
+  size_t newInputNo = 0;
+
+  // For arguments that are of legal types and do not require unrolling.
+  // `tmpOps` stores a mapping from temporary operations that serve as
+  // placeholders for new arguments that will be added later. These operations
+  // will be erased once the entry block's argument list is updated.
+  DenseMap<Operation *, size_t> tmpOps;
+
+  // This counts the number of new operations created.
+  size_t newOpCount = 0;
+
+  // Enumerate through the arguments.
+  for (const auto &argType : enumerate(fnType.getInputs())) {
+    size_t origInputNo = argType.index();
+    Type origType = argType.value();
+    // Check whether the argument is of vector type.
+    auto origVecType = llvm::dyn_cast<VectorType>(origType);
+    if (!origVecType) {
+      // We need a placeholder for the old argument that will be erased later.
+      Value result = rewriter.create<arith::ConstantOp>(
+          loc, origType, rewriter.getZeroAttr(origType));
+      rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
+      tmpOps.insert({result.getDefiningOp(), newInputNo});
+      oneToNTypeMapping.addInputs(origInputNo, origType);
+      newInputNo++;
+      newOpCount++;
+      continue;
+    }
+    // Check whether the vector needs unrolling.
+    auto targetShape = getTargetShape(origVecType);
+    if (!targetShape) {
+      // We need a placeholder for the old argument that will be erased later.
+      Value result = rewriter.create<arith::ConstantOp>(
+          loc, origType, rewriter.getZeroAttr(origType));
+      rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
+      tmpOps.insert({result.getDefiningOp(), newInputNo});
+      oneToNTypeMapping.addInputs(origInputNo, origType);
+      newInputNo++;
+      newOpCount++;
+      continue;
+    }
+    llvm::errs() << "Got target shape\n";
+    VectorType unrolledType =
+        VectorType::get(*targetShape, origVecType.getElementType());
+    llvm::errs() << "Unrolled type is ";
+    unrolledType.dump();
+    SmallVector<int64_t> originalShape =
+        llvm::to_vector<4>(origVecType.getShape());
+
+    // Prepare the result vector.
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, origVecType, rewriter.getZeroAttr(origVecType));
+    newOpCount++;
+    // Prepare the placeholder for the new arguments that will be added later.
+    Value dummy = rewriter.create<arith::ConstantOp>(
+        loc, unrolledType, rewriter.getZeroAttr(unrolledType));
+    newOpCount++;
+
+    // Create the `vector.insert_strided_slice` ops.
+    SmallVector<int64_t> strides(targetShape->size(), 1);
+    SmallVector<Type> newTypes;
+    for (SmallVector<int64_t> offsets :
+         StaticTileOffsetRange(originalShape, *targetShape)) {
+      result = rewriter.create<vector::InsertStridedSliceOp>(loc, dummy, result,
+                                                             offsets, strides);
+      newTypes.push_back(unrolledType);
+      unrolledInputNums.push_back(newInputNo);
+      newInputNo++;
+      newOpCount++;
+    }
+    rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
+    oneToNTypeMapping.addInputs(origInputNo, newTypes);
+  }
+
+  llvm::errs() << "After enumerating through the arguments\n";
+  newFuncOp.dump();
+
+  // Change the function signature.
+  auto convertedTypes = oneToNTypeMapping.getConvertedTypes();
+  auto newFnType =
+      FunctionType::get(rewriter.getContext(), TypeRange(convertedTypes),
+                        TypeRange(fnType.getResults()));
+  rewriter.modifyOpInPlace(newFuncOp,
+                           [&] { newFuncOp.setFunctionType(newFnType); });
+
+  llvm::errs() << "After changing function signature\n";
+  newFuncOp.dump();
+
+  // Update the arguments in the entry block.
+  entryBlock.eraseArguments(0, fnType.getNumInputs());
+  SmallVector<Location> locs(convertedTypes.size(), newFuncOp.getLoc());
+  entryBlock.addArguments(convertedTypes, locs);
+
+  llvm::errs() << "After updating the arguments in the entry block\n";
+  newFuncOp.dump();
+
+  // Replace the placeholder values with the new arguments. We assume there is
+  // only one block for now.
+  size_t idx = 0;
+  for (auto opPair : llvm::enumerate(entryBlock.getOperations())) {
+    size_t count = opPair.index();
+    Operation &op = opPair.value();
+    // We first look for operands that are placeholders for initially legal
+    // arguments.
+    for (auto operandPair : llvm::enumerate(op.getOperands())) {
+      Operation *operandOp = operandPair.value().getDefiningOp();
+      if (tmpOps.find(operandOp) != tmpOps.end())
+        rewriter.modifyOpInPlace(&op, [&] {
+          op.setOperand(operandPair.index(),
+                        newFuncOp.getArgument(tmpOps[operandOp]));
+        });
+    }
+    // Since all newly created operations are in the beginning, reaching the end
+    // of them means that any later `vector.insert_strided_slice` should not be
+    // touched.
+    if (count >= newOpCount)
+      continue;
+    auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op);
+    if (vecOp) {
+      size_t unrolledInputNo = unrolledInputNums[idx];
+      rewriter.modifyOpInPlace(&op, [&] {
+        op.setOperand(0, newFuncOp.getArgument(unrolledInputNo));
+      });
+      idx++;
+    }
+    count++;
+  }
+
+  // Erase the original funcOp. The `tmpOps` do not need to be erased since
+  // they have no uses and will be handled by dead-code elimination.
+  rewriter.eraseOp(funcOp);
+  return success();
+}
+
+void mlir::populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns) {
+  patterns.add<FuncOpVectorUnroll>(patterns.getContext());
+}
+
+//===----------------------------------------------------------------------===//
+// func::ReturnOp Conversion Patterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// A pattern for rewriting function signature and the return op to convert
+/// vectors to be of valid types.
+class ReturnOpVectorUnroll : public OpRewritePattern<func::ReturnOp> {
+public:
+  using OpRewritePattern<func::ReturnOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(func::ReturnOp returnOp,
+                                PatternRewriter &rewriter) const override;
+};
+} // namespace
+
+LogicalResult
+ReturnOpVectorUnroll::matchAndRewrite(func::ReturnOp returnOp,
+                                      PatternRewriter &rewriter) const {
+
+  // Check whether the parent funcOp is valid.
+  func::FuncOp funcOp = dyn_cast<func::FuncOp>(returnOp->getParentOp());
+  if (!funcOp)
+    return failure();
+
+  auto fnType = funcOp.getFunctionType();
+  OneToNTypeMapping oneToNTypeMapping(fnType.getResults());
+  Location loc = returnOp.getLoc();
+
+  // For the new return op.
+  SmallVector<Value> newOperands;
+
+  // Enumerate through the results.
+  for (const auto &argType : enumerate(fnType.getResults())) {
+    size_t origResultNo = argType.index();
+    Type origType = argType.value();
+    auto origVecType = llvm::dyn_cast<VectorType>(origType);
+    // Check whether the argument is of vector type.
+    if (!origVecType) {
+      oneToNTypeMapping.addInputs(origResultNo, origType);
+      newOperands.push_back(returnOp.getOperand(origResultNo));
+      continue;
+    }
+    // Check whether the vector needs unrolling.
+    auto targetShape = getTargetShape(origVecType);
+    if (!targetShape) {
+      // The original argument can be used.
+      oneToNTypeMapping.addInputs(origResultNo, origType);
+      newOperands.push_back(returnOp.getOperand(origResultNo));
+      continue;
+    }
+    llvm::errs() << "Got target shape\n";
+    VectorType unrolledType =
+        VectorType::get(*targetShape, origVecType.getElementType());
+    llvm::errs() << "Unrolled type is ";
+    unrolledType.dump();
+
+    // Create `vector.extract_strided_slice` ops to form legal vectors from the
+    // original operand of illegal type.
+    SmallVector<int64_t> originalShape =
+        llvm::to_vector<4>(origVecType.getShape());
+    SmallVector<int64_t> strides(targetShape->size(), 1);
+    SmallVector<Type> newTypes;
+    Value returnValue = returnOp.getOperand(origResultNo);
+    for (SmallVector<int64_t> offsets :
+         StaticTileOffsetRange(originalShape, *targetShape)) {
+      Value result = rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, returnValue, offsets, *targetShape, strides);
+      newOperands.push_back(result);
+      newTypes.push_back(unrolledType);
+    }
+    oneToNTypeMapping.addInputs(origResultNo, newTypes);
+  }
+
+  llvm::errs() << "After enumerating through the arguments\n";
+  funcOp.dump();
+
+  // Change the function signature.
+  auto newFnType =
+      FunctionType::get(rewriter.getContext(), TypeRange(fnType.getInputs()),
+                        TypeRange(oneToNTypeMapping.getConvertedTypes()));
+  rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setFunctionType(newFnType); });
+  llvm::errs() << "After changing function signature\n";
+  funcOp.dump();
+
+  // Replace the return op using the new operands. This will automatically
+  // update the entry block as well.
+  rewriter.replaceOp(returnOp,
+                     rewriter.create<func::ReturnOp>(loc, newOperands));
+
+  return success();
+}
+
+void mlir::populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns) {
+  patterns.add<ReturnOpVectorUnroll>(patterns.getContext());
+}
+
 //===----------------------------------------------------------------------===//
 // Builtin Variables
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir b/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir
new file mode 100644
index 0000000000000..d5c777908d7e2
--- /dev/null
+++ b/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir
@@ -0,0 +1,132 @@
+// RUN: mlir-opt -convert-to-spirv="run-signature-conversion" -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: @simple_scalar
+// CHECK-SAME: (%[[ARG0:.+]]: i32)
+func.func @simple_scalar(%arg0 : i32) -> i32 {
+  // CHECK: return %[[ARG0]] : i32
+  return %arg0 : i32
+}
+
+// -----
+
+// CHECK-LABEL: @simple_vector_4
+// CHECK-SAME: (%[[ARG0:.+]]: vector<4xi32>)
+func.func @simple_vector_4(%arg0 : vector<4xi32>) -> vector<4xi32> {
+  // CHECK: return %[[ARG0]] : vector<4xi32>
+  return %arg0 : vector<4x...
[truncated]

Copy link

github-actions bot commented Jul 10, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@kuhar kuhar requested a review from Hardcode84 July 10, 2024 15:58
@Hardcode84
Copy link
Contributor

I have a general question: can we use/extend existing vector::populateVectorUnrollPatterns instead of introducing SPIR-V-specific pass?

@kuhar
Copy link
Member

kuhar commented Jul 10, 2024

I have a general question: can we use/extend existing vector::populateVectorUnrollPatterns instead of introducing SPIR-V-specific pass?

Angel actually started like this, but we were concerned that there are unlikely to be any other users in the repo, so from the point of view of code maintenance it would be better to keep it on the spirv side. If someone else needs it, we can always promote it to vector transforms.

@Hardcode84
Copy link
Contributor

Angel actually started like this, but we were concerned that there are unlikely to be any other users in the repo, so from the point of view of code maintenance it would be better to keep it on the spirv side. If someone else needs it, we can always promote it to vector transforms.

Ok. I'm fine either way, but it looked like a generic enough utility.

@angelz913 angelz913 force-pushed the spirv-func-vector-unrolling branch 2 times, most recently from 4b91a99 to 139ea57 Compare July 10, 2024 19:54
@Hardcode84
Copy link
Contributor

Also, one potential issue is that PR doesn't cover func::call op and if you have one, it will result in broken IR.

@angelz913 angelz913 force-pushed the spirv-func-vector-unrolling branch 3 times, most recently from 7dac5b8 to 6d4a13c Compare July 11, 2024 15:17
@angelz913 angelz913 changed the title [mlir][spirv] Implement vector type legalization in function signatures [mlir][spirv] Implement vector type legalization for function signatures Jul 11, 2024
@angelz913
Copy link
Contributor Author

Also, one potential issue is that PR doesn't cover func::call op and if you have one, it will result in broken IR.

We will deal with them later. I have included this in the "Future Steps" section in the PR description.

applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
return signalPassFailure();
}
return;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why we are skipping the rest of the pass?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My thinking is that this PR only takes care of the new patterns, so the new tests I added only check for the correctness of those patterns, without going through the SPIR-V dialect conversion (the rest of this pass). In the future, I will add more vector patterns and more command-line options for controlling/testing.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1, currently this looks like two different passes hiding in one class. I'd expect the rest of the code to run after signature conversion.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can have a dedicated --test-convert-to-spirv-patterns pass or whatever to test all sorts of patterns you'd like to test in isolation. In the test pass you can have many options, like TestLinalgTransforms.cpp. This keeps the main pass introduced here less cluttered and more targeted end-to-end so more clean, given I think we are sort of aiming for this pass to be directly consumed by downstream users eventually. WDYT?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We discussed this offline and also converged on adding fine-grained test passes

@angelz913 angelz913 force-pushed the spirv-func-vector-unrolling branch 2 times, most recently from de7e65a to 7f8dc86 Compare July 15, 2024 20:25
@angelz913 angelz913 force-pushed the spirv-func-vector-unrolling branch from 7f8dc86 to 2fb2a64 Compare July 15, 2024 20:45
@angelz913 angelz913 force-pushed the spirv-func-vector-unrolling branch from dfb6e2e to af0fb06 Compare July 16, 2024 01:30
Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@Hardcode84 Hardcode84 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@angelz913 angelz913 force-pushed the spirv-func-vector-unrolling branch from af0fb06 to d702290 Compare July 17, 2024 03:23
@angelz913 angelz913 force-pushed the spirv-func-vector-unrolling branch from d702290 to 7347d06 Compare July 17, 2024 04:39
auto it = tmpOps.find(operandOp);
if (it != tmpOps.end())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This wasn't resolved

@angelz913 angelz913 force-pushed the spirv-func-vector-unrolling branch from 10984de to a8972cc Compare July 17, 2024 13:19
@kuhar kuhar merged commit 6867e49 into llvm:main Jul 17, 2024
5 of 6 checks passed
yuxuanchen1997 pushed a commit that referenced this pull request Jul 25, 2024
…res (#98337)

Summary:
### Description
This PR implements a minimal version of function signature conversion to
unroll vectors into 1D and with a size supported by SPIR-V (2, 3 or 4
depending on the original dimension). This PR also includes new unit
tests that only check for function signature conversion.

### Future Plans
- Check for capabilities that support vectors of size 8 or 16.
- Set up `OneToNTypeConversion` and `DialectConversion` to replace the
current implementation that uses `GreedyPatternRewriteDriver`.
- Introduce other vector unrolling patterns to cancel out the
`vector.insert_strided_slice` and `vector.extract_strided_slice` ops and
fully legalize the vector types in the function body.
- Handle `func::CallOp` and declarations.
- Restructure the code in `SPIRVConversion.cpp`.
- Create test passes for testing sets of patterns in isolation.
- Optimize the way original shape is splitted into target shapes, e.g.
`vector<5xi32>` can be splitted into `vector<4xi32>` and
`vector<1xi32>`.

---------

Co-authored-by: Jakub Kuderski <[email protected]>

Test Plan: 

Reviewers: 

Subscribers: 

Tasks: 

Tags: 


Differential Revision: https://phabricator.intern.facebook.com/D60250907
@angelz913 angelz913 deleted the spirv-func-vector-unrolling branch July 25, 2024 21:15
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants