Skip to content

Commit fef3566

Browse files
authored
[mlir] Pass Options ownership modifications (#110582)
This change makes two (related) changes: First, it updates the tablegen option for `ListOption` to emit a `SmallVector` instead of an `ArrayRef`. This brings `ListOption` more inline with the traditional `Option`, where values are typically provided using types that have storage. After this change, all options should be fully owned by a Pass' `Options` object after it has been fully constructed, unless the underlying type of the `Option` explicitly indicates otherwise. Second, it updates the generated constructors for Passes to consume options by value instead of reference, and prefers moving options into the pass itself. This should be more efficient for non-trivial options objects, where the previous interface forced a copy to be materialized. Now, at worst case the API materializes a copy (no worse than before); at best-case, all options objects are moved into place. Ideally, we could update the Pass constructor to take an r-value reference to the Options object instead, but this approach will require numerous changes to existing passes and their factory functions. --------- Authored-by: Nikhil Kalra <[email protected]>
1 parent afc0557 commit fef3566

File tree

4 files changed

+16
-14
lines changed

4 files changed

+16
-14
lines changed

mlir/include/mlir/Transforms/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#define MLIR_TRANSFORMS_PASSES_H
1616

1717
#include "mlir/Pass/Pass.h"
18+
#include "mlir/Pass/PassManager.h"
1819
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1920
#include "mlir/Transforms/LocationSnapshot.h"
2021
#include "mlir/Transforms/ViewOpGraph.h"

mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ struct NarrowingPattern : OpRewritePattern<SourceOp> {
4444
NarrowingPattern(MLIRContext *ctx, const ArithIntNarrowingOptions &options,
4545
PatternBenefit benefit = 1)
4646
: OpRewritePattern<SourceOp>(ctx, benefit),
47-
supportedBitwidths(options.bitwidthsSupported) {
47+
supportedBitwidths(options.bitwidthsSupported.begin(),
48+
options.bitwidthsSupported.end()) {
4849
assert(!supportedBitwidths.empty() && "Invalid options");
4950
assert(!llvm::is_contained(supportedBitwidths, 0) && "Invalid bitwidth");
5051
llvm::sort(supportedBitwidths);
@@ -757,7 +758,8 @@ struct ArithIntNarrowingPass final
757758
MLIRContext *ctx = op->getContext();
758759
RewritePatternSet patterns(ctx);
759760
populateArithIntNarrowingPatterns(
760-
patterns, ArithIntNarrowingOptions{bitwidthsSupported});
761+
patterns, ArithIntNarrowingOptions{
762+
llvm::to_vector_of<unsigned>(bitwidthsSupported)});
761763
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
762764
signalPassFailure();
763765
}

mlir/tools/mlir-tblgen/PassGen.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ static void emitPassOptionsStruct(const Pass &pass, raw_ostream &os) {
9797
std::string type = opt.getType().str();
9898

9999
if (opt.isListOption())
100-
type = "::llvm::ArrayRef<" + type + ">";
100+
type = "::llvm::SmallVector<" + type + ">";
101101

102102
os.indent(2) << llvm::formatv("{0} {1}", type, opt.getCppVariableName());
103103

@@ -128,8 +128,8 @@ static void emitPassDecls(const Pass &pass, raw_ostream &os) {
128128

129129
// Declaration of the constructor with options.
130130
if (ArrayRef<PassOption> options = pass.getOptions(); !options.empty())
131-
os << llvm::formatv("std::unique_ptr<::mlir::Pass> create{0}(const "
132-
"{0}Options &options);\n",
131+
os << llvm::formatv("std::unique_ptr<::mlir::Pass> create{0}("
132+
"{0}Options options);\n",
133133
passName);
134134
}
135135

@@ -236,7 +236,7 @@ namespace impl {{
236236

237237
const char *const friendDefaultConstructorWithOptionsDeclTemplate = R"(
238238
namespace impl {{
239-
std::unique_ptr<::mlir::Pass> create{0}(const {0}Options &options);
239+
std::unique_ptr<::mlir::Pass> create{0}({0}Options options);
240240
} // namespace impl
241241
)";
242242

@@ -247,8 +247,8 @@ const char *const friendDefaultConstructorDefTemplate = R"(
247247
)";
248248

249249
const char *const friendDefaultConstructorWithOptionsDefTemplate = R"(
250-
friend std::unique_ptr<::mlir::Pass> create{0}(const {0}Options &options) {{
251-
return std::make_unique<DerivedT>(options);
250+
friend std::unique_ptr<::mlir::Pass> create{0}({0}Options options) {{
251+
return std::make_unique<DerivedT>(std::move(options));
252252
}
253253
)";
254254

@@ -259,8 +259,8 @@ std::unique_ptr<::mlir::Pass> create{0}() {{
259259
)";
260260

261261
const char *const defaultConstructorWithOptionsDefTemplate = R"(
262-
std::unique_ptr<::mlir::Pass> create{0}(const {0}Options &options) {{
263-
return impl::create{0}(options);
262+
std::unique_ptr<::mlir::Pass> create{0}({0}Options options) {{
263+
return impl::create{0}(std::move(options));
264264
}
265265
)";
266266

@@ -326,10 +326,10 @@ static void emitPassDefs(const Pass &pass, raw_ostream &os) {
326326

327327
if (ArrayRef<PassOption> options = pass.getOptions(); !options.empty()) {
328328
os.indent(2) << llvm::formatv(
329-
"{0}Base(const {0}Options &options) : {0}Base() {{\n", passName);
329+
"{0}Base({0}Options options) : {0}Base() {{\n", passName);
330330

331331
for (const PassOption &opt : pass.getOptions())
332-
os.indent(4) << llvm::formatv("{0} = options.{0};\n",
332+
os.indent(4) << llvm::formatv("{0} = std::move(options.{0});\n",
333333
opt.getCppVariableName());
334334

335335
os.indent(2) << "}\n";

mlir/unittests/TableGen/PassGenTest.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,7 @@ TEST(PassGenTest, PassOptions) {
7272
TestPassWithOptionsOptions options;
7373
options.testOption = 57;
7474

75-
llvm::SmallVector<int64_t, 2> testListOption = {1, 2};
76-
options.testListOption = testListOption;
75+
options.testListOption = {1, 2};
7776

7877
const auto unwrap = [](const std::unique_ptr<mlir::Pass> &pass) {
7978
return static_cast<const TestPassWithOptions *>(pass.get());

0 commit comments

Comments
 (0)