Skip to content

Commit

Permalink
Fix type hints Python 3.9 compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelosthege authored and ricardoV94 committed Oct 11, 2023
1 parent 3bc68bf commit 57d73cc
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions pymc/sampling/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ def sample_blackjax_nuts(
var_names: Optional[Sequence[str]] = None,
keep_untransformed: bool = False,
chain_method: str = "parallel",
postprocessing_backend: Literal["cpu", "gpu"] | None = None,
postprocessing_backend: Optional[Literal["cpu", "gpu"]] = None,
postprocessing_vectorize: Literal["vmap", "scan"] = "scan",
idata_kwargs: Optional[Dict[str, Any]] = None,
postprocessing_chunks=None, # deprecated
Expand Down Expand Up @@ -546,7 +546,7 @@ def sample_numpyro_nuts(
progressbar: bool = True,
keep_untransformed: bool = False,
chain_method: str = "parallel",
postprocessing_backend: Literal["cpu", "gpu"] | None = None,
postprocessing_backend: Optional[Literal["cpu", "gpu"]] = None,
postprocessing_vectorize: Literal["vmap", "scan"] = "scan",
idata_kwargs: Optional[Dict] = None,
nuts_kwargs: Optional[Dict] = None,
Expand Down

0 comments on commit 57d73cc

Please sign in to comment.