Skip to content

[mlir] Refactor ConvertVectorToLLVMPass options #128219

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Mar 10, 2025
Merged
32 changes: 27 additions & 5 deletions mlir/include/mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
#define MLIR_CONVERSION_PASSES

include "mlir/Pass/PassBase.td"

include "mlir/Dialect/Vector/Transforms/VectorTransformsBase.td"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the additional include required?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good spot, thank you

Copy link
Contributor

@banach-space banach-space Feb 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see that this has been removed and then added back?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, it looks unused but it's actually required for this file. You can see some enums from this file are referenced in the description of the options for the ConvertVectorToLLVMPass, to not duplicate the option descriptions. I removed it the first time by mistake.


//===----------------------------------------------------------------------===//
// ToLLVM
Expand Down Expand Up @@ -1410,10 +1410,32 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
"bool", /*default=*/"false",
"Enables the use of X86Vector dialect while lowering the vector "
"dialect.">,
Option<"vectorTransformsOptions", "vector-transform-options",
"vector::VectorTransformsOptions",
/*default=*/"vector::VectorTransformsOptions()",
"Options to lower some operations like contractions and transposes.">,
Option<"vectorContractLowering", "vector-contract-lowering",
"vector::VectorContractLowering",
/*default=*/"vector::VectorContractLowering::Dot",
VectorContractLoweringAttr.summary, [{::llvm::cl::values(
clEnumValN(::mlir::vector::VectorContractLowering::Dot, "dot",
"Progressively lower to finer grained `vector.contract` and dot-products. (default)"),
clEnumValN(::mlir::vector::VectorContractLowering::Matmul, "matmul",
"Lower to `vector.matrix_multiply`, maps 1-1 to LLVM matrix intrinsics."),
clEnumValN(::mlir::vector::VectorContractLowering::OuterProduct, "outerproduct",
"Lower to `vector.outerproduct`."),
clEnumValN(::mlir::vector::VectorContractLowering::ParallelArith, "parallelarith",
"Lower contract with all reduction dimensions unrolled to 1 to a vector elementwise operations.")
)}]>,
Option<"vectorTransposeLowering", "vector-transpose-lowering",
"vector::VectorTransposeLowering",
/*default=*/"vector::VectorTransposeLowering::EltWise",
VectorTransposeLoweringAttr.summary, [{::llvm::cl::values(
clEnumValN(::mlir::vector::VectorTransposeLowering::EltWise, "eltwise",
"Lower transpose into element-wise extract and inserts (default)"),
clEnumValN(::mlir::vector::VectorTransposeLowering::Flat, "flat",
"Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix intrinsics"),
clEnumValN(::mlir::vector::VectorTransposeLowering::Shuffle1D, "shuffle1d",
"Lower 2-D transpose to `vector.shuffle` on 1-D vector."),
clEnumValN(::mlir::vector::VectorTransposeLowering::Shuffle16x16, "shuffle16x16",
"Lower 2-D transpose to `vector.shuffle` on 16x16 vector.")
)}]>,
];
}

Expand Down
11 changes: 7 additions & 4 deletions mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
#define MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H

#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"

