14
14
#include " mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h"
15
15
#include " mlir/Dialect/Transform/IR/TransformDialect.h"
16
16
#include " mlir/Dialect/Transform/IR/TransformInterfaces.h"
17
+ #include " mlir/Dialect/Transform/IR/TransformOps.h"
17
18
#include " mlir/IR/BuiltinOps.h"
18
19
#include " mlir/IR/Verifier.h"
19
20
#include " mlir/IR/Visitors.h"
@@ -194,7 +195,7 @@ saveReproToTempFile(llvm::raw_ostream &os, Operation *target,
194
195
Operation *transform, StringRef passName,
195
196
const Pass::Option<std::string> &debugPayloadRootTag,
196
197
const Pass::Option<std::string> &debugTransformRootTag,
197
- const Pass::Option <std::string> &transformLibraryFileName ,
198
+ const Pass::ListOption <std::string> &transformLibraryPaths ,
198
199
StringRef binaryName) {
199
200
using llvm::sys::fs::TempFile;
200
201
Operation *root = getRootOperation (target);
@@ -231,7 +232,7 @@ static void performOptionalDebugActions(
231
232
Operation *target, Operation *transform, StringRef passName,
232
233
const Pass::Option<std::string> &debugPayloadRootTag,
233
234
const Pass::Option<std::string> &debugTransformRootTag,
234
- const Pass::Option <std::string> &transformLibraryFileName ,
235
+ const Pass::ListOption <std::string> &transformLibraryPaths ,
235
236
StringRef binaryName) {
236
237
MLIRContext *context = target->getContext ();
237
238
@@ -284,7 +285,7 @@ static void performOptionalDebugActions(
284
285
DEBUG_WITH_TYPE (DEBUG_TYPE_DUMP_FILE, {
285
286
saveReproToTempFile (llvm::dbgs (), target, transform, passName,
286
287
debugPayloadRootTag, debugTransformRootTag,
287
- transformLibraryFileName , binaryName);
288
+ transformLibraryPaths , binaryName);
288
289
});
289
290
290
291
// Remove temporary attributes if they were set.
@@ -534,7 +535,7 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
534
535
const RaggedArray<MappedValue> &extraMappings,
535
536
const TransformOptions &options,
536
537
const Pass::Option<std::string> &transformFileName,
537
- const Pass::Option <std::string> &transformLibraryFileName ,
538
+ const Pass::ListOption <std::string> &transformLibraryPaths ,
538
539
const Pass::Option<std::string> &debugPayloadRootTag,
539
540
const Pass::Option<std::string> &debugTransformRootTag,
540
541
StringRef binaryName) {
@@ -597,7 +598,8 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
597
598
if (failed (
598
599
mergeSymbolsInto (SymbolTable::getNearestSymbolTable (transformRoot),
599
600
transformLibraryModule->get ()->clone ())))
600
- return failure ();
601
+ return emitError (transformRoot->getLoc (),
602
+ " failed to merge library symbols into transform root" );
601
603
}
602
604
603
605
// Step 4
@@ -606,7 +608,7 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
606
608
// repro to stderr and/or a file.
607
609
performOptionalDebugActions (target, transformRoot, passName,
608
610
debugPayloadRootTag, debugTransformRootTag,
609
- transformLibraryFileName , binaryName);
611
+ transformLibraryPaths , binaryName);
610
612
611
613
// Step 5
612
614
// ------
@@ -615,55 +617,148 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
615
617
extraMappings, options);
616
618
}
617
619
620
+ // / Expands the given list of `paths` to a list of `.mlir` files.
621
+ // /
622
+ // / Each entry in `paths` may either be a regular file, in which case it ends up
623
+ // / in the result list, or a directory, in which case all (regular) `.mlir`
624
+ // / files in that directory are added. Any other file types lead to a failure.
625
+ static LogicalResult
626
+ expandPathsToMLIRFiles (ArrayRef<std::string> &paths, MLIRContext *const context,
627
+ SmallVectorImpl<std::string> &fileNames) {
628
+ for (const std::string &path : paths) {
629
+ auto loc = FileLineColLoc::get (context, path, 0 , 0 );
630
+
631
+ if (llvm::sys::fs::is_regular_file (path)) {
632
+ LLVM_DEBUG (DBGS () << " Adding '" << path << " ' to list of files\n " );
633
+ fileNames.push_back (path);
634
+ continue ;
635
+ }
636
+
637
+ if (!llvm::sys::fs::is_directory (path)) {
638
+ return emitError (loc)
639
+ << " '" << path << " ' is neither a file nor a directory" ;
640
+ }
641
+
642
+ LLVM_DEBUG (DBGS () << " Looking for files in '" << path << " ':\n " );
643
+
644
+ std::error_code ec;
645
+ for (llvm::sys::fs::directory_iterator it (path, ec), itEnd;
646
+ it != itEnd && !ec; it.increment (ec)) {
647
+ const std::string &fileName = it->path ();
648
+
649
+ if (it->type () != llvm::sys::fs::file_type::regular_file) {
650
+ LLVM_DEBUG (DBGS () << " Skipping non-regular file '" << fileName
651
+ << " '\n " );
652
+ continue ;
653
+ }
654
+
655
+ if (!StringRef (fileName).endswith (" .mlir" )) {
656
+ LLVM_DEBUG (DBGS () << " Skipping '" << fileName
657
+ << " ' because it does not end with '.mlir'\n " );
658
+ continue ;
659
+ }
660
+
661
+ LLVM_DEBUG (DBGS () << " Adding '" << fileName << " ' to list of files\n " );
662
+ fileNames.push_back (fileName);
663
+ }
664
+
665
+ if (ec)
666
+ return emitError (loc) << " error while opening files in '" << path
667
+ << " ': " << ec.message ();
668
+ }
669
+
670
+ return success ();
671
+ }
672
+
618
673
LogicalResult transform::detail::interpreterBaseInitializeImpl (
619
674
MLIRContext *context, StringRef transformFileName,
620
- StringRef transformLibraryFileName ,
675
+ ArrayRef<std::string> transformLibraryPaths ,
621
676
std::shared_ptr<OwningOpRef<ModuleOp>> &sharedTransformModule,
622
677
std::shared_ptr<OwningOpRef<ModuleOp>> &transformLibraryModule,
623
678
function_ref<std::optional<LogicalResult>(OpBuilder &, Location)>
624
679
moduleBuilder) {
625
- OwningOpRef<ModuleOp> parsedTransformModule;
626
- if (failed (parseTransformModuleFromFile (context, transformFileName,
627
- parsedTransformModule)))
628
- return failure ();
629
- if (parsedTransformModule && failed (mlir::verify (*parsedTransformModule)))
630
- return failure ();
680
+ auto unknownLoc = UnknownLoc::get (context);
631
681
632
- OwningOpRef<ModuleOp> parsedLibraryModule;
633
- if (failed (parseTransformModuleFromFile (context, transformLibraryFileName,
634
- parsedLibraryModule)))
635
- return failure ();
636
- if (parsedLibraryModule && failed (mlir::verify (*parsedLibraryModule)))
682
+ // Parse module from file.
683
+ OwningOpRef<ModuleOp> moduleFromFile;
684
+ {
685
+ auto loc = FileLineColLoc::get (context, transformFileName, 0 , 0 );
686
+ if (failed (parseTransformModuleFromFile (context, transformFileName,
687
+ moduleFromFile)))
688
+ return emitError (loc) << " failed to parse transform module" ;
689
+ if (moduleFromFile && failed (mlir::verify (*moduleFromFile)))
690
+ return emitError (loc) << " failed to verify transform module" ;
691
+ }
692
+
693
+ // Assemble list of library files.
694
+ SmallVector<std::string> libraryFileNames;
695
+ if (failed (expandPathsToMLIRFiles (transformLibraryPaths, context,
696
+ libraryFileNames)))
637
697
return failure ();
638
698
639
- if (parsedTransformModule) {
640
- sharedTransformModule = std::make_shared<OwningOpRef<ModuleOp>>(
641
- std::move (parsedTransformModule));
699
+ // Parse modules from library files.
700
+ SmallVector<OwningOpRef<ModuleOp>> parsedLibraries;
701
+ for (const std::string &libraryFileName : libraryFileNames) {
702
+ OwningOpRef<ModuleOp> parsedLibrary;
703
+ auto loc = FileLineColLoc::get (context, libraryFileName, 0 , 0 );
704
+ if (failed (parseTransformModuleFromFile (context, libraryFileName,
705
+ parsedLibrary)))
706
+ return emitError (loc) << " failed to parse transform library module" ;
707
+ if (parsedLibrary && failed (mlir::verify (*parsedLibrary)))
708
+ return emitError (loc) << " failed to verify transform library module" ;
709
+ parsedLibraries.push_back (std::move (parsedLibrary));
710
+ }
711
+
712
+ // Build shared transform module.
713
+ if (moduleFromFile) {
714
+ sharedTransformModule =
715
+ std::make_shared<OwningOpRef<ModuleOp>>(std::move (moduleFromFile));
642
716
} else if (moduleBuilder) {
643
- // TODO: better location story.
644
- auto location = UnknownLoc::get (context);
717
+ auto loc = FileLineColLoc::get (context, " <shared-transform-module>" , 0 , 0 );
645
718
auto localModule = std::make_shared<OwningOpRef<ModuleOp>>(
646
- ModuleOp::create (location , " __transform" ));
719
+ ModuleOp::create (unknownLoc , " __transform" ));
647
720
648
721
OpBuilder b (context);
649
722
b.setInsertionPointToEnd (localModule->get ().getBody ());
650
- if (std::optional<LogicalResult> result = moduleBuilder (b, location )) {
723
+ if (std::optional<LogicalResult> result = moduleBuilder (b, loc )) {
651
724
if (failed (*result))
652
- return failure ();
725
+ return (*localModule)->emitError ()
726
+ << " failed to create shared transform module" ;
653
727
sharedTransformModule = std::move (localModule);
654
728
}
655
729
}
656
730
657
- if (!parsedLibraryModule || !*parsedLibraryModule )
731
+ if (parsedLibraries. empty () )
658
732
return success ();
659
733
734
+ // Merge parsed libraries into one module.
735
+ auto loc = FileLineColLoc::get (context, " <shared-library-module>" , 0 , 0 );
736
+ OwningOpRef<ModuleOp> mergedParsedLibraries =
737
+ ModuleOp::create (loc, " __transform" );
738
+ {
739
+ mergedParsedLibraries.get ()->setAttr (" transform.with_named_sequence" ,
740
+ UnitAttr::get (context));
741
+ IRRewriter rewriter (context);
742
+ // TODO: extend `mergeSymbolsInto` to support multiple `other` modules.
743
+ for (OwningOpRef<ModuleOp> &parsedLibrary : parsedLibraries) {
744
+ if (failed (mergeSymbolsInto (mergedParsedLibraries.get (),
745
+ std::move (parsedLibrary))))
746
+ return mergedParsedLibraries->emitError ()
747
+ << " failed to verify merged transform module" ;
748
+ }
749
+ }
750
+
751
+ // Use parsed libaries to resolve symbols in shared transform module or return
752
+ // as separate library module.
660
753
if (sharedTransformModule && *sharedTransformModule) {
661
754
if (failed (mergeSymbolsInto (sharedTransformModule->get (),
662
- std::move (parsedLibraryModule))))
663
- return failure ();
755
+ std::move (mergedParsedLibraries))))
756
+ return (*sharedTransformModule)->emitError ()
757
+ << " failed to merge symbols from library files "
758
+ " into shared transform module" ;
664
759
} else {
665
- transformLibraryModule =
666
- std::make_shared<OwningOpRef<ModuleOp>>( std:: move (parsedLibraryModule ));
760
+ transformLibraryModule = std::make_shared<OwningOpRef<ModuleOp>>(
761
+ std::move (mergedParsedLibraries ));
667
762
}
668
763
return success ();
669
764
}
0 commit comments