Skip to content

Commit 7876899

Browse files
[mlir][transform] Fix handling of transitive include in interpreter. (#67560)
Until now, the interpreter would only load those symbols from the provided library files that were declared in the main transform module. However, sequences in the library may include other sequences on their own. Until now, if such sequences were not *also* declared in the main transform module, the interpreter would fail to resolve them. Forward declaring all of them is undesirable as it defeats the purpose of encapsulation into library modules. This PR implements a kind of linker for transform scripts to solve this problem. The linker merges all symbols of the library module into the main module before interpreting the latter. Symbols whose names collide are handled as follows: (1) if they are both functions (in the sense of `FunctionOpInterface`) with compatible signatures, one is external, and the other one is public, then they are merged; (2) of one of them is private, that one is renamed; and (3) an error is raised otherwise. One consequence of this change is that the loading of the library files in the interpreter pass is not idempotent anymore, i.e., subsequent interpreter passes cannot (and need not) load the same library files again since would lead to doubly defined symbols.
1 parent cd184c8 commit 7876899

11 files changed

+475
-107
lines changed

mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,15 @@ def Transform_Dialect : Dialect {
4343
constexpr const static ::llvm::StringLiteral kArgReadOnlyAttrName =
4444
"transform.readonly";
4545

46+
/// Names of the attributes indicating whether an argument of an external
47+
/// transform dialect symbol is consumed or only read.
48+
StringAttr getConsumedAttrName() const {
49+
return StringAttr::get(getContext(), kArgConsumedAttrName);
50+
}
51+
StringAttr getReadOnlyAttrName() const {
52+
return StringAttr::get(getContext(), kArgReadOnlyAttrName);
53+
}
54+
4655
template <typename DataTy>
4756
const DataTy &getExtraData() const {
4857
return *static_cast<const DataTy *>(

mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,11 @@ LogicalResult interpreterBaseRunOnOperationImpl(
6262
/// transform script. If empty, `debugTransformRootTag` is considered or the
6363
/// pass root operation must contain a single top-level transform op that
6464
/// will be interpreted.
65-
/// - transformLibraryFileName: if non-empty, the name of the file containing
66-
/// definitions of external symbols referenced in the transform script.
67-
/// These definitions will be used to replace declarations.
65+
/// - transformLibraryFileName: if non-empty, the module in this file will be
66+
/// merged into the main transform script run by the interpreter before
67+
/// execution. This allows to provide definitions for external functions
68+
/// used in the main script. Other public symbols in the library module may
69+
/// lead to collisions with public symbols in the main script.
6870
/// - debugPayloadRootTag: if non-empty, the value of the attribute named
6971
/// `kTransformDialectTagAttrName` indicating the single op that is
7072
/// considered the payload root of the transform interpreter; otherwise, the
@@ -85,7 +87,7 @@ LogicalResult interpreterBaseRunOnOperationImpl(
8587
/// as template arguments. They are *not* expected to to implement `initialize`
8688
/// or `runOnOperation`. They *are* expected to call the copy constructor of
8789
/// this class in their copy constructors, short of which the file-based
88-
/// transform dialect script injection facility will become nonoperational.
90+
/// transform dialect script injection facility will become non-operational.
8991
///
9092
/// Concrete passes may implement the `runBeforeInterpreter` and
9193
/// `runAfterInterpreter` to customize the behavior of the pass.

mlir/include/mlir/IR/SymbolTable.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,23 @@ class SymbolTable {
5555
/// after insertion as attribute.
5656
StringAttr insert(Operation *symbol, Block::iterator insertPt = {});
5757

58+
/// Renames the given op or the op refered to by the given name to the given
59+
/// new name and updates the symbol table and all usages of the symbol
60+
/// accordingly. Fails if the updating of the usages fails.
61+
LogicalResult rename(StringAttr from, StringAttr to);
62+
LogicalResult rename(Operation *op, StringAttr to);
63+
LogicalResult rename(StringAttr from, StringRef to);
64+
LogicalResult rename(Operation *op, StringRef to);
65+
66+
/// Renames the given op or the op refered to by the given name to the a name
67+
/// that is unique within this and the provided other symbol tables and
68+
/// updates the symbol table and all usages of the symbol accordingly. Returns
69+
/// the new name or failure if the renaming fails.
70+
FailureOr<StringAttr> renameToUnique(StringAttr from,
71+
ArrayRef<SymbolTable *> others);
72+
FailureOr<StringAttr> renameToUnique(Operation *op,
73+
ArrayRef<SymbolTable *> others);
74+
5875
/// Return the name of the attribute used for symbol names.
5976
static StringRef getSymbolAttrName() { return "sym_name"; }
6077

0 commit comments

Comments
 (0)