Skip to content

[mlir] Refactor ConvertVectorToLLVMPass options #128219

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Mar 10, 2025

Conversation

abulavin
Copy link
Contributor

@abulavin abulavin commented Feb 21, 2025

The VectorTransformsOptions on the ConvertVectorToLLVMPass is currently represented as a struct, which makes it not serialisable. This means a pass pipeline that contains this pass cannot be represented as textual form, which breaks reproducer generation and options such as --dump-pass-pipeline.

This PR expands the VectorTransformsOptions struct into the two options that are actually used by the Pass' patterns: vector-contract-lowering and vector-transpose-lowering . The other options present in VectorTransformOptions are not used by any patterns in this pass.

Additionally, I have changed some interfaces to only take these specific options over the full options struct as, again, the vector contract and transpose lowering patterns only need one of their respective options.

Finally, I have added a simple lit test that just prints the pass pipeline using --dump-pass-pipeline to ensure the options on this pass remain serialisable.

Fixes #129046

…Pass

Refactor ConvertVectorToLLVMPass options
Copy link

Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

@llvmbot
Copy link
Member

llvmbot commented Feb 21, 2025

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir-spirv

Author: Artemiy Bulavin (abulavin)

Changes

The VectorTransformsOptions on the ConvertVectorToLLVMPass is currently represented as a struct, which makes it not serialisable. This means a pass pipeline that contains this pass cannot be represented as textual form, which breaks reproducer generation and options such as --dump-pass-pipeline.

This PR expands the VectorTransformsOptions struct into the two options that are actually used by the Pass' patterns: vector-contract-lowering and vector-transpose-lowering . The other options present in VectorTransformOptions are not used by any patterns in this pass.

Additionally, I have changed some interfaces to only take these specific options over the full options struct as, again, many of the conversion patterns only need one of the options.

Finally, I have added a simple lit test that just prints the pass pipeline using --dump-pass-pipeline to ensure the options on this pass remain serialisable.


Patch is 22.72 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/128219.diff

