Skip to content

[mlir] Refactor ConvertVectorToLLVMPass options #128219

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Mar 10, 2025
Merged
32 changes: 27 additions & 5 deletions mlir/include/mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
#define MLIR_CONVERSION_PASSES

include "mlir/Pass/PassBase.td"

include "mlir/Dialect/Vector/Transforms/VectorTransformsBase.td"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the additional include required?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good spot, thank you

Copy link
Contributor

@banach-space banach-space Feb 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see that this has been removed and then added back?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, it looks unused but it's actually required for this file. You can see some enums from this file are referenced in the description of the options for the ConvertVectorToLLVMPass, to not duplicate the option descriptions. I removed it the first time by mistake.


//===----------------------------------------------------------------------===//
// ToLLVM
Expand Down Expand Up @@ -1410,10 +1410,32 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
"bool", /*default=*/"false",
"Enables the use of X86Vector dialect while lowering the vector "
"dialect.">,
Option<"vectorTransformsOptions", "vector-transform-options",
"vector::VectorTransformsOptions",
/*default=*/"vector::VectorTransformsOptions()",
"Options to lower some operations like contractions and transposes.">,
Option<"vectorContractLowering", "vector-contract-lowering",
"vector::VectorContractLowering",
/*default=*/"vector::VectorContractLowering::Dot",
VectorContractLoweringAttr.summary, [{::llvm::cl::values(
clEnumValN(::mlir::vector::VectorContractLowering::Dot, "dot",
"Progressively lower to finer grained `vector.contract` and dot-products. (default)"),
clEnumValN(::mlir::vector::VectorContractLowering::Matmul, "matmul",
"Lower to `vector.matrix_multiply`, maps 1-1 to LLVM matrix intrinsics."),
clEnumValN(::mlir::vector::VectorContractLowering::OuterProduct, "outerproduct",
"Lower to `vector.outerproduct`."),
clEnumValN(::mlir::vector::VectorContractLowering::ParallelArith, "parallelarith",
"Lower contract with all reduction dimensions unrolled to 1 to a vector elementwise operations.")
)}]>,
Option<"vectorTransposeLowering", "vector-transpose-lowering",
"vector::VectorTransposeLowering",
/*default=*/"vector::VectorTransposeLowering::EltWise",
VectorTransposeLoweringAttr.summary, [{::llvm::cl::values(
clEnumValN(::mlir::vector::VectorTransposeLowering::EltWise, "eltwise",
"Lower transpose into element-wise extract and inserts (default)"),
clEnumValN(::mlir::vector::VectorTransposeLowering::Flat, "flat",
"Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix intrinsics"),
clEnumValN(::mlir::vector::VectorTransposeLowering::Shuffle1D, "shuffle1d",
"Lower 2-D transpose to `vector.shuffle` on 1-D vector."),
clEnumValN(::mlir::vector::VectorTransposeLowering::Shuffle16x16, "shuffle16x16",
"Lower 2-D transpose to `vector.shuffle` on 16x16 vector.")
)}]>,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMHO, this is becoming too verbose and too fine-grained - do we need this?

In general, there are three different ways to drive this pass (and the relevant transformations):

  • C++ API
  • Transform Dialect (this is what we tend to use in tests)
  • LLVM command-line (llvm::cl) flags (added in this PR)

From my perspective, llvm::cl flags are mostly useful for testing and prototyping. However, in practice, we seem to favor Transform Dialect for finer-grained control. That said, this may be a matter of personal preference.

My bigger concern is that we are not testing these flags. If they duplicate existing Transform Dialect functionality, we should ensure that relevant RUN lines are duplicated, rather than adding tests only to check the flags themselves.

This is a broader issue in MLIR and somewhat tangential to this PR (which is attempting to fix the issue reported here: MLIR: Specifying Structs as cl::Options for PassOptions).

