Skip to content

Commit 9df63b2

Browse files
[mlir][Transforms] Add 1:N matchAndRewrite overload (llvm#116470)
This commit adds a new `matchAndRewrite` overload to `ConversionPattern` to support 1:N replacements. This is the first of two main PRs that merge the 1:1 and 1:N dialect conversion drivers. The existing `matchAndRewrite` function supports only 1:1 replacements, as can be seen from the `ArrayRef<Value>` parameter. ```c++ LogicalResult ConversionPattern::matchAndRewrite( Operation *op, ArrayRef<Value> operands /*adaptor values*/, ConversionPatternRewriter &rewriter) const; ``` This commit adds a `matchAndRewrite` overload that is called by the dialect conversion driver. By default, this new overload dispatches to the original 1:1 `matchAndRewrite` implementation. Existing `ConversionPattern`s do not need to be changed as long as there are no 1:N type conversions or value replacements. ```c++ LogicalResult ConversionPattern::matchAndRewrite( Operation *op, ArrayRef<ValueRange> operands /*adaptor values*/, ConversionPatternRewriter &rewriter) const { // Note: getOneToOneAdaptorOperands produces a fatal error if at least one // ValueRange has 0 or more than 1 value. return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter); } ``` The `ConversionValueMapping`, which keeps track of value replacements and materializations, still does not support 1:N replacements. We still rely on argument materializations to convert N replacement values back into a single value. The `ConversionValueMapping` will be generalized to 1:N mappings in the second main PR. Before handing the adaptor values to a `ConversionPattern`, all argument materializations are "unpacked". The `ConversionPattern` receives N replacement values and does not see any argument materializations. This implementation strategy allows us to use the 1:N infrastructure/API in `ConversionPattern`s even though some functionality is still missing in the driver. This strategy was chosen to keep the sizes of the PRs smaller and to make it easier for downstream users to adapt to API changes. This commit also updates the the "decompose call graphs" transformation and the "sparse tensor codegen" transformation to use the new 1:N `ConversionPattern` API. Note for LLVM conversion: If you are using a type converter with 1:N type conversion rules or if your patterns are performing 1:N replacements (via `replaceOpWithMultiple` or `applySignatureConversion`), conversion pattern applications will start failing (fatal LLVM error) with this error message: `pattern 'name' does not support 1:N conversion`. The name of the failing pattern is shown in the error message. These patterns must be updated to the new 1:N `matchAndRewrite` API.
1 parent b22cc5a commit 9df63b2

File tree

12 files changed

+497
-325
lines changed

12 files changed

+497
-325
lines changed

mlir/include/mlir/Conversion/LLVMCommon/Pattern.h

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,8 @@ template <typename SourceOp>
143143
class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
144144
public:
145145
using OpAdaptor = typename SourceOp::Adaptor;
146+
using OneToNOpAdaptor =
147+
typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
146148

147149
explicit ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter,
148150
PatternBenefit benefit = 1)
@@ -153,17 +155,29 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
153155
/// Wrappers around the RewritePattern methods that pass the derived op type.
154156
void rewrite(Operation *op, ArrayRef<Value> operands,
155157
ConversionPatternRewriter &rewriter) const final {
156-
rewrite(cast<SourceOp>(op), OpAdaptor(operands, cast<SourceOp>(op)),
157-
rewriter);
158+
auto sourceOp = cast<SourceOp>(op);
159+
rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
160+
}
161+
void rewrite(Operation *op, ArrayRef<ValueRange> operands,
162+
ConversionPatternRewriter &rewriter) const final {
163+
auto sourceOp = cast<SourceOp>(op);
164+
rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
158165
}
159166
LogicalResult match(Operation *op) const final {
160167
return match(cast<SourceOp>(op));
161168
}
162169
LogicalResult
163170
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
164171
ConversionPatternRewriter &rewriter) const final {
165-
return matchAndRewrite(cast<SourceOp>(op),
166-
OpAdaptor(operands, cast<SourceOp>(op)), rewriter);
172+
auto sourceOp = cast<SourceOp>(op);
173+
return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
174+
}
175+
LogicalResult
176+
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
177+
ConversionPatternRewriter &rewriter) const final {
178+
auto sourceOp = cast<SourceOp>(op);
179+
return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
180+
rewriter);
167181
}
168182

