Skip to content

Commit 153310a

Browse files
replace with multiple
Apply suggestions from code review Co-authored-by: Markus Böck <[email protected]> address comments [WIP] 1:N conversion pattern update test cases
1 parent e872b86 commit 153310a

File tree

9 files changed

+381
-303
lines changed

9 files changed

+381
-303
lines changed

mlir/include/mlir/Conversion/LLVMCommon/Pattern.h

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,8 @@ template <typename SourceOp>
143143
class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
144144
public:
145145
using OpAdaptor = typename SourceOp::Adaptor;
146+
using OneToNOpAdaptor =
147+
typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
146148

147149
explicit ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter,
148150
PatternBenefit benefit = 1)
@@ -153,17 +155,29 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
153155
/// Wrappers around the RewritePattern methods that pass the derived op type.
154156
void rewrite(Operation *op, ArrayRef<Value> operands,
155157
ConversionPatternRewriter &rewriter) const final {
156-
rewrite(cast<SourceOp>(op), OpAdaptor(operands, cast<SourceOp>(op)),
157-
rewriter);
158+
auto sourceOp = cast<SourceOp>(op);
159+
rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
160+
}
161+
void rewrite(Operation *op, ArrayRef<ValueRange> operands,
162+
ConversionPatternRewriter &rewriter) const final {
163+
auto sourceOp = cast<SourceOp>(op);
164+
rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
158165
}
159166
LogicalResult match(Operation *op) const final {
160167
return match(cast<SourceOp>(op));
161168
}
162169
LogicalResult
163170
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
164171
ConversionPatternRewriter &rewriter) const final {
165-
return matchAndRewrite(cast<SourceOp>(op),
166-
OpAdaptor(operands, cast<SourceOp>(op)), rewriter);
172+
auto sourceOp = cast<SourceOp>(op);
173+
return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
174+
}
175+
LogicalResult
176+
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
177+
ConversionPatternRewriter &rewriter) const final {
178+
auto sourceOp = cast<SourceOp>(op);
179+
return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
180+
rewriter);
167181
}
168182

169183
/// Rewrite and Match methods that operate on the SourceOp type. These must be
@@ -175,6 +189,12 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
175189
ConversionPatternRewriter &rewriter) const {
176190
llvm_unreachable("must override rewrite or matchAndRewrite");
177191
}
192+
virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
193+
ConversionPatternRewriter &rewriter) const {
194+
SmallVector<Value> oneToOneOperands =
195+
getOneToOneAdaptorOperands(adaptor.getOperands());
196+
rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
197+
}
178198
virtual LogicalResult
179199
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
180200
ConversionPatternRewriter &rewriter) const {
@@ -183,6 +203,13 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
183203
rewrite(op, adaptor, rewriter);
184204
return success();
185205
}
206+
virtual LogicalResult
207+
matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
208+
ConversionPatternRewriter &rewriter) const {
209+
SmallVector<Value> oneToOneOperands =
210+
getOneToOneAdaptorOperands(adaptor.getOperands());
211+
return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
212+
}
186213

187214
private:
188215
using ConvertToLLVMPattern::match;

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -537,6 +537,10 @@ class ConversionPattern : public RewritePattern {
537537
ConversionPatternRewriter &rewriter) const {
538538
llvm_unreachable("unimplemented rewrite");
539539
}
540+
virtual void rewrite(Operation *op, ArrayRef<ValueRange> operands,
541+
ConversionPatternRewriter &rewriter) const {
542+
rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
543+
}
540544

541545
/// Hook for derived classes to implement combined matching and rewriting.
542546
virtual LogicalResult
@@ -547,6 +551,11 @@ class ConversionPattern : public RewritePattern {
547551
rewrite(op, operands, rewriter);
548552
return success();
549553
}
554+
virtual LogicalResult
555+
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
556+
ConversionPatternRewriter &rewriter) const {
557+
return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
558+
}
550559

