-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[mlir] Python: Parse ModuleOp from file path #126572
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir Author: Nikhil Kalra (nikalra) ChangesFor extremely large models, it may be inefficient to load the model into memory in Python prior to passing it to the MLIR C APIs for deserialization. This change adds an API to parse a ModuleOp directly from a file path. Re-lands 4e14b8a. Full diff: https://github.com/llvm/llvm-project/pull/126572.diff 6 Files Affected:
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 7d2fd89e8560fc9..14ccae650606af8 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -309,6 +309,10 @@ MLIR_CAPI_EXPORTED MlirModule mlirModuleCreateEmpty(MlirLocation location);
MLIR_CAPI_EXPORTED MlirModule mlirModuleCreateParse(MlirContext context,
MlirStringRef module);
+/// Parses a module from file and transfers ownership to the caller.
+MLIR_CAPI_EXPORTED MlirModule
+mlirModuleCreateParseFromFile(MlirContext context, MlirStringRef fileName);
+
/// Gets the context that a module was created with.
MLIR_CAPI_EXPORTED MlirContext mlirModuleGetContext(MlirModule module);
diff --git a/mlir/include/mlir/Bindings/Python/Nanobind.h b/mlir/include/mlir/Bindings/Python/Nanobind.h
index ca942c83d3e2fad..48b23b57df10841 100644
--- a/mlir/include/mlir/Bindings/Python/Nanobind.h
+++ b/mlir/include/mlir/Bindings/Python/Nanobind.h
@@ -23,6 +23,9 @@
#endif
#include <nanobind/nanobind.h>
#include <nanobind/ndarray.h>
+#if __has_include(<filesystem>)
+#include <nanobind/stl/filesystem.h>
+#endif
#include <nanobind/stl/function.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/pair.h>
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 47a85c2a486fd46..81936323631dd7d 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -6,6 +6,9 @@
//
//===----------------------------------------------------------------------===//
+#if __has_include(<filesystem>)
+#include <filesystem>
+#endif
#include <optional>
#include <utility>
@@ -299,7 +302,7 @@ struct PyAttrBuilderMap {
return *builder;
}
static void dunderSetItemNamed(const std::string &attributeKind,
- nb::callable func, bool replace) {
+ nb::callable func, bool replace) {
PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func),
replace);
}
@@ -3049,6 +3052,21 @@ void mlir::python::populateIRCore(nb::module_ &m) {
},
nb::arg("asm"), nb::arg("context").none() = nb::none(),
kModuleParseDocstring)
+#if __has_include(<filesystem>)
+ .def_static(
+ "parse",
+ [](const std::filesystem::path &path,
+ DefaultingPyMlirContext context) {
+ PyMlirContext::ErrorCapture errors(context->getRef());
+ MlirModule module = mlirModuleCreateParseFromFile(
+ context->get(), toMlirStringRef(path.string()));
+ if (mlirModuleIsNull(module))
+ throw MLIRError("Unable to parse module assembly", errors.take());
+ return PyModule::forModule(module).releaseObject();
+ },
+ nb::arg("asm"), nb::arg("context").none() = nb::none(),
+ kModuleParseDocstring)
+#endif
.def_static(
"create",
[](DefaultingPyLocation loc) {
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index f27af0ca9a2c78b..999e8cbda1295a1 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -22,6 +22,7 @@
#include "mlir/IR/Location.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
+#include "mlir/IR/OwningOpRef.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/Verifier.h"
@@ -328,6 +329,15 @@ MlirModule mlirModuleCreateParse(MlirContext context, MlirStringRef module) {
return MlirModule{owning.release().getOperation()};
}
+MlirModule mlirModuleCreateParseFromFile(MlirContext context,
+ MlirStringRef fileName) {
+ OwningOpRef<ModuleOp> owning =
+ parseSourceFile<ModuleOp>(unwrap(fileName), unwrap(context));
+ if (!owning)
+ return MlirModule{nullptr};
+ return MlirModule{owning.release().getOperation()};
+}
+
MlirContext mlirModuleGetContext(MlirModule module) {
return wrap(unwrap(module).getContext());
}
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index fb7efb8cd28a5eb..096b87b36244368 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -46,6 +46,7 @@ import abc
import collections
from collections.abc import Callable, Sequence
import io
+from pathlib import Path
from typing import Any, ClassVar, TypeVar, overload
__all__ = [
@@ -2123,7 +2124,7 @@ class Module:
Creates an empty module
"""
@staticmethod
- def parse(asm: str | bytes, context: Context | None = None) -> Module:
+ def parse(asm: str | bytes | Path, context: Context | None = None) -> Module:
"""
Parses a module's assembly format from a string.
diff --git a/mlir/test/python/ir/module.py b/mlir/test/python/ir/module.py
index ecafcb46af2175d..441916b38ee73bb 100644
--- a/mlir/test/python/ir/module.py
+++ b/mlir/test/python/ir/module.py
@@ -1,6 +1,8 @@
# RUN: %PYTHON %s | FileCheck %s
import gc
+from pathlib import Path
+from tempfile import NamedTemporaryFile
from mlir.ir import *
@@ -27,6 +29,24 @@ def testParseSuccess():
print(str(module))
+# Verify successful parse from file.
+# CHECK-LABEL: TEST: testParseFromFileSuccess
+# CHECK: module @successfulParse
+@run
+def testParseFromFileSuccess():
+ ctx = Context()
+ with NamedTemporaryFile(mode="w") as tmp_file:
+ tmp_file.write(r"""module @successfulParse {}""")
+ tmp_file.flush()
+ module = Module.parse(Path(tmp_file.name), ctx)
+ assert module.context is ctx
+ print("CLEAR CONTEXT")
+ ctx = None # Ensure that module captures the context.
+ gc.collect()
+ module.operation.verify()
+ print(str(module))
+
+
# Verify parse error.
# CHECK-LABEL: TEST: testParseError
# CHECK: testParseError: <
|
Basically LGTM but just double checking, is |
mlir/lib/Bindings/Python/IRCore.cpp
Outdated
@@ -3049,6 +3052,21 @@ void mlir::python::populateIRCore(nb::module_ &m) { | |||
}, | |||
nb::arg("asm"), nb::arg("context").none() = nb::none(), | |||
kModuleParseDocstring) | |||
#if __has_include(<filesystem>) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think that's a OK solution: the test will fail on gcc7 builds.
mlir/lib/Bindings/Python/IRCore.cpp
Outdated
.def_static( | ||
"parse", | ||
[](const std::filesystem::path &path, | ||
DefaultingPyMlirContext context) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Other places in the python bindings are using either std::string
for filepath, or a "streaming" idiom with a "file object" (and our own PyFileAccumulator
), I would just align with either of these patterns.
https://github.com/llvm/llvm-project/blob/main/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp#L77
https://github.com/llvm/llvm-project/blob/main/mlir/lib/Bindings/Python/Pass.cpp#L82
https://github.com/llvm/llvm-project/blob/main/mlir/lib/Bindings/Python/IRCore.cpp#L1321-L1336
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was already an overload that accepted string asm; I ended up adding a new parseFile
API instead. It's a little messier in terms of the API surface, but hopefully that should work across all of the supported configs.
@joker-eph where does it say gcc7 is within the support window? The project is -std=c++17 and filesystem is a c++17 library that simply wasn't supported by GCC until 8. Seems we in fact do not support gcc7. |
The minimum versions are documented at https://llvm.org/docs/GettingStarted.html#host-c-toolchain-both-compiler-and-standard-library |
Ok. A little weird I guess since I would assume std=c++17 includes STL but okay. |
For extremely large models, it may be inefficient to load the model into memory in Python prior to passing it to the MLIR C APIs for deserialization. This change adds an API to parse a ModuleOp directly from a file path. Re-lands [4e14b8a](llvm@4e14b8a).
For extremely large models, it may be inefficient to load the model into memory in Python prior to passing it to the MLIR C APIs for deserialization. This change adds an API to parse a ModuleOp directly from a file path. Re-lands [4e14b8a](llvm@4e14b8a).
For extremely large models, it may be inefficient to load the model into memory in Python prior to passing it to the MLIR C APIs for deserialization. This change adds an API to parse a ModuleOp directly from a file path. Re-lands [4e14b8a](llvm@4e14b8a).
For extremely large models, it may be inefficient to load the model into memory in Python prior to passing it to the MLIR C APIs for deserialization. This change adds an API to parse a ModuleOp directly from a file path.
Re-lands 4e14b8a.