Skip to content

[mlir][Transforms] Dialect Conversion: Simplify block conversion API #94866

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

matthias-springer
Copy link
Member

@matthias-springer matthias-springer commented Jun 8, 2024

This commit simplifies and improves documentation for the part of the ConversionPatternRewriter API that deals with signature conversions.

There are now two public functions for signature conversion:

  • applySignatureConversion converts a single block signature. This function used to take a Region * (but converted only the entry block). It now takes a Block *.
  • convertRegionTypes converts all block signatures of a region.

convertNonEntryRegionTypes is removed because it is not widely used and can easily be expressed with a call to applySignatureConversion inside a loop. (See Detensorize.cpp for an example.)

Note: For consistency, convertRegionTypes could be renamed to applySignatureConversion (overload) in the future. (Or applySignatureConversion renamed to convertBlockTypes.)

Also clarify when a type converter and/or signature conversion object is needed and for what purpose.

Internal code refactoring (NFC) of ConversionPatternRewriterImpl (the part that deals with signature conversions). This part of the codebase was quite convoluted and unintuitive.

From a functional perspective, this change is NFC. However, the public API changes, thus not marking as NFC.

Note for LLVM integration: When you see applySignatureConversion(region, ...), replace with applySignatureConversion(region->front(), ...). In the unlikely case that you see convertNonEntryRegionTypes, apply the same changes as this commit did to Detensorize.cpp.

@llvmbot
Copy link
Member

llvmbot commented Jun 8, 2024

@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir-spirv

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

This commit simplifies and improves documentation for the part of the ConversionPatternRewriter API that deals with signature conversions.

There are now two public functions for signature conversion:

  • applySignatureConversion converts a single block signature. This function used to take a Region * (but converted only the entry block). It now takes a Block *.
  • convertRegionTypes converts all block signatures of a region.

convertNonEntryRegionTypes is removed because it is not widely used and can easily be expressed with a call to applySignatureConversion inside a loop. (See Detensorize.cpp for an example.)

Note: For consistency, convertRegionTypes could be renamed to applySignatureConversion (overload) in the future.

Also clarify when a type converter and/or signature conversion object is needed and for what purpose.

Internal code refactoring (NFC) of ConversionPatternRewriterImpl (the part that deals with signature conversions). This part of the codebase was quite convoluted and unintuitive.

From a functional perspective, this change is NFC. However, the public API changes, thus not marking as NFC.

Note for LLVM integration: When you see applySignatureConversion(region, ...), replace with applySignatureConversion(region->front(), ...). In the unlikely case that you see convertNonEntryRegionTypes, apply the same changes as this commit did to Detensorize.cpp.


Full diff: https://github.com/llvm/llvm-project/pull/94866.diff

5 Files Affected:

  • (modified) mlir/docs/DialectConversion.md (+17-13)
  • (modified) mlir/include/mlir/Transforms/DialectConversion.h (+20-23)
  • (modified) mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp (+8-12)
  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+27-96)
diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md
index a355d5a90e4d1..8338109eb97c3 100644
--- a/mlir/docs/DialectConversion.md
+++ b/mlir/docs/DialectConversion.md
@@ -372,19 +372,23 @@ class TypeConverter {
 From the perspective of type conversion, the types of block arguments are a bit
 special. Throughout the conversion process, blocks may move between regions of
 different operations. Given this, the conversion of the types for blocks must be
-done explicitly via a conversion pattern. To convert the types of block
-arguments within a Region, a custom hook on the `ConversionPatternRewriter` must
-be invoked; `convertRegionTypes`. This hook uses a provided type converter to
-apply type conversions to all blocks within a given region, and all blocks that
-move into that region. As noted above, the conversions performed by this method
-use the argument materialization hook on the `TypeConverter`. This hook also
-takes an optional `TypeConverter::SignatureConversion` parameter that applies a
-custom conversion to the entry block of the region. The types of the entry block
-arguments are often tied semantically to details on the operation, e.g. func::FuncOp,
-AffineForOp, etc. To convert the signature of just the region entry block, and
-not any other blocks within the region, the `applySignatureConversion` hook may
-be used instead. A signature conversion, `TypeConverter::SignatureConversion`,
-can be built programmatically:
+done explicitly via a conversion pattern. 
+
+To convert the types of block arguments within a Region, a custom hook on the
+`ConversionPatternRewriter` must be invoked; `convertRegionTypes`. This hook
+uses a provided type converter to apply type conversions to all blocks of a
+given region. As noted above, the conversions performed by this method use the
+argument materialization hook on the `TypeConverter`. This hook also takes an
+optional `TypeConverter::SignatureConversion` parameter that applies a custom
+conversion to the entry block of the region. The types of the entry block
+arguments are often tied semantically to details on the operation, e.g.,
+`func::FuncOp`, `AffineForOp`, etc.
+
+To convert the signature of just one given block, the
+`applySignatureConversion` hook can be used.
+
+A signature conversion, `TypeConverter::SignatureConversion`, can be built
+programmatically:
 
 ```c++
 class SignatureConversion {
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 83198c9b0db54..5f4a972748ffc 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -247,7 +247,8 @@ class TypeConverter {
   /// Attempts a 1-1 type conversion, expecting the result type to be
   /// `TargetType`. Returns the converted type cast to `TargetType` on success,
   /// and a null type on conversion or cast failure.
-  template <typename TargetType> TargetType convertType(Type t) const {
+  template <typename TargetType>
+  TargetType convertType(Type t) const {
     return dyn_cast_or_null<TargetType>(convertType(t));
   }
 
@@ -661,42 +662,38 @@ class ConversionPatternRewriter final : public PatternRewriter {
 public:
   ~ConversionPatternRewriter() override;
 
-  /// Apply a signature conversion to the entry block of the given region. This
-  /// replaces the entry block with a new block containing the updated
-  /// signature. The new entry block to the region is returned for convenience.
+  /// Apply a signature conversion to given block. This replaces the block with
+  /// a new block containing the updated signature. The operations of the given
+  /// block are inlined into the newly-created block, which is returned.
+  ///
   /// If no block argument types are changing, the entry original block will be
   /// left in place and returned.
   ///
-  /// If provided, `converter` will be used for any materializations.
+  /// A signature converison must be provided. (Type converters can construct
+  /// signature conversion with `convertBlockSignature`.) Optionally, a type
+  /// converter can be provided to build materializations.
   Block *
-  applySignatureConversion(Region *region,
+  applySignatureConversion(Block *block,
                            TypeConverter::SignatureConversion &conversion,
                            const TypeConverter *converter = nullptr);
 
-  /// Convert the types of block arguments within the given region. This
+  /// Apply a signature conversion to each block in the given region. This
   /// replaces each block with a new block containing the updated signature. If
   /// an updated signature would match the current signature, the respective
-  /// block is left in place as is.
+  /// block is left in place as is. (See `applySignatureConversion` for
+  /// details.) The new entry block of the region is returned.
+  ///
+  /// SignatureConversions are computed with the specified type converter.
+  /// This function returns "failure" if the type converter failed to compute
+  /// a SignatureConversion for at least one block.
   ///
-  /// The entry block may have a special conversion if `entryConversion` is
-  /// provided. On success, the new entry block to the region is returned for
-  /// convenience. Otherwise, failure is returned.
+  /// Optionally, a special SignatureConversion can be specified for the entry
+  /// block. This is because the types of the entry block arguments are often
+  /// tied semantically to details on the operation.
   FailureOr<Block *> convertRegionTypes(
       Region *region, const TypeConverter &converter,
       TypeConverter::SignatureConversion *entryConversion = nullptr);
 
-  /// Convert the types of block arguments within the given region except for
-  /// the entry region. This replaces each non-entry block with a new block
-  /// containing the updated signature. If an updated signature would match the
-  /// current signature, the respective block is left in place as is.
-  ///
-  /// If special conversion behavior is needed for the non-entry blocks (for
-  /// example, we need to convert only a subset of a BB arguments), such
-  /// behavior can be specified in blockConversions.
-  LogicalResult convertNonEntryRegionTypes(
-      Region *region, const TypeConverter &converter,
-      ArrayRef<TypeConverter::SignatureConversion> blockConversions);
-
   /// Replace all the uses of the block argument `from` with value `to`.
   void replaceUsesOfBlockArgument(BlockArgument from, Value to);
 
diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
index d90cf931385fc..f62de1f17a666 100644
--- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
+++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
@@ -162,7 +162,7 @@ struct ForOpConversion final : SCFToSPIRVPattern<scf::ForOp> {
     signatureConverter.remapInput(0, newIndVar);
     for (unsigned i = 1, e = body->getNumArguments(); i < e; i++)
       signatureConverter.remapInput(i, header->getArgument(i));
-    body = rewriter.applySignatureConversion(&forOp.getRegion(),
+    body = rewriter.applySignatureConversion(&forOp.getRegion().front(),
                                              signatureConverter);
 
     // Move the blocks from the forOp into the loopOp. This is the body of the
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
index 22968096a6891..af38485291182 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
@@ -106,27 +106,23 @@ struct FunctionNonEntryBlockConversion
                   ConversionPatternRewriter &rewriter) const override {
     rewriter.startOpModification(op);
     Region &region = op.getFunctionBody();
-    SmallVector<TypeConverter::SignatureConversion, 2> conversions;
 
-    for (Block &block : llvm::drop_begin(region, 1)) {
-      conversions.emplace_back(block.getNumArguments());
-      TypeConverter::SignatureConversion &back = conversions.back();
+    for (Block &block :
+         llvm::make_early_inc_range(llvm::drop_begin(region, 1))) {
+      TypeConverter::SignatureConversion conversion(
+          /*numOrigInputs=*/block.getNumArguments());
 
       for (BlockArgument blockArgument : block.getArguments()) {
         int idx = blockArgument.getArgNumber();
 
         if (blockArgsToDetensor.count(blockArgument))
-          back.addInputs(idx, {getTypeConverter()->convertType(
-                                  block.getArgumentTypes()[idx])});
+          conversion.addInputs(idx, {getTypeConverter()->convertType(
+                                        block.getArgumentTypes()[idx])});
         else
-          back.addInputs(idx, {block.getArgumentTypes()[idx]});
+          conversion.addInputs(idx, {block.getArgumentTypes()[idx]});
       }
-    }
 
-    if (failed(rewriter.convertNonEntryRegionTypes(&region, *typeConverter,
-                                                   conversions))) {
-      rewriter.cancelOpModification(op);
-      return failure();
+      rewriter.applySignatureConversion(&block, conversion, getTypeConverter());
     }
 
     rewriter.finalizeOpModification(op);
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index d407d60334c70..2f0efe1b1e454 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -839,27 +839,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   // Type Conversion
   //===--------------------------------------------------------------------===//
 
-  /// Attempt to convert the signature of the given block, if successful a new
-  /// block is returned containing the new arguments. Returns `block` if it did
-  /// not require conversion.
-  FailureOr<Block *> convertBlockSignature(
-      ConversionPatternRewriter &rewriter, Block *block,
-      const TypeConverter *converter,
-      TypeConverter::SignatureConversion *conversion = nullptr);
-
-  /// Convert the types of non-entry block arguments within the given region.
-  LogicalResult convertNonEntryRegionTypes(
-      ConversionPatternRewriter &rewriter, Region *region,
-      const TypeConverter &converter,
-      ArrayRef<TypeConverter::SignatureConversion> blockConversions = {});
-
-  /// Apply a signature conversion on the given region, using `converter` for
-  /// materializations if not null.
-  Block *
-  applySignatureConversion(ConversionPatternRewriter &rewriter, Region *region,
-                           TypeConverter::SignatureConversion &conversion,
-                           const TypeConverter *converter);
-
   /// Convert the types of block arguments within the given region.
   FailureOr<Block *>
   convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region,
@@ -1294,34 +1273,6 @@ bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const {
 //===----------------------------------------------------------------------===//
 // Type Conversion
 
-FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature(
-    ConversionPatternRewriter &rewriter, Block *block,
-    const TypeConverter *converter,
-    TypeConverter::SignatureConversion *conversion) {
-  if (conversion)
-    return applySignatureConversion(rewriter, block, converter, *conversion);
-
-  // If a converter wasn't provided, and the block wasn't already converted,
-  // there is nothing we can do.
-  if (!converter)
-    return failure();
-
-  // Try to convert the signature for the block with the provided converter.
-  if (auto conversion = converter->convertBlockSignature(block))
-    return applySignatureConversion(rewriter, block, converter, *conversion);
-  return failure();
-}
-
-Block *ConversionPatternRewriterImpl::applySignatureConversion(
-    ConversionPatternRewriter &rewriter, Region *region,
-    TypeConverter::SignatureConversion &conversion,
-    const TypeConverter *converter) {
-  if (!region->empty())
-    return *convertBlockSignature(rewriter, &region->front(), converter,
-                                  &conversion);
-  return nullptr;
-}
-
 FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
     ConversionPatternRewriter &rewriter, Region *region,
     const TypeConverter &converter,
@@ -1330,42 +1281,29 @@ FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
   if (region->empty())
     return nullptr;
 
-  if (failed(convertNonEntryRegionTypes(rewriter, region, converter)))
-    return failure();
-
-  FailureOr<Block *> newEntry = convertBlockSignature(
-      rewriter, &region->front(), &converter, entryConversion);
-  return newEntry;
-}
-
-LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
-    ConversionPatternRewriter &rewriter, Region *region,
-    const TypeConverter &converter,
-    ArrayRef<TypeConverter::SignatureConversion> blockConversions) {
-  regionToConverter[region] = &converter;
-  if (region->empty())
-    return success();
-
-  // Convert the arguments of each block within the region.
-  int blockIdx = 0;
-  assert((blockConversions.empty() ||
-          blockConversions.size() == region->getBlocks().size() - 1) &&
-         "expected either to provide no SignatureConversions at all or to "
-         "provide a SignatureConversion for each non-entry block");
-
+  // Convert the arguments of each non-entry block within the region.
   for (Block &block :
        llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
-    TypeConverter::SignatureConversion *blockConversion =
-        blockConversions.empty()
-            ? nullptr
-            : const_cast<TypeConverter::SignatureConversion *>(
-                  &blockConversions[blockIdx++]);
-
-    if (failed(convertBlockSignature(rewriter, &block, &converter,
-                                     blockConversion)))
+    // Compute the signature for the block with the provided converter.
+    std::optional<TypeConverter::SignatureConversion> conversion =
+        converter.convertBlockSignature(&block);
+    if (!conversion)
       return failure();
-  }
-  return success();
+    // Convert the block with the computed signature.
+    applySignatureConversion(rewriter, &block, &converter, *conversion);
+  }
+
+  // Convert the entry block. If an entry signature conversion was provided,
+  // use that one. Otherwise, compute the signature with the type converter.
+  if (entryConversion)
+    return applySignatureConversion(rewriter, &region->front(), &converter,
+                                    *entryConversion);
+  std::optional<TypeConverter::SignatureConversion> conversion =
+      converter.convertBlockSignature(&region->front());
+  if (!conversion)
+    return failure();
+  return applySignatureConversion(rewriter, &region->front(), &converter,
+                                  *conversion);
 }
 
 Block *ConversionPatternRewriterImpl::applySignatureConversion(
@@ -1676,12 +1614,12 @@ void ConversionPatternRewriter::eraseBlock(Block *block) {
 }
 
 Block *ConversionPatternRewriter::applySignatureConversion(
-    Region *region, TypeConverter::SignatureConversion &conversion,
+    Block *block, TypeConverter::SignatureConversion &conversion,
     const TypeConverter *converter) {
-  assert(!impl->wasOpReplaced(region->getParentOp()) &&
+  assert(!impl->wasOpReplaced(block->getParentOp()) &&
          "attempting to apply a signature conversion to a block within a "
          "replaced/erased op");
-  return impl->applySignatureConversion(*this, region, conversion, converter);
+  return impl->applySignatureConversion(*this, block, converter, conversion);
 }
 
 FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
@@ -1693,16 +1631,6 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
   return impl->convertRegionTypes(*this, region, converter, entryConversion);
 }
 
-LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes(
-    Region *region, const TypeConverter &converter,
-    ArrayRef<TypeConverter::SignatureConversion> blockConversions) {
-  assert(!impl->wasOpReplaced(region->getParentOp()) &&
-         "attempting to apply a signature conversion to a block within a "
-         "replaced/erased op");
-  return impl->convertNonEntryRegionTypes(*this, region, converter,
-                                          blockConversions);
-}
-
 void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
                                                            Value to) {
   LLVM_DEBUG({
@@ -2231,11 +2159,14 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
     // If the region of the block has a type converter, try to convert the block
     // directly.
     if (auto *converter = impl.regionToConverter.lookup(block->getParent())) {
-      if (failed(impl.convertBlockSignature(rewriter, block, converter))) {
+      std::optional<TypeConverter::SignatureConversion> conversion =
+          converter->convertBlockSignature(block);
+      if (!conversion) {
         LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
                                            "block"));
         return failure();
       }
+      impl.applySignatureConversion(rewriter, block, converter, *conversion);
       continue;
     }
 

@matthias-springer matthias-springer requested a review from ftynse June 8, 2024 19:35
This commit simplifies and improves documentation for the part of the `ConversionPatternRewriter` API that deals with signature conversions.

There are now two public functions for signature conversion:
* `applySignatureConversion` converts a single block signature.
* `convertRegionTypes` converts all block signatures of a region.

Note: `convertRegionTypes` could be renamed to `applySignatureConversion` (overload) in the future.

Also clarify when a type converter and/or signature conversion object is needed and for what purpose.

From a functional perspective, this change is NFC. However, the public API changes, thus not marking as NFC.
@matthias-springer matthias-springer force-pushed the users/matthias-springer/dialect_conversion_block_api branch from 5c86bfa to 4ea3920 Compare June 8, 2024 19:43
Copy link
Member

@zero9178 zero9178 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you! Very nice improvements and simplifications

Copy link
Contributor

@Dinistro Dinistro left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks for the cleanup 🙂

@matthias-springer matthias-springer merged commit 52050f3 into main Jun 10, 2024
7 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/dialect_conversion_block_api branch June 10, 2024 19:49
Lukacma pushed a commit to Lukacma/llvm-project that referenced this pull request Jun 12, 2024
…lvm#94866)

This commit simplifies and improves documentation for the part of the
`ConversionPatternRewriter` API that deals with signature conversions.

There are now two public functions for signature conversion:
* `applySignatureConversion` converts a single block signature. This
function used to take a `Region *` (but converted only the entry block).
It now takes a `Block *`.
* `convertRegionTypes` converts all block signatures of a region.

`convertNonEntryRegionTypes` is removed because it is not widely used
and can easily be expressed with a call to `applySignatureConversion`
inside a loop. (See `Detensorize.cpp` for an example.)

Note: For consistency, `convertRegionTypes` could be renamed to
`applySignatureConversion` (overload) in the future. (Or
`applySignatureConversion` renamed to `convertBlockTypes`.)

Also clarify when a type converter and/or signature conversion object is
needed and for what purpose.

Internal code refactoring (NFC) of `ConversionPatternRewriterImpl` (the
part that deals with signature conversions). This part of the codebase
was quite convoluted and unintuitive.

From a functional perspective, this change is NFC. However, the public
API changes, thus not marking as NFC.

Note for LLVM integration: When you see
`applySignatureConversion(region, ...)`, replace with
`applySignatureConversion(region->front(), ...)`. In the unlikely case
that you see `convertNonEntryRegionTypes`, apply the same changes as
this commit did to `Detensorize.cpp`.

---------

Co-authored-by: Markus Böck <[email protected]>
@HerrCai0907 HerrCai0907 mentioned this pull request Jun 13, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants