We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 3bc68bf commit 57d73ccCopy full SHA for 57d73cc
pymc/sampling/jax.py
@@ -336,7 +336,7 @@ def sample_blackjax_nuts(
336
var_names: Optional[Sequence[str]] = None,
337
keep_untransformed: bool = False,
338
chain_method: str = "parallel",
339
- postprocessing_backend: Literal["cpu", "gpu"] | None = None,
+ postprocessing_backend: Optional[Literal["cpu", "gpu"]] = None,
340
postprocessing_vectorize: Literal["vmap", "scan"] = "scan",
341
idata_kwargs: Optional[Dict[str, Any]] = None,
342
postprocessing_chunks=None, # deprecated
@@ -546,7 +546,7 @@ def sample_numpyro_nuts(
546
progressbar: bool = True,
547
548
549
550
551
idata_kwargs: Optional[Dict] = None,
552
nuts_kwargs: Optional[Dict] = None,
0 commit comments