-
-
Notifications
You must be signed in to change notification settings - Fork 986
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Type annotate messengers #3309
Type annotate messengers #3309
Conversation
ordabayevy
commented
Jan 2, 2024
- plate_messenger
- reentrant_messenger
- reparam_messenger
pyro/poutine/runtime.py
Outdated
is_observed: bool | ||
args: Tuple | ||
kwargs: Dict | ||
value: Optional[torch.Tensor] | ||
value: Optional[T] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've learned this neat trick with Generic
and TypeVar
where the type of value
can be inferred from the Callable
signature. I also fixed effectful
so that it gives the correct signature for the decorated function when the return type is diferent from torch.Tensor
(e.g. reparam_messenger._get_init_messengers
).
pyro/poutine/runtime.py
Outdated
@@ -368,6 +365,7 @@ def _fn( | |||
) | |||
# apply the stack and return its return value | |||
apply_stack(msg) | |||
assert msg["value"] is not None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this always correct? All tests have passed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@eb8680 ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great, just a couple nits.
pyro/poutine/runtime.py
Outdated
@@ -368,6 +365,7 @@ def _fn( | |||
) | |||
# apply the stack and return its return value | |||
apply_stack(msg) | |||
assert msg["value"] is not None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@eb8680 ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@fritzo can you have another look? I have addressed your comments.
) -> Union[T, torch.Tensor, None]: | ||
obs: Optional[_T] = None, | ||
**kwargs: _P.kwargs, | ||
) -> Optional[_T]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changed this back to return Optional
and removed the assert msg["value"] is not None
line. One concern I have is that if _T
itself is None
then it will raise an assertion error.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, Thanks for the ping!