Skip to content

Commit 9c4611f

Browse files
[mlir] Implement pass utils for 1:N type conversions.
The current dialect conversion does not support 1:N type conversions. This commit implements a (poor-man's) dialect conversion pass that does just that. To keep the pass independent of the "real" dialect conversion infrastructure, it provides a specialization of the TypeConverter class that allows for N:1 target materializations, a specialization of the RewritePattern and PatternRewriter classes that automatically add appropriate unrealized casts supporting 1:N type conversions and provide converted operands for implementing subclasses, and a conversion driver that applies the provided patterns and replaces the unrealized casts that haven't folded away with user-provided materializations. The current pass is powerful enough to express many existing manual solutions for 1:N type conversions or extend transforms that previously didn't support them, out of which this patch implements call graph type decomposition (which is currently implemented with a ValueDecomposer that is only used there). The goal of this pass is to illustrate the effect that 1:N type conversions could have, gain experience in how patterns should be written that achieve that effect, and get feedback on how the APIs of the dialect conversion should be extended or changed to support such patterns. The hope is that the "real" dialect conversion eventually supports such patterns, at which point, this pass could be removed again. Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D144469
1 parent 1c4fedf commit 9c4611f

File tree

12 files changed

+1142
-0
lines changed

12 files changed

+1142
-0
lines changed
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
//===- OneToNTypeFuncConversions.h - 1:N type conv. for Func ----*- C++ -*-===//
2+
//
3+
// Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_FUNC_TRANSFORMS_ONETONTYPEFUNCCONVERSIONS_H
10+
#define MLIR_DIALECT_FUNC_TRANSFORMS_ONETONTYPEFUNCCONVERSIONS_H
11+
12+
namespace mlir {
13+
class TypeConverter;
14+
class RewritePatternSet;
15+
} // namespace mlir
16+
17+
namespace mlir {
18+
19+
// Populates the provided pattern set with patterns that do 1:N type conversions
20+
// on func ops. This is intended to be used with `applyPartialOneToNConversion`.
21+
void populateFuncTypeConversionPatterns(TypeConverter &typeConverter,
22+
RewritePatternSet &patterns);
23+
24+
} // namespace mlir
25+
26+
#endif // MLIR_DIALECT_FUNC_TRANSFORMS_ONETONTYPEFUNCCONVERSIONS_H
Lines changed: 256 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,256 @@
1+
//===-- OneToNTypeConversion.h - Utils for 1:N type conversion --*- C++ -*-===//
2+
//
3+
// Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file provides utils for implementing (poor-man's) dialect conversion
10+
// passes with 1:N type conversions.
11+
//
12+
// The main function, `applyPartialOneToNConversion`, first applies a set of
13+
// `RewritePattern`s, which produce unrealized casts to convert the operands and
14+
// results from and to the source types, and then replaces all newly added
15+
// unrealized casts by user-provided materializations. For this to work, the
16+
// main function requires a special `TypeConverter`, a special
17+
// `PatternRewriter`, and special RewritePattern`s, which extend their
18+
// respective base classes for 1:N type converions.
19+
//
20+
// Note that this is much more simple-minded than the "real" dialect conversion,
21+
// which checks for legality before applying patterns and does probably many
22+
// other additional things. Ideally, some of the extensions here could be
23+
// integrated there.
24+
//
25+
//===----------------------------------------------------------------------===//
26+
27+
#ifndef MLIR_TRANSFORMS_ONETONTYPECONVERSION_H
28+
#define MLIR_TRANSFORMS_ONETONTYPECONVERSION_H
29+
30+
#include "mlir/IR/PatternMatch.h"
31+
#include "mlir/Transforms/DialectConversion.h"
32+
#include "llvm/ADT/SmallVector.h"
33+
34+
namespace mlir {
35+
36+
/// Extends `TypeConverter` with 1:N target materializations. Such
37+
/// materializations have to provide the "reverse" of 1:N type conversions,
38+
/// i.e., they need to materialize N values with target types into one value
39+
/// with a source type (which isn't possible in the base class currently).
40+
class OneToNTypeConverter : public TypeConverter {
41+
public:
42+
/// Callback that expresses user-provided materialization logic from the given
43+
/// value to N values of the given types. This is useful for expressing target
44+
/// materializations for 1:N type conversions, which materialize one value in
45+
/// a source type as N values in target types.
46+
using OneToNMaterializationCallbackFn =
47+
std::function<std::optional<SmallVector<Value>>(OpBuilder &, TypeRange,
48+
Value, Location)>;
49+
50+
/// Creates the mapping of the given range of original types to target types
51+
/// of the conversion and stores that mapping in the given (signature)
52+
/// conversion. This function simply calls
53+
/// `TypeConverter::convertSignatureArgs` and exists here with a different
54+
/// name to reflect the broader semantic.
55+
LogicalResult computeTypeMapping(TypeRange types,
56+
SignatureConversion &result) {
57+
return convertSignatureArgs(types, result);
58+
}
59+
60+
/// Applies one of the user-provided 1:N target materializations. If several
61+
/// exists, they are tried out in the reverse order in which they have been
62+
/// added until the first one succeeds. If none succeeds, the functions
63+
/// returns `std::nullopt`.
64+
std::optional<SmallVector<Value>>
65+
materializeTargetConversion(OpBuilder &builder, Location loc,
66+
TypeRange resultTypes, Value input) const;
67+
68+
/// Adds a 1:N target materialization to the converter. Such materializations
69+
/// build IR that converts N values with target types into 1 value of the
70+
/// source type.
71+
void addTargetMaterialization(OneToNMaterializationCallbackFn &&callback) {
72+
oneToNTargetMaterializations.emplace_back(std::move(callback));
73+
}
74+
75+
private:
76+
SmallVector<OneToNMaterializationCallbackFn> oneToNTargetMaterializations;
77+
};
78+
79+
/// Stores a 1:N mapping of types and provides several useful accessors. This
80+
/// class extends `SignatureConversion`, which already supports 1:N type
81+
/// mappings but lacks some accessors into the mapping as well as access to the
82+
/// original types.
83+
class OneToNTypeMapping : public TypeConverter::SignatureConversion {
84+
public:
85+
OneToNTypeMapping(TypeRange originalTypes)
86+
: TypeConverter::SignatureConversion(originalTypes.size()),
87+
originalTypes(originalTypes) {}
88+
89+
using TypeConverter::SignatureConversion::getConvertedTypes;
90+
91+
/// Returns the list of types that corresponds to the original type at the
92+
/// given index.
93+
TypeRange getConvertedTypes(unsigned originalTypeNo) const;
94+
95+
/// Returns the list of original types.
96+
TypeRange getOriginalTypes() const { return originalTypes; }
97+
98+
/// Returns the slice of converted values that corresponds the original value
99+
/// at the given index.
100+
ValueRange getConvertedValues(ValueRange convertedValues,
101+
unsigned originalValueNo) const;
102+
103+
/// Fills the given result vector with as many copies of the location of the
104+
/// original value as the number of values it is converted to.
105+
void convertLocation(Value originalValue, unsigned originalValueNo,
106+
llvm::SmallVectorImpl<Location> &result) const;
107+
108+
/// Fills the given result vector with as many copies of the lociation of each
109+
/// original value as the number of values they are respectively converted to.
110+
void convertLocations(ValueRange originalValues,
111+
llvm::SmallVectorImpl<Location> &result) const;
112+
113+
/// Returns true iff at least one type conversion maps an input type to a type
114+
/// that is different from itself.
115+
bool hasNonIdentityConversion() const;
116+
117+
private:
118+
llvm::SmallVector<Type> originalTypes;
119+
};
120+
121+
/// Extends the basic `RewritePattern` class with a type converter member and
122+
/// some accessors to it. This is useful for patterns that are not
123+
/// `ConversionPattern`s but still require access to a type converter.
124+
class RewritePatternWithConverter : public mlir::RewritePattern {
125+
public:
126+
/// Construct a conversion pattern with the given converter, and forward the
127+
/// remaining arguments to RewritePattern.
128+
template <typename... Args>
129+
RewritePatternWithConverter(TypeConverter &typeConverter, Args &&...args)
130+
: RewritePattern(std::forward<Args>(args)...),
131+
typeConverter(&typeConverter) {}
132+
133+
/// Return the type converter held by this pattern, or nullptr if the pattern
134+
/// does not require type conversion.
135+
TypeConverter *getTypeConverter() const { return typeConverter; }
136+
137+
template <typename ConverterTy>
138+
std::enable_if_t<std::is_base_of<TypeConverter, ConverterTy>::value,
139+
ConverterTy *>
140+
getTypeConverter() const {
141+
return static_cast<ConverterTy *>(typeConverter);
142+
}
143+
144+
protected:
145+
/// A type converter for use by this pattern.
146+
TypeConverter *const typeConverter;
147+
};
148+
149+
/// Specialization of `PatternRewriter` that `OneToNConversionPattern`s use. The
150+
/// class provides additional rewrite methods that are specific to 1:N type
151+
/// conversions.
152+
class OneToNPatternRewriter : public PatternRewriter {
153+
public:
154+
OneToNPatternRewriter(MLIRContext *context) : PatternRewriter(context) {}
155+
156+
/// Replaces the results of the operation with the specified list of values
157+
/// mapped back to the original types as specified in the provided type
158+
/// mapping. That type mapping must match the replaced op (i.e., the original
159+
/// types must be the same as the result types of the op) and the new values
160+
/// (i.e., the converted types must be the same as the types of the new
161+
/// values).
162+
void replaceOp(Operation *op, ValueRange newValues,
163+
const OneToNTypeMapping &resultMapping);
164+
using PatternRewriter::replaceOp;
165+
166+
/// Applies the given argument conversion to the given block. This consists of
167+
/// replacing each original argument with N arguments as specified in the
168+
/// argument conversion and inserting unrealized casts from the converted
169+
/// values to the original types, which are then used in lieu of the original
170+
/// ones. (Eventually, `applyPartialOneToNConversion` replaces these casts
171+
/// with a user-provided argument materialization if necessary.) This is
172+
/// similar to `ArgConverter::applySignatureConversion` but (1) handles 1:N
173+
/// type conversion properly and probably (2) doesn't handle many other edge
174+
/// cases.
175+
Block *applySignatureConversion(Block *block,
176+
OneToNTypeMapping &argumentConversion);
177+
};
178+
179+
/// Base class for patterns with 1:N type conversions. Derived classes have to
180+
/// overwrite the `matchAndRewrite` overlaod that provides additional
181+
/// information for 1:N type conversions.
182+
class OneToNConversionPattern : public RewritePatternWithConverter {
183+
public:
184+
using RewritePatternWithConverter::RewritePatternWithConverter;
185+
186+
/// This function has to be implemented by base classes and is called from the
187+
/// usual overloads. Like in "normal" `DialectConversion`, the function is
188+
/// provided with the converted operands (which thus have target types). Since
189+
/// 1:N conversion are supported, there is usually no 1:1 relationship between
190+
/// the original and the converted operands. Instead, the provided
191+
/// `operandMapping` can be used to access the converted operands that
192+
/// correspond to a particular original operand. Similarly, `resultMapping`
193+
/// is provided to help with assembling the result values, which may have 1:N
194+
/// correspondences as well. In that case, the original op should be replaced
195+
/// with the overload of `replaceOp` that takes the provided `resultMapping`
196+
/// in order to deal with the mapping of converted result values to their
197+
/// usages in the original types correctly.
198+
virtual LogicalResult matchAndRewrite(Operation *op,
199+
OneToNPatternRewriter &rewriter,
200+
const OneToNTypeMapping &operandMapping,
201+
const OneToNTypeMapping &resultMapping,
202+
ValueRange convertedOperands) const = 0;
203+
204+
LogicalResult matchAndRewrite(Operation *op,
205+
PatternRewriter &rewriter) const final;
206+
};
207+
208+
/// This class is a wrapper around `OneToNConversionPattern` for matching
209+
/// against instances of a particular op class.
210+
template <typename SourceOp>
211+
class OneToNOpConversionPattern : public OneToNConversionPattern {
212+
public:
213+
OneToNOpConversionPattern(TypeConverter &typeConverter, MLIRContext *context,
214+
PatternBenefit benefit = 1,
215+
ArrayRef<StringRef> generatedNames = {})
216+
: OneToNConversionPattern(typeConverter, SourceOp::getOperationName(),
217+
benefit, context, generatedNames) {}
218+
219+
using OneToNConversionPattern::matchAndRewrite;
220+
221+
/// Overload that derived classes have to override for their op type.
222+
virtual LogicalResult matchAndRewrite(SourceOp op,
223+
OneToNPatternRewriter &rewriter,
224+
const OneToNTypeMapping &operandMapping,
225+
const OneToNTypeMapping &resultMapping,
226+
ValueRange convertedOperands) const = 0;
227+
228+
LogicalResult matchAndRewrite(Operation *op, OneToNPatternRewriter &rewriter,
229+
const OneToNTypeMapping &operandMapping,
230+
const OneToNTypeMapping &resultMapping,
231+
ValueRange convertedOperands) const final {
232+
return matchAndRewrite(cast<SourceOp>(op), rewriter, operandMapping,
233+
resultMapping, convertedOperands);
234+
}
235+
};
236+
237+
/// Applies the given set of patterns recursively on the given op and adds user
238+
/// materializations where necessary. The patterns are expected to be
239+
/// `OneToNConversionPattern`, which help converting the types of the operands
240+
/// and results of the matched ops. The provided type converter is used to
241+
/// convert the operands of matched ops from their original types to operands
242+
/// with different types. Unlike in `DialectConversion`, this supports 1:N type
243+
/// conversions. Those conversions at the "boundary" of the pattern application,
244+
/// where converted results are not consumed by replaced ops that expect the
245+
/// converted operands or vice versa, the function inserts user materializations
246+
/// from the type converter. Also unlike `DialectConversion`, there are no legal
247+
/// or illegal types; the function simply applies the given patterns and does
248+
/// not fail if some ops or types remain unconverted (i.e., the conversion is
249+
/// only "partial").
250+
LogicalResult
251+
applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
252+
const FrozenRewritePatternSet &patterns);
253+
254+
} // namespace mlir
255+
256+
#endif // MLIR_TRANSFORMS_ONETONTYPECONVERSION_H

mlir/lib/Dialect/Func/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRFuncTransforms
33
DuplicateFunctionElimination.cpp
44
FuncBufferize.cpp
55
FuncConversions.cpp
6+
OneToNFuncConversions.cpp
67

78
ADDITIONAL_HEADER_DIRS
89
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Func/Transforms

0 commit comments

Comments
 (0)