169183
/// Rewrite and Match methods that operate on the SourceOp type. These must be
@@ -175,6 +189,12 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
175189
ConversionPatternRewriter &rewriter) const {
176190
llvm_unreachable("must override rewrite or matchAndRewrite");
177191
}
192+
virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
193+
ConversionPatternRewriter &rewriter) const {
194+
SmallVector<Value> oneToOneOperands =
195+
getOneToOneAdaptorOperands(adaptor.getOperands());
196+
rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
197+
}
178198
virtual LogicalResult
179199
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
180200
ConversionPatternRewriter &rewriter) const {
@@ -183,6 +203,13 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
183203
rewrite(op, adaptor, rewriter);
184204
return success();
185205
}
206+
virtual LogicalResult
207+
matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
208+
ConversionPatternRewriter &rewriter) const {
209+
SmallVector<Value> oneToOneOperands =
210+
getOneToOneAdaptorOperands(adaptor.getOperands());
211+
return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
212+
}
186213

187214
private:
188215
using ConvertToLLVMPattern::match;

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -538,8 +538,15 @@ class ConversionPattern : public RewritePattern {
538538
ConversionPatternRewriter &rewriter) const {
539539
llvm_unreachable("unimplemented rewrite");
540540
}
541+
virtual void rewrite(Operation *op, ArrayRef<ValueRange> operands,
542+
ConversionPatternRewriter &rewriter) const {
543+
rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
544+
}
541545

542546
/// Hook for derived classes to implement combined matching and rewriting.
547+
/// This overload supports only 1:1 replacements. The 1:N overload is called
548+
/// by the driver. By default, it calls this 1:1 overload or reports a fatal
549+
/// error if 1:N replacements were found.
543550
virtual LogicalResult
544551
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
545552
ConversionPatternRewriter &rewriter) const {
@@ -549,6 +556,14 @@ class ConversionPattern : public RewritePattern {
549556
return success();
550557
}
551558

559+
/// Hook for derived classes to implement combined matching and rewriting.
560+
/// This overload supports 1:N replacements.
561+
virtual LogicalResult
562+
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
563+
ConversionPatternRewriter &rewriter) const {
564+
return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
565+
}
566+
552567
/// Attempt to match and rewrite the IR root at the specified operation.
553568
LogicalResult matchAndRewrite(Operation *op,
554569
PatternRewriter &rewriter) const final;
@@ -575,6 +590,15 @@ class ConversionPattern : public RewritePattern {
575590
: RewritePattern(std::forward<Args>(args)...),
576591
typeConverter(&typeConverter) {}
577592

593+
/// Given an array of value ranges, which are the inputs to a 1:N adaptor,
594+
/// try to extract the single value of each range to construct a the inputs
595+
/// for a 1:1 adaptor.
596+
///
597+
/// This function produces a fatal error if at least one range has 0 or
598+
/// more than 1 value: "pattern 'name' does not support 1:N conversion"
599+
SmallVector<Value>
600+
getOneToOneAdaptorOperands(ArrayRef<ValueRange> operands) const;
601+
578602
protected:
579603
/// An optional type converter for use by this pattern.
580604
const TypeConverter *typeConverter = nullptr;
@@ -590,6 +614,8 @@ template <typename SourceOp>
590614
class OpConversionPattern : public ConversionPattern {
591615
public:
592616
using OpAdaptor = typename SourceOp::Adaptor;
617+
using OneToNOpAdaptor =
618+
typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
593619

594620
OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
595621
: ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
@@ -608,12 +634,24 @@ class OpConversionPattern : public ConversionPattern {
608634
auto sourceOp = cast<SourceOp>(op);
609635
rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
610636
}
637+
void rewrite(Operation *op, ArrayRef<ValueRange> operands,
638+
ConversionPatternRewriter &rewriter) const final {
639+
auto sourceOp = cast<SourceOp>(op);
640+
rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
641+
}
611642
LogicalResult
612643
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
613644
ConversionPatternRewriter &rewriter) const final {
614645
auto sourceOp = cast<SourceOp>(op);
615646
return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
616647
}
648+
LogicalResult
649+
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
650+
ConversionPatternRewriter &rewriter) const final {
651+
auto sourceOp = cast<SourceOp>(op);
652+
return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
653+
rewriter);
654+
}
617655

