Skip to content

Commit 6a2071c

Browse files
[mlir][transform] Allow passing various library files to interpreter. (llvm#67120)
The transfrom interpreter accepts an argument to a "library" file with named sequences. This patch exteneds this functionality such that (1) several such individual files are accepted and (2) folders can be passed in, in which all `*.mlir` files are loaded.
1 parent a233a49 commit 6a2071c

12 files changed

+200
-53
lines changed

mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ namespace detail {
3333
/// Template-free implementation of TransformInterpreterPassBase::initialize.
3434
LogicalResult interpreterBaseInitializeImpl(
3535
MLIRContext *context, StringRef transformFileName,
36-
StringRef transformLibraryFileName,
36+
ArrayRef<std::string> transformLibraryPaths,
3737
std::shared_ptr<OwningOpRef<ModuleOp>> &module,
3838
std::shared_ptr<OwningOpRef<ModuleOp>> &libraryModule,
3939
function_ref<std::optional<LogicalResult>(OpBuilder &, Location)>
@@ -48,7 +48,7 @@ LogicalResult interpreterBaseRunOnOperationImpl(
4848
const RaggedArray<MappedValue> &extraMappings,
4949
const TransformOptions &options,
5050
const Pass::Option<std::string> &transformFileName,
51-
const Pass::Option<std::string> &transformLibraryFileName,
51+
const Pass::ListOption<std::string> &transformLibraryPaths,
5252
const Pass::Option<std::string> &debugPayloadRootTag,
5353
const Pass::Option<std::string> &debugTransformRootTag,
5454
StringRef binaryName);
@@ -62,11 +62,12 @@ LogicalResult interpreterBaseRunOnOperationImpl(
6262
/// transform script. If empty, `debugTransformRootTag` is considered or the
6363
/// pass root operation must contain a single top-level transform op that
6464
/// will be interpreted.
65-
/// - transformLibraryFileName: if non-empty, the module in this file will be
65+
/// - transformLibraryPaths: if non-empty, the modules in these files will be
6666
/// merged into the main transform script run by the interpreter before
6767
/// execution. This allows to provide definitions for external functions
68-
/// used in the main script. Other public symbols in the library module may
69-
/// lead to collisions with public symbols in the main script.
68+
/// used in the main script. Other public symbols in the library modules may
69+
/// lead to collisions with public symbols in the main script and among each
70+
/// other.
7071
/// - debugPayloadRootTag: if non-empty, the value of the attribute named
7172
/// `kTransformDialectTagAttrName` indicating the single op that is
7273
/// considered the payload root of the transform interpreter; otherwise, the
@@ -118,16 +119,26 @@ class TransformInterpreterPassBase : public GeneratedBase<Concrete> {
118119
REQUIRE_PASS_OPTION(transformFileName);
119120
REQUIRE_PASS_OPTION(debugPayloadRootTag);
120121
REQUIRE_PASS_OPTION(debugTransformRootTag);
121-
REQUIRE_PASS_OPTION(transformLibraryFileName);
122122

123123
#undef REQUIRE_PASS_OPTION
124124

125+
#define REQUIRE_PASS_LIST_OPTION(NAME) \
126+
static_assert( \
127+
std::is_same_v< \
128+
std::remove_reference_t<decltype(std::declval<Concrete &>().NAME)>, \
129+
Pass::ListOption<std::string>>, \
130+
"required " #NAME " string pass option is missing")
131+
132+
REQUIRE_PASS_LIST_OPTION(transformLibraryPaths);
133+
134+
#undef REQUIRE_PASS_LIST_OPTION
135+
125136
StringRef transformFileName =
126137
static_cast<Concrete *>(this)->transformFileName;
127-
StringRef transformLibraryFileName =
128-
static_cast<Concrete *>(this)->transformLibraryFileName;
138+
ArrayRef<std::string> transformLibraryPaths =
139+
static_cast<Concrete *>(this)->transformLibraryPaths;
129140
return detail::interpreterBaseInitializeImpl(
130-
context, transformFileName, transformLibraryFileName,
141+
context, transformFileName, transformLibraryPaths,
131142
sharedTransformModule, transformLibraryModule,
132143
[this](OpBuilder &builder, Location loc) {
133144
return static_cast<Concrete *>(this)->constructTransformModule(
@@ -162,7 +173,7 @@ class TransformInterpreterPassBase : public GeneratedBase<Concrete> {
162173
op, pass->getArgument(), sharedTransformModule,
163174
transformLibraryModule,
164175
/*extraMappings=*/{}, options, pass->transformFileName,
165-
pass->transformLibraryFileName, pass->debugPayloadRootTag,
176+
pass->transformLibraryPaths, pass->debugPayloadRootTag,
166177
pass->debugTransformRootTag, binaryName)) ||
167178
failed(pass->runAfterInterpreter(op))) {
168179
return pass->signalPassFailure();

mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp

Lines changed: 126 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h"
1515
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
1616
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
17+
#include "mlir/Dialect/Transform/IR/TransformOps.h"
1718
#include "mlir/IR/BuiltinOps.h"
1819
#include "mlir/IR/Verifier.h"
1920
#include "mlir/IR/Visitors.h"
@@ -194,7 +195,7 @@ saveReproToTempFile(llvm::raw_ostream &os, Operation *target,
194195
Operation *transform, StringRef passName,
195196
const Pass::Option<std::string> &debugPayloadRootTag,
196197
const Pass::Option<std::string> &debugTransformRootTag,
197-
const Pass::Option<std::string> &transformLibraryFileName,
198+
const Pass::ListOption<std::string> &transformLibraryPaths,
198199
StringRef binaryName) {
199200
using llvm::sys::fs::TempFile;
200201
Operation *root = getRootOperation(target);
@@ -231,7 +232,7 @@ static void performOptionalDebugActions(
231232
Operation *target, Operation *transform, StringRef passName,
232233
const Pass::Option<std::string> &debugPayloadRootTag,
233234
const Pass::Option<std::string> &debugTransformRootTag,
234-
const Pass::Option<std::string> &transformLibraryFileName,
235+
const Pass::ListOption<std::string> &transformLibraryPaths,
235236
StringRef binaryName) {
236237
MLIRContext *context = target->getContext();
237238

@@ -284,7 +285,7 @@ static void performOptionalDebugActions(
284285
DEBUG_WITH_TYPE(DEBUG_TYPE_DUMP_FILE, {
285286
saveReproToTempFile(llvm::dbgs(), target, transform, passName,
286287
debugPayloadRootTag, debugTransformRootTag,
287-
transformLibraryFileName, binaryName);
288+
transformLibraryPaths, binaryName);
288289
});
289290

290291
// Remove temporary attributes if they were set.
@@ -534,7 +535,7 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
534535
const RaggedArray<MappedValue> &extraMappings,
535536
const TransformOptions &options,
536537
const Pass::Option<std::string> &transformFileName,
537-
const Pass::Option<std::string> &transformLibraryFileName,
538+
const Pass::ListOption<std::string> &transformLibraryPaths,
538539
const Pass::Option<std::string> &debugPayloadRootTag,
539540
const Pass::Option<std::string> &debugTransformRootTag,
540541
StringRef binaryName) {
@@ -597,7 +598,8 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
597598
if (failed(
598599
mergeSymbolsInto(SymbolTable::getNearestSymbolTable(transformRoot),
599600
transformLibraryModule->get()->clone())))
600-
return failure();
601+
return emitError(transformRoot->getLoc(),
602+
"failed to merge library symbols into transform root");
601603
}
602604

603605
// Step 4
@@ -606,7 +608,7 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
606608
// repro to stderr and/or a file.
607609
performOptionalDebugActions(target, transformRoot, passName,
608610
debugPayloadRootTag, debugTransformRootTag,
609-
transformLibraryFileName, binaryName);
611+
transformLibraryPaths, binaryName);
610612

611613
// Step 5
612614
// ------
@@ -615,55 +617,148 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
615617
extraMappings, options);
616618
}
617619

620+
/// Expands the given list of `paths` to a list of `.mlir` files.
621+
///
622+
/// Each entry in `paths` may either be a regular file, in which case it ends up
623+
/// in the result list, or a directory, in which case all (regular) `.mlir`
624+
/// files in that directory are added. Any other file types lead to a failure.
625+
static LogicalResult
626+
expandPathsToMLIRFiles(ArrayRef<std::string> &paths, MLIRContext *const context,
627+
SmallVectorImpl<std::string> &fileNames) {
628+
for (const std::string &path : paths) {
629+
auto loc = FileLineColLoc::get(context, path, 0, 0);
630+
631+
if (llvm::sys::fs::is_regular_file(path)) {
632+
LLVM_DEBUG(DBGS() << "Adding '" << path << "' to list of files\n");
633+
fileNames.push_back(path);
634+
continue;
635+
}
636+
637+
if (!llvm::sys::fs::is_directory(path)) {
638+
return emitError(loc)
639+
<< "'" << path << "' is neither a file nor a directory";
640+
}
641+
642+
LLVM_DEBUG(DBGS() << "Looking for files in '" << path << "':\n");
643+
644+
std::error_code ec;
645+
for (llvm::sys::fs::directory_iterator it(path, ec), itEnd;
646+
it != itEnd && !ec; it.increment(ec)) {
647+
const std::string &fileName = it->path();
648+
649+
if (it->type() != llvm::sys::fs::file_type::regular_file) {
650+
LLVM_DEBUG(DBGS() << " Skipping non-regular file '" << fileName
651+
<< "'\n");
652+
continue;
653+
}
654+
655+
if (!StringRef(fileName).endswith(".mlir")) {
656+
LLVM_DEBUG(DBGS() << " Skipping '" << fileName
657+
<< "' because it does not end with '.mlir'\n");
658+
continue;
659+
}
660+
661+
LLVM_DEBUG(DBGS() << " Adding '" << fileName << "' to list of files\n");
662+
fileNames.push_back(fileName);
663+
}
664+
665+
if (ec)
666+
return emitError(loc) << "error while opening files in '" << path
667+
<< "': " << ec.message();
668+
}
669+
670+
return success();
671+
}
672+
618673
LogicalResult transform::detail::interpreterBaseInitializeImpl(
619674
MLIRContext *context, StringRef transformFileName,
620-
StringRef transformLibraryFileName,
675+
ArrayRef<std::string> transformLibraryPaths,
621676
std::shared_ptr<OwningOpRef<ModuleOp>> &sharedTransformModule,
622677
std::shared_ptr<OwningOpRef<ModuleOp>> &transformLibraryModule,
623678
function_ref<std::optional<LogicalResult>(OpBuilder &, Location)>
624679
moduleBuilder) {
625-
OwningOpRef<ModuleOp> parsedTransformModule;
626-
if (failed(parseTransformModuleFromFile(context, transformFileName,
627-
parsedTransformModule)))
628-
return failure();
629-
if (parsedTransformModule && failed(mlir::verify(*parsedTransformModule)))
630-
return failure();
680+
auto unknownLoc = UnknownLoc::get(context);
631681

632-
OwningOpRef<ModuleOp> parsedLibraryModule;
633-
if (failed(parseTransformModuleFromFile(context, transformLibraryFileName,
634-
parsedLibraryModule)))
635-
return failure();
636-
if (parsedLibraryModule && failed(mlir::verify(*parsedLibraryModule)))
682+
// Parse module from file.
683+
OwningOpRef<ModuleOp> moduleFromFile;
684+
{
685+
auto loc = FileLineColLoc::get(context, transformFileName, 0, 0);
686+
if (failed(parseTransformModuleFromFile(context, transformFileName,
687+
moduleFromFile)))
688+
return emitError(loc) << "failed to parse transform module";
689+
if (moduleFromFile && failed(mlir::verify(*moduleFromFile)))
690+
return emitError(loc) << "failed to verify transform module";
691+
}
692+
693+
// Assemble list of library files.
694+
SmallVector<std::string> libraryFileNames;
695+
if (failed(expandPathsToMLIRFiles(transformLibraryPaths, context,
696+
libraryFileNames)))
637697
return failure();
638698

639-
if (parsedTransformModule) {
640-
sharedTransformModule = std::make_shared<OwningOpRef<ModuleOp>>(
641-
std::move(parsedTransformModule));
699+
// Parse modules from library files.
700+
SmallVector<OwningOpRef<ModuleOp>> parsedLibraries;
701+
for (const std::string &libraryFileName : libraryFileNames) {
702+
OwningOpRef<ModuleOp> parsedLibrary;
703+
auto loc = FileLineColLoc::get(context, libraryFileName, 0, 0);
704+
if (failed(parseTransformModuleFromFile(context, libraryFileName,
705+
parsedLibrary)))
706+
return emitError(loc) << "failed to parse transform library module";
707+
if (parsedLibrary && failed(mlir::verify(*parsedLibrary)))
708+
return emitError(loc) << "failed to verify transform library module";
709+
parsedLibraries.push_back(std::move(parsedLibrary));
710+
}
711+
712+
// Build shared transform module.
713+
if (moduleFromFile) {
714+
sharedTransformModule =
715+
std::make_shared<OwningOpRef<ModuleOp>>(std::move(moduleFromFile));
642716
} else if (moduleBuilder) {
643-
// TODO: better location story.
644-
auto location = UnknownLoc::get(context);
717+
auto loc = FileLineColLoc::get(context, "<shared-transform-module>", 0, 0);
645718
auto localModule = std::make_shared<OwningOpRef<ModuleOp>>(
646-
ModuleOp::create(location, "__transform"));
719+
ModuleOp::create(unknownLoc, "__transform"));
647720

648721
OpBuilder b(context);
649722
b.setInsertionPointToEnd(localModule->get().getBody());
650-
if (std::optional<LogicalResult> result = moduleBuilder(b, location)) {
723+
if (std::optional<LogicalResult> result = moduleBuilder(b, loc)) {
651724
if (failed(*result))
652-
return failure();
725+
return (*localModule)->emitError()
726+
<< "failed to create shared transform module";
653727
sharedTransformModule = std::move(localModule);
654728
}
655729
}
656730

657-
if (!parsedLibraryModule || !*parsedLibraryModule)
731+
if (parsedLibraries.empty())
658732
return success();
659733

734+
// Merge parsed libraries into one module.
735+
auto loc = FileLineColLoc::get(context, "<shared-library-module>", 0, 0);
736+
OwningOpRef<ModuleOp> mergedParsedLibraries =
737+
ModuleOp::create(loc, "__transform");
738+
{
739+
mergedParsedLibraries.get()->setAttr("transform.with_named_sequence",
740+
UnitAttr::get(context));
741+
IRRewriter rewriter(context);
742+
// TODO: extend `mergeSymbolsInto` to support multiple `other` modules.
743+
for (OwningOpRef<ModuleOp> &parsedLibrary : parsedLibraries) {
744+
if (failed(mergeSymbolsInto(mergedParsedLibraries.get(),
745+
std::move(parsedLibrary))))
746+
return mergedParsedLibraries->emitError()
747+
<< "failed to verify merged transform module";
748+
}
749+
}
750+
751+
// Use parsed libaries to resolve symbols in shared transform module or return
752+
// as separate library module.
660753
if (sharedTransformModule && *sharedTransformModule) {
661754
if (failed(mergeSymbolsInto(sharedTransformModule->get(),
662-
std::move(parsedLibraryModule))))
663-
return failure();
755+
std::move(mergedParsedLibraries))))
756+
return (*sharedTransformModule)->emitError()
757+
<< "failed to merge symbols from library files "
758+
"into shared transform module";
664759
} else {
665-
transformLibraryModule =
666-
std::make_shared<OwningOpRef<ModuleOp>>(std::move(parsedLibraryModule));
760+
transformLibraryModule = std::make_shared<OwningOpRef<ModuleOp>>(
761+
std::move(mergedParsedLibraries));
667762
}
668763
return success();
669764
}

mlir/test/Dialect/LLVM/lower-to-llvm-e2e.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
// RUN: mlir-opt %s -test-lower-to-llvm -cse | FileCheck %s
44

5-
// RUN: mlir-opt %s -test-transform-dialect-interpreter="transform-library-file-name=%p/lower-to-llvm-transform-symbol-def.mlir debug-payload-root-tag=payload" \
5+
// RUN: mlir-opt %s -test-transform-dialect-interpreter="transform-library-paths=%p/lower-to-llvm-transform-symbol-def.mlir debug-payload-root-tag=payload" \
66
// RUN: -test-transform-dialect-erase-schedule -cse \
77
// RUN: | FileCheck %s
88

mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-file-name=%p/test-interpreter-external-symbol-decl.mlir transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
1+
// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-file-name=%p/test-interpreter-external-symbol-decl.mlir transform-library-paths=%p/test-interpreter-library/definitions-self-contained.mlir})" \
22
// RUN: --verify-diagnostics
33

4-
// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-file-name=%p/test-interpreter-external-symbol-decl.mlir transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter{transform-file-name=%p/test-interpreter-external-symbol-decl.mlir transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
4+
// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-file-name=%p/test-interpreter-external-symbol-decl.mlir transform-library-paths=%p/test-interpreter-library/definitions-self-contained.mlir}, test-transform-dialect-interpreter{transform-file-name=%p/test-interpreter-external-symbol-decl.mlir transform-library-paths=%p/test-interpreter-library/definitions-self-contained.mlir})" \
55
// RUN: --verify-diagnostics
66

77
// The external transform script has a declaration to the named sequence @foo,
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-paths=%p%{fs-sep}test-interpreter-library})" \
2+
// RUN: --verify-diagnostics --split-input-file | FileCheck %s
3+
4+
// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-paths=%p%{fs-sep}test-interpreter-library/definitions-self-contained.mlir,%p%{fs-sep}test-interpreter-library/definitions-with-unresolved.mlir})" \
5+
// RUN: --verify-diagnostics --split-input-file | FileCheck %s
6+
7+
// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-paths=%p%{fs-sep}test-interpreter-library}, test-transform-dialect-interpreter)" \
8+
// RUN: --verify-diagnostics --split-input-file | FileCheck %s
9+
10+
// The definition of the @foo named sequence is provided in another file. It
11+
// will be included because of the pass option. Repeated application of the
12+
// same pass, with or without the library option, should not be a problem.
13+
// Note that the same diagnostic produced twice at the same location only
14+
// needs to be matched once.
15+
16+
// expected-remark @below {{message}}
17+
module attributes {transform.with_named_sequence} {
18+
// CHECK: transform.named_sequence @print_message
19+
transform.named_sequence @print_message(%arg0: !transform.any_op {transform.readonly})
20+
21+
transform.named_sequence @reference_other_module(!transform.any_op {transform.readonly})
22+
23+
transform.sequence failures(propagate) {
24+
^bb0(%arg0: !transform.any_op):
25+
include @print_message failures(propagate) (%arg0) : (!transform.any_op) -> ()
26+
include @reference_other_module failures(propagate) (%arg0) : (!transform.any_op) -> ()
27+
}
28+
}

0 commit comments

Comments
 (0)