[mlir] Python: Parse ModuleOp from file path #126572

Merged: 5 commits, Feb 12, 2025

Conversation

nikalra
Contributor

@nikalra nikalra commented Feb 10, 2025

For extremely large models, it may be inefficient to load the model into memory in Python prior to passing it to the MLIR C APIs for deserialization. This change adds an API to parse a ModuleOp directly from a file path.

Re-lands 4e14b8a.
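
To make the memory argument concrete, here is a minimal sketch that does not call MLIR at all. It uses the standard-library `tracemalloc` module to compare how much memory Python allocates when it reads a file into a string (what a string-based parse requires) against only preparing the path string (what a path-based parse would need). The helper names `peak_python_bytes` and `compare_routes`, the file name, and the padding content are hypothetical and exist only for illustration.

```python
import os
import tracemalloc
from pathlib import Path
from tempfile import TemporaryDirectory


def peak_python_bytes(action):
    """Return the peak number of bytes Python allocated while running action()."""
    tracemalloc.start()
    try:
        result = action()  # keep the result alive until the peak has been read
        _, peak = tracemalloc.get_traced_memory()
    finally:
        tracemalloc.stop()
    return peak


def compare_routes(size_bytes=1 << 20):
    """Measure both routes on a throwaway file of roughly size_bytes bytes."""
    with TemporaryDirectory() as tmp_dir:
        path = Path(tmp_dir) / "large_model.mlir"  # hypothetical file name
        # The padding only inflates the file; it is never parsed here.
        path.write_text("module {}\n" + "/" * size_bytes)
        # Route 1: materialize the whole text in Python before parsing.
        read_peak = peak_python_bytes(lambda: path.read_text())
        # Route 2: only the path string would cross into the bindings.
        path_peak = peak_python_bytes(lambda: os.fspath(path))
    return read_peak, path_peak
```

The first number grows with the file size while the second stays small, which is the overhead this change is meant to avoid. Memory used on the C++ side by the parser itself is outside what `tracemalloc` can observe.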

@llvmbot
Member

llvmbot commented Feb 10, 2025

@llvm/pr-subscribers-mlir

Author: Nikhil Kalra (nikalra)

Changes


Full diff: https://github.com/llvm/llvm-project/pull/126572.diff

6 Files Affected:

  • (modified) mlir/include/mlir-c/IR.h (+4)
  • (modified) mlir/include/mlir/Bindings/Python/Nanobind.h (+3)
  • (modified) mlir/lib/Bindings/Python/IRCore.cpp (+19-1)
  • (modified) mlir/lib/CAPI/IR/IR.cpp (+10)
  • (modified) mlir/python/mlir/_mlir_libs/_mlir/ir.pyi (+2-1)
  • (modified) mlir/test/python/ir/module.py (+20)
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 7d2fd89e8560fc9..14ccae650606af8 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -309,6 +309,10 @@ MLIR_CAPI_EXPORTED MlirModule mlirModuleCreateEmpty(MlirLocation location);
 MLIR_CAPI_EXPORTED MlirModule mlirModuleCreateParse(MlirContext context,
                                                     MlirStringRef module);
 
+/// Parses a module from file and transfers ownership to the caller.
+MLIR_CAPI_EXPORTED MlirModule
+mlirModuleCreateParseFromFile(MlirContext context, MlirStringRef fileName);
+
 /// Gets the context that a module was created with.
 MLIR_CAPI_EXPORTED MlirContext mlirModuleGetContext(MlirModule module);
 
diff --git a/mlir/include/mlir/Bindings/Python/Nanobind.h b/mlir/include/mlir/Bindings/Python/Nanobind.h
index ca942c83d3e2fad..48b23b57df10841 100644
--- a/mlir/include/mlir/Bindings/Python/Nanobind.h
+++ b/mlir/include/mlir/Bindings/Python/Nanobind.h
@@ -23,6 +23,9 @@
 #endif
 #include <nanobind/nanobind.h>
 #include <nanobind/ndarray.h>
+#if __has_include(<filesystem>)
+#include <nanobind/stl/filesystem.h>
+#endif
 #include <nanobind/stl/function.h>
 #include <nanobind/stl/optional.h>
 #include <nanobind/stl/pair.h>
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 47a85c2a486fd46..81936323631dd7d 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -6,6 +6,9 @@
 //
 //===----------------------------------------------------------------------===//
 
+#if __has_include(<filesystem>)
+#include <filesystem>
+#endif
 #include <optional>
 #include <utility>
 
@@ -299,7 +302,7 @@ struct PyAttrBuilderMap {
     return *builder;
   }
   static void dunderSetItemNamed(const std::string &attributeKind,
-                                nb::callable func, bool replace) {
+                                 nb::callable func, bool replace) {
     PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func),
                                               replace);
   }
@@ -3049,6 +3052,21 @@ void mlir::python::populateIRCore(nb::module_ &m) {
           },
           nb::arg("asm"), nb::arg("context").none() = nb::none(),
           kModuleParseDocstring)
+#if __has_include(<filesystem>)
+      .def_static(
+          "parse",
+          [](const std::filesystem::path &path,
+             DefaultingPyMlirContext context) {
+            PyMlirContext::ErrorCapture errors(context->getRef());
+            MlirModule module = mlirModuleCreateParseFromFile(
+                context->get(), toMlirStringRef(path.string()));
+            if (mlirModuleIsNull(module))
+              throw MLIRError("Unable to parse module assembly", errors.take());
+            return PyModule::forModule(module).releaseObject();
+          },
+          nb::arg("asm"), nb::arg("context").none() = nb::none(),
+          kModuleParseDocstring)
+#endif
       .def_static(
           "create",
           [](DefaultingPyLocation loc) {
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index f27af0ca9a2c78b..999e8cbda1295a1 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -22,6 +22,7 @@
 #include "mlir/IR/Location.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/OperationSupport.h"
+#include "mlir/IR/OwningOpRef.h"
 #include "mlir/IR/Types.h"
 #include "mlir/IR/Value.h"
 #include "mlir/IR/Verifier.h"
@@ -328,6 +329,15 @@ MlirModule mlirModuleCreateParse(MlirContext context, MlirStringRef module) {
   return MlirModule{owning.release().getOperation()};
 }
 
+MlirModule mlirModuleCreateParseFromFile(MlirContext context,
+                                         MlirStringRef fileName) {
+  OwningOpRef<ModuleOp> owning =
+      parseSourceFile<ModuleOp>(unwrap(fileName), unwrap(context));
+  if (!owning)
+    return MlirModule{nullptr};
+  return MlirModule{owning.release().getOperation()};
+}
+
 MlirContext mlirModuleGetContext(MlirModule module) {
   return wrap(unwrap(module).getContext());
 }
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index fb7efb8cd28a5eb..096b87b36244368 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -46,6 +46,7 @@ import abc
 import collections
 from collections.abc import Callable, Sequence
 import io
+from pathlib import Path
 from typing import Any, ClassVar, TypeVar, overload
 
 __all__ = [
@@ -2123,7 +2124,7 @@ class Module:
         Creates an empty module
         """
     @staticmethod
-    def parse(asm: str | bytes, context: Context | None = None) -> Module:
+    def parse(asm: str | bytes | Path, context: Context | None = None) -> Module:
         """
         Parses a module's assembly format from a string.
 
diff --git a/mlir/test/python/ir/module.py b/mlir/test/python/ir/module.py
index ecafcb46af2175d..441916b38ee73bb 100644
--- a/mlir/test/python/ir/module.py
+++ b/mlir/test/python/ir/module.py
@@ -1,6 +1,8 @@
 # RUN: %PYTHON %s | FileCheck %s
 
 import gc
+from pathlib import Path
+from tempfile import NamedTemporaryFile
 from mlir.ir import *
 
 
@@ -27,6 +29,24 @@ def testParseSuccess():
     print(str(module))
 
 
+# Verify successful parse from file.
+# CHECK-LABEL: TEST: testParseFromFileSuccess
+# CHECK: module @successfulParse
+@run
+def testParseFromFileSuccess():
+    ctx = Context()
+    with NamedTemporaryFile(mode="w") as tmp_file:
+        tmp_file.write(r"""module @successfulParse {}""")
+        tmp_file.flush()
+        module = Module.parse(Path(tmp_file.name), ctx)
+        assert module.context is ctx
+        print("CLEAR CONTEXT")
+        ctx = None  # Ensure that module captures the context.
+        gc.collect()
+        module.operation.verify()
+        print(str(module))
+
+
 # Verify parse error.
 # CHECK-LABEL: TEST: testParseError
 # CHECK: testParseError: <

@makslevental
Contributor

Basically LGTM, but just double-checking: is __has_include part of the standard preprocessor or some kind of extension?

@@ -3049,6 +3052,21 @@ void mlir::python::populateIRCore(nb::module_ &m) {
},
nb::arg("asm"), nb::arg("context").none() = nb::none(),
kModuleParseDocstring)
#if __has_include(<filesystem>)
Collaborator

I don't think that's an OK solution: the test will fail on gcc7 builds.

.def_static(
"parse",
[](const std::filesystem::path &path,
DefaultingPyMlirContext context) {
Collaborator

@joker-eph joker-eph Feb 11, 2025

Other places in the Python bindings use either std::string for the file path or a "streaming" idiom with a "file object" (and our own PyFileAccumulator). I would just align with one of these patterns.

https://github.com/llvm/llvm-project/blob/main/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp#L77
https://github.com/llvm/llvm-project/blob/main/mlir/lib/Bindings/Python/Pass.cpp#L82
https://github.com/llvm/llvm-project/blob/main/mlir/lib/Bindings/Python/IRCore.cpp#L1321-L1336
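
If a binding takes the file path as a plain string, as in the first pattern mentioned above, callers holding a `pathlib.Path` can normalize it on the Python side. The sketch below uses only the standard library; the helper name `as_path_string` is hypothetical and not part of the MLIR bindings.

```python
import os
from pathlib import Path


def as_path_string(path):
    """Normalize str, bytes, or os.PathLike input to a plain str path."""
    # os.fsdecode accepts str, bytes, and os.PathLike objects. It returns str
    # unchanged, decodes bytes with the filesystem encoding, and raises
    # TypeError for anything that is not path-like.
    return os.fsdecode(path)
```

For example, `as_path_string(Path("model.mlir"))` returns `"model.mlir"`, which a string-typed binding argument can accept without any `std::filesystem` support on the C++ side.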

Contributor Author

@nikalra nikalra Feb 12, 2025

There was already an overload that accepted string asm; I ended up adding a new parseFile API instead. It's a little messier in terms of the API surface, but hopefully that should work across all of the supported configs.
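
A hedged sketch of how calling code might use this from Python follows. The thread names `parseFile` but does not show its signature, so the call below assumes it mirrors `parse` with a string path and an optional context. The fallback branch targets a build that carries the `Path` overload from the diff posted earlier. The wrapper name `load_module` is hypothetical.

```python
import os
from pathlib import Path


def load_module(path, context=None):
    """Parse an MLIR module from a file path without reading it in Python."""
    from mlir.ir import Module  # assumes the MLIR Python bindings are installed

    parse_file = getattr(Module, "parseFile", None)
    if parse_file is not None:
        # Entry point named in the reply above; assumed to take a str path.
        return parse_file(os.fsdecode(path), context)
    # Fallback for a build carrying the Path overload from the posted diff.
    return Module.parse(Path(os.fsdecode(path)), context)
```

Probing with `getattr` keeps the caller working across both variants of the API discussed in this review, at the cost of one extra attribute lookup per call.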

@makslevental
Contributor

@joker-eph where does it say gcc7 is within the support window? The project is -std=c++17 and filesystem is a c++17 library that simply wasn't supported by GCC until 8. Seems we in fact do not support gcc7.

@jpienaar
Member

The minimum versions are documented at https://llvm.org/docs/GettingStarted.html#host-c-toolchain-both-compiler-and-standard-library

@makslevental
Contributor

> The minimum versions are documented at https://llvm.org/docs/GettingStarted.html#host-c-toolchain-both-compiler-and-standard-library

Ok. A little weird I guess since I would assume std=c++17 includes STL but okay.

@nikalra nikalra merged commit 65ed4fa into llvm:main Feb 12, 2025
8 checks passed
@nikalra nikalra deleted the reland-path-py branch February 12, 2025 22:02
Labels
mlir:python MLIR Python bindings mlir