Skip to content

Commit ab18cc2

Browse files
authored
[MLIR][py] Add PyThreadPool as wrapper around MlirLlvmThreadPool in MLIR python bindings (#130109)
In some projects like JAX ir.Context are used with disabled multi-threading to avoid caching multiple threading pools: https://github.com/jax-ml/jax/blob/623865fe9538100d877ba9d36f788d0f95a11ed2/jax/_src/interpreters/mlir.py#L606-L611 However, when context has enabled multithreading it also uses locks on the StorageUniquers and this can be helpful to avoid data races in the multi-threaded execution (for example with free-threaded cpython, jax-ml/jax#26272). With this PR user can enable the multi-threading: 1) enables additional locking and 2) set a shared threading pool such that cached contexts can have one global pool.
1 parent 78060a7 commit ab18cc2

File tree

6 files changed

+106
-3
lines changed

6 files changed

+106
-3
lines changed

mlir/include/mlir-c/IR.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,15 @@ MLIR_CAPI_EXPORTED bool mlirContextIsRegisteredOperation(MlirContext context,
162162
MLIR_CAPI_EXPORTED void mlirContextSetThreadPool(MlirContext context,
163163
MlirLlvmThreadPool threadPool);
164164

165+
/// Gets the number of threads of the thread pool of the context when
166+
/// multithreading is enabled. Returns 1 if no multithreading.
167+
MLIR_CAPI_EXPORTED unsigned mlirContextGetNumThreads(MlirContext context);
168+
169+
/// Gets the thread pool of the context when enabled multithreading, otherwise
170+
/// an assertion is raised.
171+
MLIR_CAPI_EXPORTED MlirLlvmThreadPool
172+
mlirContextGetThreadPool(MlirContext context);
173+
165174
//===----------------------------------------------------------------------===//
166175
// Dialect API.
167176
//===----------------------------------------------------------------------===//

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2743,6 +2743,13 @@ void mlir::python::populateIRCore(nb::module_ &m) {
27432743
// __init__.py will subclass it with site-specific functionality and set a
27442744
// "Context" attribute on this module.
27452745
//----------------------------------------------------------------------------
2746+
2747+
// Expose DefaultThreadPool to python
2748+
nb::class_<PyThreadPool>(m, "ThreadPool")
2749+
.def("__init__", [](PyThreadPool &self) { new (&self) PyThreadPool(); })
2750+
.def("get_max_concurrency", &PyThreadPool::getMaxConcurrency)
2751+
.def("_mlir_thread_pool_ptr", &PyThreadPool::_mlir_thread_pool_ptr);
2752+
27462753
nb::class_<PyMlirContext>(m, "_BaseContext")
27472754
.def("__init__",
27482755
[](PyMlirContext &self) {
@@ -2814,6 +2821,25 @@ void mlir::python::populateIRCore(nb::module_ &m) {
28142821
mlirContextEnableMultithreading(self.get(), enable);
28152822
},
28162823
nb::arg("enable"))
2824+
.def("set_thread_pool",
2825+
[](PyMlirContext &self, PyThreadPool &pool) {
2826+
// we should disable multi-threading first before setting
2827+
// new thread pool otherwise the assert in
2828+
// MLIRContext::setThreadPool will be raised.
2829+
mlirContextEnableMultithreading(self.get(), false);
2830+
mlirContextSetThreadPool(self.get(), pool.get());
2831+
})
2832+
.def("get_num_threads",
2833+
[](PyMlirContext &self) {
2834+
return mlirContextGetNumThreads(self.get());
2835+
})
2836+
.def("_mlir_thread_pool_ptr",
2837+
[](PyMlirContext &self) {
2838+
MlirLlvmThreadPool pool = mlirContextGetThreadPool(self.get());
2839+
std::stringstream ss;
2840+
ss << pool.ptr;
2841+
return ss.str();
2842+
})
28172843
.def(
28182844
"is_registered_operation",
28192845
[](PyMlirContext &self, std::string &name) {

mlir/lib/Bindings/Python/IRModule.h

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#define MLIR_BINDINGS_PYTHON_IRMODULES_H
1212

1313
#include <optional>
14+
#include <sstream>
1415
#include <utility>
1516
#include <vector>
1617

@@ -22,9 +23,10 @@
2223
#include "mlir-c/IR.h"
2324
#include "mlir-c/IntegerSet.h"
2425
#include "mlir-c/Transforms.h"
25-
#include "mlir/Bindings/Python/NanobindAdaptors.h"
2626
#include "mlir/Bindings/Python/Nanobind.h"
27+
#include "mlir/Bindings/Python/NanobindAdaptors.h"
2728
#include "llvm/ADT/DenseMap.h"
29+
#include "llvm/Support/ThreadPool.h"
2830

2931
namespace mlir {
3032
namespace python {
@@ -158,6 +160,29 @@ class PyThreadContextEntry {
158160
FrameKind frameKind;
159161
};
160162

163+
/// Wrapper around MlirLlvmThreadPool
164+
/// Python object owns the C++ thread pool
165+
class PyThreadPool {
166+
public:
167+
PyThreadPool() {
168+
ownedThreadPool = std::make_unique<llvm::DefaultThreadPool>();
169+
}
170+
PyThreadPool(const PyThreadPool &) = delete;
171+
PyThreadPool(PyThreadPool &&) = delete;
172+
173+
int getMaxConcurrency() const { return ownedThreadPool->getMaxConcurrency(); }
174+
MlirLlvmThreadPool get() { return wrap(ownedThreadPool.get()); }
175+
176+
std::string _mlir_thread_pool_ptr() const {
177+
std::stringstream ss;
178+
ss << ownedThreadPool.get();
179+
return ss.str();
180+
}
181+
182+
private:
183+
std::unique_ptr<llvm::ThreadPoolInterface> ownedThreadPool;
184+
};
185+
161186
/// Wrapper around MlirContext.
162187
using PyMlirContextRef = PyObjectRef<PyMlirContext>;
163188
class PyMlirContext {

mlir/lib/CAPI/IR/IR.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,14 @@ void mlirContextSetThreadPool(MlirContext context,
114114
unwrap(context)->setThreadPool(*unwrap(threadPool));
115115
}
116116

117+
unsigned mlirContextGetNumThreads(MlirContext context) {
118+
return unwrap(context)->getNumThreads();
119+
}
120+
121+
MlirLlvmThreadPool mlirContextGetThreadPool(MlirContext context) {
122+
return wrap(&unwrap(context)->getThreadPool());
123+
}
124+
117125
//===----------------------------------------------------------------------===//
118126
// Dialect API.
119127
//===----------------------------------------------------------------------===//

mlir/python/mlir/_mlir_libs/__init__.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,13 +148,25 @@ def process_initializer_module(module_name):
148148
break
149149

150150
class Context(ir._BaseContext):
151-
def __init__(self, load_on_create_dialects=None, *args, **kwargs):
151+
def __init__(
152+
self, load_on_create_dialects=None, thread_pool=None, *args, **kwargs
153+
):
152154
super().__init__(*args, **kwargs)
153155
self.append_dialect_registry(get_dialect_registry())
154156
for hook in post_init_hooks:
155157
hook(self)
158+
if disable_multithreading and thread_pool is not None:
159+
raise ValueError(
160+
"Context constructor has given thread_pool argument, "
161+
"but disable_multithreading flag is True. "
162+
"Please, set thread_pool argument to None or "
163+
"set disable_multithreading flag to False."
164+
)
156165
if not disable_multithreading:
157-
self.enable_multithreading(True)
166+
if thread_pool is None:
167+
self.enable_multithreading(True)
168+
else:
169+
self.set_thread_pool(thread_pool)
158170
if load_on_create_dialects is not None:
159171
logger.debug(
160172
"Loading all dialects from load_on_create_dialects arg %r",

mlir/test/python/ir/context_lifecycle.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,26 @@
4747
assert '"mlir.ir.Context._CAPIPtr"' in repr(c4_capsule)
4848
c5 = mlir.ir.Context._CAPICreate(c4_capsule)
4949
assert c4 is c5
50+
c4 = None
51+
c5 = None
52+
gc.collect()
53+
54+
# Create a global threadpool and use it in two contexts
55+
tp = mlir.ir.ThreadPool()
56+
assert tp.get_max_concurrency() > 0
57+
c5 = mlir.ir.Context()
58+
c5.set_thread_pool(tp)
59+
assert c5.get_num_threads() == tp.get_max_concurrency()
60+
assert c5._mlir_thread_pool_ptr() == tp._mlir_thread_pool_ptr()
61+
c6 = mlir.ir.Context()
62+
c6.set_thread_pool(tp)
63+
assert c6.get_num_threads() == tp.get_max_concurrency()
64+
assert c6._mlir_thread_pool_ptr() == tp._mlir_thread_pool_ptr()
65+
c7 = mlir.ir.Context(thread_pool=tp)
66+
assert c7.get_num_threads() == tp.get_max_concurrency()
67+
assert c7._mlir_thread_pool_ptr() == tp._mlir_thread_pool_ptr()
68+
assert mlir.ir.Context._get_live_count() == 3
69+
c5 = None
70+
c6 = None
71+
c7 = None
72+
gc.collect()

0 commit comments

Comments
 (0)