551560
/// Attempt to match and rewrite the IR root at the specified operation.
552561
LogicalResult matchAndRewrite(Operation *op,
@@ -574,6 +583,15 @@ class ConversionPattern : public RewritePattern {
574583
: RewritePattern(std::forward<Args>(args)...),
575584
typeConverter(&typeConverter) {}
576585

586+
/// Given an array of value ranges, which are the inputs to a 1:N adaptor,
587+
/// try to extract the single value of each range to construct a the inputs
588+
/// for a 1:1 adaptor.
589+
///
590+
/// This function produces a fatal error if at least one range has 0 or
591+
/// more than 1 value: "pattern 'name' does not support 1:N conversion"
592+
SmallVector<Value>
593+
getOneToOneAdaptorOperands(ArrayRef<ValueRange> operands) const;
594+
577595
protected:
578596
/// An optional type converter for use by this pattern.
579597
const TypeConverter *typeConverter = nullptr;
@@ -589,6 +607,8 @@ template <typename SourceOp>
589607
class OpConversionPattern : public ConversionPattern {
590608
public:
591609
using OpAdaptor = typename SourceOp::Adaptor;
610+
using OneToNOpAdaptor =
611+
typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
592612

593613
OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
594614
: ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
@@ -607,12 +627,24 @@ class OpConversionPattern : public ConversionPattern {
607627
auto sourceOp = cast<SourceOp>(op);
608628
rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
609629
}
630+
void rewrite(Operation *op, ArrayRef<ValueRange> operands,
631+
ConversionPatternRewriter &rewriter) const final {
632+
auto sourceOp = cast<SourceOp>(op);
633+
rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
634+
}
610635
LogicalResult
611636
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
612637
ConversionPatternRewriter &rewriter) const final {
613638
auto sourceOp = cast<SourceOp>(op);
614639
return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
615640
}
641+
LogicalResult
642+
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
643+
ConversionPatternRewriter &rewriter) const final {
644+
auto sourceOp = cast<SourceOp>(op);
645+
return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
646+
rewriter);
647+
}
616648

617649
/// Rewrite and Match methods that operate on the SourceOp type. These must be
618650
/// overridden by the derived pattern class.
@@ -623,6 +655,12 @@ class OpConversionPattern : public ConversionPattern {
623655
ConversionPatternRewriter &rewriter) const {
624656
llvm_unreachable("must override matchAndRewrite or a rewrite method");
625657
}
658+
virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
659+
ConversionPatternRewriter &rewriter) const {
660+
SmallVector<Value> oneToOneOperands =
661+
getOneToOneAdaptorOperands(adaptor.getOperands());
662+
rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
663+
}
626664
virtual LogicalResult
627665
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
628666
ConversionPatternRewriter &rewriter) const {
@@ -631,6 +669,13 @@ class OpConversionPattern : public ConversionPattern {
631669
rewrite(op, adaptor, rewriter);
632670
return success();
633671
}
672+
virtual LogicalResult
673+
matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
674+
ConversionPatternRewriter &rewriter) const {
675+
SmallVector<Value> oneToOneOperands =
676+
getOneToOneAdaptorOperands(adaptor.getOperands());
677+
return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
678+
}
634679

635680
private:
636681
using ConversionPattern::matchAndRewrite;
@@ -656,18 +701,31 @@ class OpInterfaceConversionPattern : public ConversionPattern {
656701
ConversionPatternRewriter &rewriter) const final {
657702
rewrite(cast<SourceOp>(op), operands, rewriter);
658703
}
704+
void rewrite(Operation *op, ArrayRef<ValueRange> operands,
705+
ConversionPatternRewriter &rewriter) const final {
706+
rewrite(cast<SourceOp>(op), operands, rewriter);
707+
}
659708
LogicalResult
660709
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
661710
ConversionPatternRewriter &rewriter) const final {
662711
return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
663712
}
713+
LogicalResult
714+
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
715+
ConversionPatternRewriter &rewriter) const final {
716+
return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
717+
}
664718

665719
/// Rewrite and Match methods that operate on the SourceOp type. These must be
666720
/// overridden by the derived pattern class.
667721
virtual void rewrite(SourceOp op, ArrayRef<Value> operands,
668722
ConversionPatternRewriter &rewriter) const {
669723
llvm_unreachable("must override matchAndRewrite or a rewrite method");
670724
}
725+
virtual void rewrite(SourceOp op, ArrayRef<ValueRange> operands,
726+
ConversionPatternRewriter &rewriter) const {
727+
rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
728+
}
671729
virtual LogicalResult
672730
matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
673731
ConversionPatternRewriter &rewriter) const {
@@ -676,6 +734,11 @@ class OpInterfaceConversionPattern : public ConversionPattern {
676734
rewrite(op, operands, rewriter);
677735
return success();
678736
}
737+
virtual LogicalResult
738+
matchAndRewrite(SourceOp op, ArrayRef<ValueRange> operands,
739+
ConversionPatternRewriter &rewriter) const {
740+
return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
741+
}
679742

680743
private:
681744
using ConversionPattern::matchAndRewrite;

mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp

Lines changed: 6 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -13,40 +13,6 @@
1313
using namespace mlir;
1414
using namespace mlir::func;
1515

16-
//===----------------------------------------------------------------------===//
17-
// Helper functions
18-
//===----------------------------------------------------------------------===//
19-
20-
/// If the given value can be decomposed with the type converter, decompose it.
21-
/// Otherwise, return the given value.
22-
// TODO: Value decomposition should happen automatically through a 1:N adaptor.
23-
// This function will disappear when the 1:1 and 1:N drivers are merged.
24-
static SmallVector<Value> decomposeValue(OpBuilder &builder, Location loc,
25-
Value value,
26-
const TypeConverter *converter) {
27-
// Try to convert the given value's type. If that fails, just return the
28-
// given value.
29-
SmallVector<Type> convertedTypes;
30-
if (failed(converter->convertType(value.getType(), convertedTypes)))
31-
return {value};
32-
if (convertedTypes.empty())
33-
return {};
34-
35-
// If the given value's type is already legal, just return the given value.
36-
TypeRange convertedTypeRange(convertedTypes);
37-
if (convertedTypeRange == TypeRange(value.getType()))
38-
return {value};
39-
40-
// Try to materialize a target conversion. If the materialization did not
41-
// produce values of the requested type, the materialization failed. Just
42-
// return the given value in that case.
43-
SmallVector<Value> result = converter->materializeTargetConversion(
44-
builder, loc, convertedTypeRange, value);
45-
if (result.empty())
46-
return {value};
47-
return result;
48-
}
49-
5016
//===----------------------------------------------------------------------===//
5117
// DecomposeCallGraphTypesForFuncArgs
5218
//===----------------------------------------------------------------------===//
@@ -102,16 +68,11 @@ struct DecomposeCallGraphTypesForReturnOp
10268
using OpConversionPattern::OpConversionPattern;
10369

10470
LogicalResult
105-
matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
71+
matchAndRewrite(ReturnOp op, OneToNOpAdaptor adaptor,
10672
ConversionPatternRewriter &rewriter) const final {
10773
SmallVector<Value, 2> newOperands;
108-
for (Value operand : adaptor.getOperands()) {
109-
// TODO: We can directly take the values from the adaptor once this is a
110-
// 1:N conversion pattern.
111-
llvm::append_range(newOperands,
112-
decomposeValue(rewriter, operand.getLoc(), operand,
113-
getTypeConverter()));
114-
}
74+
for (ValueRange operand : adaptor.getOperands())
75+
llvm::append_range(newOperands, operand);
11576
rewriter.replaceOpWithNewOp<ReturnOp>(op, newOperands);
11677
return success();
11778
}
@@ -128,18 +89,13 @@ struct DecomposeCallGraphTypesForCallOp : public OpConversionPattern<CallOp> {
12889
using OpConversionPattern::OpConversionPattern;
12990

13091
LogicalResult
131-
matchAndRewrite(CallOp op, OpAdaptor adaptor,
92+
matchAndRewrite(CallOp op, OneToNOpAdaptor adaptor,
13293
ConversionPatternRewriter &rewriter) const final {
13394

13495
// Create the operands list of the new `CallOp`.
13596
SmallVector<Value, 2> newOperands;
136-
for (Value operand : adaptor.getOperands()) {
137-
// TODO: We can directly take the values from the adaptor once this is a
138-
// 1:N conversion pattern.
139-
llvm::append_range(newOperands,
140-
decomposeValue(rewriter, operand.getLoc(), operand,
141-
getTypeConverter()));
142-
}
97+
for (ValueRange operand : adaptor.getOperands())
98+
llvm::append_range(newOperands, operand);
14399

144100
// Create the new result types for the new `CallOp` and track the number of
145101
// replacement types for each original op result.

mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ struct CallOpSignatureConversion : public OpConversionPattern<CallOp> {
2121

2222
/// Hook for derived classes to implement combined matching and rewriting.
2323
LogicalResult
24-
matchAndRewrite(CallOp callOp, OpAdaptor adaptor,
24+
matchAndRewrite(CallOp callOp, OneToNOpAdaptor adaptor,
2525
ConversionPatternRewriter &rewriter) const override {
2626
// Convert the original function results.
2727
SmallVector<Type, 1> convertedResults;
@@ -37,7 +37,8 @@ struct CallOpSignatureConversion : public OpConversionPattern<CallOp> {
3737
// Substitute with the new result types from the corresponding FuncType
3838
// conversion.
3939
rewriter.replaceOpWithNewOp<CallOp>(
40-
callOp, callOp.getCallee(), convertedResults, adaptor.getOperands());
40+
callOp, callOp.getCallee(), convertedResults,
41+
getOneToOneAdaptorOperands(adaptor.getOperands()));
4142
return success();
4243
}
4344
};

0 commit comments

Comments
 (0)