Let’s wait for others to chime in. More generally, I hope we can establish clearer guidelines on the relationship between different ways of "driving" transformations and preferred testing approaches.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be news to me that transform dialect is the preferred way of testing something in MLIR: so far it's been an infrastructure on the side, not the primary way to access functionality (which remains passes and pass pipelines).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The previous PR added these options to the pass initially so I'm assuming that this implies the need for this kind of control?

I'm also unfamiliar with using the transform dialect for testing -- my understanding was that pass pipelines were generally tested using invocations of mlir-opt. Indeed, there are a large number of test passes registered here that would suggest so, unless I've misunderstood. If yourself and the other reviewers would like me to write some tests as using the transform interpreter than I am happy to do so.

That being said, when I was writing this PR I attempted to write a test for these options. However, as someone who has limited knowledge of the vector to llvm lowering, it's not obvious to me how to correctly isolate and test the effect of a single pattern option on the entirety of the ConvertVectorToLLVMPass. For example, it's hard for me to say (and seems like a brittle test) what the resulting LLVM IR looks like when vector.contract is lowered with the dot option versus the matmul option, when this is only one of many patterns in the pass. As a result, my hope is that the correctness of these specific patterns in this PR is covered by existing tests in the vector dialect.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are only two knobs which make a huge difference in the transformations that happen within the pass. It just happens that they are enums and have a bunch of values. I think we should support them.

Would it be acceptable to add the default enum values to tests already using this pass and also add more flag combinations to the dump pipeline test? I'm afraid that if we widely test all the combinations on existing tests we get stuck with failures.

It would be news to me that transform dialect is the preferred way of testing something in MLIR:

Well, for better or worse, quite a few tests in the Vector dialect were moved from pipeline to TD so they are no longer tested on the pipeline side.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The #123491 added these options to the pass initially so I'm assuming that this implies the need for this kind of control?

Well, the lack of upstream testing would suggest otherwise. :)

Now, there's a good reason for that - these transformations are tested via Transform Dialect (TD). However, without explicit tests, it becomes very difficult to distinguish between dead code and actively used functionality. And that’s my biggest concern.

That said, the test you added for serializing the options is a great idea to mitigate this - thanks!

Still, there’s no clear indication (in-tree) that this functionality is actually needed.


Indeed, there are a large number of test passes registered here that would suggest so, unless I've misunderstood.

Yes, but there’s a comparable number of Transform Dialect Ops.

Let’s be a bit more specific:

  • 27 test files in mlir/test/Dialect/Tensor7 use TD.
  • 158 test files in mlir/test/Dialect/Linalg79 use TD.
  • 71 test files in mlir/test/Dialect/Vector29 use TD.

(I focused on the "Tensor compiler" specifically.)

So, TD is not the dominant driver for test pipelines, but it's also not just a side infrastructure. I suspect it’s more commonly used where fine-grained control is required - for example, here.

Now, note that this PR adds:

clEnumValN(::mlir::vector::VectorTransposeLowering::Shuffle1D, "shuffle1d",
            "Lower 2-D transpose to `vector.shuffle` on 1-D vector."),

So the key question is:

  • Do we duplicate the test to verify that the newly supported flag works as expected?

As a result, my hope is that the correctness of these specific patterns in this PR is covered by existing tests in the vector dialect.

Absolutely! Adding new Vector tests would be way beyond the scope of this PR 😅

Would it be acceptable to add the default enum values to tests already using this pass and also add more flag combinations to the dump pipeline test?

I'm not sure I follow ...

My thinking is more along the lines of adding new RUN lines to e.g. https://github.com/llvm/llvm-project/blob/60cc3af0d93ecb8bfc9d6bebc6cbc395df3bb4b6/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir. For example, we could add:

mlir-opt -vector-transpose-lowering=shuffle1d %s | FileCheck --check-prefix=SHUFFLE_1D

while still keeping the existing:

// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s

