Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PT FE] Support torch==2.6.0 #28879

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import logging
import inspect
from typing import Any, Optional
import torch

from openvino.frontend.pytorch.py_pytorch_frontend import _FrontEndPytorchDecoder as Decoder
Expand All @@ -15,7 +16,6 @@
make_constant, fetch_attr, pt_to_ov_type_map, torch_tensor_to_ov_const)

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)


class BaseFXDecoder (Decoder):
Expand Down Expand Up @@ -234,6 +234,28 @@ def __init__(self, pt_module, fx_gm=None, nodes=None,
self.input_types.append(
BaseFXDecoder.get_type_for_value(arg))

@classmethod
def from_exported_program(cls, exported_program: torch.export.ExportedProgram) -> 'TorchFXPythonDecoder':
"""
Create a TorchFXPythonDecoder instance from an exported PyTorch program.
"""
from packaging import version
if version.parse(torch.__version__) >= version.parse("2.6"):
from torch.export.decomp_utils import CustomDecompTable
from openvino.frontend.pytorch.torchdynamo.decompositions import ops_to_not_decompose
decomp = CustomDecompTable()
for op in ops_to_not_decompose():
decomp.pop(op)
exported_program = exported_program.run_decompositions(decomp)
elif version.parse(torch.__version__) >= version.parse("2.2"):
from torch._decomp import get_decompositions
from openvino.frontend.pytorch.torchdynamo.decompositions import get_export_decomposition_list
decomp = get_decompositions(get_export_decomposition_list())
exported_program = exported_program.run_decompositions(decomp_table=decomp)
gm = exported_program.module()
logger.debug(gm.code)
return cls(gm, dynamic_shapes=True)

def get_input_signature_name(self, index: int) -> str:
if self._input_signature is not None and index < len(self._input_signature):
return self._input_signature[index]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -297,3 +297,11 @@ def get_export_decomposition_list():
except ImportError:
pass
return decomp


def ops_to_not_decompose():
# List of operations that shouldn't be decomposed
return [
torch.ops.aten.upsample_nearest2d.default,
torch.ops.aten.col2im.default,
]
6 changes: 3 additions & 3 deletions tests/requirements_pytorch
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
# optimum still requires numpy<2.0.0
numpy==1.26.4; python_version < "3.12"
numpy==2.1.1; python_version >= "3.12"
torch==2.5.1; platform_system != "Darwin" or platform_machine != "x86_64"
torch==2.6.0; platform_system != "Darwin" or platform_machine != "x86_64"
torch==2.2.2; platform_system == "Darwin" and platform_machine == "x86_64"
--extra-index-url https://download.pytorch.org/whl/cpu

torchvision==0.20.1; platform_system != "Darwin" or platform_machine != "x86_64"
torchvision==0.21.0; platform_system != "Darwin" or platform_machine != "x86_64"
torchvision==0.17.2; platform_system == "Darwin" and platform_machine == "x86_64"
torchaudio==2.5.1; platform_system != "Darwin" or platform_machine != "x86_64"
torchaudio==2.6.0; platform_system != "Darwin" or platform_machine != "x86_64"
torchaudio==2.2.2; platform_system == "Darwin" and platform_machine == "x86_64"
# before updating transformers version, make sure no tests (esp. sdpa2pa) are failing
transformers==4.47.1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,6 @@ def extract_module_extensions(args):
return {extension.module: extension for extension in extensions if isinstance(extension, ModuleExtension)}


def get_decoder_for_exported_program(model):
from openvino.frontend.pytorch.fx_decoder import TorchFXPythonDecoder
import torch

from packaging import version
if version.parse(torch.__version__) >= version.parse("2.2"):
from torch._decomp import get_decompositions
from openvino.frontend.pytorch.torchdynamo.decompositions import get_export_decomposition_list
decomp = get_decompositions(get_export_decomposition_list())
model = model.run_decompositions(decomp_table=decomp)
gm = model.module()
log.debug(gm.code)
decoder = TorchFXPythonDecoder(gm, dynamic_shapes=True)
return decoder


def get_pytorch_decoder(model, example_inputs, args):
try:
from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder
Expand Down Expand Up @@ -65,7 +49,7 @@ def get_pytorch_decoder(model, example_inputs, args):
inputs = prepare_torch_inputs(example_inputs)
if not isinstance(model, (TorchScriptPythonDecoder, TorchFXPythonDecoder)):
if hasattr(torch, "export") and isinstance(model, (torch.export.ExportedProgram)):
decoder = get_decoder_for_exported_program(model)
decoder = TorchFXPythonDecoder.from_exported_program(model)
else:
decoder = TorchScriptPythonDecoder(
model,
Expand Down Expand Up @@ -123,7 +107,7 @@ def get_pytorch_decoder_for_model_on_disk(argv, args):
try:
exported_program = torch.export.load(input_model)
if hasattr(torch, "export") and isinstance(exported_program, (torch.export.ExportedProgram)):
argv.input_model = get_decoder_for_exported_program(exported_program)
argv.input_model = TorchFXPythonDecoder.from_exported_program(exported_program)
argv.framework = 'pytorch'
return True
except:
Expand Down
Loading