Skip to content

Commit f3dcc0f

Browse files
authored
[mlir] Refactor ConvertVectorToLLVMPass options (#128219)
The `VectorTransformsOptions` on the `ConvertVectorToLLVMPass` is currently represented as a struct, which makes it not serialisable. This means a pass pipeline that contains this pass cannot be represented as textual form, which breaks reproducer generation and options such as `--dump-pass-pipeline`. This PR expands the `VectorTransformsOptions` struct into the two options that are actually used by the Pass' patterns: `vector-contract-lowering` and `vector-transpose-lowering` . The other options present in VectorTransformOptions are not used by any patterns in this pass. Additionally, I have changed some interfaces to only take these specific options over the full options struct as, again, the vector contract and transpose lowering patterns only need one of their respective options. Finally, I have added a simple lit test that just prints the pass pipeline using `--dump-pass-pipeline` to ensure the options on this pass remain serialisable. Fixes #129046
1 parent 8c8eff2 commit f3dcc0f

File tree

8 files changed

+123
-76
lines changed

8 files changed

+123
-76
lines changed

mlir/include/mlir/Conversion/Passes.td

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
#define MLIR_CONVERSION_PASSES
1111

1212
include "mlir/Pass/PassBase.td"
13-
13+
include "mlir/Dialect/Vector/Transforms/VectorTransformsBase.td"
1414

1515
//===----------------------------------------------------------------------===//
1616
// ToLLVM
@@ -1410,10 +1410,32 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
14101410
"bool", /*default=*/"false",
14111411
"Enables the use of X86Vector dialect while lowering the vector "
14121412
"dialect.">,
1413-
Option<"vectorTransformsOptions", "vector-transform-options",
1414-
"vector::VectorTransformsOptions",
1415-
/*default=*/"vector::VectorTransformsOptions()",
1416-
"Options to lower some operations like contractions and transposes.">,
1413+
Option<"vectorContractLowering", "vector-contract-lowering",
1414+
"vector::VectorContractLowering",
1415+
/*default=*/"vector::VectorContractLowering::Dot",
1416+
VectorContractLoweringAttr.summary, [{::llvm::cl::values(
1417+
clEnumValN(::mlir::vector::VectorContractLowering::Dot, "dot",
1418+
"Progressively lower to finer grained `vector.contract` and dot-products. (default)"),
1419+
clEnumValN(::mlir::vector::VectorContractLowering::Matmul, "matmul",
1420+
"Lower to `vector.matrix_multiply`, maps 1-1 to LLVM matrix intrinsics."),
1421+
clEnumValN(::mlir::vector::VectorContractLowering::OuterProduct, "outerproduct",
1422+
"Lower to `vector.outerproduct`."),
1423+
clEnumValN(::mlir::vector::VectorContractLowering::ParallelArith, "parallelarith",
1424+
"Lower contract with all reduction dimensions unrolled to 1 to a vector elementwise operations.")
1425+
)}]>,
1426+
Option<"vectorTransposeLowering", "vector-transpose-lowering",
1427+
"vector::VectorTransposeLowering",
1428+
/*default=*/"vector::VectorTransposeLowering::EltWise",
1429+
VectorTransposeLoweringAttr.summary, [{::llvm::cl::values(
1430+
clEnumValN(::mlir::vector::VectorTransposeLowering::EltWise, "eltwise",
1431+
"Lower transpose into element-wise extract and inserts (default)"),
1432+
clEnumValN(::mlir::vector::VectorTransposeLowering::Flat, "flat",
1433+
"Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix intrinsics"),
1434+
clEnumValN(::mlir::vector::VectorTransposeLowering::Shuffle1D, "shuffle1d",
1435+
"Lower 2-D transpose to `vector.shuffle` on 1-D vector."),
1436+
clEnumValN(::mlir::vector::VectorTransposeLowering::Shuffle16x16, "shuffle16x16",
1437+
"Lower 2-D transpose to `vector.shuffle` on 16x16 vector.")
1438+
)}]>,
14171439
];
14181440
}
14191441

mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
1010
#define MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
1111

12+
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
1213
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
1314