618656
/// Rewrite and Match methods that operate on the SourceOp type. These must be
619657
/// overridden by the derived pattern class.
@@ -624,6 +662,12 @@ class OpConversionPattern : public ConversionPattern {
624662
ConversionPatternRewriter &rewriter) const {
625663
llvm_unreachable("must override matchAndRewrite or a rewrite method");
626664
}
665+
virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
666+
ConversionPatternRewriter &rewriter) const {
667+
SmallVector<Value> oneToOneOperands =
668+
getOneToOneAdaptorOperands(adaptor.getOperands());
669+
rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
670+
}
627671
virtual LogicalResult
628672
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
629673
ConversionPatternRewriter &rewriter) const {
@@ -632,6 +676,13 @@ class OpConversionPattern : public ConversionPattern {
632676
rewrite(op, adaptor, rewriter);
633677
return success();
634678
}
679+
virtual LogicalResult
680+
matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
681+
ConversionPatternRewriter &rewriter) const {
682+
SmallVector<Value> oneToOneOperands =
683+
getOneToOneAdaptorOperands(adaptor.getOperands());
684+
return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
685+
}
635686

636687
private:
637688
using ConversionPattern::matchAndRewrite;
@@ -657,18 +708,31 @@ class OpInterfaceConversionPattern : public ConversionPattern {
657708
ConversionPatternRewriter &rewriter) const final {
658709
rewrite(cast<SourceOp>(op), operands, rewriter);
659710
}
711+
void rewrite(Operation *op, ArrayRef<ValueRange> operands,
712+
ConversionPatternRewriter &rewriter) const final {
713+
rewrite(cast<SourceOp>(op), operands, rewriter);
714+
}
660715
LogicalResult
661716
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
662717
ConversionPatternRewriter &rewriter) const final {
663718
return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
664719
}
720+
LogicalResult
721+
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
722+
ConversionPatternRewriter &rewriter) const final {
723+
return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
724+
}
665725