namespace mlir {
Expand Down Expand Up @@ -47,7 +48,8 @@ namespace vector {
/// Progressively lower a `vector.contract` with row-major matmul semantics to
/// linearized `vector.extract` + `vector.outerproduct` + `vector.insert`.
void populateVectorContractLoweringPatterns(
RewritePatternSet &patterns, VectorTransformsOptions options,
RewritePatternSet &patterns,
VectorContractLowering vectorContractLoweringOption,
PatternBenefit benefit = 1, bool disableOuterProductLowering = false);

/// Populate the pattern set with the following patterns:
Expand Down Expand Up @@ -142,9 +144,10 @@ void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns,
///
/// [TransposeOp2DToShuffleLowering]
///
void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns,
VectorTransformsOptions options,
PatternBenefit benefit = 1);
void populateVectorTransposeLoweringPatterns(
RewritePatternSet &patterns,
VectorTransposeLowering vectorTransposeLowering,
PatternBenefit benefit = 1);

/// Populate the pattern set with the following patterns:
///
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,11 @@ void ConvertVectorToLLVMPass::runOnOperation() {
populateVectorToVectorCanonicalizationPatterns(patterns);
populateVectorBitCastLoweringPatterns(patterns);
populateVectorBroadcastLoweringPatterns(patterns);
populateVectorContractLoweringPatterns(patterns, vectorTransformsOptions);
populateVectorContractLoweringPatterns(patterns, vectorContractLowering);
populateVectorMaskOpLoweringPatterns(patterns);
populateVectorShapeCastLoweringPatterns(patterns);
populateVectorInterleaveLoweringPatterns(patterns);
populateVectorTransposeLoweringPatterns(patterns, vectorTransformsOptions);
populateVectorTransposeLoweringPatterns(patterns, vectorTransposeLowering);
// Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
populateVectorMaskMaterializationPatterns(patterns,
Expand Down
5 changes: 2 additions & 3 deletions mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1374,9 +1374,8 @@ LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) {
// further transformations to canonicalize/cancel.
{
RewritePatternSet patterns(context);
auto options = vector::VectorTransformsOptions().setVectorTransposeLowering(
vector::VectorTransposeLowering::EltWise);
vector::populateVectorTransposeLoweringPatterns(patterns, options);
vector::populateVectorTransposeLoweringPatterns(
patterns, vector::VectorTransposeLowering::EltWise);
vector::populateVectorShapeCastLoweringPatterns(patterns);
if (failed(applyPatternsGreedily(op, std::move(patterns))))
return failure();
Expand Down
9 changes: 3 additions & 6 deletions mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,7 @@ void transform::ApplyLowerBroadcastPatternsOp::populatePatterns(

void transform::ApplyLowerContractionPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
vector::VectorTransformsOptions vectorTransformOptions;
vectorTransformOptions.setVectorTransformsOptions(getLoweringStrategy());
populateVectorContractLoweringPatterns(patterns, vectorTransformOptions,
populateVectorContractLoweringPatterns(patterns, getLoweringStrategy(),
/*benefit=*/1,
/*disableOuterProductLowering=*/true);
}
Expand Down Expand Up @@ -161,9 +159,8 @@ void transform::ApplyLowerTransferPatternsOp::populatePatterns(

void transform::ApplyLowerTransposePatternsOp::populatePatterns(
RewritePatternSet &patterns) {
vector::populateVectorTransposeLoweringPatterns(
patterns, vector::VectorTransformsOptions().setVectorTransposeLowering(
getLoweringStrategy()));
vector::populateVectorTransposeLoweringPatterns(patterns,
getLoweringStrategy());
if (getAvx2LoweringStrategy()) {
auto avx2LoweringOptions =
x86vector::avx2::LoweringOptions().setTransposeOptions(
Expand Down
80 changes: 39 additions & 41 deletions mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,13 +215,13 @@ namespace {
/// ```
/// %flattened_a = vector.shape_cast %a
/// %flattened_b = vector.shape_cast %b
/// %flattened_d = vector.matmul %flattened_a, %flattened_b
/// %flattened_d = vector.matrix_multiply %flattened_a, %flattened_b
/// %d = vector.shape_cast %%flattened_d
/// %e = add %c, %d
/// ```
/// `vector.matmul` later lowers to `llvm.matrix.multiply`.
/// `vector.matrix_multiply` later lowers to `llvm.matrix.multiply`.
//
/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
/// This only kicks in when vectorContractLowering is set to Matmul and
/// the vector.contract op is a row-major matrix multiply.
class ContractionOpToMatmulOpLowering
: public vector::MaskableOpRewritePattern<vector::ContractionOp> {
Expand All @@ -236,11 +236,11 @@ class ContractionOpToMatmulOpLowering
}

ContractionOpToMatmulOpLowering(
vector::VectorTransformsOptions vectorTransformOptions,
vector::VectorContractLowering vectorContractLowering,
MLIRContext *context, PatternBenefit benefit = 1,
FilterConstraintType constraint = defaultFilter)
: MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
vectorTransformOptions(vectorTransformOptions),
vectorContractLowering(vectorContractLowering),
filter(std::move(constraint)) {}

FailureOr<Value>
Expand All @@ -249,7 +249,7 @@ class ContractionOpToMatmulOpLowering

private:
/// Options to control the vector patterns.
vector::VectorTransformsOptions vectorTransformOptions;
vector::VectorContractLowering vectorContractLowering;
FilterConstraintType filter;
};

Expand All @@ -266,7 +266,7 @@ class ContractionOpToMatmulOpLowering
/// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
/// ```
///
/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
/// This only kicks in when vectorContractLowering is set to OuterProduct and
/// the vector.contract op is a row-major matrix multiply.
class ContractionOpToOuterProductOpLowering
: public MaskableOpRewritePattern<vector::ContractionOp> {
Expand All @@ -281,11 +281,11 @@ class ContractionOpToOuterProductOpLowering
}

ContractionOpToOuterProductOpLowering(
vector::VectorTransformsOptions vectorTransformOptions,
vector::VectorContractLowering vectorContractLowering,
MLIRContext *context, PatternBenefit benefit = 1,
FilterConstraintType constraint = defaultFilter)
: MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
vectorTransformOptions(vectorTransformOptions),
vectorContractLowering(vectorContractLowering),
filter(std::move(constraint)) {}

FailureOr<Value>
Expand All @@ -294,7 +294,7 @@ class ContractionOpToOuterProductOpLowering

private:
/// Options to control the vector patterns.
vector::VectorTransformsOptions vectorTransformOptions;
vector::VectorContractLowering vectorContractLowering;
FilterConstraintType filter;
};

Expand Down Expand Up @@ -329,19 +329,19 @@ class ContractionOpToDotLowering
}

ContractionOpToDotLowering(
vector::VectorTransformsOptions vectorTransformOptions,
vector::VectorContractLowering vectorContractLowering,
MLIRContext *context, PatternBenefit benefit = 1,
const FilterConstraintType &constraint = defaultFilter)
: MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
vectorContractLowering(vectorContractLowering), filter(defaultFilter) {}

FailureOr<Value>
matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
PatternRewriter &rewriter) const override;

private:
/// Options to control the vector patterns.
vector::VectorTransformsOptions vectorTransformOptions;
vector::VectorContractLowering vectorContractLowering;
FilterConstraintType filter;
};

Expand Down Expand Up @@ -370,11 +370,12 @@ class ContractionOpLowering
return success();
}

ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
MLIRContext *context, PatternBenefit benefit = 1,
FilterConstraintType constraint = defaultFilter)
ContractionOpLowering(
vector::VectorContractLowering vectorContractLoweringOption,
MLIRContext *context, PatternBenefit benefit = 1,
FilterConstraintType constraint = defaultFilter)
: MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
vectorTransformOptions(vectorTransformOptions),
vectorContractLoweringOption(vectorContractLoweringOption),
filter(std::move(constraint)) {}

FailureOr<Value>
Expand All @@ -383,7 +384,7 @@ class ContractionOpLowering

private:
/// Options to control the vector patterns.
vector::VectorTransformsOptions vectorTransformOptions;
vector::VectorContractLowering vectorContractLoweringOption;
FilterConstraintType filter;
// Lower one parallel dimension.
FailureOr<Value> lowerParallel(PatternRewriter &rewriter,
Expand Down Expand Up @@ -635,14 +636,13 @@ struct UnrolledOuterProductGenerator
/// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
/// ```
///
/// This only kicks in when VectorTransformsOptions is set to OuterProduct but
/// This only kicks in when vectorContractLowering is set to OuterProduct but
/// otherwise supports any layout permutation of the matrix-multiply.
FailureOr<Value>
ContractionOpToOuterProductOpLowering::matchAndRewriteMaskableOp(
vector::ContractionOp op, MaskingOpInterface maskOp,
PatternRewriter &rewriter) const {
if (vectorTransformOptions.vectorContractLowering !=
vector::VectorContractLowering::OuterProduct)
if (vectorContractLowering != vector::VectorContractLowering::OuterProduct)
return failure();

if (failed(filter(op)))
Expand Down Expand Up @@ -672,8 +672,7 @@ FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
if (failed(filter(op)))
return failure();

if (vectorTransformOptions.vectorContractLowering !=
vector::VectorContractLowering::Dot)
if (vectorContractLowering != vector::VectorContractLowering::Dot)
return failure();

auto iteratorTypes = op.getIteratorTypes().getValue();
Expand Down Expand Up @@ -789,11 +788,11 @@ struct ContractOpToElementwise
return success();
}
ContractOpToElementwise(
vector::VectorTransformsOptions vectorTransformOptions,
vector::VectorContractLowering vectorContractLowering,
MLIRContext *context, PatternBenefit benefit = 1,
const FilterConstraintType &constraint = defaultFilter)
: MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
vectorContractLowering(vectorContractLowering), filter(defaultFilter) {}

FailureOr<Value>
matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
Expand All @@ -806,8 +805,7 @@ struct ContractOpToElementwise
if (failed(filter(contractOp)))
return failure();

if (vectorTransformOptions.vectorContractLowering !=
vector::VectorContractLowering::ParallelArith)
if (vectorContractLowering != vector::VectorContractLowering::ParallelArith)
return failure();

ArrayRef<int64_t> lhsShape = contractOp.getLhsType().getShape();
Expand Down Expand Up @@ -898,7 +896,7 @@ struct ContractOpToElementwise

private:
/// Options to control the vector patterns.
vector::VectorTransformsOptions vectorTransformOptions;
vector::VectorContractLowering vectorContractLowering;
FilterConstraintType filter;
};

Expand All @@ -913,7 +911,7 @@ struct ContractOpToElementwise
/// until a pure contraction is reached (no free/batch dimensions),
/// which is replaced by a dot-product.
///
/// This only kicks in when either VectorTransformsOptions is set
/// This only kicks in when either vectorContractLoweringOption is set
/// to DOT or when other contraction patterns fail.
//
// TODO: break down into transpose/reshape/cast ops
Expand Down Expand Up @@ -941,25 +939,25 @@ FailureOr<Value> ContractionOpLowering::matchAndRewriteMaskableOp(
// TODO: implement benefits, cost models.
MLIRContext *ctx = op.getContext();

ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx);
ContractionOpToMatmulOpLowering pat1(vectorContractLoweringOption, ctx);
FailureOr<Value> newVal1 =
pat1.matchAndRewriteMaskableOp(op, maskOp, rewriter);
if (!failed(newVal1))
return newVal1;

ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx);
ContractionOpToOuterProductOpLowering pat2(vectorContractLoweringOption, ctx);
FailureOr<Value> newVal2 =
pat2.matchAndRewriteMaskableOp(op, maskOp, rewriter);
if (!failed(newVal2))
return newVal2;

ContractionOpToDotLowering pat3(vectorTransformOptions, ctx);
ContractionOpToDotLowering pat3(vectorContractLoweringOption, ctx);
FailureOr<Value> newVal3 =
pat3.matchAndRewriteMaskableOp(op, maskOp, rewriter);
if (!failed(newVal3))
return newVal3;

ContractOpToElementwise pat4(vectorTransformOptions, ctx);
ContractOpToElementwise pat4(vectorContractLoweringOption, ctx);
FailureOr<Value> newVal4 =
pat4.matchAndRewriteMaskableOp(op, maskOp, rewriter);
if (!failed(newVal4))
Expand Down Expand Up @@ -1273,14 +1271,14 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
/// %mtb = maybe_transpose
/// %flattened_a = vector.shape_cast %mta
/// %flattened_b = vector.shape_cast %mtb
/// %flattened_d = vector.matmul %flattened_a, %flattened_b
/// %flattened_d = vector.matrix_multiply %flattened_a, %flattened_b
/// %mtd = vector.shape_cast %flattened_d
/// %d = maybe_untranspose %mtd
/// %e = add %c, %d
/// ```
/// `vector.matmul` later lowers to `llvm.matrix.multiply`.
/// `vector.matrix_multiply` later lowers to `llvm.matrix.multiply`.
//
/// This only kicks in when VectorTransformsOptions is set to `Matmul`.
/// This only kicks in when vectorContractLowering is set to `Matmul`.
/// vector.transpose operations are inserted if the vector.contract op is not a
/// row-major matrix multiply.
///
Expand All @@ -1292,8 +1290,7 @@ FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
if (maskOp)
return failure();

if (vectorTransformOptions.vectorContractLowering !=
vector::VectorContractLowering::Matmul)
if (vectorContractLowering != vector::VectorContractLowering::Matmul)
return failure();
if (failed(filter(op)))
return failure();
Expand Down Expand Up @@ -1382,13 +1379,14 @@ FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
} // namespace

void mlir::vector::populateVectorContractLoweringPatterns(
RewritePatternSet &patterns, VectorTransformsOptions options,
PatternBenefit benefit, bool disableOuterProductLowering) {
RewritePatternSet &patterns,
VectorContractLowering vectorContractLoweringOption, PatternBenefit benefit,
bool disableOuterProductLowering) {
if (!disableOuterProductLowering)
patterns.add<OuterProductOpLowering>(patterns.getContext(), benefit);
patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering,
ContractionOpToOuterProductOpLowering>(
options, patterns.getContext(), benefit);
vectorContractLoweringOption, patterns.getContext(), benefit);
}

void mlir::vector::populateVectorOuterProductLoweringPatterns(
Expand Down
Loading