1415
namespace mlir {
@@ -47,7 +48,8 @@ namespace vector {
4748
/// Progressively lower a `vector.contract` with row-major matmul semantics to
4849
/// linearized `vector.extract` + `vector.outerproduct` + `vector.insert`.
4950
void populateVectorContractLoweringPatterns(
50-
RewritePatternSet &patterns, VectorTransformsOptions options,
51+
RewritePatternSet &patterns,
52+
VectorContractLowering vectorContractLoweringOption,
5153
PatternBenefit benefit = 1, bool disableOuterProductLowering = false);
5254

5355
/// Populate the pattern set with the following patterns:
@@ -142,9 +144,10 @@ void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns,
142144
///
143145
/// [TransposeOp2DToShuffleLowering]
144146
///
145-
void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns,
146-
VectorTransformsOptions options,
147-
PatternBenefit benefit = 1);
147+
void populateVectorTransposeLoweringPatterns(
148+
RewritePatternSet &patterns,
149+
VectorTransposeLowering vectorTransposeLowering,
150+
PatternBenefit benefit = 1);
148151

149152
/// Populate the pattern set with the following patterns:
150153
///

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,11 @@ void ConvertVectorToLLVMPass::runOnOperation() {
6969
populateVectorToVectorCanonicalizationPatterns(patterns);
7070
populateVectorBitCastLoweringPatterns(patterns);
7171
populateVectorBroadcastLoweringPatterns(patterns);
72-
populateVectorContractLoweringPatterns(patterns, vectorTransformsOptions);
72+
populateVectorContractLoweringPatterns(patterns, vectorContractLowering);
7373
populateVectorMaskOpLoweringPatterns(patterns);
7474
populateVectorShapeCastLoweringPatterns(patterns);
7575
populateVectorInterleaveLoweringPatterns(patterns);
76-
populateVectorTransposeLoweringPatterns(patterns, vectorTransformsOptions);
76+
populateVectorTransposeLoweringPatterns(patterns, vectorTransposeLowering);
7777
// Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
7878
populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
7979
populateVectorMaskMaterializationPatterns(patterns,

mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1374,9 +1374,8 @@ LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) {
13741374
// further transformations to canonicalize/cancel.
13751375
{
13761376
RewritePatternSet patterns(context);
1377-
auto options = vector::VectorTransformsOptions().setVectorTransposeLowering(
1378-
vector::VectorTransposeLowering::EltWise);
1379-
vector::populateVectorTransposeLoweringPatterns(patterns, options);
1377+
vector::populateVectorTransposeLoweringPatterns(
1378+
patterns, vector::VectorTransposeLowering::EltWise);
13801379
vector::populateVectorShapeCastLoweringPatterns(patterns);
13811380
if (failed(applyPatternsGreedily(op, std::move(patterns))))
13821381
return failure();

mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -102,9 +102,7 @@ void transform::ApplyLowerBroadcastPatternsOp::populatePatterns(
102102

103103
void transform::ApplyLowerContractionPatternsOp::populatePatterns(
104104
RewritePatternSet &patterns) {
105-
vector::VectorTransformsOptions vectorTransformOptions;
106-
vectorTransformOptions.setVectorTransformsOptions(getLoweringStrategy());
107-
populateVectorContractLoweringPatterns(patterns, vectorTransformOptions,
105+
populateVectorContractLoweringPatterns(patterns, getLoweringStrategy(),
108106
/*benefit=*/1,
109107
/*disableOuterProductLowering=*/true);
110108
}
@@ -161,9 +159,8 @@ void transform::ApplyLowerTransferPatternsOp::populatePatterns(
161159

162160
void transform::ApplyLowerTransposePatternsOp::populatePatterns(
163161
RewritePatternSet &patterns) {
164-
vector::populateVectorTransposeLoweringPatterns(
165-
patterns, vector::VectorTransformsOptions().setVectorTransposeLowering(
166-
getLoweringStrategy()));
162+
vector::populateVectorTransposeLoweringPatterns(patterns,
163+
getLoweringStrategy());
167164
if (getAvx2LoweringStrategy()) {
168165
auto avx2LoweringOptions =
169166
x86vector::avx2::LoweringOptions().setTransposeOptions(

mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp

Lines changed: 39 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -215,13 +215,13 @@ namespace {
215215
/// ```
216216
/// %flattened_a = vector.shape_cast %a
217217
/// %flattened_b = vector.shape_cast %b
218-
/// %flattened_d = vector.matmul %flattened_a, %flattened_b
218+
/// %flattened_d = vector.matrix_multiply %flattened_a, %flattened_b
219219
/// %d = vector.shape_cast %%flattened_d
220220
/// %e = add %c, %d
221221
/// ```
222-
/// `vector.matmul` later lowers to `llvm.matrix.multiply`.
222+
/// `vector.matrix_multiply` later lowers to `llvm.matrix.multiply`.
223223
//
224-
/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
224+
/// This only kicks in when vectorContractLowering is set to Matmul and
225225
/// the vector.contract op is a row-major matrix multiply.
226226
class ContractionOpToMatmulOpLowering
227227
: public vector::MaskableOpRewritePattern<vector::ContractionOp> {
@@ -236,11 +236,11 @@ class ContractionOpToMatmulOpLowering
236236
}
237237

238238
ContractionOpToMatmulOpLowering(
239-
vector::VectorTransformsOptions vectorTransformOptions,
239+
vector::VectorContractLowering vectorContractLowering,
240240
MLIRContext *context, PatternBenefit benefit = 1,
241241
FilterConstraintType constraint = defaultFilter)
242242
: MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
243-
vectorTransformOptions(vectorTransformOptions),
243+
vectorContractLowering(vectorContractLowering),
244244
filter(std::move(constraint)) {}
245245

246246
FailureOr<Value>
@@ -249,7 +249,7 @@ class ContractionOpToMatmulOpLowering
249249

250250
private:
251251
/// Options to control the vector patterns.
252-
vector::VectorTransformsOptions vectorTransformOptions;
252+
vector::VectorContractLowering vectorContractLowering;
253253
FilterConstraintType filter;
254254
};
255255

@@ -266,7 +266,7 @@ class ContractionOpToMatmulOpLowering
266266
/// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
267267
/// ```
268268
///
269-
/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
269+
/// This only kicks in when vectorContractLowering is set to OuterProduct and
270270
/// the vector.contract op is a row-major matrix multiply.
271271
class ContractionOpToOuterProductOpLowering
272272
: public MaskableOpRewritePattern<vector::ContractionOp> {
@@ -281,11 +281,11 @@ class ContractionOpToOuterProductOpLowering
281281
}
282282

283283
ContractionOpToOuterProductOpLowering(
284-
vector::VectorTransformsOptions vectorTransformOptions,
284+
vector::VectorContractLowering vectorContractLowering,
285285
MLIRContext *context, PatternBenefit benefit = 1,
286286
FilterConstraintType constraint = defaultFilter)
287287
: MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
288-
vectorTransformOptions(vectorTransformOptions),
288+
vectorContractLowering(vectorContractLowering),
289289
filter(std::move(constraint)) {}
290290

291291
FailureOr<Value>
@@ -294,7 +294,7 @@ class ContractionOpToOuterProductOpLowering
294294

295295
private:
296296
/// Options to control the vector patterns.
297-
vector::VectorTransformsOptions vectorTransformOptions;
297+
vector::VectorContractLowering vectorContractLowering;
298298
FilterConstraintType filter;
299299
};
300300

@@ -329,19 +329,19 @@ class ContractionOpToDotLowering
329329
}
330330

331331
ContractionOpToDotLowering(
332-
vector::VectorTransformsOptions vectorTransformOptions,
332+
vector::VectorContractLowering vectorContractLowering,
333333
MLIRContext *context, PatternBenefit benefit = 1,
334334
const FilterConstraintType &constraint = defaultFilter)
335335
: MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
336-
vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
336+
vectorContractLowering(vectorContractLowering), filter(defaultFilter) {}
337337

338338
FailureOr<Value>
339339
matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
340340
PatternRewriter &rewriter) const override;
341341

342342
private:
343343
/// Options to control the vector patterns.
344-
vector::VectorTransformsOptions vectorTransformOptions;
344+
vector::VectorContractLowering vectorContractLowering;
345345
FilterConstraintType filter;
346346
};
347347

@@ -370,11 +370,12 @@ class ContractionOpLowering
370370
return success();
371371
}
372372

373-
ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
374-
MLIRContext *context, PatternBenefit benefit = 1,
375-
FilterConstraintType constraint = defaultFilter)
373+
ContractionOpLowering(
374+
vector::VectorContractLowering vectorContractLoweringOption,
375+
MLIRContext *context, PatternBenefit benefit = 1,
376+
FilterConstraintType constraint = defaultFilter)
376377
: MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
377-
vectorTransformOptions(vectorTransformOptions),
378+
vectorContractLoweringOption(vectorContractLoweringOption),
378379
filter(std::move(constraint)) {}
379380

380381
FailureOr<Value>
@@ -383,7 +384,7 @@ class ContractionOpLowering
383384

384385
private:
385386
/// Options to control the vector patterns.
386-
vector::VectorTransformsOptions vectorTransformOptions;
387+
vector::VectorContractLowering vectorContractLoweringOption;
387388
FilterConstraintType filter;
388389
// Lower one parallel dimension.
389390
FailureOr<Value> lowerParallel(PatternRewriter &rewriter,
@@ -635,14 +636,13 @@ struct UnrolledOuterProductGenerator
635636
/// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
636637
/// ```
637638
///
638-
/// This only kicks in when VectorTransformsOptions is set to OuterProduct but
639+
/// This only kicks in when vectorContractLowering is set to OuterProduct but
639640
/// otherwise supports any layout permutation of the matrix-multiply.
640641
FailureOr<Value>
641642
ContractionOpToOuterProductOpLowering::matchAndRewriteMaskableOp(
642643
vector::ContractionOp op, MaskingOpInterface maskOp,
643644
PatternRewriter &rewriter) const {
644-
if (vectorTransformOptions.vectorContractLowering !=
645-
vector::VectorContractLowering::OuterProduct)
645+
if (vectorContractLowering != vector::VectorContractLowering::OuterProduct)
646646
return failure();
647647

648648
if (failed(filter(op)))
@@ -672,8 +672,7 @@ FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
672672
if (failed(filter(op)))
673673
return failure();
674674

675-
if (vectorTransformOptions.vectorContractLowering !=
676-
vector::VectorContractLowering::Dot)
675+
if (vectorContractLowering != vector::VectorContractLowering::Dot)
677676
return failure();
678677

679678
auto iteratorTypes = op.getIteratorTypes().getValue();
@@ -789,11 +788,11 @@ struct ContractOpToElementwise
789788
return success();
790789
}
791790
ContractOpToElementwise(
792-
vector::VectorTransformsOptions vectorTransformOptions,
791+
vector::VectorContractLowering vectorContractLowering,
793792
MLIRContext *context, PatternBenefit benefit = 1,
794793
const FilterConstraintType &constraint = defaultFilter)
795794
: MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
796-
vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
795+
vectorContractLowering(vectorContractLowering), filter(defaultFilter) {}
797796

798797
FailureOr<Value>
799798
matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
@@ -806,8 +805,7 @@ struct ContractOpToElementwise
806805
if (failed(filter(contractOp)))
807806
return failure();
808807

809-
if (vectorTransformOptions.vectorContractLowering !=
810-
vector::VectorContractLowering::ParallelArith)
808+
if (vectorContractLowering != vector::VectorContractLowering::ParallelArith)
811809
return failure();
812810

813811
ArrayRef<int64_t> lhsShape = contractOp.getLhsType().getShape();
@@ -898,7 +896,7 @@ struct ContractOpToElementwise
898896

899897
private:
900898
/// Options to control the vector patterns.
901-
vector::VectorTransformsOptions vectorTransformOptions;
899+
vector::VectorContractLowering vectorContractLowering;
902900
FilterConstraintType filter;
903901
};
904902

@@ -913,7 +911,7 @@ struct ContractOpToElementwise
913911
/// until a pure contraction is reached (no free/batch dimensions),
914912
/// which is replaced by a dot-product.
915913
///
916-
/// This only kicks in when either VectorTransformsOptions is set
914+
/// This only kicks in when either vectorContractLoweringOption is set
917915
/// to DOT or when other contraction patterns fail.
918916
//
919917
// TODO: break down into transpose/reshape/cast ops
@@ -941,25 +939,25 @@ FailureOr<Value> ContractionOpLowering::matchAndRewriteMaskableOp(
941939
// TODO: implement benefits, cost models.
942940
MLIRContext *ctx = op.getContext();
943941

944-
ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx);
942+
ContractionOpToMatmulOpLowering pat1(vectorContractLoweringOption, ctx);
945943
FailureOr<Value> newVal1 =
946944
pat1.matchAndRewriteMaskableOp(op, maskOp, rewriter);
947945
if (!failed(newVal1))
948946
return newVal1;
949947

950-
ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx);
948+
ContractionOpToOuterProductOpLowering pat2(vectorContractLoweringOption, ctx);
951949
FailureOr<Value> newVal2 =
952950
pat2.matchAndRewriteMaskableOp(op, maskOp, rewriter);
953951
if (!failed(newVal2))
954952
return newVal2;
955953

956-
ContractionOpToDotLowering pat3(vectorTransformOptions, ctx);
954+
ContractionOpToDotLowering pat3(vectorContractLoweringOption, ctx);
957955
FailureOr<Value> newVal3 =
958956
pat3.matchAndRewriteMaskableOp(op, maskOp, rewriter);
959957
if (!failed(newVal3))
960958
return newVal3;
961959

962-
ContractOpToElementwise pat4(vectorTransformOptions, ctx);
960+
ContractOpToElementwise pat4(vectorContractLoweringOption, ctx);
963961
FailureOr<Value> newVal4 =
964962
pat4.matchAndRewriteMaskableOp(op, maskOp, rewriter);
965963
if (!failed(newVal4))
@@ -1273,14 +1271,14 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
12731271
/// %mtb = maybe_transpose
12741272
/// %flattened_a = vector.shape_cast %mta
12751273
/// %flattened_b = vector.shape_cast %mtb
1276-
/// %flattened_d = vector.matmul %flattened_a, %flattened_b
1274+
/// %flattened_d = vector.matrix_multiply %flattened_a, %flattened_b
12771275
/// %mtd = vector.shape_cast %flattened_d
12781276
/// %d = maybe_untranspose %mtd
12791277
/// %e = add %c, %d
12801278
/// ```
1281-
/// `vector.matmul` later lowers to `llvm.matrix.multiply`.
1279+
/// `vector.matrix_multiply` later lowers to `llvm.matrix.multiply`.
12821280
//
1283-
/// This only kicks in when VectorTransformsOptions is set to `Matmul`.
1281+
/// This only kicks in when vectorContractLowering is set to `Matmul`.
12841282
/// vector.transpose operations are inserted if the vector.contract op is not a
12851283
/// row-major matrix multiply.
12861284
///
@@ -1292,8 +1290,7 @@ FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
12921290
if (maskOp)
12931291
return failure();
12941292

1295-
if (vectorTransformOptions.vectorContractLowering !=
1296-
vector::VectorContractLowering::Matmul)
1293+
if (vectorContractLowering != vector::VectorContractLowering::Matmul)
12971294
return failure();
12981295
if (failed(filter(op)))
12991296
return failure();
@@ -1382,13 +1379,14 @@ FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
13821379
} // namespace
13831380

13841381
void mlir::vector::populateVectorContractLoweringPatterns(
1385-
RewritePatternSet &patterns, VectorTransformsOptions options,
1386-
PatternBenefit benefit, bool disableOuterProductLowering) {
1382+
RewritePatternSet &patterns,
1383+
VectorContractLowering vectorContractLoweringOption, PatternBenefit benefit,
1384+
bool disableOuterProductLowering) {
13871385
if (!disableOuterProductLowering)
13881386
patterns.add<OuterProductOpLowering>(patterns.getContext(), benefit);
13891387
patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering,
13901388
ContractionOpToOuterProductOpLowering>(
1391-
options, patterns.getContext(), benefit);
1389+
vectorContractLoweringOption, patterns.getContext(), benefit);
13921390
}
13931391

13941392
void mlir::vector::populateVectorOuterProductLoweringPatterns(

0 commit comments

Comments
 (0)