8 Files Affected:

  • (modified) mlir/include/mlir/Conversion/Passes.td (+27-5)
  • (modified) mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h (+7-4)
  • (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp (+2-2)
  • (modified) mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp (+2-3)
  • (modified) mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp (+3-6)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp (+31-33)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp (+13-15)
  • (added) mlir/test/Conversion/VectorToLLVM/test-serialisable.mlir (+16)
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index cccdf0a8518bf..606a38f7d98eb 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -10,7 +10,7 @@
 #define MLIR_CONVERSION_PASSES
 
 include "mlir/Pass/PassBase.td"
-
+include "mlir/Dialect/Vector/Transforms/VectorTransformsBase.td"
 
 //===----------------------------------------------------------------------===//
 // ToLLVM
@@ -1410,10 +1410,32 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
            "bool", /*default=*/"false",
            "Enables the use of X86Vector dialect while lowering the vector "
 	   "dialect.">,
-    Option<"vectorTransformsOptions", "vector-transform-options",
-           "vector::VectorTransformsOptions",
-           /*default=*/"vector::VectorTransformsOptions()",
-           "Options to lower some operations like contractions and transposes.">,
+    Option<"vectorContractLowering", "vector-contract-lowering",
+           "vector::VectorContractLowering",
+           /*default=*/"vector::VectorContractLowering::Dot",
+           VectorContractLoweringAttr.summary, [{::llvm::cl::values(
+           clEnumValN(::mlir::vector::VectorContractLowering::Dot, "dot", 
+            "Progressively lower to finer grained `vector.contract` and dot-products. (default)"),
+           clEnumValN(::mlir::vector::VectorContractLowering::Matmul, "matmul",
+            "Lower to `vector.matrix_multiply`, maps 1-1 to LLVM matrix intrinsics."),
+           clEnumValN(::mlir::vector::VectorContractLowering::OuterProduct, "outerproduct",
+            "Lower to `vector.outerproduct`."),
+           clEnumValN(::mlir::vector::VectorContractLowering::ParallelArith, "parallelarith",
+            "Lower contract with all reduction dimensions unrolled to 1 to a vector elementwise operations.")
+	        )}]>,
+    Option<"vectorTransposeLowering", "vector-transpose-lowering",
+           "vector::VectorTransposeLowering",
+           /*default=*/"vector::VectorTransposeLowering::EltWise",
+           VectorTransposeLoweringAttr.summary, [{::llvm::cl::values(
+           clEnumValN(::mlir::vector::VectorTransposeLowering::EltWise, "eltwise",
+            "Lower transpose into element-wise extract and inserts (default)"),
+           clEnumValN(::mlir::vector::VectorTransposeLowering::Flat, "flat",
+            "Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix intrinsics"),
+           clEnumValN(::mlir::vector::VectorTransposeLowering::Shuffle1D, "shuffle1d",
+            "Lower 2-D transpose to `vector.shuffle` on 1-D vector."),
+           clEnumValN(::mlir::vector::VectorTransposeLowering::Shuffle16x16, "shuffle16x16",
+            "Lower 2-D transpose to `vector.shuffle` on 16x16 vector.")
+          )}]>,
   ];
 }
 
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 6aeae30a0a6c0..601a65333d026 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -9,6 +9,7 @@
 #ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
 #define MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
 
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 
 namespace mlir {
@@ -47,7 +48,8 @@ namespace vector {
 /// Progressively lower a `vector.contract` with row-major matmul semantics to
 /// linearized `vector.extract` + `vector.outerproduct` + `vector.insert`.
 void populateVectorContractLoweringPatterns(
-    RewritePatternSet &patterns, VectorTransformsOptions options,
+    RewritePatternSet &patterns,
+    VectorContractLowering vectorContractLoweringOption,
     PatternBenefit benefit = 1, bool disableOuterProductLowering = false);
 
 /// Populate the pattern set with the following patterns:
@@ -142,9 +144,10 @@ void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns,
 ///
 /// [TransposeOp2DToShuffleLowering]
 ///
-void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns,
-                                             VectorTransformsOptions options,
-                                             PatternBenefit benefit = 1);
+void populateVectorTransposeLoweringPatterns(
+    RewritePatternSet &patterns,
+    VectorTransposeLowering vectorTransposeLowering,
+    PatternBenefit benefit = 1);
 
 /// Populate the pattern set with the following patterns:
 ///
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index e3a81bd20212d..eb1555df5d574 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -69,11 +69,11 @@ void ConvertVectorToLLVMPass::runOnOperation() {
     populateVectorToVectorCanonicalizationPatterns(patterns);
     populateVectorBitCastLoweringPatterns(patterns);
     populateVectorBroadcastLoweringPatterns(patterns);
-    populateVectorContractLoweringPatterns(patterns, vectorTransformsOptions);
+    populateVectorContractLoweringPatterns(patterns, vectorContractLowering);
     populateVectorMaskOpLoweringPatterns(patterns);
     populateVectorShapeCastLoweringPatterns(patterns);
     populateVectorInterleaveLoweringPatterns(patterns);
-    populateVectorTransposeLoweringPatterns(patterns, vectorTransformsOptions);
+    populateVectorTransposeLoweringPatterns(patterns, vectorTransposeLowering);
     // Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
     populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
     populateVectorMaskMaterializationPatterns(patterns,
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index c56dbcca2175d..a60410d01ac57 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -1374,9 +1374,8 @@ LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) {
   // further transformations to canonicalize/cancel.
   {
     RewritePatternSet patterns(context);
-    auto options = vector::VectorTransformsOptions().setVectorTransposeLowering(
-        vector::VectorTransposeLowering::EltWise);
-    vector::populateVectorTransposeLoweringPatterns(patterns, options);
+    vector::populateVectorTransposeLoweringPatterns(
+        patterns, vector::VectorTransposeLowering::EltWise);
     vector::populateVectorShapeCastLoweringPatterns(patterns);
     if (failed(applyPatternsGreedily(op, std::move(patterns))))
       return failure();
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 241e83e234d62..20c577273d786 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -102,9 +102,7 @@ void transform::ApplyLowerBroadcastPatternsOp::populatePatterns(
 
 void transform::ApplyLowerContractionPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
-  vector::VectorTransformsOptions vectorTransformOptions;
-  vectorTransformOptions.setVectorTransformsOptions(getLoweringStrategy());
-  populateVectorContractLoweringPatterns(patterns, vectorTransformOptions,
+  populateVectorContractLoweringPatterns(patterns, getLoweringStrategy(),
                                          /*benefit=*/1,
                                          /*disableOuterProductLowering=*/true);
 }
@@ -161,9 +159,8 @@ void transform::ApplyLowerTransferPatternsOp::populatePatterns(
 
 void transform::ApplyLowerTransposePatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
-  vector::populateVectorTransposeLoweringPatterns(
-      patterns, vector::VectorTransformsOptions().setVectorTransposeLowering(
-                    getLoweringStrategy()));
+  vector::populateVectorTransposeLoweringPatterns(patterns,
+                                                  getLoweringStrategy());
   if (getAvx2LoweringStrategy()) {
     auto avx2LoweringOptions =
         x86vector::avx2::LoweringOptions().setTransposeOptions(
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 21261478f0648..d2f60a55fb4a6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -221,7 +221,7 @@ namespace {
 /// ```
 /// `vector.matmul` later lowers to `llvm.matrix.multiply`.
 //
-/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
+/// This only kicks in when VectorTransformsOptions is set to Matmul and
 /// the vector.contract op is a row-major matrix multiply.
 class ContractionOpToMatmulOpLowering
     : public vector::MaskableOpRewritePattern<vector::ContractionOp> {
@@ -236,11 +236,11 @@ class ContractionOpToMatmulOpLowering
   }
 
   ContractionOpToMatmulOpLowering(
-      vector::VectorTransformsOptions vectorTransformOptions,
+      vector::VectorContractLowering vectorContractLowering,
       MLIRContext *context, PatternBenefit benefit = 1,
       FilterConstraintType constraint = defaultFilter)
       : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
-        vectorTransformOptions(vectorTransformOptions),
+        vectorContractLowering(vectorContractLowering),
         filter(std::move(constraint)) {}
 
   FailureOr<Value>
@@ -249,7 +249,7 @@ class ContractionOpToMatmulOpLowering
 
 private:
   /// Options to control the vector patterns.
-  vector::VectorTransformsOptions vectorTransformOptions;
+  vector::VectorContractLowering vectorContractLowering;
   FilterConstraintType filter;
 };
 
@@ -281,11 +281,11 @@ class ContractionOpToOuterProductOpLowering
   }
 
   ContractionOpToOuterProductOpLowering(
-      vector::VectorTransformsOptions vectorTransformOptions,
+      vector::VectorContractLowering vectorContractLowering,
       MLIRContext *context, PatternBenefit benefit = 1,
       FilterConstraintType constraint = defaultFilter)
       : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
-        vectorTransformOptions(vectorTransformOptions),
+        vectorContractLowering(vectorContractLowering),
         filter(std::move(constraint)) {}
 
   FailureOr<Value>
@@ -294,7 +294,7 @@ class ContractionOpToOuterProductOpLowering
 
 private:
   /// Options to control the vector patterns.
-  vector::VectorTransformsOptions vectorTransformOptions;
+  vector::VectorContractLowering vectorContractLowering;
   FilterConstraintType filter;
 };
 
@@ -329,11 +329,11 @@ class ContractionOpToDotLowering
   }
 
   ContractionOpToDotLowering(
-      vector::VectorTransformsOptions vectorTransformOptions,
+      vector::VectorContractLowering vectorContractLowering,
       MLIRContext *context, PatternBenefit benefit = 1,
       const FilterConstraintType &constraint = defaultFilter)
       : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
-        vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
+        vectorContractLowering(vectorContractLowering), filter(defaultFilter) {}
 
   FailureOr<Value>
   matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
@@ -341,7 +341,7 @@ class ContractionOpToDotLowering
 
 private:
   /// Options to control the vector patterns.
-  vector::VectorTransformsOptions vectorTransformOptions;
+  vector::VectorContractLowering vectorContractLowering;
   FilterConstraintType filter;
 };
 
@@ -370,11 +370,12 @@ class ContractionOpLowering
     return success();
   }
 
-  ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
-                        MLIRContext *context, PatternBenefit benefit = 1,
-                        FilterConstraintType constraint = defaultFilter)
+  ContractionOpLowering(
+      vector::VectorContractLowering vectorContractLoweringOption,
+      MLIRContext *context, PatternBenefit benefit = 1,
+      FilterConstraintType constraint = defaultFilter)
       : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
-        vectorTransformOptions(vectorTransformOptions),
+        vectorContractLoweringOption(vectorContractLoweringOption),
         filter(std::move(constraint)) {}
 
   FailureOr<Value>
@@ -383,7 +384,7 @@ class ContractionOpLowering
 
 private:
   /// Options to control the vector patterns.
-  vector::VectorTransformsOptions vectorTransformOptions;
+  vector::VectorContractLowering vectorContractLoweringOption;
   FilterConstraintType filter;
   // Lower one parallel dimension.
   FailureOr<Value> lowerParallel(PatternRewriter &rewriter,
@@ -641,8 +642,7 @@ FailureOr<Value>
 ContractionOpToOuterProductOpLowering::matchAndRewriteMaskableOp(
     vector::ContractionOp op, MaskingOpInterface maskOp,
     PatternRewriter &rewriter) const {
-  if (vectorTransformOptions.vectorContractLowering !=
-      vector::VectorContractLowering::OuterProduct)
+  if (vectorContractLowering != vector::VectorContractLowering::OuterProduct)
     return failure();
 
   if (failed(filter(op)))
@@ -672,8 +672,7 @@ FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
   if (failed(filter(op)))
     return failure();
 
-  if (vectorTransformOptions.vectorContractLowering !=
-      vector::VectorContractLowering::Dot)
+  if (vectorContractLowering != vector::VectorContractLowering::Dot)
     return failure();
 
   auto iteratorTypes = op.getIteratorTypes().getValue();
@@ -789,11 +788,11 @@ struct ContractOpToElementwise
     return success();
   }
   ContractOpToElementwise(
-      vector::VectorTransformsOptions vectorTransformOptions,
+      vector::VectorContractLowering vectorContractLowering,
       MLIRContext *context, PatternBenefit benefit = 1,
       const FilterConstraintType &constraint = defaultFilter)
       : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
-        vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
+        vectorContractLowering(vectorContractLowering), filter(defaultFilter) {}
 
   FailureOr<Value>
   matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
@@ -806,8 +805,7 @@ struct ContractOpToElementwise
     if (failed(filter(contractOp)))
       return failure();
 
-    if (vectorTransformOptions.vectorContractLowering !=
-        vector::VectorContractLowering::ParallelArith)
+    if (vectorContractLowering != vector::VectorContractLowering::ParallelArith)
       return failure();
 
     ArrayRef<int64_t> lhsShape = contractOp.getLhsType().getShape();
@@ -898,7 +896,7 @@ struct ContractOpToElementwise
 
 private:
   /// Options to control the vector patterns.
-  vector::VectorTransformsOptions vectorTransformOptions;
+  vector::VectorContractLowering vectorContractLowering;
   FilterConstraintType filter;
 };
 
@@ -941,25 +939,25 @@ FailureOr<Value> ContractionOpLowering::matchAndRewriteMaskableOp(
   // TODO: implement benefits, cost models.
   MLIRContext *ctx = op.getContext();
 
-  ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx);
+  ContractionOpToMatmulOpLowering pat1(vectorContractLoweringOption, ctx);
   FailureOr<Value> newVal1 =
       pat1.matchAndRewriteMaskableOp(op, maskOp, rewriter);
   if (!failed(newVal1))
     return newVal1;
 
-  ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx);
+  ContractionOpToOuterProductOpLowering pat2(vectorContractLoweringOption, ctx);
   FailureOr<Value> newVal2 =
       pat2.matchAndRewriteMaskableOp(op, maskOp, rewriter);
   if (!failed(newVal2))
     return newVal2;
 
-  ContractionOpToDotLowering pat3(vectorTransformOptions, ctx);
+  ContractionOpToDotLowering pat3(vectorContractLoweringOption, ctx);
   FailureOr<Value> newVal3 =
       pat3.matchAndRewriteMaskableOp(op, maskOp, rewriter);
   if (!failed(newVal3))
     return newVal3;
 
-  ContractOpToElementwise pat4(vectorTransformOptions, ctx);
+  ContractOpToElementwise pat4(vectorContractLoweringOption, ctx);
   FailureOr<Value> newVal4 =
       pat4.matchAndRewriteMaskableOp(op, maskOp, rewriter);
   if (!failed(newVal4))
@@ -1292,8 +1290,7 @@ FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
   if (maskOp)
     return failure();
 
-  if (vectorTransformOptions.vectorContractLowering !=
-      vector::VectorContractLowering::Matmul)
+  if (vectorContractLowering != vector::VectorContractLowering::Matmul)
     return failure();
   if (failed(filter(op)))
     return failure();
@@ -1382,13 +1379,14 @@ FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
 } // namespace
 
 void mlir::vector::populateVectorContractLoweringPatterns(
-    RewritePatternSet &patterns, VectorTransformsOptions options,
-    PatternBenefit benefit, bool disableOuterProductLowering) {
+    RewritePatternSet &patterns,
+    VectorContractLowering vectorContractLoweringOption, PatternBenefit benefit,
+    bool disableOuterProductLowering) {
   if (!disableOuterProductLowering)
     patterns.add<OuterProductOpLowering>(patterns.getContext(), benefit);
   patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering,
                ContractionOpToOuterProductOpLowering>(
-      options, patterns.getContext(), benefit);
+      vectorContractLoweringOption, patterns.getContext(), benefit);
 }
 
 void mlir::vector::populateVectorOuterProductLoweringPatterns(
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index fb4dee33bc5f5..732e316c93381 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -304,10 +304,10 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
 
-  TransposeOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
+  TransposeOpLowering(vector::VectorTransposeLowering vectorTransposeLowering,
                       MLIRContext *context, PatternBenefit benefit = 1)
       : OpRewritePattern<vector::TransposeOp>(context, benefit),
-        vectorTransformOptions(vectorTransformOptions) {}
+        vectorTransposeLowering(vectorTransposeLowering) {}
 
   LogicalResult matchAndRewrite(vector::TransposeOp op,
                                 PatternRewriter &rewriter) const override {
@@ -324,14 +324,13 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
     // Set up convenience transposition table.
     ArrayRef<int64_t> transp = op.getPermutation();
 
-    if (isShuffleLike(vectorTransformOptions.vectorTransposeLowering) &&
+    if (isShuffleLike(vectorTransposeLowering) &&
         succeeded(isTranspose2DSlice(op)))
       return rewriter.notifyMatchFailure(
           op, "Options specifies lowering to shuffle");
 
     // Handle a true 2-D matrix transpose differently when requested.
-    if (vectorTransformOptions.vectorTransposeLowering ==
-            vector::VectorTransposeLowering::Flat &&
+    if (vectorTransposeLowering == vector::VectorTransposeLowering::Flat &&
         resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
       Type flattenedType =
           VectorType::get(resType.getNumElements(), resType.getElementType());
@@ -380,7 +379,7 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
 
 private:
   /// Options to control the vector patterns.
-  vector::VectorTransformsOptions vectorTransformOptions;
+  vector::VectorTransposeLowering vectorTransposeLowering;
 };
 
 /// Rewrites vector.transpose as vector.shape_cast. This pattern is only applied
@@ -454,14 +453,14 @@ class TransposeOp2DToShuffleLowering
   using OpRewritePattern::OpRewritePattern;
 
   TransposeOp2DToShuffleLowering(
-      vector::VectorTransformsOptions vectorTransformOptions,
+      vector::VectorTransposeLowering vectorTransposeLowering,
       MLIRContext *context, PatternBenefit benefit = 1)
       : OpRewritePattern<vector::TransposeOp>(context, benefit),
-        vectorTransformOptions(vectorTransformOptions) {}
+        vectorTransposeLower...
[truncated]

@abulavin
Copy link
Contributor Author

@abulavin abulavin changed the title Refactor ConvertVectorToLLVMPass options [mlir] Refactor ConvertVectorToLLVMPass options Feb 21, 2025
Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome! This LGTM but please wait for others to chime in. Thanks!

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for working on this 🙏🏻

Before we proceed, I raised a question about the overall direction in MLIR regarding these flags. I don't mean to block this - if this functionality is required, lets land it. However, if we are to land it, we should also properly test/exercise these flags. That's one thing that was missing in the previous PR.

Regardless of the outcome, this is helping us improve MLIR 🙏🏻

Comment on lines 1413 to 1438
Option<"vectorContractLowering", "vector-contract-lowering",
"vector::VectorContractLowering",
/*default=*/"vector::VectorContractLowering::Dot",
VectorContractLoweringAttr.summary, [{::llvm::cl::values(
clEnumValN(::mlir::vector::VectorContractLowering::Dot, "dot",
"Progressively lower to finer grained `vector.contract` and dot-products. (default)"),
clEnumValN(::mlir::vector::VectorContractLowering::Matmul, "matmul",
"Lower to `vector.matrix_multiply`, maps 1-1 to LLVM matrix intrinsics."),
clEnumValN(::mlir::vector::VectorContractLowering::OuterProduct, "outerproduct",
"Lower to `vector.outerproduct`."),
clEnumValN(::mlir::vector::VectorContractLowering::ParallelArith, "parallelarith",
"Lower contract with all reduction dimensions unrolled to 1 to a vector elementwise operations.")
)}]>,
Option<"vectorTransposeLowering", "vector-transpose-lowering",
"vector::VectorTransposeLowering",
/*default=*/"vector::VectorTransposeLowering::EltWise",
VectorTransposeLoweringAttr.summary, [{::llvm::cl::values(
clEnumValN(::mlir::vector::VectorTransposeLowering::EltWise, "eltwise",
"Lower transpose into element-wise extract and inserts (default)"),
clEnumValN(::mlir::vector::VectorTransposeLowering::Flat, "flat",
"Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix intrinsics"),
clEnumValN(::mlir::vector::VectorTransposeLowering::Shuffle1D, "shuffle1d",
"Lower 2-D transpose to `vector.shuffle` on 1-D vector."),
clEnumValN(::mlir::vector::VectorTransposeLowering::Shuffle16x16, "shuffle16x16",
"Lower 2-D transpose to `vector.shuffle` on 16x16 vector.")
)}]>,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMHO, this is becoming too verbose and too fine-grained - do we need this?

In general, there are three different ways to drive this pass (and the relevant transformations):

  • C++ API
  • Transform Dialect (this is what we tend to use in tests)
  • LLVM command-line (llvm::cl) flags (added in this PR)

From my perspective, llvm::cl flags are mostly useful for testing and prototyping. However, in practice, we seem to favor Transform Dialect for finer-grained control. That said, this may be a matter of personal preference.

My bigger concern is that we are not testing these flags. If they duplicate existing Transform Dialect functionality, we should ensure that relevant RUN lines are duplicated, rather than adding tests only to check the flags themselves.

This is a broader issue in MLIR and somewhat tangential to this PR (which is attempting to fix the issue reported here: MLIR: Specifying Structs as cl::Options for PassOptions).

Let’s wait for others to chime in. More generally, I hope we can establish clearer guidelines on the relationship between different ways of "driving" transformations and preferred testing approaches.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be news to me that transform dialect is the preferred way of testing something in MLIR: so far it's been an infrastructure on the side, not the primary way to access functionality (which remains passes and pass pipelines).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The previous PR added these options to the pass initially so I'm assuming that this implies the need for this kind of control?

I'm also unfamiliar with using the transform dialect for testing -- my understanding was that pass pipelines were generally tested using invocations of mlir-opt. Indeed, there are a large number of test passes registered here that would suggest so, unless I've misunderstood. If yourself and the other reviewers would like me to write some tests as using the transform interpreter than I am happy to do so.

That being said, when I was writing this PR I attempted to write a test for these options. However, as someone who has limited knowledge of the vector to llvm lowering, it's not obvious to me how to correctly isolate and test the effect of a single pattern option on the entirety of the ConvertVectorToLLVMPass. For example, it's hard for me to say (and seems like a brittle test) what the resulting LLVM IR looks like when vector.contract is lowered with the dot option versus the matmul option, when this is only one of many patterns in the pass. As a result, my hope is that the correctness of these specific patterns in this PR is covered by existing tests in the vector dialect.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are only two knobs which make a huge difference in the transformations that happen within the pass. It just happens that they are enums and have a bunch of values. I think we should support them.

Would it be acceptable to add the default enum values to tests already using this pass and also add more flag combinations to the dump pipeline test? I'm afraid that if we widely test all the combinations on existing tests we get stuck with failures.

It would be news to me that transform dialect is the preferred way of testing something in MLIR:

Well, for better or worse, quite a few tests in the Vector dialect were moved from pipeline to TD so they are no longer tested on the pipeline side.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The #123491 added these options to the pass initially so I'm assuming that this implies the need for this kind of control?

Well, the lack of upstream testing would suggest otherwise. :)

Now, there's a good reason for that - these transformations are tested via Transform Dialect (TD). However, without explicit tests, it becomes very difficult to distinguish between dead code and actively used functionality. And that’s my biggest concern.

That said, the test you added for serializing the options is a great idea to mitigate this - thanks!

Still, there’s no clear indication (in-tree) that this functionality is actually needed.


Indeed, there are a large number of test passes registered here that would suggest so, unless I've misunderstood.

Yes, but there’s a comparable number of Transform Dialect Ops.

Let’s be a bit more specific:

  • 27 test files in mlir/test/Dialect/Tensor7 use TD.
  • 158 test files in mlir/test/Dialect/Linalg79 use TD.
  • 71 test files in mlir/test/Dialect/Vector29 use TD.

(I focused on the "Tensor compiler" specifically.)

So, TD is not the dominant driver for test pipelines, but it's also not just a side infrastructure. I suspect it’s more commonly used where fine-grained control is required - for example, here.

Now, note that this PR adds:

clEnumValN(::mlir::vector::VectorTransposeLowering::Shuffle1D, "shuffle1d",
            "Lower 2-D transpose to `vector.shuffle` on 1-D vector."),

So the key question is:

  • Do we duplicate the test to verify that the newly supported flag works as expected?

As a result, my hope is that the correctness of these specific patterns in this PR is covered by existing tests in the vector dialect.

Absolutely! Adding new Vector tests would be way beyond the scope of this PR 😅

Would it be acceptable to add the default enum values to tests already using this pass and also add more flag combinations to the dump pipeline test?

I'm not sure I follow ...

My thinking is more along the lines of adding new RUN lines to e.g. https://github.com/llvm/llvm-project/blob/60cc3af0d93ecb8bfc9d6bebc6cbc395df3bb4b6/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir. For example, we could add:

mlir-opt -vector-transpose-lowering=shuffle1d %s | FileCheck --check-prefix=SHUFFLE_1D

while still keeping the existing:

// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s

The downside? check-prefix explosion :(


@dcaballe, I am guessing that you will use this downstream? If yes, then lets merge this - we should aim to progress

ASAP. Also, I am not really providing any viable alternatives.

Copy link
Contributor Author

@abulavin abulavin Feb 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@banach-space What about the alternative of adding tests like:

// RUN: mlir-opt -convert-vector-to-llvm="vector-contract-lowering=dot" %s | FileCheck %s

func.func @vector_contract(%arg0: vector<2x2xf32>, %arg1: vector<2xf32>, %arg2: vector<2xf32>) -> vector<2xf32> {
    %out = vector.contract {
    // .. and so on
}

// CHECK-LABEL: @vector_contract
// Insert various checks on the resulting LLVMIR here

This does still suffer from the check-prefix explosion but does thoroughly test the flags and the effect of these options on the overall ConvertToLLVMPass pretty unambiguously. There's also no duplication with existing vector to LLVM conversion tests because these are new options for the pass. In my opinion these are the most suitable kinds of tests but do require some upfront work for me to properly understand the transformations before I can write tests for them. I've not written Vector to LLVM conversion tests before though so I would like some thoughts on whether this is worth the effort.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestion!

There's also no duplication with existing vector to LLVM conversion tests because these are new options for the pass.

Note, we actually do have "options" for these, see:

(this is a TD Op rather than an "option" - the point is that it exposes exactly that functionality).

As for test duplication, this is already tested here:

transform.apply_patterns to %f {
transform.apply_patterns.vector.lower_contraction lowering_strategy = "dot"
} : !transform.any_op
transform.yield

This brings me back to my original point, there are effectively 3 ways to drive most (all?) transformations:

  • C++ API
  • TD Ops
  • LLVM command-line (llvm::cl) flags (added in this PR)

While there's plenty of duplication throughout wider LLVM, we should actively be thinking how to reduce that.

@dcaballe
Copy link
Contributor

I think test-serialisable.mlir is a great way to test the CL API without actually testing the underlying transformations. We should add more RUN rules there to check all the enum values for the new options. That should good coverage.

Would it be acceptable to add the default enum values to tests already using this pass and also add more flag combinations to the dump pipeline test?

I'm not sure I follow ...

My point was that we can s/convert-vector-to-llvm/convert-vector-to-llvm="vector-contract-lowering=, ..." for all the existing tests and explicitly add the default enum values for the new options that this PR is adding. That would at least test some enum values of the CL API without testing the actual conversion functionality. I wouldn't go with anything that entails adding more CHECK rules as this is out of scope for this PR.

@banach-space
Copy link
Contributor

My point was that we can s/convert-vector-to-llvm/convert-vector-to-llvm="vector-contract-lowering=, ..." for all the existing tests and explicitly add the default enum values for the new options that this PR is adding.

If these are "defaults" then adding them to -convert-vector-to-llvm should not be required, right? As in, that's the default values we are already using today:

/// Structure to control the behavior of vector transform patterns.
struct VectorTransformsOptions {
/// Option to control the lowering of vector.contract.
VectorContractLowering vectorContractLowering = VectorContractLowering::Dot;
VectorTransformsOptions &
setVectorTransformsOptions(VectorContractLowering opt) {
vectorContractLowering = opt;
return *this;
}
/// Option to control the lowering of vector.multi_reduction.
VectorMultiReductionLowering vectorMultiReductionLowering =
VectorMultiReductionLowering::InnerParallel;
VectorTransformsOptions &
setVectorMultiReductionLowering(VectorMultiReductionLowering opt) {
vectorMultiReductionLowering = opt;
return *this;
}
/// Option to control the lowering of vector.transpose.
VectorTransposeLowering vectorTransposeLowering =
VectorTransposeLowering::EltWise;
VectorTransformsOptions &
setVectorTransposeLowering(VectorTransposeLowering opt) {
vectorTransposeLowering = opt;
return *this;
}
/// Option to control the splitting of vector transfers.
VectorTransferSplit vectorTransferSplit = VectorTransferSplit::None;
VectorTransformsOptions &setVectorTransferSplit(VectorTransferSplit opt) {
vectorTransferSplit = opt;
return *this;
}
};

Perhaps just update "test-seriasable.mlir" so that it verifies that the intended defaults are indeed used. And then, add a separate RUN line where the defaults are overriden with some other values?

@abulavin
Copy link
Contributor Author

Perhaps just update "test-seriasable.mlir" so that it verifies that the intended defaults are indeed used. And then, add a separate RUN line where the defaults are overriden with some other values?

Good idea, I have just updated this PR to do this. Thanks.

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am happy with this direction and really appreciate the effort that you've put into this, Artemiy!

I've left some minor comments for you, but nothing major. Once addressed, lets land this. Thanks!

@@ -10,7 +10,7 @@
#define MLIR_CONVERSION_PASSES

include "mlir/Pass/PassBase.td"

include "mlir/Dialect/Vector/Transforms/VectorTransformsBase.td"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the additional include required?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good spot, thank you

Copy link
Contributor

@banach-space banach-space Feb 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see that this has been removed and then added back?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, it looks unused but it's actually required for this file. You can see some enums from this file are referenced in the description of the options for the ConvertVectorToLLVMPass, to not duplicate the option descriptions. I removed it the first time by mistake.

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

Am I correct that you don't have commit access?

@abulavin
Copy link
Contributor Author

LGTM, thanks!

Am I correct that you don't have commit access?

Thanks for your thorough and helpful review. I enjoyed our discussion.
No, I do not have write access, could you please merge it for me once CI on c65d803 passes? I had to add back a tablegen header as removing it causes the build to fail.

@abulavin
Copy link
Contributor Author

@banach-space if there are no further issues please can this be merged for me? I do not have write access.

@banach-space
Copy link
Contributor

banach-space commented Mar 10, 2025

@banach-space if there are no further issues please can this be merged for me? I do not have write access.

Sorry about the delay, landing this now. Thanks for bearing with me and for helping us to improve this 🙏🏻

EDIT

Please make sure your e-mail is correctly configured: https://llvm.org/docs/GitHub.html#before-your-first-pr

@abulavin
Copy link
Contributor Author

Please make sure your e-mail is correctly configured: https://llvm.org/docs/GitHub.html#before-your-first-pr

I believe that's all done! Cheers.

@banach-space banach-space merged commit f3dcc0f into llvm:main Mar 10, 2025
7 checks passed
Copy link

@abulavin Congratulations on having your first Pull Request (PR) merged into the LLVM Project!

Your changes will be combined with recent changes from other authors, then tested by our build bots. If there is a problem with a build, you may receive a report in an email or a comment on this PR.

Please check whether problems have been caused by your change specifically, as the builds can include changes from many authors. It is not uncommon for your change to be included in a build that fails due to someone else's changes, or infrastructure issues.

How to do this, and the rest of the post-merge process, is covered in detail here.

If your change does cause a problem, it may be reverted, or you can revert it yourself. This is a normal part of LLVM development. You can fix your changes and open a new PR to merge them again.

If you don't get any reports, no action is required from you. Your changes are working as expected, well done!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[MLIR][Crash]--convert-vector-to-llvm="enable-x86vector" --dump-pass-pipeline --test-vector-scan-lowering triggers crash.
6 participants