Skip to content

Commit 4e14b8a

Browse files
authored
[mlir] Python: Parse ModuleOp from file path (llvm#125736)
For extremely large models, it may be inefficient to load the model into memory in Python prior to passing it to the MLIR C APIs for deserialization. This change adds an API to parse a ModuleOp directly from a file path.
1 parent 0ad1f83 commit 4e14b8a

File tree

6 files changed

+52
-2
lines changed

6 files changed

+52
-2
lines changed

mlir/include/mlir-c/IR.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,10 @@ MLIR_CAPI_EXPORTED MlirModule mlirModuleCreateEmpty(MlirLocation location);
309309
MLIR_CAPI_EXPORTED MlirModule mlirModuleCreateParse(MlirContext context,
310310
MlirStringRef module);
311311

312+
/// Parses a module from file and transfers ownership to the caller.
313+
MLIR_CAPI_EXPORTED MlirModule
314+
mlirModuleCreateParseFromFile(MlirContext context, MlirStringRef fileName);
315+
312316
/// Gets the context that a module was created with.
313317
MLIR_CAPI_EXPORTED MlirContext mlirModuleGetContext(MlirModule module);
314318

mlir/include/mlir/Bindings/Python/Nanobind.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#endif
2424
#include <nanobind/nanobind.h>
2525
#include <nanobind/ndarray.h>
26+
#include <nanobind/stl/filesystem.h>
2627
#include <nanobind/stl/function.h>
2728
#include <nanobind/stl/optional.h>
2829
#include <nanobind/stl/pair.h>

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9+
#include <filesystem>
910
#include <optional>
1011
#include <utility>
1112

@@ -299,7 +300,7 @@ struct PyAttrBuilderMap {
299300
return *builder;
300301
}
301302
static void dunderSetItemNamed(const std::string &attributeKind,
302-
nb::callable func, bool replace) {
303+
nb::callable func, bool replace) {
303304
PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func),
304305
replace);
305306
}
@@ -3049,6 +3050,19 @@ void mlir::python::populateIRCore(nb::module_ &m) {
30493050
},
30503051
nb::arg("asm"), nb::arg("context").none() = nb::none(),
30513052
kModuleParseDocstring)
3053+
.def_static(
3054+
"parse",
3055+
[](const std::filesystem::path &path,
3056+
DefaultingPyMlirContext context) {
3057+
PyMlirContext::ErrorCapture errors(context->getRef());
3058+
MlirModule module = mlirModuleCreateParseFromFile(
3059+
context->get(), toMlirStringRef(path.string()));
3060+
if (mlirModuleIsNull(module))
3061+
throw MLIRError("Unable to parse module assembly", errors.take());
3062+
return PyModule::forModule(module).releaseObject();
3063+
},
3064+
nb::arg("asm"), nb::arg("context").none() = nb::none(),
3065+
kModuleParseDocstring)
30523066
.def_static(
30533067
"create",
30543068
[](DefaultingPyLocation loc) {

mlir/lib/CAPI/IR/IR.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "mlir/IR/Location.h"
2323
#include "mlir/IR/Operation.h"
2424
#include "mlir/IR/OperationSupport.h"
25+
#include "mlir/IR/OwningOpRef.h"
2526
#include "mlir/IR/Types.h"
2627
#include "mlir/IR/Value.h"
2728
#include "mlir/IR/Verifier.h"
@@ -328,6 +329,15 @@ MlirModule mlirModuleCreateParse(MlirContext context, MlirStringRef module) {
328329
return MlirModule{owning.release().getOperation()};
329330
}
330331

332+
MlirModule mlirModuleCreateParseFromFile(MlirContext context,
333+
MlirStringRef fileName) {
334+
OwningOpRef<ModuleOp> owning =
335+
parseSourceFile<ModuleOp>(unwrap(fileName), unwrap(context));
336+
if (!owning)
337+
return MlirModule{nullptr};
338+
return MlirModule{owning.release().getOperation()};
339+
}
340+
331341
MlirContext mlirModuleGetContext(MlirModule module) {
332342
return wrap(unwrap(module).getContext());
333343
}

mlir/python/mlir/_mlir_libs/_mlir/ir.pyi

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ import abc
4646
import collections
4747
from collections.abc import Callable, Sequence
4848
import io
49+
from pathlib import Path
4950
from typing import Any, ClassVar, TypeVar, overload
5051

5152
__all__ = [
@@ -2123,7 +2124,7 @@ class Module:
21232124
Creates an empty module
21242125
"""
21252126
@staticmethod
2126-
def parse(asm: str | bytes, context: Context | None = None) -> Module:
2127+
def parse(asm: str | bytes | Path, context: Context | None = None) -> Module:
21272128
"""
21282129
Parses a module's assembly format from a string.
21292130

mlir/test/python/ir/module.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# RUN: %PYTHON %s | FileCheck %s
22

33
import gc
4+
from pathlib import Path
5+
from tempfile import NamedTemporaryFile
46
from mlir.ir import *
57

68

@@ -27,6 +29,24 @@ def testParseSuccess():
2729
print(str(module))
2830

2931

32+
# Verify successful parse from file.
33+
# CHECK-LABEL: TEST: testParseFromFileSuccess
34+
# CHECK: module @successfulParse
35+
@run
36+
def testParseFromFileSuccess():
37+
ctx = Context()
38+
with NamedTemporaryFile(mode="w") as tmp_file:
39+
tmp_file.write(r"""module @successfulParse {}""")
40+
tmp_file.flush()
41+
module = Module.parse(Path(tmp_file.name), ctx)
42+
assert module.context is ctx
43+
print("CLEAR CONTEXT")
44+
ctx = None # Ensure that module captures the context.
45+
gc.collect()
46+
module.operation.verify()
47+
print(str(module))
48+
49+
3050
# Verify parse error.
3151
# CHECK-LABEL: TEST: testParseError
3252
# CHECK: testParseError: <

0 commit comments

Comments
 (0)