fix: support decollate for numpy scalars #8470
base: dev
Conversation
187c141 → c438fe0 (compare)
@@ -625,6 +625,8 @@ def decollate_batch(batch, detach: bool = True, pad=True, fill_value=None):
         type(batch).__module__ == "numpy" and not isinstance(batch, Iterable)
     ):
         return batch
+    if isinstance(batch, np.ndarray) and batch.ndim == 0:
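A quick sketch of why the extra branch is needed (variable names here are illustrative): numpy scalar types such as np.float32 are caught by the existing guard because they are not Iterable, but a 0-d ndarray slips through, since np.ndarray defines __iter__.

```python
import numpy as np
from collections.abc import Iterable

scalar = np.float32(1.0)  # a numpy scalar type
zero_d = np.array(1)      # a 0-d numpy array

# The existing guard from the diff, applied to each value:
scalar_guarded = type(scalar).__module__ == "numpy" and not isinstance(scalar, Iterable)
zero_d_guarded = type(zero_d).__module__ == "numpy" and not isinstance(zero_d, Iterable)

print(scalar_guarded)  # True: the early return handles numpy scalar types
print(zero_d_guarded)  # False: 0-d arrays fall through to the iteration logic
```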
Thanks for the PR! Do you think it might be beneficial to convert the array into a tensor? This way, the data could be handled more consistently.
We could; I think it does not matter for my use cases. As long as the function handles numpy scalars in the form of an array, it works for me!
I will add this change and convert it to a tensor there (L629) if you prefer :)
Thanks for the quick fix!
May I ask the reason for only converting to a tensor when batch.ndim == 0 here?
I noticed different behavior when using the decollate_batch function on torch tensors vs numpy arrays (see discussion #8472), so I don't want to convert numpy arrays to torch tensors, as that would introduce breaking changes.
This PR only addresses issue #8471, as I think the current failure was unintended and the case should be supported.
fix linter (Signed-off-by: Arthur Dujardin <[email protected]>)
fix numpy decollate multi arrays (Signed-off-by: Arthur Dujardin <[email protected]>)
fix linter (Signed-off-by: Arthur Dujardin <[email protected]>)
fix numpy scalar support (Signed-off-by: Arthur Dujardin <[email protected]>)
minor refactoring for typing (Signed-off-by: Arthur Dujardin <[email protected]>)
convert scalar array to tensor (Signed-off-by: Arthur Dujardin <[email protected]>)
451c207 → 49d4954 (compare)
Could we consider a more complete solution? The issue, it seems, is that 0-d arrays are iterable but can't be iterated over. We already check for non-iterable things in decollate_batch.
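The observation above can be checked directly: a 0-d numpy array passes an isinstance(..., Iterable) test, yet raises as soon as it is actually iterated.

```python
import numpy as np
from collections.abc import Iterable

zero_d = np.array(1)

# Passes the abstract-class check, because np.ndarray defines __iter__ ...
passes_iterable_check = isinstance(zero_d, Iterable)

# ... but iterating a 0-d array raises TypeError at runtime.
try:
    iter(zero_d)
    raised = False
except TypeError:  # e.g. "iteration over a 0-d array"
    raised = True

print(passes_iterable_check, raised)
```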
Thanks for the feedback. The initial PR was:

    if isinstance(batch, (float, int, str, bytes)) or (
        type(batch).__module__ == "numpy" and not isinstance(batch, Iterable)
    ):
        return batch
    if isinstance(batch, np.ndarray) and batch.ndim == 0:
        return batch.item() if detach else batch
    # rest ...

Is this something that you find more complete? Note I refactored the PR to convert from numpy array to torch tensor, as suggested by @KumoLiu.
What I had in mind was more of the following change:

    ...
    if batch is None or isinstance(batch, (float, int, str, bytes)):
        return batch
    if getattr(batch, "ndim", -1) == 0:  # assumes only Numpy objects and Pytorch tensors have ndim
        return batch.item() if detach else batch
    if isinstance(batch, torch.Tensor):
        if detach:
            batch = batch.detach()
        # REMOVE
        # if batch.ndim == 0:
        #     return batch.item() if detach else batch
    ...
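The duck-typed check suggested above can be sketched in isolation (a simplified stand-in, not MONAI's actual decollate_batch; the function name is hypothetical). Both numpy arrays and torch tensors expose .ndim and .item(), so a single branch covers both without an explicit type check:

```python
import numpy as np

def unwrap_if_scalar(batch, detach=True):
    # Plain Python objects have no .ndim, so getattr's default (-1)
    # keeps them on the fall-through path.
    if getattr(batch, "ndim", -1) == 0:
        return batch.item() if detach else batch
    return batch

label = unwrap_if_scalar(np.array(7))    # 0-d array -> plain Python int
untouched = unwrap_if_scalar([1, 2, 3])  # a list has no ndim -> returned as-is
print(label, untouched)
```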
Thanks! I will update the PR to include these changes.
Description

This PR supports numpy scalars (e.g. in the form of np.array(1)) in the decollate_batch function (fixes issue #8471).

Types of changes
- Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
- Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
- Documentation updated, tested the make html command in the docs/ folder.
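To illustrate the intended post-fix behavior, here is a toy stand-in for decollation (the function and its logic are illustrative assumptions, not MONAI's implementation): a batch dict carrying a 0-d numpy scalar is unwrapped instead of raising on iteration, and scalar entries are broadcast to every decollated sample.

```python
import numpy as np

def toy_decollate(batch):
    # Unwrap 0-d arrays instead of trying to iterate them.
    if isinstance(batch, np.ndarray) and batch.ndim == 0:
        return batch.item()
    # Split an ordinary array along its first (batch) dimension.
    if isinstance(batch, np.ndarray):
        return [toy_decollate(b) for b in batch]
    # Split a collated dict into a list of per-sample dicts.
    if isinstance(batch, dict):
        decollated = {k: toy_decollate(v) for k, v in batch.items()}
        n = max((len(v) for v in decollated.values() if isinstance(v, list)), default=1)
        return [
            {k: (v[i] if isinstance(v, list) else v) for k, v in decollated.items()}
            for i in range(n)
        ]
    return batch

batch = {"img": np.ones((2, 3)), "epoch": np.array(5)}
samples = toy_decollate(batch)  # two samples, each with the scalar unwrapped
print(samples)
```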