Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PT FE] Support torch==2.6.0 #28879

Open
wants to merge 13 commits into
base: master
Choose a base branch
from

Conversation

mvafin
Copy link
Contributor

@mvafin mvafin commented Feb 7, 2025

Details:

  • Support torch==2.6.0

Tickets:

@mvafin mvafin requested a review from eaidova February 7, 2025 14:55
@mvafin mvafin requested review from a team as code owners February 7, 2025 14:55
@github-actions github-actions bot added category: Python API OpenVINO Python bindings category: PyTorch FE OpenVINO PyTorch Frontend category: OVC OVC tool labels Feb 7, 2025
Signed-off-by: Maxim Vafin <[email protected]>
Copy link
Member

@rkazants rkazants left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here update is also required: https://github.com/openvinotoolkit/openvino/blob/master/src/bindings/python/src/openvino/preprocess/torchvision/requirements.txt

Ideally, that file should removed in favor of the single file

@anzr299 anzr299 requested review from anzr299 and removed request for anzr299 February 9, 2025 18:15
@mvafin
Copy link
Contributor Author

mvafin commented Feb 10, 2025

Here update is also required: https://github.com/openvinotoolkit/openvino/blob/master/src/bindings/python/src/openvino/preprocess/torchvision/requirements.txt

Ideally, that file should removed in favor of the single file

No, the update is not required there. It is a product requirements, not a test requirements, they only limit lower bound and upper bound for python 3.9

@mvafin mvafin requested a review from a team as a code owner February 14, 2025 09:27
@mvafin mvafin requested review from slyalin and PiotrKrzem February 14, 2025 09:27
Signed-off-by: Maxim Vafin <[email protected]>
@mvafin mvafin requested a review from a team as a code owner February 14, 2025 18:00
@github-actions github-actions bot added category: CI OpenVINO public CI category: CPP API OpenVINO CPP API bindings github_actions Pull requests that update GitHub Actions code labels Feb 14, 2025
@@ -129,7 +129,7 @@ def test_arange_end_only(self, dtype, end, use_out, ie_device, precision, ir_ver
@pytest.mark.parametrize("start,end", [(0, 1), (-1, 1), (1, 5), (0.5, 2.5)])
def test_arange_start_end(self, dtype, end, start, ie_device, precision, ir_version):
self._test(*self.create_model(dtype, 2), ie_device, precision, ir_version,
kwargs_to_prepare_input={"end": end, "start": start, "dtype": dtype})
kwargs_to_prepare_input={"end": end, "start": start, "dtype": dtype}, aot_autograd=True)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like this torch.arange with torch.tensor inputs fail to trace with both make_fx and aot_autograd. The reason aot_autograd passes the test is that it adds graph break while tracing and falls back before it reaches to the openvino compiler function (fx_openvino in backend.py).

I can also reproduce this issue with torch inductor backend on my side. Below is a short reproducer:

import torch

def arange_fn(x):
    return torch.arange(x)

end = 5
end_tensor = torch.tensor(end)

compiled_fn = torch.compile(arange_fn, fullgraph=True)

# Executes without any issue
out = arange_fn(end)
print("Eager scalar: ", out)

# Executes without any issue
out = arange_fn(end_tensor)
print("Eager torch.tensor: ", out)

# Executes without any issue
out = compiled_fn(end)
print("Inductor scalar: ", out)

# Fails
out = compiled_fn(end_tensor)
print("Inductor torch.tensor: ", out)

I am not sure if this is a bug in pytorch side or simply not supported by aot_autograd and make_fx. Either way, we may not be able to fix it in openvino backend side.

And alternative way to test this op to change the input type to a regular scalar when testing for torch.compile maybe?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That seems like a problem. I think tensor input should be better supported then scalar input. Makes sense to create an issue to torch.

@github-actions github-actions bot removed category: CI OpenVINO public CI github_actions Pull requests that update GitHub Actions code labels Feb 21, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
category: CPP API OpenVINO CPP API bindings category: OVC OVC tool category: Python API OpenVINO Python bindings category: PyTorch FE OpenVINO PyTorch Frontend
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants