Skip to content

Commit 108d3a3

Browse files
Revert "[mlir][Transforms] Delete 1:N dialect conversion driver (llvm#121389)"
This reverts commit 23e3cbb. Signed-off-by: nithinsubbiah <[email protected]>
1 parent 2f9a3ab commit 108d3a3

File tree

23 files changed

+1801
-4
lines changed

23 files changed

+1801
-4
lines changed
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
//===- OneToNTypeFuncConversions.h - 1:N type conv. for Func ----*- C++ -*-===//
2+
//
3+
// Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_FUNC_TRANSFORMS_ONETONTYPEFUNCCONVERSIONS_H
10+
#define MLIR_DIALECT_FUNC_TRANSFORMS_ONETONTYPEFUNCCONVERSIONS_H
11+
12+
namespace mlir {
13+
class TypeConverter;
14+
class RewritePatternSet;
15+
} // namespace mlir
16+
17+
namespace mlir {
18+
19+
// Populates the provided pattern set with patterns that do 1:N type conversions
20+
// on func ops. This is intended to be used with `applyPartialOneToNConversion`.
21+
void populateFuncTypeConversionPatterns(const TypeConverter &typeConverter,
22+
RewritePatternSet &patterns);
23+
24+
} // namespace mlir
25+
26+
#endif // MLIR_DIALECT_FUNC_TRANSFORMS_ONETONTYPEFUNCCONVERSIONS_H

mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,15 @@ void populateSCFStructuralTypeConversions(const TypeConverter &typeConverter,
6363
void populateSCFStructuralTypeConversionTarget(
6464
const TypeConverter &typeConverter, ConversionTarget &target);
6565

66+
/// Populates the provided pattern set with patterns that do 1:N type
67+
/// conversions on (some) SCF ops. This is intended to be used with
68+
/// applyPartialOneToNConversion.
69+
/// FIXME: The 1:N dialect conversion is deprecated and will be removed soon.
70+
/// 1:N support has been added to the regular dialect conversion driver.
71+
/// Use populateSCFStructuralTypeConversions() instead.
72+
void populateSCFStructuralOneToNTypeConversions(
73+
const TypeConverter &typeConverter, RewritePatternSet &patterns);
74+
6675
/// Populate patterns for SCF software pipelining transformation. See the
6776
/// ForLoopPipeliningPattern for the transformation details.
6877
void populateSCFLoopPipeliningPatterns(RewritePatternSet &patterns,

mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
2121
#include "mlir/Transforms/DialectConversion.h"
2222
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23+
#include "mlir/Transforms/OneToNTypeConversion.h"
2324
#include "llvm/ADT/SmallSet.h"
2425
#include "llvm/Support/LogicalResult.h"
2526

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,13 @@ class TypeConverter {
4545
// Copy the registered conversions, but not the caches
4646
TypeConverter(const TypeConverter &other)
4747
: conversions(other.conversions),
48+
argumentMaterializations(other.argumentMaterializations),
4849
sourceMaterializations(other.sourceMaterializations),
4950
targetMaterializations(other.targetMaterializations),
5051
typeAttributeConversions(other.typeAttributeConversions) {}
5152
TypeConverter &operator=(const TypeConverter &other) {
5253
conversions = other.conversions;
54+
argumentMaterializations = other.argumentMaterializations;
5355
sourceMaterializations = other.sourceMaterializations;
5456
targetMaterializations = other.targetMaterializations;
5557
typeAttributeConversions = other.typeAttributeConversions;
@@ -178,6 +180,21 @@ class TypeConverter {
178180
/// can be a TypeRange; in that case, the function must return a
179181
/// SmallVector<Value>.
180182

183+
/// This method registers a materialization that will be called when
184+
/// converting (potentially multiple) block arguments that were the result of
185+
/// a signature conversion of a single block argument, to a single SSA value
186+
/// with the old block argument type.
187+
///
188+
/// Note: Argument materializations are used only with the 1:N dialect
189+
/// conversion driver. The 1:N dialect conversion driver will be removed soon
190+
/// and so will be argument materializations.
191+
template <typename FnT, typename T = typename llvm::function_traits<
192+
std::decay_t<FnT>>::template arg_t<1>>
193+
void addArgumentMaterialization(FnT &&callback) {
194+
argumentMaterializations.emplace_back(
195+
wrapSourceMaterialization<T>(std::forward<FnT>(callback)));
196+
}
197+
181198
/// This method registers a materialization that will be called when
182199
/// converting a replacement value back to its original source type.
183200
/// This is used when some uses of the original value persist beyond the main
@@ -305,6 +322,8 @@ class TypeConverter {
305322
/// generating a cast sequence of some kind. See the respective
306323
/// `add*Materialization` for more information on the context for these
307324
/// methods.
325+
Value materializeArgumentConversion(OpBuilder &builder, Location loc,
326+
Type resultType, ValueRange inputs) const;
308327
Value materializeSourceConversion(OpBuilder &builder, Location loc,
309328
Type resultType, ValueRange inputs) const;
310329
Value materializeTargetConversion(OpBuilder &builder, Location loc,
@@ -490,6 +509,7 @@ class TypeConverter {
490509
SmallVector<ConversionCallbackFn, 4> conversions;
491510

492511
/// The list of registered materialization functions.
512+
SmallVector<SourceMaterializationCallbackFn, 2> argumentMaterializations;
493513
SmallVector<SourceMaterializationCallbackFn, 2> sourceMaterializations;
494514
SmallVector<TargetMaterializationCallbackFn, 2> targetMaterializations;
495515

Lines changed: 285 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,285 @@
1+
//===-- OneToNTypeConversion.h - Utils for 1:N type conversion --*- C++ -*-===//
2+
//
3+
// Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Note: The 1:N dialect conversion is deprecated and will be removed soon.
10+
// 1:N support has been added to the regular dialect conversion driver.
11+
//
12+
// This file provides utils for implementing (poor-man's) dialect conversion
13+
// passes with 1:N type conversions.
14+
//
15+
// The main function, `applyPartialOneToNConversion`, first applies a set of
16+
// `RewritePattern`s, which produce unrealized casts to convert the operands and
17+
// results from and to the source types, and then replaces all newly added
18+
// unrealized casts by user-provided materializations. For this to work, the
19+
// main function requires a special `TypeConverter`, a special
20+
// `PatternRewriter`, and special RewritePattern`s, which extend their
21+
// respective base classes for 1:N type converions.
22+
//
23+
// Note that this is much more simple-minded than the "real" dialect conversion,
24+
// which checks for legality before applying patterns and does probably many
25+
// other additional things. Ideally, some of the extensions here could be
26+
// integrated there.
27+
//
28+
//===----------------------------------------------------------------------===//
29+
30+
#ifndef MLIR_TRANSFORMS_ONETONTYPECONVERSION_H
31+
#define MLIR_TRANSFORMS_ONETONTYPECONVERSION_H
32+
33+
#include "mlir/IR/PatternMatch.h"
34+
#include "mlir/Transforms/DialectConversion.h"
35+
#include "llvm/ADT/SmallVector.h"
36+
37+
namespace mlir {
38+
39+
/// Stores a 1:N mapping of types and provides several useful accessors. This
40+
/// class extends `SignatureConversion`, which already supports 1:N type
41+
/// mappings but lacks some accessors into the mapping as well as access to the
42+
/// original types.
43+
class OneToNTypeMapping : public TypeConverter::SignatureConversion {
44+
public:
45+
OneToNTypeMapping(TypeRange originalTypes)
46+
: TypeConverter::SignatureConversion(originalTypes.size()),
47+
originalTypes(originalTypes) {}
48+
49+
using TypeConverter::SignatureConversion::getConvertedTypes;
50+
51+
/// Returns the list of types that corresponds to the original type at the
52+
/// given index.
53+
TypeRange getConvertedTypes(unsigned originalTypeNo) const;
54+
55+
/// Returns the list of original types.
56+
TypeRange getOriginalTypes() const { return originalTypes; }
57+
58+
/// Returns the slice of converted values that corresponds the original value
59+
/// at the given index.
60+
ValueRange getConvertedValues(ValueRange convertedValues,
61+
unsigned originalValueNo) const;
62+
63+
/// Fills the given result vector with as many copies of the location of the
64+
/// original value as the number of values it is converted to.
65+
void convertLocation(Value originalValue, unsigned originalValueNo,
66+
llvm::SmallVectorImpl<Location> &result) const;
67+
68+
/// Fills the given result vector with as many copies of the lociation of each
69+
/// original value as the number of values they are respectively converted to.
70+
void convertLocations(ValueRange originalValues,
71+
llvm::SmallVectorImpl<Location> &result) const;
72+
73+
/// Returns true iff at least one type conversion maps an input type to a type
74+
/// that is different from itself.
75+
bool hasNonIdentityConversion() const;
76+
77+
private:
78+
llvm::SmallVector<Type> originalTypes;
79+
};
80+
81+
/// Extends the basic `RewritePattern` class with a type converter member and
82+
/// some accessors to it. This is useful for patterns that are not
83+
/// `ConversionPattern`s but still require access to a type converter.
84+
class RewritePatternWithConverter : public mlir::RewritePattern {
85+
public:
86+
/// Construct a conversion pattern with the given converter, and forward the
87+
/// remaining arguments to RewritePattern.
88+
template <typename... Args>
89+
RewritePatternWithConverter(const TypeConverter &typeConverter,
90+
Args &&...args)
91+
: RewritePattern(std::forward<Args>(args)...),
92+
typeConverter(&typeConverter) {}
93+
94+
/// Return the type converter held by this pattern, or nullptr if the pattern
95+
/// does not require type conversion.
96+
const TypeConverter *getTypeConverter() const { return typeConverter; }
97+
98+
template <typename ConverterTy>
99+
std::enable_if_t<std::is_base_of<TypeConverter, ConverterTy>::value,
100+
const ConverterTy *>
101+
getTypeConverter() const {
102+
return static_cast<const ConverterTy *>(typeConverter);
103+
}
104+
105+
protected:
106+
/// A type converter for use by this pattern.
107+
const TypeConverter *const typeConverter;
108+
};
109+
110+
/// Specialization of `PatternRewriter` that `OneToNConversionPattern`s use. The
111+
/// class provides additional rewrite methods that are specific to 1:N type
112+
/// conversions.
113+
class OneToNPatternRewriter : public PatternRewriter {
114+
public:
115+
OneToNPatternRewriter(MLIRContext *context,
116+
OpBuilder::Listener *listener = nullptr)
117+
: PatternRewriter(context, listener) {}
118+
119+
/// Replaces the results of the operation with the specified list of values
120+
/// mapped back to the original types as specified in the provided type
121+
/// mapping. That type mapping must match the replaced op (i.e., the original
122+
/// types must be the same as the result types of the op) and the new values
123+
/// (i.e., the converted types must be the same as the types of the new
124+
/// values).
125+
/// FIXME: The 1:N dialect conversion is deprecated and will be removed soon.
126+
/// Use replaceOpWithMultiple() instead.
127+
void replaceOp(Operation *op, ValueRange newValues,
128+
const OneToNTypeMapping &resultMapping);
129+
using PatternRewriter::replaceOp;
130+
131+
/// Applies the given argument conversion to the given block. This consists of
132+
/// replacing each original argument with N arguments as specified in the
133+
/// argument conversion and inserting unrealized casts from the converted
134+
/// values to the original types, which are then used in lieu of the original
135+
/// ones. (Eventually, `applyPartialOneToNConversion` replaces these casts
136+
/// with a user-provided argument materialization if necessary.) This is
137+
/// similar to `ArgConverter::applySignatureConversion` but (1) handles 1:N
138+
/// type conversion properly and probably (2) doesn't handle many other edge
139+
/// cases.
140+
Block *applySignatureConversion(Block *block,
141+
OneToNTypeMapping &argumentConversion);
142+
};
143+
144+
/// Base class for patterns with 1:N type conversions. Derived classes have to
145+
/// overwrite the `matchAndRewrite` overlaod that provides additional
146+
/// information for 1:N type conversions.
147+
class OneToNConversionPattern : public RewritePatternWithConverter {
148+
public:
149+
using RewritePatternWithConverter::RewritePatternWithConverter;
150+
151+
/// This function has to be implemented by derived classes and is called from
152+
/// the usual overloads. Like in "normal" `DialectConversion`, the function is
153+
/// provided with the converted operands (which thus have target types). Since
154+
/// 1:N conversions are supported, there is usually no 1:1 relationship
155+
/// between the original and the converted operands. Instead, the provided
156+
/// `operandMapping` can be used to access the converted operands that
157+
/// correspond to a particular original operand. Similarly, `resultMapping`
158+
/// is provided to help with assembling the result values, which may have 1:N
159+
/// correspondences as well. In that case, the original op should be replaced
160+
/// with the overload of `replaceOp` that takes the provided `resultMapping`
161+
/// in order to deal with the mapping of converted result values to their
162+
/// usages in the original types correctly.
163+
virtual LogicalResult matchAndRewrite(Operation *op,
164+
OneToNPatternRewriter &rewriter,
165+
const OneToNTypeMapping &operandMapping,
166+
const OneToNTypeMapping &resultMapping,
167+
ValueRange convertedOperands) const = 0;
168+
169+
LogicalResult matchAndRewrite(Operation *op,
170+
PatternRewriter &rewriter) const final;
171+
};
172+
173+
/// This class is a wrapper around `OneToNConversionPattern` for matching
174+
/// against instances of a particular op class.
175+
template <typename SourceOp>
176+
class OneToNOpConversionPattern : public OneToNConversionPattern {
177+
public:
178+
OneToNOpConversionPattern(const TypeConverter &typeConverter,
179+
MLIRContext *context, PatternBenefit benefit = 1,
180+
ArrayRef<StringRef> generatedNames = {})
181+
: OneToNConversionPattern(typeConverter, SourceOp::getOperationName(),
182+
benefit, context, generatedNames) {}
183+
/// Generic adaptor around the root op of this pattern using the converted
184+
/// operands. Importantly, each operand is represented as a *range* of values,
185+
/// namely the N values each original operand gets converted to. Concretely,
186+
/// this makes the result type of the accessor functions of the adaptor class
187+
/// be a `ValueRange`.
188+
class OpAdaptor
189+
: public SourceOp::template GenericAdaptor<ArrayRef<ValueRange>> {
190+
public:
191+
using RangeT = ArrayRef<ValueRange>;
192+
using BaseT = typename SourceOp::template GenericAdaptor<RangeT>;
193+
using Properties = typename SourceOp::template InferredProperties<SourceOp>;
194+
195+
OpAdaptor(const OneToNTypeMapping *operandMapping,
196+
const OneToNTypeMapping *resultMapping,
197+
const ValueRange *convertedOperands, RangeT values, SourceOp op)
198+
: BaseT(values, op), operandMapping(operandMapping),
199+
resultMapping(resultMapping), convertedOperands(convertedOperands) {}
200+
201+
/// Get the type mapping of the original operands to the converted operands.
202+
const OneToNTypeMapping &getOperandMapping() const {
203+
return *operandMapping;
204+
}
205+
206+
/// Get the type mapping of the original results to the converted results.
207+
const OneToNTypeMapping &getResultMapping() const { return *resultMapping; }
208+
209+
/// Get a flat range of all converted operands. Unlike `getOperands`, which
210+
/// returns an `ArrayRef` with one `ValueRange` for each original operand,
211+
/// this function returns a `ValueRange` that contains all converted
212+
/// operands irrespectively of which operand they originated from.
213+
ValueRange getFlatOperands() const { return *convertedOperands; }
214+
215+
private:
216+
const OneToNTypeMapping *operandMapping;
217+
const OneToNTypeMapping *resultMapping;
218+
const ValueRange *convertedOperands;
219+
};
220+
221+
using OneToNConversionPattern::matchAndRewrite;
222+
223+
/// Overload that derived classes have to override for their op type.
224+
virtual LogicalResult
225+
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
226+
OneToNPatternRewriter &rewriter) const = 0;
227+
228+
LogicalResult matchAndRewrite(Operation *op, OneToNPatternRewriter &rewriter,
229+
const OneToNTypeMapping &operandMapping,
230+
const OneToNTypeMapping &resultMapping,
231+
ValueRange convertedOperands) const final {
232+
// Wrap converted operands and type mappings into an adaptor.
233+
SmallVector<ValueRange> valueRanges;
234+
for (int64_t i = 0; i < op->getNumOperands(); i++) {
235+
auto values = operandMapping.getConvertedValues(convertedOperands, i);
236+
valueRanges.push_back(values);
237+
}
238+
OpAdaptor adaptor(&operandMapping, &resultMapping, &convertedOperands,
239+
valueRanges, cast<SourceOp>(op));
240+
241+
// Call overload implemented by the derived class.
242+
return matchAndRewrite(cast<SourceOp>(op), adaptor, rewriter);
243+
}
244+
};
245+
246+
/// Applies the given set of patterns recursively on the given op and adds user
247+
/// materializations where necessary. The patterns are expected to be
248+
/// `OneToNConversionPattern`, which help converting the types of the operands
249+
/// and results of the matched ops. The provided type converter is used to
250+
/// convert the operands of matched ops from their original types to operands
251+
/// with different types. Unlike in `DialectConversion`, this supports 1:N type
252+
/// conversions. Those conversions at the "boundary" of the pattern application,
253+
/// where converted results are not consumed by replaced ops that expect the
254+
/// converted operands or vice versa, the function inserts user materializations
255+
/// from the type converter. Also unlike `DialectConversion`, there are no legal
256+
/// or illegal types; the function simply applies the given patterns and does
257+
/// not fail if some ops or types remain unconverted (i.e., the conversion is
258+
/// only "partial").
259+
/// FIXME: The 1:N dialect conversion is deprecated and will be removed soon.
260+
/// 1:N support has been added to the regular dialect conversion driver.
261+
/// Use applyPartialConversion() instead.
262+
LogicalResult
263+
applyPartialOneToNConversion(Operation *op, TypeConverter &typeConverter,
264+
const FrozenRewritePatternSet &patterns);
265+
266+
/// Add a pattern to the given pattern list to convert the signature of a
267+
/// FunctionOpInterface op with the given type converter. This only supports
268+
/// ops which use FunctionType to represent their type. This is intended to be
269+
/// used with the 1:N dialect conversion.
270+
/// FIXME: The 1:N dialect conversion is deprecated and will be removed soon.
271+
/// 1:N support has been added to the regular dialect conversion driver.
272+
/// Use populateFunctionOpInterfaceTypeConversionPattern() instead.
273+
void populateOneToNFunctionOpInterfaceTypeConversionPattern(
274+
StringRef functionLikeOpName, const TypeConverter &converter,
275+
RewritePatternSet &patterns);
276+
template <typename FuncOpT>
277+
void populateOneToNFunctionOpInterfaceTypeConversionPattern(
278+
const TypeConverter &converter, RewritePatternSet &patterns) {
279+
populateOneToNFunctionOpInterfaceTypeConversionPattern(
280+
FuncOpT::getOperationName(), converter, patterns);
281+
}
282+
283+
} // namespace mlir
284+
285+
#endif // MLIR_TRANSFORMS_ONETONTYPECONVERSION_H

mlir/lib/Dialect/Func/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
add_mlir_dialect_library(MLIRFuncTransforms
22
DuplicateFunctionElimination.cpp
33
FuncConversions.cpp
4+
OneToNFuncConversions.cpp
45

56
ADDITIONAL_HEADER_DIRS
67
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Func/Transforms

0 commit comments

Comments
 (0)