
Potentially wrong number of arguments in cumsum op #2303

Open
msluszniak opened this issue Aug 9, 2024 · 6 comments
Labels
bug Unexpected behaviour that should be corrected (type) torch.export triaged Reviewed and examined, release as been assigned if applicable (status)


@msluszniak

msluszniak commented Aug 9, 2024

I discovered that the expected number of arguments for the cumsum op in coremltools might be wrong while exporting a PyTorch module to a *.pte file (ExecuTorch). See:

inputs = _get_inputs(context, node, expected=3)
The number of arguments is set to 3, while only two are used. Indeed, in my export I got an error that the cumsum op expects 3 arguments while only 2 were provided. Changing this value to 2 suppressed the error.
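To make the failure mode concrete, here is a minimal, self-contained sketch of an exact-arity check like the one at that line. `check_inputs` and the two input lists are hypothetical stand-ins, not the coremltools API. The assumption (not verified against the coremltools source) is that a TorchScript `aten::cumsum` node carries `(self, dim, dtype)` with the default `dtype=None` materialized, while the edge-dialect node produced by `torch.export`/`to_edge` omits the defaulted `dtype` and carries only `(self, dim)`.

```python
from typing import List, Sequence, Union


def check_inputs(inputs: Sequence[object], expected: Union[int, Sequence[int]]) -> List[object]:
    """Hypothetical arity check: raise unless len(inputs) is one of the allowed counts."""
    allowed = [expected] if isinstance(expected, int) else list(expected)
    if len(inputs) not in allowed:
        raise ValueError(f"node expects {allowed} input(s), got {len(inputs)}")
    return list(inputs)


# Assumed shapes of the node inputs (illustrative only):
torchscript_inputs = ["x", 1, None]  # self, dim, dtype (default materialized)
export_inputs = ["x", 1]             # self, dim (defaulted dtype omitted)

# An exact check of 3 accepts the first form and rejects the second.
check_inputs(torchscript_inputs, expected=3)
try:
    check_inputs(export_inputs, expected=3)
except ValueError as err:
    print(err)  # node expects [3] input(s), got 2

# Allowing either count accepts both forms.
for inputs in (torchscript_inputs, export_inputs):
    check_inputs(inputs, expected=(2, 3))
```

Under this assumption, accepting either count (rather than hard-coding 2) would keep TorchScript-based conversion working while also admitting export graphs that drop the optional `dtype`.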

To Reproduce

import torch
from torch.export import export
from executorch.exir import to_edge
from torch import nn
 
import coremltools as ct
 
class TestCumsum(nn.Module):
    """
    Test class
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self) -> torch.Tensor:
        """Test forward"""
        h, w = 100, 100
        grid = torch.ones([h, w], device="cpu", dtype=torch.float32)
        
        _y_embed = grid.cumsum(dim=0)
        x_embed = grid.cumsum(dim=1)
        
        return x_embed
 
 
if __name__ == "__main__":
 
    example_args = tuple([])
    aten_dialect = export(TestCumsum(), example_args)
    edge_dialect = to_edge(aten_dialect).exported_program()

    mlmodel = ct.convert(edge_dialect)

System environment (please complete the following information):

  • coremltools version: 8.0b1
  • OS (e.g. MacOS version or Linux type): MacOS Version 14.5
  • Any other relevant version information (e.g. PyTorch or TensorFlow version): PyTorch 2.5.0.dev20240716

Additional context

I discovered this while exporting a PyTorch model to ExecuTorch with the Core ML backend.

@msluszniak msluszniak added the bug Unexpected behaviour that should be corrected (type) label Aug 9, 2024
@jakesabathia2
Collaborator

Right now, the torch version that coremltools officially supports is 2.3.0 (https://github.com/apple/coremltools/blob/main/coremltools/_deps/__init__.py#L156), while your environment has 2.5.0.

@mkopcins

@jakesabathia2 From what I can tell, this issue persists across all versions of PyTorch since at least 2.0, as the signature of torch.cumsum has not changed since then. Can you suggest why this would work with older versions of PyTorch?

@wojtke

wojtke commented Oct 11, 2024

The third argument is optional.
https://pytorch.org/docs/stable/generated/torch.cumsum.html

It is treated as optional in torch.onnx:
https://github.com/pytorch/pytorch/blob/e30c55ee527b40d67555464b9e402b4b7ce03737/torch/onnx/symbolic_opset11.py#L412-L424

and coremltools doesn't actually use the third argument here; it ignores it:

@register_torch_op
def cumsum(context, node):
    inputs = _get_inputs(context, node, expected=3)
    x = inputs[0]
    if is_bool(x.dtype):
        x = mb.cast(x=x, dtype='int32')
    res = mb.cumsum(x=x, axis=inputs[1], name=node.name)
    context.add(res)
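For illustration, a pure-Python sketch of what the converter above computes, extended to accept an optional third `dtype` input instead of requiring exactly three. The names `cumsum_2d` and `convert_cumsum`, the nested-list tensor representation, and passing `dtype` as a Python callable such as `float` are all hypothetical; this mirrors the shape of the logic, not the Core ML MIL implementation. Unlike the snippet above, which ignores `inputs[2]`, the sketch applies `dtype` when present (casting first, as `torch.cumsum` documents) and treats a missing or `None` third input as "no dtype".

```python
from itertools import accumulate
from typing import Callable, List, Optional, Sequence


def cumsum_2d(
    x: Sequence[Sequence[object]],
    axis: int,
    dtype: Optional[Callable[[object], object]] = None,
) -> List[List[object]]:
    """Running sum of a 2-D nested list along axis 0 (down columns) or 1 (along rows).

    When `dtype` is given, every element is cast with it first; otherwise
    bools are promoted to int, mirroring the bool -> int32 cast above.
    """
    if axis not in (0, 1):
        raise ValueError("axis must be 0 or 1 for a 2-D input")

    def default_cast(v: object) -> object:
        return int(v) if isinstance(v, bool) else v

    cast = dtype if dtype is not None else default_cast
    rows = [[cast(v) for v in row] for row in x]
    if axis == 1:
        return [list(accumulate(row)) for row in rows]
    # axis == 0: transpose, accumulate each column, transpose back.
    cols = [list(accumulate(col)) for col in zip(*rows)]
    return [list(r) for r in zip(*cols)]


def convert_cumsum(inputs: Sequence[object]) -> List[List[object]]:
    """Hypothetical converter that accepts (x, dim) or (x, dim, dtype)."""
    if len(inputs) not in (2, 3):
        raise ValueError(f"cumsum expects 2 or 3 inputs, got {len(inputs)}")
    x, dim = inputs[0], inputs[1]
    dtype = inputs[2] if len(inputs) == 3 else None  # optional third argument
    return cumsum_2d(x, axis=dim, dtype=dtype)


grid = [[True, False], [True, True]]
print(convert_cumsum([grid, 0]))         # [[1, 0], [2, 1]]
print(convert_cumsum([grid, 1, float]))  # [[1.0, 1.0], [1.0, 2.0]]
```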

@YifanShenSZ YifanShenSZ self-assigned this Oct 15, 2024
@YifanShenSZ YifanShenSZ added triaged Reviewed and examined, release as been assigned if applicable (status) torch.export labels Oct 15, 2024
@msluszniak
Author

As an update, below are the versions of all the libraries I'm using, along with the code that triggers the issue:

Collecting environment information...
PyTorch version: 2.6.0.dev20241007
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 15.0.1 (arm64)
GCC version: Could not collect
Clang version: 16.0.0 (clang-1600.0.26.3)
CMake version: version 3.30.4
Libc version: N/A

Python version: 3.10.0 (default, Mar 3 2022, 03:54:28) [Clang 12.0.0 ] (64-bit runtime)
Python platform: macOS-15.0.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M3 Pro

Versions of relevant libraries:
[pip3] executorch==0.5.0a0+cb3a546
[pip3] executorchcoreml==0.0.1
[pip3] numpy==1.21.3
[pip3] torch==2.6.0.dev20241007
[pip3] torchaudio==2.5.0.dev20241007
[pip3] torchsr==1.0.4
[pip3] torchvision==0.20.0.dev20241007
[conda] executorch 0.5.0a0+cb3a546 pypi_0 pypi
[conda] executorchcoreml 0.0.1 pypi_0 pypi
[conda] numpy 1.21.3 pypi_0 pypi
[conda] torch 2.6.0.dev20241007 pypi_0 pypi
[conda] torchaudio 2.5.0.dev20241007 pypi_0 pypi
[conda] torchsr 1.0.4 pypi_0 pypi
[conda] torchvision 0.20.0.dev20241007 pypi_0 pypi

The code that triggers the issue is identical to the reproduction in the original report.

@YifanShenSZ
Collaborator

YifanShenSZ commented Oct 17, 2024

I confirmed the fix locally with torch 2.4, using a slightly modified reproduction:

import numpy as np

import torch
from torch.export import export
from executorch.exir import to_edge
from torch import nn

import coremltools as ct


class TestCumsum(nn.Module):
    def forward(self, grid) -> tuple[torch.Tensor, torch.Tensor]:
        _y_embed = grid.cumsum(dim=0)
        x_embed = grid.cumsum(dim=1)
        
        return _y_embed, x_embed
 
 
if __name__ == "__main__":
    torch_model = TestCumsum()
    torch_model.eval()

    h, w = 100, 100
    grid = torch.ones([h, w], device="cpu", dtype=torch.float32)

    aten_dialect = export(torch_model, (grid,))
    edge_dialect = to_edge(aten_dialect).exported_program()

    coreml_model = ct.convert(edge_dialect)


    outputs_torch = torch_model(grid)
    y_embed_torch = outputs_torch[0].detach().numpy()
    x_embed_torch = outputs_torch[1].detach().numpy()
    
    outputs_coreml = coreml_model.predict({"grid": grid.detach().numpy()})
    y_embed_coreml = outputs_coreml["aten_cumsum_default"]
    x_embed_coreml = outputs_coreml["aten_cumsum_default_1"]

    np.testing.assert_allclose(y_embed_coreml, y_embed_torch)
    np.testing.assert_allclose(x_embed_coreml, x_embed_torch)
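For an all-ones grid, the expected outputs also have a closed form, which gives a sanity check independent of both frameworks: `cumsum(dim=0)` yields `i + 1` at every position in row `i`, and `cumsum(dim=1)` yields `j + 1` at every position in column `j`. The stdlib-only sketch below uses a small 4 x 3 grid rather than 100 x 100; the helper names are hypothetical.

```python
from itertools import accumulate
from typing import List


def ones(h: int, w: int) -> List[List[float]]:
    """An h x w grid of 1.0, like torch.ones([h, w])."""
    return [[1.0] * w for _ in range(h)]


def cumsum_rows(grid: List[List[float]]) -> List[List[float]]:
    """Running sum along each row, like grid.cumsum(dim=1)."""
    return [list(accumulate(row)) for row in grid]


def cumsum_cols(grid: List[List[float]]) -> List[List[float]]:
    """Running sum down each column, like grid.cumsum(dim=0)."""
    cols = [list(accumulate(col)) for col in zip(*grid)]
    return [list(r) for r in zip(*cols)]


h, w = 4, 3
grid = ones(h, w)
y_embed = cumsum_cols(grid)  # dim=0
x_embed = cumsum_rows(grid)  # dim=1

# Closed form for an all-ones grid.
assert all(y_embed[i][j] == i + 1 for i in range(h) for j in range(w))
assert all(x_embed[i][j] == j + 1 for i in range(h) for j in range(w))
print(y_embed[-1], x_embed[0])  # [4.0, 4.0, 4.0] [1.0, 2.0, 3.0]
```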

We will include this fix in the next release.

@msluszniak
Author

Great! Thank you very much for your help 😁

Projects
None yet
Development

No branches or pull requests

5 participants