
fix: support decollate for numpy scalars #8470


Open

wants to merge 1 commit into base: dev

Conversation


@arthurdjn arthurdjn commented Jun 3, 2025

Description

This PR adds support for numpy scalars (e.g. in the form of np.array(1)) to the decollate_batch function (fixes issue #8471).
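
For context, a minimal sketch of the failure mode this targets (an assumed reproduction based on the description above, not copied from issue #8471): a collated batch that contains a numpy scalar stored as a 0-d array.

```python
import numpy as np
from monai.data import decollate_batch

# Hypothetical collated batch: "count" is a numpy scalar in 0-d array form.
batch = {"image": np.zeros((2, 1, 4, 4)), "count": np.array(1)}

# Without the fix, decollating such a batch can fail because a 0-d numpy array
# cannot be iterated over; with it, the scalar entry is carried through per item.
items = decollate_batch(batch)
print(items)
```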

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • Breaking change (fix or new feature that would cause existing functionality to change).
  • New tests added to cover the changes.
  • Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
  • In-line docstrings updated.
  • Documentation updated, tested make html command in the docs/ folder.

@arthurdjn arthurdjn force-pushed the support-decollate-batch-numpy-scalars branch 5 times, most recently from 187c141 to c438fe0 on June 3, 2025 13:03
@arthurdjn arthurdjn marked this pull request as ready for review June 3, 2025 14:20
@@ -625,6 +625,8 @@ def decollate_batch(batch, detach: bool = True, pad=True, fill_value=None):
         type(batch).__module__ == "numpy" and not isinstance(batch, Iterable)
     ):
         return batch
+    if isinstance(batch, np.ndarray) and batch.ndim == 0:
+        return batch.item() if detach else batch
Contributor
Thanks for the PR! Do you think it might be beneficial to convert the array into a tensor? That way, the data could be handled more consistently.

Author

We could; I think it does not matter for my use cases. As long as the function handles numpy scalars in the form of an array, it works for me!

I will add this change and convert it to a tensor there (L629) if you prefer :)
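
One possible shape of that conversion, purely as an illustrative sketch (the helper name is hypothetical and this is not necessarily the code that was committed):

```python
import numpy as np
import torch

def _scalar_array_to_tensor(batch):
    # Hypothetical helper: wrap a 0-d numpy array as a 0-d torch tensor so the
    # value is handled like other tensor scalars further downstream.
    if isinstance(batch, np.ndarray) and batch.ndim == 0:
        return torch.as_tensor(batch)
    return batch
```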

Contributor

Thanks for the quick fix!
May I ask the reason for only converting to a tensor when batch.ndim == 0 here?

Author

@arthurdjn arthurdjn Jun 6, 2025

I noticed a different behavior when using the decollate_batch function on torch tensors vs numpy arrays (see discussion #8472), so I don't want to convert numpy arrays to torch tensors, as that would introduce some breaking changes.

This PR only addresses issue #8471, since I think that behavior was not expected and should be supported.

fix linter

Signed-off-by: Arthur Dujardin <[email protected]>

fix numpy decollate multi arrays

Signed-off-by: Arthur Dujardin <[email protected]>

fix linter

Signed-off-by: Arthur Dujardin <[email protected]>

fix numpy scalar support

Signed-off-by: Arthur Dujardin <[email protected]>

minor refactoring for typing

Signed-off-by: Arthur Dujardin <[email protected]>

convert scalar array to tensor

Signed-off-by: Arthur Dujardin <[email protected]>
@arthurdjn arthurdjn force-pushed the support-decollate-batch-numpy-scalars branch from 451c207 to 49d4954 on June 4, 2025 09:00
@arthurdjn arthurdjn requested a review from KumoLiu June 4, 2025 09:01
@ericspod
Member

ericspod commented Jun 10, 2025

Could we consider a more complete solution? The issue, it seems, is that 0-d arrays pass the iterable check but can't actually be iterated over. We already check for non-iterable things in decollate_batch here. Can we modify this to correctly pick up when the batch is a 0-d array and just return it in that case? Or return its contents?
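
The 0-d behaviour described above is easy to demonstrate (illustration only, not part of the PR): an ndarray passes the Iterable check because the class defines __iter__, yet iterating a 0-d instance raises.

```python
from collections.abc import Iterable
import numpy as np

a = np.array(1)                  # 0-d numpy array
print(isinstance(a, Iterable))   # True: ndarray defines __iter__, so the ABC check passes
try:
    iter(a)
except TypeError as exc:
    print(exc)                   # "iteration over a 0-d array"
```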

@arthurdjn
Author

arthurdjn commented Jun 10, 2025

Thanks for the feedback. The initial PR was:

if isinstance(batch, (float, int, str, bytes)) or (
    type(batch).__module__ == "numpy" and not isinstance(batch, Iterable)
):
    return batch
if isinstance(batch, np.ndarray) and batch.ndim == 0:
    return batch.item() if detach else batch
# rest ...

Is this something that you find more complete?

Note

I refactored the PR to convert the numpy scalar array to a torch tensor, as suggested by @KumoLiu.

@ericspod
Member

Is this something that you find more complete?

What I had in mind was more of the following change:

...
    if batch is None or isinstance(batch, (float, int, str, bytes)):
        return batch
    if getattr(batch, "ndim", -1) == 0:  # assumes only Numpy objects and Pytorch tensors have ndim
        return batch.item() if detach else batch
    if isinstance(batch, torch.Tensor):
        if detach:
            batch = batch.detach()
# REMOVE
#      if batch.ndim == 0:
#          return batch.item() if detach else batch
...
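
A quick check of how that getattr-based condition classifies common inputs (illustrative only, assuming numpy and torch are installed):

```python
import numpy as np
import torch

for x in (np.array(3.5), torch.tensor(3.5), np.zeros((2, 3)), 3.5):
    print(type(x).__name__, getattr(x, "ndim", -1) == 0)
# np.array(3.5) and torch.tensor(3.5) both report ndim == 0 and would be returned
# via .item(); higher-dimensional arrays and tensors fall through to the usual
# handling, and plain Python scalars are caught earlier by the isinstance check.
```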

@arthurdjn
Author

Thanks! I will update the PR to include these changes.
