Skip to content

Commit 57d73cc

Browse files
michaelosthegericardoV94
authored andcommitted
Fix type hints Python 3.9 compatibility
1 parent 3bc68bf commit 57d73cc

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

pymc/sampling/jax.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,7 @@ def sample_blackjax_nuts(
336336
var_names: Optional[Sequence[str]] = None,
337337
keep_untransformed: bool = False,
338338
chain_method: str = "parallel",
339-
postprocessing_backend: Literal["cpu", "gpu"] | None = None,
339+
postprocessing_backend: Optional[Literal["cpu", "gpu"]] = None,
340340
postprocessing_vectorize: Literal["vmap", "scan"] = "scan",
341341
idata_kwargs: Optional[Dict[str, Any]] = None,
342342
postprocessing_chunks=None, # deprecated
@@ -546,7 +546,7 @@ def sample_numpyro_nuts(
546546
progressbar: bool = True,
547547
keep_untransformed: bool = False,
548548
chain_method: str = "parallel",
549-
postprocessing_backend: Literal["cpu", "gpu"] | None = None,
549+
postprocessing_backend: Optional[Literal["cpu", "gpu"]] = None,
550550
postprocessing_vectorize: Literal["vmap", "scan"] = "scan",
551551
idata_kwargs: Optional[Dict] = None,
552552
nuts_kwargs: Optional[Dict] = None,

0 commit comments

Comments
 (0)