
cumsum op with CoreML backend fails export #6201

Open
msluszniak opened this issue Oct 14, 2024 · 3 comments
Labels
bug Something isn't working · module: coreml Issues related to Apple's Core ML delegation · partner: apple For backend delegation, kernels, demo, etc. from the 3rd-party partner, Apple · triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@msluszniak

🐛 Describe the bug

I'm currently working on an ExecuTorch example for EfficientSAM. I've developed a runner that successfully exports the model to .pte format. However, I'm hitting a problem: there is a mismatch in the number of arguments that the cumsum op expects on the CoreML side. The converter expects 3 arguments, but only two are actually passed, so I get the error ValueError: node aten_cumsum_default (cumsum) got 2 input(s), expected [3]. When I manually adjust this line to expected=2, the export works fine. I opened an issue mentioning this on the CoreML GitHub, but the response I received (here) was that I was testing on a version of PyTorch that CoreML doesn't support. So I downgraded, which led to another issue: some functionality used in ExecuTorch requires PyTorch versions (2.5.0+) that CoreML doesn't support.

For now, I've "hacked" around it by replacing each call to cumsum with the following code:

def vectorized_cumsum(input_tensor, dim):
    """Cumulative sum along `dim`, built from slicing instead of aten.cumsum."""
    output = input_tensor.clone()
    slices = [slice(None)] * input_tensor.dim()
    for i in range(1, input_tensor.size(dim)):
        # Add the previous slice along `dim` into the current slice.
        slices[dim] = i
        minus_one_slices = slices.copy()
        minus_one_slices[dim] = i - 1
        output[tuple(slices)] += output[tuple(minus_one_slices)]
    return output
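
For reference, a quick sanity check (a minimal sketch using the workaround above) that the replacement matches torch.cumsum:

import torch

x = torch.randn(4, 5)
assert torch.allclose(vectorized_cumsum(x, dim=0), x.cumsum(dim=0))
assert torch.allclose(vectorized_cumsum(x, dim=1), x.cumsum(dim=1))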

So it works, but the problem is still valid, since this replacement is not optimized. I understand this is more of a problem on the CoreML side, but I've already filed an issue there, and I want to check whether there are better ways to approach it.
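One possibly cleaner workaround, rather than rewriting the model, is to override the coremltools translation for cumsum at conversion time. The following is a hedged sketch, not a vetted fix: register_torch_op(override=True), _get_inputs, and mb.cumsum are existing coremltools APIs, but the exact input layout (whether a third dtype argument is present) is an assumption about how the op gets lowered in this torch version:

# A hedged sketch: override coremltools' cumsum translation so it accepts
# either (x, dim) or (x, dim, dtype) inputs.
from coremltools.converters.mil import Builder as mb
from coremltools.converters.mil.frontend.torch.ops import _get_inputs
from coremltools.converters.mil.frontend.torch.torch_op_registry import register_torch_op

@register_torch_op(override=True)
def cumsum(context, node):
    # expected=[2, 3]: tolerate a missing dtype argument (assumption: this
    # matches how _get_inputs validates arity in this coremltools version).
    inputs = _get_inputs(context, node, expected=[2, 3])
    x = inputs[0]
    axis = inputs[1].val
    context.add(mb.cumsum(x=x, axis=axis, name=node.name))

Importing a module containing this override before calling ct.convert would apply it process-wide; the upstream fix confirmed later in this thread would supersede it.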

Versions

Collecting environment information...
PyTorch version: 2.6.0.dev20241007
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 15.0.1 (arm64)
GCC version: Could not collect
Clang version: 16.0.0 (clang-1600.0.26.3)
CMake version: version 3.30.4
Libc version: N/A

Python version: 3.10.0 (default, Mar 3 2022, 03:54:28) [Clang 12.0.0 ] (64-bit runtime)
Python platform: macOS-15.0.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M3 Pro

Versions of relevant libraries:
[pip3] executorch==0.5.0a0+cb3a546
[pip3] executorchcoreml==0.0.1
[pip3] numpy==1.21.3
[pip3] torch==2.6.0.dev20241007
[pip3] torchaudio==2.5.0.dev20241007
[pip3] torchsr==1.0.4
[pip3] torchvision==0.20.0.dev20241007
[conda] executorch 0.5.0a0+cb3a546 pypi_0 pypi
[conda] executorchcoreml 0.0.1 pypi_0 pypi
[conda] numpy 1.21.3 pypi_0 pypi
[conda] torch 2.6.0.dev20241007 pypi_0 pypi
[conda] torchaudio 2.5.0.dev20241007 pypi_0 pypi
[conda] torchsr 1.0.4 pypi_0 pypi
[conda] torchvision 0.20.0.dev20241007 pypi_0 pypi

@dbort added the bug, partner: apple, triaged, and module: coreml labels on Oct 14, 2024
@cccclai
Contributor

cccclai commented Oct 14, 2024

Hi, thank you for trying it out! I noticed the reported issue mentioned:

"I'm not able to reproduce the error with a small code sample."

If so, is there a way to repro with EfficientSAM?

@msluszniak
Author

Sorry, now I have a small snippet that reproduces the issue:

import torch
from torch.export import export
from executorch.exir import to_edge
from torch import nn

import coremltools as ct


class TestCumsum(nn.Module):
    """Minimal module that triggers the CoreML cumsum arity error."""

    def forward(self) -> torch.Tensor:
        h, w = 100, 100
        grid = torch.ones([h, w], device="cpu", dtype=torch.float32)

        _y_embed = grid.cumsum(dim=0)
        x_embed = grid.cumsum(dim=1)

        return x_embed


if __name__ == "__main__":
    example_args = ()
    aten_dialect = export(TestCumsum(), example_args)
    edge_dialect = to_edge(aten_dialect).exported_program()

    mlmodel = ct.convert(edge_dialect)

@YifanShenSZ
Collaborator

YifanShenSZ commented Oct 17, 2024

Locally confirmed the fix with the (slightly modified) repro below on torch 2.4:

import numpy as np

import torch
from torch.export import export
from executorch.exir import to_edge
from torch import nn

import coremltools as ct


class TestCumsum(nn.Module):
    def forward(self, grid) -> torch.Tensor:
        y_embed = grid.cumsum(dim=0)
        x_embed = grid.cumsum(dim=1)

        return y_embed, x_embed


if __name__ == "__main__":
    torch_model = TestCumsum()
    torch_model.eval()

    h, w = 100, 100
    grid = torch.ones([h, w], device="cpu", dtype=torch.float32)

    aten_dialect = export(torch_model, (grid,))
    edge_dialect = to_edge(aten_dialect).exported_program()

    coreml_model = ct.convert(edge_dialect)

    # Compare the CoreML outputs against the PyTorch reference.
    outputs_torch = torch_model(grid)
    y_embed_torch = outputs_torch[0].detach().numpy()
    x_embed_torch = outputs_torch[1].detach().numpy()

    outputs_coreml = coreml_model.predict({"grid": grid.detach().numpy()})
    y_embed_coreml = outputs_coreml["aten_cumsum_default"]
    x_embed_coreml = outputs_coreml["aten_cumsum_default_1"]

    np.testing.assert_allclose(y_embed_coreml, y_embed_torch)
    np.testing.assert_allclose(x_embed_coreml, x_embed_torch)

Will include this fix in the next release.