The downside? check-prefix explosion :(


@dcaballe, I am guessing that you will use this downstream? If yes, then lets merge this - we should aim to progress

ASAP. Also, I am not really providing any viable alternatives.

Copy link
Contributor Author

@abulavin abulavin Feb 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@banach-space What about the alternative of adding tests like:

// RUN: mlir-opt -convert-vector-to-llvm="vector-contract-lowering=dot" %s | FileCheck %s

func.func @vector_contract(%arg0: vector<2x2xf32>, %arg1: vector<2xf32>, %arg2: vector<2xf32>) -> vector<2xf32> {
    %out = vector.contract {
    // .. and so on
}

// CHECK-LABEL: @vector_contract
// Insert various checks on the resulting LLVMIR here

This does still suffer from the check-prefix explosion but does thoroughly test the flags and the effect of these options on the overall ConvertToLLVMPass pretty unambiguously. There's also no duplication with existing vector to LLVM conversion tests because these are new options for the pass. In my opinion these are the most suitable kinds of tests but do require some upfront work for me to properly understand the transformations before I can write tests for them. I've not written Vector to LLVM conversion tests before though so I would like some thoughts on whether this is worth the effort.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestion!

There's also no duplication with existing vector to LLVM conversion tests because these are new options for the pass.

Note, we actually do have "options" for these, see:

(this is a TD Op rather than an "option" - the point is that it exposes exactly that functionality).

As for test duplication, this is already tested here:

transform.apply_patterns to %f {
transform.apply_patterns.vector.lower_contraction lowering_strategy = "dot"
} : !transform.any_op
transform.yield

This brings me back to my original point, there are effectively 3 ways to drive most (all?) transformations:

  • C++ API
  • TD Ops
  • LLVM command-line (llvm::cl) flags (added in this PR)

While there's plenty of duplication throughout wider LLVM, we should actively be thinking how to reduce that.

];
}

Expand Down
11 changes: 7 additions & 4 deletions mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
#define MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H

#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"

