-
-
Notifications
You must be signed in to change notification settings - Fork 985
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Annotate handlers & add py.typed #3321
Conversation
df2faa3
to
a0921aa
Compare
a0921aa
to
9f4166e
Compare
@overload | ||
def condition( | ||
data: Union[Dict[str, "torch.Tensor"], "Trace"], | ||
) -> ConditionMessenger: ... |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some handlers have arguments that are not optional so fn
has to be required in the signature. This is the trick with overloading to have fn
as optional. Based on the signature mypy can figure out which type annotations to use.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great, I have just a couple questions.
|
||
|
||
@_make_handler(BlockMessenger) | ||
def block( # type: ignore[empty-body] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would it be more idiomatic to use pass
rather than ... # type: ignore[empty-body]
, here and in other targets of @_make_handler
? Or would that cause mypy to complain about the return type?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just checked it, it gives the same empty-body mypy error:
(pyro) yordabay@yo-dl-dev:/mnt/disks/dev/repos/pyro$ git diff
diff --git a/pyro/poutine/handlers.py b/pyro/poutine/handlers.py
index c54da6cf..404ae667 100644
--- a/pyro/poutine/handlers.py
+++ b/pyro/poutine/handlers.py
@@ -165,7 +165,7 @@ def block(
@_make_handler(BlockMessenger)
-def block( # type: ignore[empty-body]
+def block(
fn: Optional[Callable[_P, _T]] = None,
hide_fn: Optional[Callable[["Message"], Optional[bool]]] = None,
expose_fn: Optional[Callable[["Message"], Optional[bool]]] = None,
@@ -175,7 +175,8 @@ def block( # type: ignore[empty-body]
expose: Optional[List[str]] = None,
hide_types: Optional[List[str]] = None,
expose_types: Optional[List[str]] = None,
-) -> Union[BlockMessenger, Callable[_P, _T]]: ...
+) -> Union[BlockMessenger, Callable[_P, _T]]:
+ pass
@overload
(pyro) yordabay@yo-dl-dev:/mnt/disks/dev/repos/pyro$ make lint
ruff check .
black --check *.py pyro examples tests scripts profiler
Skipping .ipynb files as Jupyter dependencies are not installed.
You can fix this by running ``pip install "black[jupyter]"``
All done! ✨ 🍰 ✨
621 files would be left unchanged.
python scripts/update_headers.py --check
mypy --install-types --non-interactive pyro scripts
pyro/poutine/handlers.py:168: error: Missing return statement [empty-body]
Found 1 error in 1 file (checked 321 source files)
make: *** [Makefile:24: lint] Error 1
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for checking!
@@ -94,14 +94,14 @@ def _masked_observe( | |||
name: str, | |||
fn: TorchDistributionMixin, | |||
obs: Optional[torch.Tensor], | |||
obs_mask: torch.Tensor, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice! Is torch.cuda.BoolTensor
a subclass of torch.BoolTensor
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems like not:
>>> issubclass(torch.cuda.BoolTensor, torch.Tensor)
False
>>> issubclass(torch.cuda.BoolTensor, torch.BoolTensor)
False
But it should be okay since it is only used for type checking, right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, it should be ok. Let's just keep in mind that assert isinstance(obs_mask, torch.BoolTensor)
would fail in runtime.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah. That's why I try to use assert x is not None
when x: Optional[torch.Tensor]
*args, | ||
**kwargs, | ||
) -> torch.Tensor: | ||
# Split into two auxiliary sample sites. | ||
with poutine.mask(mask=obs_mask): | ||
observed = sample(f"{name}_observed", fn, *args, **kwargs, obs=obs) | ||
with poutine.mask(mask=~obs_mask): | ||
with poutine.mask(mask=~obs_mask): # type: ignore[call-overload] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here, unfortunately, ~torch.BoolTensor
returns a torch.Tensor
No description provided.