666726
/// Rewrite and Match methods that operate on the SourceOp type. These must be
667727
/// overridden by the derived pattern class.
668728
virtual void rewrite(SourceOp op, ArrayRef<Value> operands,
669729
ConversionPatternRewriter &rewriter) const {
670730
llvm_unreachable("must override matchAndRewrite or a rewrite method");
671731
}
732+
virtual void rewrite(SourceOp op, ArrayRef<ValueRange> operands,
733+
ConversionPatternRewriter &rewriter) const {
734+
rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
735+
}
672736
virtual LogicalResult
673737
matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
674738
ConversionPatternRewriter &rewriter) const {
@@ -677,6 +741,11 @@ class OpInterfaceConversionPattern : public ConversionPattern {
677741
rewrite(op, operands, rewriter);
678742
return success();
679743
}
744+
virtual LogicalResult
745+
matchAndRewrite(SourceOp op, ArrayRef<ValueRange> operands,
746+
ConversionPatternRewriter &rewriter) const {
747+
return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
748+
}
680749

681750
private:
682751
using ConversionPattern::matchAndRewrite;

mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp

Lines changed: 6 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -13,40 +13,6 @@
1313
using namespace mlir;
1414
using namespace mlir::func;
1515

16-
//===----------------------------------------------------------------------===//
17-
// Helper functions
18-
//===----------------------------------------------------------------------===//
19-
20-
/// If the given value can be decomposed with the type converter, decompose it.
21-
/// Otherwise, return the given value.
22-
// TODO: Value decomposition should happen automatically through a 1:N adaptor.
23-
// This function will disappear when the 1:1 and 1:N drivers are merged.
24-
static SmallVector<Value> decomposeValue(OpBuilder &builder, Location loc,
25-
Value value,
26-
const TypeConverter *converter) {
27-
// Try to convert the given value's type. If that fails, just return the
28-
// given value.
29-
SmallVector<Type> convertedTypes;
30-
if (failed(converter->convertType(value.getType(), convertedTypes)))
31-
return {value};
32-
if (convertedTypes.empty())
33-
return {};
34-
35-
// If the given value's type is already legal, just return the given value.
36-
TypeRange convertedTypeRange(convertedTypes);
37-
if (convertedTypeRange == TypeRange(value.getType()))
38-
return {value};
39-
40-
// Try to materialize a target conversion. If the materialization did not
41-
// produce values of the requested type, the materialization failed. Just
42-
// return the given value in that case.
43-
SmallVector<Value> result = converter->materializeTargetConversion(
44-
builder, loc, convertedTypeRange, value);
45-
if (result.empty())
46-
return {value};
47-
return result;
48-
}
49-
5016
//===----------------------------------------------------------------------===//
5117
// DecomposeCallGraphTypesForFuncArgs
5218
//===----------------------------------------------------------------------===//
@@ -102,16 +68,11 @@ struct DecomposeCallGraphTypesForReturnOp
10268
using OpConversionPattern::OpConversionPattern;
10369

10470
LogicalResult
105-
matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
71+
matchAndRewrite(ReturnOp op, OneToNOpAdaptor adaptor,
10672
ConversionPatternRewriter &rewriter) const final {
10773
SmallVector<Value, 2> newOperands;
108-
for (Value operand : adaptor.getOperands()) {
109-
// TODO: We can directly take the values from the adaptor once this is a
110-
// 1:N conversion pattern.
111-
llvm::append_range(newOperands,
112-
decomposeValue(rewriter, operand.getLoc(), operand,
113-
getTypeConverter()));
114-
}
74+
for (ValueRange operand : adaptor.getOperands())
75+
llvm::append_range(newOperands, operand);
11576
rewriter.replaceOpWithNewOp<ReturnOp>(op, newOperands);
11677
return success();
11778
}
@@ -128,18 +89,13 @@ struct DecomposeCallGraphTypesForCallOp : public OpConversionPattern<CallOp> {
12889
using OpConversionPattern::OpConversionPattern;
12990

13091
LogicalResult
131-
matchAndRewrite(CallOp op, OpAdaptor adaptor,
92+
matchAndRewrite(CallOp op, OneToNOpAdaptor adaptor,
13293
ConversionPatternRewriter &rewriter) const final {
13394

13495
// Create the operands list of the new `CallOp`.
13596
SmallVector<Value, 2> newOperands;
136-
for (Value operand : adaptor.getOperands()) {
137-
// TODO: We can directly take the values from the adaptor once this is a
138-
// 1:N conversion pattern.
139-
llvm::append_range(newOperands,
140-
decomposeValue(rewriter, operand.getLoc(), operand,
141-
getTypeConverter()));
142-
}
97+
for (ValueRange operand : adaptor.getOperands())
98+
llvm::append_range(newOperands, operand);
14399

144100
// Create the new result types for the new `CallOp` and track the number of
145101
// replacement types for each original op result.

mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,14 @@
1313
using namespace mlir;
1414
using namespace mlir::func;
1515

16+
/// Flatten the given value ranges into a single vector of values.
17+
static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
18+
SmallVector<Value> result;
19+
for (const auto &vals : values)
20+
llvm::append_range(result, vals);
21+
return result;
22+
}
23+
1624
namespace {
1725
/// Converts the operand and result types of the CallOp, used together with the
1826
/// FuncOpSignatureConversion.
@@ -21,7 +29,7 @@ struct CallOpSignatureConversion : public OpConversionPattern<CallOp> {
2129

2230
/// Hook for derived classes to implement combined matching and rewriting.
2331
LogicalResult
24-
matchAndRewrite(CallOp callOp, OpAdaptor adaptor,
32+
matchAndRewrite(CallOp callOp, OneToNOpAdaptor adaptor,
2533
ConversionPatternRewriter &rewriter) const override {
2634
// Convert the original function results. Keep track of how many result
2735
// types an original result type is converted into.
@@ -38,9 +46,9 @@ struct CallOpSignatureConversion : public OpConversionPattern<CallOp> {
3846

3947
// Substitute with the new result types from the corresponding FuncType
4048
// conversion.
41-
auto newCallOp =
42-
rewriter.create<CallOp>(callOp.getLoc(), callOp.getCallee(),
43-
convertedResults, adaptor.getOperands());
49+
auto newCallOp = rewriter.create<CallOp>(
50+
callOp.getLoc(), callOp.getCallee(), convertedResults,
51+
flattenValues(adaptor.getOperands()));
4452
SmallVector<ValueRange> replacements;
4553
size_t offset = 0;
4654
for (int i = 0, e = callOp->getNumResults(); i < e; ++i) {

0 commit comments

Comments
 (0)