namespace mlir {
Expand Down Expand Up @@ -47,7 +48,8 @@ namespace vector {
/// Progressively lower a `vector.contract` with row-major matmul semantics to
/// linearized `vector.extract` + `vector.outerproduct` + `vector.insert`.
void populateVectorContractLoweringPatterns(
RewritePatternSet &patterns, VectorTransformsOptions options,
RewritePatternSet &patterns,
VectorContractLowering vectorContractLoweringOption,
PatternBenefit benefit = 1, bool disableOuterProductLowering = false);

/// Populate the pattern set with the following patterns:
Expand Down Expand Up @@ -142,9 +144,10 @@ void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns,
///
/// [TransposeOp2DToShuffleLowering]
///
void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns,
VectorTransformsOptions options,
PatternBenefit benefit = 1);
void populateVectorTransposeLoweringPatterns(
RewritePatternSet &patterns,
VectorTransposeLowering vectorTransposeLowering,
PatternBenefit benefit = 1);

/// Populate the pattern set with the following patterns:
///
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,11 @@ void ConvertVectorToLLVMPass::runOnOperation() {
populateVectorToVectorCanonicalizationPatterns(patterns);
populateVectorBitCastLoweringPatterns(patterns);
populateVectorBroadcastLoweringPatterns(patterns);
populateVectorContractLoweringPatterns(patterns, vectorTransformsOptions);
populateVectorContractLoweringPatterns(patterns, vectorContractLowering);
populateVectorMaskOpLoweringPatterns(patterns);
populateVectorShapeCastLoweringPatterns(patterns);
populateVectorInterleaveLoweringPatterns(patterns);
populateVectorTransposeLoweringPatterns(patterns, vectorTransformsOptions);
populateVectorTransposeLoweringPatterns(patterns, vectorTransposeLowering);
// Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
populateVectorMaskMaterializationPatterns(patterns,
Expand Down
5 changes: 2 additions & 3 deletions mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1374,9 +1374,8 @@ LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) {
// further transformations to canonicalize/cancel.
{
RewritePatternSet patterns(context);
auto options = vector::VectorTransformsOptions().setVectorTransposeLowering(
vector::VectorTransposeLowering::EltWise);
vector::populateVectorTransposeLoweringPatterns(patterns, options);
vector::populateVectorTransposeLoweringPatterns(
patterns, vector::VectorTransposeLowering::EltWise);
vector::populateVectorShapeCastLoweringPatterns(patterns);
if (failed(applyPatternsGreedily(op, std::move(patterns))))
return failure();
Expand Down
9 changes: 3 additions & 6 deletions mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,7 @@ void transform::ApplyLowerBroadcastPatternsOp::populatePatterns(

void transform::ApplyLowerContractionPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
vector::VectorTransformsOptions vectorTransformOptions;
vectorTransformOptions.setVectorTransformsOptions(getLoweringStrategy());
populateVectorContractLoweringPatterns(patterns, vectorTransformOptions,
populateVectorContractLoweringPatterns(patterns, getLoweringStrategy(),
/*benefit=*/1,
/*disableOuterProductLowering=*/true);
}
Expand Down Expand Up @@ -161,9 +159,8 @@ void transform::ApplyLowerTransferPatternsOp::populatePatterns(

void transform::ApplyLowerTransposePatternsOp::populatePatterns(
RewritePatternSet &patterns) {
vector::populateVectorTransposeLoweringPatterns(
patterns, vector::VectorTransformsOptions().setVectorTransposeLowering(
getLoweringStrategy()));
vector::populateVectorTransposeLoweringPatterns(patterns,
getLoweringStrategy());
if (getAvx2LoweringStrategy()) {
auto avx2LoweringOptions =
x86vector::avx2::LoweringOptions().setTransposeOptions(
Expand Down
72 changes: 35 additions & 37 deletions mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ namespace {
/// ```
/// `vector.matmul` later lowers to `llvm.matrix.multiply`.
//
/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
/// This only kicks in when vectorContractLowering is set to Matmul and
/// the vector.contract op is a row-major matrix multiply.
class ContractionOpToMatmulOpLowering
: public vector::MaskableOpRewritePattern<vector::ContractionOp> {
Expand All @@ -236,11 +236,11 @@ class ContractionOpToMatmulOpLowering
}

ContractionOpToMatmulOpLowering(
vector::VectorTransformsOptions vectorTransformOptions,
vector::VectorContractLowering vectorContractLowering,
MLIRContext *context, PatternBenefit benefit = 1,
FilterConstraintType constraint = defaultFilter)
: MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
vectorTransformOptions(vectorTransformOptions),
vectorContractLowering(vectorContractLowering),
filter(std::move(constraint)) {}

FailureOr<Value>
Expand All @@ -249,7 +249,7 @@ class ContractionOpToMatmulOpLowering

private:
/// Options to control the vector patterns.
vector::VectorTransformsOptions vectorTransformOptions;
vector::VectorContractLowering vectorContractLowering;
FilterConstraintType filter;
};

Expand All @@ -266,7 +266,7 @@ class ContractionOpToMatmulOpLowering
/// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
/// ```
///
/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
/// This only kicks in when vectorContractLowering is set to OuterProduct and
/// the vector.contract op is a row-major matrix multiply.
class ContractionOpToOuterProductOpLowering
: public MaskableOpRewritePattern<vector::ContractionOp> {
Expand All @@ -281,11 +281,11 @@ class ContractionOpToOuterProductOpLowering
}

ContractionOpToOuterProductOpLowering(
vector::VectorTransformsOptions vectorTransformOptions,
vector::VectorContractLowering vectorContractLowering,
MLIRContext *context, PatternBenefit benefit = 1,
FilterConstraintType constraint = defaultFilter)
: MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
vectorTransformOptions(vectorTransformOptions),
vectorContractLowering(vectorContractLowering),
filter(std::move(constraint)) {}

FailureOr<Value>
Expand All @@ -294,7 +294,7 @@ class ContractionOpToOuterProductOpLowering

private:
/// Options to control the vector patterns.
vector::VectorTransformsOptions vectorTransformOptions;
vector::VectorContractLowering vectorContractLowering;
FilterConstraintType filter;
};

Expand Down Expand Up @@ -329,19 +329,19 @@ class ContractionOpToDotLowering
}

ContractionOpToDotLowering(
vector::VectorTransformsOptions vectorTransformOptions,
vector::VectorContractLowering vectorContractLowering,
MLIRContext *context, PatternBenefit benefit = 1,
const FilterConstraintType &constraint = defaultFilter)
: MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
vectorContractLowering(vectorContractLowering), filter(defaultFilter) {}

FailureOr<Value>
matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
PatternRewriter &rewriter) const override;

private:
/// Options to control the vector patterns.
vector::VectorTransformsOptions vectorTransformOptions;
vector::VectorContractLowering vectorContractLowering;
FilterConstraintType filter;
};

Expand Down Expand Up @@ -370,11 +370,12 @@ class ContractionOpLowering
return success();
}

ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
MLIRContext *context, PatternBenefit benefit = 1,
FilterConstraintType constraint = defaultFilter)
ContractionOpLowering(
vector::VectorContractLowering vectorContractLoweringOption,
MLIRContext *context, PatternBenefit benefit = 1,
FilterConstraintType constraint = defaultFilter)
: MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
vectorTransformOptions(vectorTransformOptions),
vectorContractLoweringOption(vectorContractLoweringOption),
filter(std::move(constraint)) {}

FailureOr<Value>
Expand All @@ -383,7 +384,7 @@ class ContractionOpLowering

private:
/// Options to control the vector patterns.
vector::VectorTransformsOptions vectorTransformOptions;
vector::VectorContractLowering vectorContractLoweringOption;
FilterConstraintType filter;
// Lower one parallel dimension.
FailureOr<Value> lowerParallel(PatternRewriter &rewriter,
Expand Down Expand Up @@ -635,14 +636,13 @@ struct UnrolledOuterProductGenerator
/// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
/// ```
///
/// This only kicks in when VectorTransformsOptions is set to OuterProduct but
/// This only kicks in when vectorContractLowering is set to OuterProduct but
/// otherwise supports any layout permutation of the matrix-multiply.
FailureOr<Value>
ContractionOpToOuterProductOpLowering::matchAndRewriteMaskableOp(
vector::ContractionOp op, MaskingOpInterface maskOp,
PatternRewriter &rewriter) const {
if (vectorTransformOptions.vectorContractLowering !=
vector::VectorContractLowering::OuterProduct)
if (vectorContractLowering != vector::VectorContractLowering::OuterProduct)
return failure();

if (failed(filter(op)))
Expand Down Expand Up @@ -672,8 +672,7 @@ FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
if (failed(filter(op)))
return failure();

if (vectorTransformOptions.vectorContractLowering !=
vector::VectorContractLowering::Dot)
if (vectorContractLowering != vector::VectorContractLowering::Dot)
return failure();

auto iteratorTypes = op.getIteratorTypes().getValue();
Expand Down Expand Up @@ -789,11 +788,11 @@ struct ContractOpToElementwise
return success();
}
ContractOpToElementwise(
vector::VectorTransformsOptions vectorTransformOptions,
vector::VectorContractLowering vectorContractLowering,
MLIRContext *context, PatternBenefit benefit = 1,
const FilterConstraintType &constraint = defaultFilter)
: MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
vectorContractLowering(vectorContractLowering), filter(defaultFilter) {}

FailureOr<Value>
matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
Expand All @@ -806,8 +805,7 @@ struct ContractOpToElementwise
if (failed(filter(contractOp)))
return failure();

if (vectorTransformOptions.vectorContractLowering !=
vector::VectorContractLowering::ParallelArith)
if (vectorContractLowering != vector::VectorContractLowering::ParallelArith)
return failure();

ArrayRef<int64_t> lhsShape = contractOp.getLhsType().getShape();
Expand Down Expand Up @@ -898,7 +896,7 @@ struct ContractOpToElementwise

private:
/// Options to control the vector patterns.
vector::VectorTransformsOptions vectorTransformOptions;
vector::VectorContractLowering vectorContractLowering;
FilterConstraintType filter;
};

Expand All @@ -913,7 +911,7 @@ struct ContractOpToElementwise
/// until a pure contraction is reached (no free/batch dimensions),
/// which is replaced by a dot-product.
///
/// This only kicks in when either VectorTransformsOptions is set
/// This only kicks in when either vectorContractLoweringOption is set
/// to DOT or when other contraction patterns fail.
//
// TODO: break down into transpose/reshape/cast ops
Expand Down Expand Up @@ -941,25 +939,25 @@ FailureOr<Value> ContractionOpLowering::matchAndRewriteMaskableOp(
// TODO: implement benefits, cost models.
MLIRContext *ctx = op.getContext();

ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx);
ContractionOpToMatmulOpLowering pat1(vectorContractLoweringOption, ctx);
FailureOr<Value> newVal1 =
pat1.matchAndRewriteMaskableOp(op, maskOp, rewriter);
if (!failed(newVal1))
return newVal1;

ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx);
ContractionOpToOuterProductOpLowering pat2(vectorContractLoweringOption, ctx);
FailureOr<Value> newVal2 =
pat2.matchAndRewriteMaskableOp(op, maskOp, rewriter);
if (!failed(newVal2))
return newVal2;

ContractionOpToDotLowering pat3(vectorTransformOptions, ctx);
ContractionOpToDotLowering pat3(vectorContractLoweringOption, ctx);
FailureOr<Value> newVal3 =
pat3.matchAndRewriteMaskableOp(op, maskOp, rewriter);
if (!failed(newVal3))
return newVal3;

ContractOpToElementwise pat4(vectorTransformOptions, ctx);
ContractOpToElementwise pat4(vectorContractLoweringOption, ctx);
FailureOr<Value> newVal4 =
pat4.matchAndRewriteMaskableOp(op, maskOp, rewriter);
if (!failed(newVal4))
Expand Down Expand Up @@ -1280,7 +1278,7 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
/// ```
/// `vector.matmul` later lowers to `llvm.matrix.multiply`.
//
/// This only kicks in when VectorTransformsOptions is set to `Matmul`.
/// This only kicks in when vectorContractLowering is set to `Matmul`.
/// vector.transpose operations are inserted if the vector.contract op is not a
/// row-major matrix multiply.
///
Expand All @@ -1292,8 +1290,7 @@ FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
if (maskOp)
return failure();

if (vectorTransformOptions.vectorContractLowering !=
vector::VectorContractLowering::Matmul)
if (vectorContractLowering != vector::VectorContractLowering::Matmul)
return failure();
if (failed(filter(op)))
return failure();
Expand Down Expand Up @@ -1382,13 +1379,14 @@ FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
} // namespace

void mlir::vector::populateVectorContractLoweringPatterns(
RewritePatternSet &patterns, VectorTransformsOptions options,
PatternBenefit benefit, bool disableOuterProductLowering) {
RewritePatternSet &patterns,
VectorContractLowering vectorContractLoweringOption, PatternBenefit benefit,
bool disableOuterProductLowering) {
if (!disableOuterProductLowering)
patterns.add<OuterProductOpLowering>(patterns.getContext(), benefit);
patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering,
ContractionOpToOuterProductOpLowering>(
options, patterns.getContext(), benefit);
vectorContractLoweringOption, patterns.getContext(), benefit);
}

void mlir::vector::populateVectorOuterProductLoweringPatterns(
Expand Down
Loading