Skip to content

Commit 0c1aad5

Browse files
authored
fix: threadpool configuration (#2671)
1 parent eb25424 commit 0c1aad5

File tree

2 files changed

+19
-7
lines changed

2 files changed

+19
-7
lines changed

src/zarr/core/sync.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,7 @@ def _get_executor() -> ThreadPoolExecutor:
5454
global _executor
5555
if not _executor:
5656
max_workers = config.get("threading.max_workers", None)
57-
print(max_workers)
58-
# if max_workers is not None and max_workers > 0:
59-
# raise ValueError(max_workers)
57+
logger.debug("Creating Zarr ThreadPoolExecutor with max_workers=%s", max_workers)
6058
_executor = ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix="zarr_pool")
6159
_get_loop().set_default_executor(_executor)
6260
return _executor
@@ -118,6 +116,9 @@ def sync(
118116
# NB: if the loop is not running *yet*, it is OK to submit work
119117
# and we will wait for it
120118
loop = _get_loop()
119+
if _executor is None and config.get("threading.max_workers", None) is not None:
120+
# trigger executor creation and attach to loop
121+
_ = _get_executor()
121122
if not isinstance(loop, asyncio.AbstractEventLoop):
122123
raise TypeError(f"loop cannot be of type {type(loop)}")
123124
if loop.is_closed():
@@ -153,6 +154,7 @@ def _get_loop() -> asyncio.AbstractEventLoop:
153154
# repeat the check just in case the loop got filled between the
154155
# previous two calls from another thread
155156
if loop[0] is None:
157+
logger.debug("Creating Zarr event loop")
156158
new_loop = asyncio.new_event_loop()
157159
loop[0] = new_loop
158160
iothread[0] = threading.Thread(target=new_loop.run_forever, name="zarr_io")

tests/test_sync.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
_get_lock,
1313
_get_loop,
1414
cleanup_resources,
15+
loop,
1516
sync,
1617
)
1718
from zarr.storage import MemoryStore
@@ -148,11 +149,20 @@ def test_open_positional_args_deprecate():
148149

149150

150151
@pytest.mark.parametrize("workers", [None, 1, 2])
151-
def test_get_executor(clean_state, workers) -> None:
152+
def test_threadpool_executor(clean_state, workers: int | None) -> None:
152153
with zarr.config.set({"threading.max_workers": workers}):
153-
e = _get_executor()
154-
if workers is not None and workers != 0:
155-
assert e._max_workers == workers
154+
_ = zarr.zeros(shape=(1,)) # trigger executor creation
155+
assert loop != [None] # confirm loop was created
156+
if workers is None:
157+
# confirm no executor was created if no workers were specified
158+
# (this is the default behavior)
159+
assert loop[0]._default_executor is None
160+
else:
161+
# confirm executor was created and attached to loop as the default executor
162+
# note: python doesn't have a direct way to get the default executor so we
163+
# use the private attribute
164+
assert _get_executor() is loop[0]._default_executor
165+
assert _get_executor()._max_workers == workers
156166

157167

158168
def test_cleanup_resources_idempotent() -> None:

0 commit comments

Comments
 (0)