@@ -215,13 +215,13 @@ namespace {
215
215
// / ```
216
216
// / %flattened_a = vector.shape_cast %a
217
217
// / %flattened_b = vector.shape_cast %b
218
- // / %flattened_d = vector.matmul %flattened_a, %flattened_b
218
+ // / %flattened_d = vector.matrix_multiply %flattened_a, %flattened_b
219
219
// / %d = vector.shape_cast %%flattened_d
220
220
// / %e = add %c, %d
221
221
// / ```
222
- // / `vector.matmul ` later lowers to `llvm.matrix.multiply`.
222
+ // / `vector.matrix_multiply ` later lowers to `llvm.matrix.multiply`.
223
223
//
224
- // / This only kicks in when VectorTransformsOptions is set to OuterProduct and
224
+ // / This only kicks in when vectorContractLowering is set to Matmul and
225
225
// / the vector.contract op is a row-major matrix multiply.
226
226
class ContractionOpToMatmulOpLowering
227
227
: public vector::MaskableOpRewritePattern<vector::ContractionOp> {
@@ -236,11 +236,11 @@ class ContractionOpToMatmulOpLowering
236
236
}
237
237
238
238
ContractionOpToMatmulOpLowering (
239
- vector::VectorTransformsOptions vectorTransformOptions ,
239
+ vector::VectorContractLowering vectorContractLowering ,
240
240
MLIRContext *context, PatternBenefit benefit = 1 ,
241
241
FilterConstraintType constraint = defaultFilter)
242
242
: MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
243
- vectorTransformOptions (vectorTransformOptions ),
243
+ vectorContractLowering (vectorContractLowering ),
244
244
filter (std::move(constraint)) {}
245
245
246
246
FailureOr<Value>
@@ -249,7 +249,7 @@ class ContractionOpToMatmulOpLowering
249
249
250
250
private:
251
251
// / Options to control the vector patterns.
252
- vector::VectorTransformsOptions vectorTransformOptions ;
252
+ vector::VectorContractLowering vectorContractLowering ;
253
253
FilterConstraintType filter;
254
254
};
255
255
@@ -266,7 +266,7 @@ class ContractionOpToMatmulOpLowering
266
266
// / %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
267
267
// / ```
268
268
// /
269
- // / This only kicks in when VectorTransformsOptions is set to OuterProduct and
269
+ // / This only kicks in when vectorContractLowering is set to OuterProduct and
270
270
// / the vector.contract op is a row-major matrix multiply.
271
271
class ContractionOpToOuterProductOpLowering
272
272
: public MaskableOpRewritePattern<vector::ContractionOp> {
@@ -281,11 +281,11 @@ class ContractionOpToOuterProductOpLowering
281
281
}
282
282
283
283
ContractionOpToOuterProductOpLowering (
284
- vector::VectorTransformsOptions vectorTransformOptions ,
284
+ vector::VectorContractLowering vectorContractLowering ,
285
285
MLIRContext *context, PatternBenefit benefit = 1 ,
286
286
FilterConstraintType constraint = defaultFilter)
287
287
: MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
288
- vectorTransformOptions (vectorTransformOptions ),
288
+ vectorContractLowering (vectorContractLowering ),
289
289
filter (std::move(constraint)) {}
290
290
291
291
FailureOr<Value>
@@ -294,7 +294,7 @@ class ContractionOpToOuterProductOpLowering
294
294
295
295
private:
296
296
// / Options to control the vector patterns.
297
- vector::VectorTransformsOptions vectorTransformOptions ;
297
+ vector::VectorContractLowering vectorContractLowering ;
298
298
FilterConstraintType filter;
299
299
};
300
300
@@ -329,19 +329,19 @@ class ContractionOpToDotLowering
329
329
}
330
330
331
331
ContractionOpToDotLowering (
332
- vector::VectorTransformsOptions vectorTransformOptions ,
332
+ vector::VectorContractLowering vectorContractLowering ,
333
333
MLIRContext *context, PatternBenefit benefit = 1 ,
334
334
const FilterConstraintType &constraint = defaultFilter)
335
335
: MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
336
- vectorTransformOptions (vectorTransformOptions ), filter(defaultFilter) {}
336
+ vectorContractLowering (vectorContractLowering ), filter(defaultFilter) {}
337
337
338
338
FailureOr<Value>
339
339
matchAndRewriteMaskableOp (vector::ContractionOp op, MaskingOpInterface maskOp,
340
340
PatternRewriter &rewriter) const override ;
341
341
342
342
private:
343
343
// / Options to control the vector patterns.
344
- vector::VectorTransformsOptions vectorTransformOptions ;
344
+ vector::VectorContractLowering vectorContractLowering ;
345
345
FilterConstraintType filter;
346
346
};
347
347
@@ -370,11 +370,12 @@ class ContractionOpLowering
370
370
return success ();
371
371
}
372
372
373
- ContractionOpLowering (vector::VectorTransformsOptions vectorTransformOptions,
374
- MLIRContext *context, PatternBenefit benefit = 1 ,
375
- FilterConstraintType constraint = defaultFilter)
373
+ ContractionOpLowering (
374
+ vector::VectorContractLowering vectorContractLoweringOption,
375
+ MLIRContext *context, PatternBenefit benefit = 1 ,
376
+ FilterConstraintType constraint = defaultFilter)
376
377
: MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
377
- vectorTransformOptions (vectorTransformOptions ),
378
+ vectorContractLoweringOption (vectorContractLoweringOption ),
378
379
filter (std::move(constraint)) {}
379
380
380
381
FailureOr<Value>
@@ -383,7 +384,7 @@ class ContractionOpLowering
383
384
384
385
private:
385
386
// / Options to control the vector patterns.
386
- vector::VectorTransformsOptions vectorTransformOptions ;
387
+ vector::VectorContractLowering vectorContractLoweringOption ;
387
388
FilterConstraintType filter;
388
389
// Lower one parallel dimension.
389
390
FailureOr<Value> lowerParallel (PatternRewriter &rewriter,
@@ -635,14 +636,13 @@ struct UnrolledOuterProductGenerator
635
636
// / %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
636
637
// / ```
637
638
// /
638
- // / This only kicks in when VectorTransformsOptions is set to OuterProduct but
639
+ // / This only kicks in when vectorContractLowering is set to OuterProduct but
639
640
// / otherwise supports any layout permutation of the matrix-multiply.
640
641
FailureOr<Value>
641
642
ContractionOpToOuterProductOpLowering::matchAndRewriteMaskableOp (
642
643
vector::ContractionOp op, MaskingOpInterface maskOp,
643
644
PatternRewriter &rewriter) const {
644
- if (vectorTransformOptions.vectorContractLowering !=
645
- vector::VectorContractLowering::OuterProduct)
645
+ if (vectorContractLowering != vector::VectorContractLowering::OuterProduct)
646
646
return failure ();
647
647
648
648
if (failed (filter (op)))
@@ -672,8 +672,7 @@ FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
672
672
if (failed (filter (op)))
673
673
return failure ();
674
674
675
- if (vectorTransformOptions.vectorContractLowering !=
676
- vector::VectorContractLowering::Dot)
675
+ if (vectorContractLowering != vector::VectorContractLowering::Dot)
677
676
return failure ();
678
677
679
678
auto iteratorTypes = op.getIteratorTypes ().getValue ();
@@ -789,11 +788,11 @@ struct ContractOpToElementwise
789
788
return success ();
790
789
}
791
790
ContractOpToElementwise (
792
- vector::VectorTransformsOptions vectorTransformOptions ,
791
+ vector::VectorContractLowering vectorContractLowering ,
793
792
MLIRContext *context, PatternBenefit benefit = 1 ,
794
793
const FilterConstraintType &constraint = defaultFilter)
795
794
: MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
796
- vectorTransformOptions (vectorTransformOptions ), filter(defaultFilter) {}
795
+ vectorContractLowering (vectorContractLowering ), filter(defaultFilter) {}
797
796
798
797
FailureOr<Value>
799
798
matchAndRewriteMaskableOp (vector::ContractionOp contractOp,
@@ -806,8 +805,7 @@ struct ContractOpToElementwise
806
805
if (failed (filter (contractOp)))
807
806
return failure ();
808
807
809
- if (vectorTransformOptions.vectorContractLowering !=
810
- vector::VectorContractLowering::ParallelArith)
808
+ if (vectorContractLowering != vector::VectorContractLowering::ParallelArith)
811
809
return failure ();
812
810
813
811
ArrayRef<int64_t > lhsShape = contractOp.getLhsType ().getShape ();
@@ -898,7 +896,7 @@ struct ContractOpToElementwise
898
896
899
897
private:
900
898
// / Options to control the vector patterns.
901
- vector::VectorTransformsOptions vectorTransformOptions ;
899
+ vector::VectorContractLowering vectorContractLowering ;
902
900
FilterConstraintType filter;
903
901
};
904
902
@@ -913,7 +911,7 @@ struct ContractOpToElementwise
913
911
// / until a pure contraction is reached (no free/batch dimensions),
914
912
// / which is replaced by a dot-product.
915
913
// /
916
- // / This only kicks in when either VectorTransformsOptions is set
914
+ // / This only kicks in when either vectorContractLoweringOption is set
917
915
// / to DOT or when other contraction patterns fail.
918
916
//
919
917
// TODO: break down into transpose/reshape/cast ops
@@ -941,25 +939,25 @@ FailureOr<Value> ContractionOpLowering::matchAndRewriteMaskableOp(
941
939
// TODO: implement benefits, cost models.
942
940
MLIRContext *ctx = op.getContext ();
943
941
944
- ContractionOpToMatmulOpLowering pat1 (vectorTransformOptions , ctx);
942
+ ContractionOpToMatmulOpLowering pat1 (vectorContractLoweringOption , ctx);
945
943
FailureOr<Value> newVal1 =
946
944
pat1.matchAndRewriteMaskableOp (op, maskOp, rewriter);
947
945
if (!failed (newVal1))
948
946
return newVal1;
949
947
950
- ContractionOpToOuterProductOpLowering pat2 (vectorTransformOptions , ctx);
948
+ ContractionOpToOuterProductOpLowering pat2 (vectorContractLoweringOption , ctx);
951
949
FailureOr<Value> newVal2 =
952
950
pat2.matchAndRewriteMaskableOp (op, maskOp, rewriter);
953
951
if (!failed (newVal2))
954
952
return newVal2;
955
953
956
- ContractionOpToDotLowering pat3 (vectorTransformOptions , ctx);
954
+ ContractionOpToDotLowering pat3 (vectorContractLoweringOption , ctx);
957
955
FailureOr<Value> newVal3 =
958
956
pat3.matchAndRewriteMaskableOp (op, maskOp, rewriter);
959
957
if (!failed (newVal3))
960
958
return newVal3;
961
959
962
- ContractOpToElementwise pat4 (vectorTransformOptions , ctx);
960
+ ContractOpToElementwise pat4 (vectorContractLoweringOption , ctx);
963
961
FailureOr<Value> newVal4 =
964
962
pat4.matchAndRewriteMaskableOp (op, maskOp, rewriter);
965
963
if (!failed (newVal4))
@@ -1273,14 +1271,14 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
1273
1271
// / %mtb = maybe_transpose
1274
1272
// / %flattened_a = vector.shape_cast %mta
1275
1273
// / %flattened_b = vector.shape_cast %mtb
1276
- // / %flattened_d = vector.matmul %flattened_a, %flattened_b
1274
+ // / %flattened_d = vector.matrix_multiply %flattened_a, %flattened_b
1277
1275
// / %mtd = vector.shape_cast %flattened_d
1278
1276
// / %d = maybe_untranspose %mtd
1279
1277
// / %e = add %c, %d
1280
1278
// / ```
1281
- // / `vector.matmul ` later lowers to `llvm.matrix.multiply`.
1279
+ // / `vector.matrix_multiply ` later lowers to `llvm.matrix.multiply`.
1282
1280
//
1283
- // / This only kicks in when VectorTransformsOptions is set to `Matmul`.
1281
+ // / This only kicks in when vectorContractLowering is set to `Matmul`.
1284
1282
// / vector.transpose operations are inserted if the vector.contract op is not a
1285
1283
// / row-major matrix multiply.
1286
1284
// /
@@ -1292,8 +1290,7 @@ FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
1292
1290
if (maskOp)
1293
1291
return failure ();
1294
1292
1295
- if (vectorTransformOptions.vectorContractLowering !=
1296
- vector::VectorContractLowering::Matmul)
1293
+ if (vectorContractLowering != vector::VectorContractLowering::Matmul)
1297
1294
return failure ();
1298
1295
if (failed (filter (op)))
1299
1296
return failure ();
@@ -1382,13 +1379,14 @@ FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
1382
1379
} // namespace
1383
1380
1384
1381
void mlir::vector::populateVectorContractLoweringPatterns (
1385
- RewritePatternSet &patterns, VectorTransformsOptions options,
1386
- PatternBenefit benefit, bool disableOuterProductLowering) {
1382
+ RewritePatternSet &patterns,
1383
+ VectorContractLowering vectorContractLoweringOption, PatternBenefit benefit,
1384
+ bool disableOuterProductLowering) {
1387
1385
if (!disableOuterProductLowering)
1388
1386
patterns.add <OuterProductOpLowering>(patterns.getContext (), benefit);
1389
1387
patterns.add <ContractionOpLowering, ContractionOpToMatmulOpLowering,
1390
1388
ContractionOpToOuterProductOpLowering>(
1391
- options , patterns.getContext (), benefit);
1389
+ vectorContractLoweringOption , patterns.getContext (), benefit);
1392
1390
}
1393
1391
1394
1392
void mlir::vector::populateVectorOuterProductLoweringPatterns (
0 commit comments