Skip to content

Commit 5c86bfa

Browse files
[mlir][Transforms] Dialect Conversion: Simplify block conversion API
This commit simplifies and improves documentation for the part of the `ConversionPatternRewriter` API that deals with signature conversions. There are now two public functions for signature conversion: * `applySignatureConversion` converts a single block signature. * `convertRegionTypes` converts all block signatures of a region. Note: `convertRegionTypes` could be renamed to `applySignatureConversion` (overload) in the future. Also clarify when a type converter and/or signature conversion object is needed and for what purpose. From a functional perspective, this change is NFC. However, the public API changes, thus not marking as NFC.
1 parent 1e92ad4 commit 5c86bfa

File tree

5 files changed

+73
-145
lines changed

5 files changed

+73
-145
lines changed

mlir/docs/DialectConversion.md

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -372,19 +372,23 @@ class TypeConverter {
372372
From the perspective of type conversion, the types of block arguments are a bit
373373
special. Throughout the conversion process, blocks may move between regions of
374374
different operations. Given this, the conversion of the types for blocks must be
375-
done explicitly via a conversion pattern. To convert the types of block
376-
arguments within a Region, a custom hook on the `ConversionPatternRewriter` must
377-
be invoked; `convertRegionTypes`. This hook uses a provided type converter to
378-
apply type conversions to all blocks within a given region, and all blocks that
379-
move into that region. As noted above, the conversions performed by this method
380-
use the argument materialization hook on the `TypeConverter`. This hook also
381-
takes an optional `TypeConverter::SignatureConversion` parameter that applies a
382-
custom conversion to the entry block of the region. The types of the entry block
383-
arguments are often tied semantically to details on the operation, e.g. func::FuncOp,
384-
AffineForOp, etc. To convert the signature of just the region entry block, and
385-
not any other blocks within the region, the `applySignatureConversion` hook may
386-
be used instead. A signature conversion, `TypeConverter::SignatureConversion`,
387-
can be built programmatically:
375+
done explicitly via a conversion pattern.
376+
377+
To convert the types of block arguments within a Region, a custom hook on the
378+
`ConversionPatternRewriter` must be invoked; `convertRegionTypes`. This hook
379+
uses a provided type converter to apply type conversions to all blocks of a
380+
given region. As noted above, the conversions performed by this method use the
381+
argument materialization hook on the `TypeConverter`. This hook also takes an
382+
optional `TypeConverter::SignatureConversion` parameter that applies a custom
383+
conversion to the entry block of the region. The types of the entry block
384+
arguments are often tied semantically to details on the operation, e.g.,
385+
`func::FuncOp`, `AffineForOp`, etc.
386+
387+
To convert the signature of just one given block, the
388+
`applySignatureConversion` hook can be used.
389+
390+
A signature conversion, `TypeConverter::SignatureConversion`, can be built
391+
programmatically:
388392

389393
```c++
390394
class SignatureConversion {

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,8 @@ class TypeConverter {
247247
/// Attempts a 1-1 type conversion, expecting the result type to be
248248
/// `TargetType`. Returns the converted type cast to `TargetType` on success,
249249
/// and a null type on conversion or cast failure.
250-
template <typename TargetType> TargetType convertType(Type t) const {
250+
template <typename TargetType>
251+
TargetType convertType(Type t) const {
251252
return dyn_cast_or_null<TargetType>(convertType(t));
252253
}
253254

@@ -661,42 +662,38 @@ class ConversionPatternRewriter final : public PatternRewriter {
661662
public:
662663
~ConversionPatternRewriter() override;
663664

664-
/// Apply a signature conversion to the entry block of the given region. This
665-
/// replaces the entry block with a new block containing the updated
666-
/// signature. The new entry block to the region is returned for convenience.
665+
/// Apply a signature conversion to given block. This replaces the block with
666+
/// a new block containing the updated signature. The operations of the given
667+
/// block are inlined into the newly-created block, which is returned.
668+
///
667669
/// If no block argument types are changing, the entry original block will be
668670
/// left in place and returned.
669671
///
670-
/// If provided, `converter` will be used for any materializations.
672+
/// A signature converison must be provided. (Type converters can construct
673+
/// signature conversion with `convertBlockSignature`.) Optionally, a type
674+
/// converter can be provided to build materializations.
671675
Block *
672-
applySignatureConversion(Region *region,
676+
applySignatureConversion(Block *block,
673677
TypeConverter::SignatureConversion &conversion,
674678
const TypeConverter *converter = nullptr);
675679

676-
/// Convert the types of block arguments within the given region. This
680+
/// Apply a signature conversion to each block in the given region. This
677681
/// replaces each block with a new block containing the updated signature. If
678682
/// an updated signature would match the current signature, the respective
679-
/// block is left in place as is.
683+
/// block is left in place as is. (See `applySignatureConversion` for
684+
/// details.) The new entry block of the region is returned.
685+
///
686+
/// SignatureConversions are computed with the specified type converter.
687+
/// This function returns "failure" if the type converter failed to compute
688+
/// a SignatureConversion for at least one block.
680689
///
681-
/// The entry block may have a special conversion if `entryConversion` is
682-
/// provided. On success, the new entry block to the region is returned for
683-
/// convenience. Otherwise, failure is returned.
690+
/// Optionally, a special SignatureConversion can be specified for the entry
691+
/// block. This is because the types of the entry block arguments are often
692+
/// tied semantically to details on the operation.
684693
FailureOr<Block *> convertRegionTypes(
685694
Region *region, const TypeConverter &converter,
686695
TypeConverter::SignatureConversion *entryConversion = nullptr);
687696

688-
/// Convert the types of block arguments within the given region except for
689-
/// the entry region. This replaces each non-entry block with a new block
690-
/// containing the updated signature. If an updated signature would match the
691-
/// current signature, the respective block is left in place as is.
692-
///
693-
/// If special conversion behavior is needed for the non-entry blocks (for
694-
/// example, we need to convert only a subset of a BB arguments), such
695-
/// behavior can be specified in blockConversions.
696-
LogicalResult convertNonEntryRegionTypes(
697-
Region *region, const TypeConverter &converter,
698-
ArrayRef<TypeConverter::SignatureConversion> blockConversions);
699-
700697
/// Replace all the uses of the block argument `from` with value `to`.
701698
void replaceUsesOfBlockArgument(BlockArgument from, Value to);
702699

mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ struct ForOpConversion final : SCFToSPIRVPattern<scf::ForOp> {
162162
signatureConverter.remapInput(0, newIndVar);
163163
for (unsigned i = 1, e = body->getNumArguments(); i < e; i++)
164164
signatureConverter.remapInput(i, header->getArgument(i));
165-
body = rewriter.applySignatureConversion(&forOp.getRegion(),
165+
body = rewriter.applySignatureConversion(&forOp.getRegion().front(),
166166
signatureConverter);
167167

168168
// Move the blocks from the forOp into the loopOp. This is the body of the

mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -106,27 +106,23 @@ struct FunctionNonEntryBlockConversion
106106
ConversionPatternRewriter &rewriter) const override {
107107
rewriter.startOpModification(op);
108108
Region &region = op.getFunctionBody();
109-
SmallVector<TypeConverter::SignatureConversion, 2> conversions;
110109

111-
for (Block &block : llvm::drop_begin(region, 1)) {
112-
conversions.emplace_back(block.getNumArguments());
113-
TypeConverter::SignatureConversion &back = conversions.back();
110+
for (Block &block :
111+
llvm::make_early_inc_range(llvm::drop_begin(region, 1))) {
112+
TypeConverter::SignatureConversion conversion(
113+
/*numOrigInputs=*/block.getNumArguments());
114114

115115
for (BlockArgument blockArgument : block.getArguments()) {
116116
int idx = blockArgument.getArgNumber();
117117

118118
if (blockArgsToDetensor.count(blockArgument))
119-
back.addInputs(idx, {getTypeConverter()->convertType(
120-
block.getArgumentTypes()[idx])});
119+
conversion.addInputs(idx, {getTypeConverter()->convertType(
120+
block.getArgumentTypes()[idx])});
121121
else
122-
back.addInputs(idx, {block.getArgumentTypes()[idx]});
122+
conversion.addInputs(idx, {block.getArgumentTypes()[idx]});
123123
}
124-
}
125124

126-
if (failed(rewriter.convertNonEntryRegionTypes(&region, *typeConverter,
127-
conversions))) {
128-
rewriter.cancelOpModification(op);
129-
return failure();
125+
rewriter.applySignatureConversion(&block, conversion, getTypeConverter());
130126
}
131127

132128
rewriter.finalizeOpModification(op);

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 27 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -839,27 +839,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
839839
// Type Conversion
840840
//===--------------------------------------------------------------------===//
841841

842-
/// Attempt to convert the signature of the given block, if successful a new
843-
/// block is returned containing the new arguments. Returns `block` if it did
844-
/// not require conversion.
845-
FailureOr<Block *> convertBlockSignature(
846-
ConversionPatternRewriter &rewriter, Block *block,
847-
const TypeConverter *converter,
848-
TypeConverter::SignatureConversion *conversion = nullptr);
849-
850-
/// Convert the types of non-entry block arguments within the given region.
851-
LogicalResult convertNonEntryRegionTypes(
852-
ConversionPatternRewriter &rewriter, Region *region,
853-
const TypeConverter &converter,
854-
ArrayRef<TypeConverter::SignatureConversion> blockConversions = {});
855-
856-
/// Apply a signature conversion on the given region, using `converter` for
857-
/// materializations if not null.
858-
Block *
859-
applySignatureConversion(ConversionPatternRewriter &rewriter, Region *region,
860-
TypeConverter::SignatureConversion &conversion,
861-
const TypeConverter *converter);
862-
863842
/// Convert the types of block arguments within the given region.
864843
FailureOr<Block *>
865844
convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region,
@@ -1294,34 +1273,6 @@ bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const {
12941273
//===----------------------------------------------------------------------===//
12951274
// Type Conversion
12961275

1297-
FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature(
1298-
ConversionPatternRewriter &rewriter, Block *block,
1299-
const TypeConverter *converter,
1300-
TypeConverter::SignatureConversion *conversion) {
1301-
if (conversion)
1302-
return applySignatureConversion(rewriter, block, converter, *conversion);
1303-
1304-
// If a converter wasn't provided, and the block wasn't already converted,
1305-
// there is nothing we can do.
1306-
if (!converter)
1307-
return failure();
1308-
1309-
// Try to convert the signature for the block with the provided converter.
1310-
if (auto conversion = converter->convertBlockSignature(block))
1311-
return applySignatureConversion(rewriter, block, converter, *conversion);
1312-
return failure();
1313-
}
1314-
1315-
Block *ConversionPatternRewriterImpl::applySignatureConversion(
1316-
ConversionPatternRewriter &rewriter, Region *region,
1317-
TypeConverter::SignatureConversion &conversion,
1318-
const TypeConverter *converter) {
1319-
if (!region->empty())
1320-
return *convertBlockSignature(rewriter, &region->front(), converter,
1321-
&conversion);
1322-
return nullptr;
1323-
}
1324-
13251276
FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
13261277
ConversionPatternRewriter &rewriter, Region *region,
13271278
const TypeConverter &converter,
@@ -1330,42 +1281,29 @@ FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
13301281
if (region->empty())
13311282
return nullptr;
13321283

1333-
if (failed(convertNonEntryRegionTypes(rewriter, region, converter)))
1334-
return failure();
1335-
1336-
FailureOr<Block *> newEntry = convertBlockSignature(
1337-
rewriter, &region->front(), &converter, entryConversion);
1338-
return newEntry;
1339-
}
1340-
1341-
LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
1342-
ConversionPatternRewriter &rewriter, Region *region,
1343-
const TypeConverter &converter,
1344-
ArrayRef<TypeConverter::SignatureConversion> blockConversions) {
1345-
regionToConverter[region] = &converter;
1346-
if (region->empty())
1347-
return success();
1348-
1349-
// Convert the arguments of each block within the region.
1350-
int blockIdx = 0;
1351-
assert((blockConversions.empty() ||
1352-
blockConversions.size() == region->getBlocks().size() - 1) &&
1353-
"expected either to provide no SignatureConversions at all or to "
1354-
"provide a SignatureConversion for each non-entry block");
1355-
1284+
// Convert the arguments of each non-entry block within the region.
13561285
for (Block &block :
13571286
llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
1358-
TypeConverter::SignatureConversion *blockConversion =
1359-
blockConversions.empty()
1360-
? nullptr
1361-
: const_cast<TypeConverter::SignatureConversion *>(
1362-
&blockConversions[blockIdx++]);
1363-
1364-
if (failed(convertBlockSignature(rewriter, &block, &converter,
1365-
blockConversion)))
1287+
// Compute the signature for the block with the provided converter.
1288+
std::optional<TypeConverter::SignatureConversion> conversion =
1289+
converter.convertBlockSignature(&block);
1290+
if (!conversion)
13661291
return failure();
1367-
}
1368-
return success();
1292+
// Convert the block with the computed signature.
1293+
applySignatureConversion(rewriter, &block, &converter, *conversion);
1294+
}
1295+
1296+
// Convert the entry block. If an entry signature conversion was provided,
1297+
// use that one. Otherwise, compute the signature with the type converter.
1298+
if (entryConversion)
1299+
return applySignatureConversion(rewriter, &region->front(), &converter,
1300+
*entryConversion);
1301+
std::optional<TypeConverter::SignatureConversion> conversion =
1302+
converter.convertBlockSignature(&region->front());
1303+
if (!conversion)
1304+
return failure();
1305+
return applySignatureConversion(rewriter, &region->front(), &converter,
1306+
*conversion);
13691307
}
13701308

13711309
Block *ConversionPatternRewriterImpl::applySignatureConversion(
@@ -1676,12 +1614,12 @@ void ConversionPatternRewriter::eraseBlock(Block *block) {
16761614
}
16771615

16781616
Block *ConversionPatternRewriter::applySignatureConversion(
1679-
Region *region, TypeConverter::SignatureConversion &conversion,
1617+
Block *block, TypeConverter::SignatureConversion &conversion,
16801618
const TypeConverter *converter) {
1681-
assert(!impl->wasOpReplaced(region->getParentOp()) &&
1619+
assert(!impl->wasOpReplaced(block->getParentOp()) &&
16821620
"attempting to apply a signature conversion to a block within a "
16831621
"replaced/erased op");
1684-
return impl->applySignatureConversion(*this, region, conversion, converter);
1622+
return impl->applySignatureConversion(*this, block, converter, conversion);
16851623
}
16861624

16871625
FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
@@ -1693,16 +1631,6 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
16931631
return impl->convertRegionTypes(*this, region, converter, entryConversion);
16941632
}
16951633

1696-
LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes(
1697-
Region *region, const TypeConverter &converter,
1698-
ArrayRef<TypeConverter::SignatureConversion> blockConversions) {
1699-
assert(!impl->wasOpReplaced(region->getParentOp()) &&
1700-
"attempting to apply a signature conversion to a block within a "
1701-
"replaced/erased op");
1702-
return impl->convertNonEntryRegionTypes(*this, region, converter,
1703-
blockConversions);
1704-
}
1705-
17061634
void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
17071635
Value to) {
17081636
LLVM_DEBUG({
@@ -2231,11 +2159,14 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
22312159
// If the region of the block has a type converter, try to convert the block
22322160
// directly.
22332161
if (auto *converter = impl.regionToConverter.lookup(block->getParent())) {
2234-
if (failed(impl.convertBlockSignature(rewriter, block, converter))) {
2162+
std::optional<TypeConverter::SignatureConversion> conversion =
2163+
converter->convertBlockSignature(block);
2164+
if (!conversion) {
22352165
LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
22362166
"block"));
22372167
return failure();
22382168
}
2169+
impl.applySignatureConversion(rewriter, block, converter, *conversion);
22392170
continue;
22402171
}
22412172

0 commit comments

Comments
 (0)