[PT FE] Support None in example (#28398)
### Details:
 - *Support `None` in example*

### Tickets:
 - *CVS-156684*

---------

Signed-off-by: Maxim Vafin <[email protected]>
mvafin authored Jan 13, 2025
1 parent 2c80544 commit 93a103b
Showing 4 changed files with 220 additions and 15 deletions.
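In practice this lets `example_input` carry `None` for optional inputs. A minimal sketch of the user-facing behavior (hypothetical module; `ov.convert_model` is the public conversion API, mirroring the new test fixtures below):

```python
import torch
import openvino as ov

class Model(torch.nn.Module):
    def forward(self, a, b):
        if b is None:
            b = torch.tensor(1., dtype=torch.float32)
        return a + b

# None marks `b` as absent; the converted model should expose only input `a`
ov_model = ov.convert_model(Model(), example_input=(torch.zeros(2), None))
```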
12 changes: 7 additions & 5 deletions src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py
@@ -4,6 +4,11 @@
# flake8: noqa
# mypy: ignore-errors

import inspect
import logging
import typing
import torch

from openvino.frontend.pytorch.py_pytorch_frontend import _FrontEndPytorchDecoder as Decoder
from openvino.frontend.pytorch.py_pytorch_frontend import _Type as DecoderType
from openvino import op, PartialShape, Type as OVType, OVAny
@@ -14,16 +19,12 @@
    prepare_example_inputs_and_model,
    convert_quantized_tensor,
    graph_has_ops,
    patch_none_example,
)
from openvino import opset11 as ops
from openvino.frontend.pytorch import quantized, patch_model
from openvino.frontend.pytorch.module_extension import ModuleExtension

import inspect
import logging
import typing
import torch

log = logging.getLogger(__name__)


@@ -133,6 +134,7 @@ def _get_scripted_model(self, pt_module, example_inputs=None, skip_freeze=False,
            scripted = torch.jit.script(pt_module)
            freeze_by_default = True
        else:
            pt_module, example_inputs = patch_none_example(pt_module, example_inputs)
            input_parameters, input_signature, pt_module, self._input_is_list = prepare_example_inputs_and_model(
                example_inputs, input_params, pt_module)

115 changes: 107 additions & 8 deletions src/bindings/python/src/openvino/frontend/pytorch/utils.py
@@ -4,12 +4,16 @@
# flake8: noqa
# mypy: ignore-errors

import inspect
import logging
import torch
import numpy as np

from openvino import op, Type as OVType, Shape, Tensor
from openvino import opset11 as ops

log = logging.getLogger(__name__)


def make_constant(*args, **kwargs):
    return op.Constant(*args, **kwargs)
@@ -162,6 +166,23 @@ def forward(self, {input_sign}):
"""


def build_wrapper(template, model):
    """
    Builds a wrapper around the given model using the provided template.
    """
    result = {}
    try:
        exec(template, result)

        wrapped_model = result["ModelWrapper"](model)
        wrapped_model.eval()
    # if wrapping fails, return the original model to avoid confusing
    # the user with an error message about the wrapper
    except Exception:
        log.error("Failed to build model wrapper.")
        wrapped_model = model
    return wrapped_model


def process_dict_inputs(inputs, input_params, model):
    ordered_inputs = []
    for input_name in input_params:
@@ -203,15 +224,8 @@ def process_dict_inputs(inputs, input_params, model):

    wrapper_class = wrapper_template.format(input_sign=", ".join(
        input_sign_str), example_input=", ".join(input_params_str))
    result = {}
    try:
        exec(wrapper_class, result)

        wrapped_model = result["ModelWrapper"](model)
        wrapped_model.eval()
    # if wrapping failed, it is better to return original model for avoid user confusion regarding error message
    except Exception:
        wrapped_model = model
    wrapped_model = build_wrapper(wrapper_class, model)

    return {"example_inputs": [inputs[name] for name in ordered_inputs]}, ordered_inputs, wrapped_model

@@ -265,3 +279,88 @@ def convert_quantized_tensor(qtensor: torch.Tensor, shared_memory: bool):
        sub = ops.subtract(convert, zero_point)
        return ops.multiply(sub, scale).outputs()
    assert False, "Unsupported qscheme"


def process_individual_input(x, x_name):
    """
    Processes an individual input and generates a signature,
    parameter string, example entry, and a wrap flag.
    """
    sign = None
    param = None
    example_entry = None
    to_wrap = False
    if isinstance(x, tuple):
        internal_input = []
        new_tuple = []
        index = 0
        for v in x:
            if v is None:
                to_wrap = True
                internal_input.append("None")
            else:
                internal_input.append(f"{x_name}[{index}]")
                new_tuple.append(v)
                index += 1
        param = f"({', '.join(internal_input)},)"
        if len(new_tuple) > 0:
            example_entry = tuple(new_tuple)
            sign = x_name
    elif x is None:
        to_wrap = True
        param = "None"
    else:
        sign = x_name
        param = x_name
        example_entry = x
    return sign, param, example_entry, to_wrap
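# Illustration (hypothetical tensors t1, t2; not part of the diff):
#   process_individual_input((t1, None, t2), "a")
#   -> sign="a", param="(a[0], None, a[1],)", example_entry=(t1, t2), to_wrap=True
#   process_individual_input(None, "b")
#   -> sign=None, param="None", example_entry=None, to_wrap=True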


def patch_none_example(model: torch.nn.Module, example):
    """
    Patches a PyTorch model to handle None values in the input example.
    """
    callable_func = getattr(model, "forward", model.__call__)
    input_params = inspect.signature(callable_func).parameters
    input_signature = list(input_params)
    input_sign_str = []
    input_params_str = []
    to_wrap = False
    if isinstance(example, tuple) and len(input_signature) >= len(example):
        new_example = []
        for i, x in enumerate(example):
            x_name = input_signature[i]
            sign, param, example_entry, _to_wrap = process_individual_input(x, x_name)
            to_wrap = to_wrap or _to_wrap
            if sign is not None:
                input_sign_str.append(str(input_params[sign]))
            input_params_str.append(param)
            if example_entry is not None:
                new_example.append(example_entry)
        if to_wrap:
            wrapper_class = wrapper_template.format(input_sign=", ".join(input_sign_str),
                                                    example_input=", ".join(input_params_str))
            wrapped_model = build_wrapper(wrapper_class, model)
            log.warning("Model has None in the example input. The input "
                        "with None will be removed from the resulting model.")
            return wrapped_model, tuple(new_example)
    elif isinstance(example, dict) and len(input_signature) >= len(example):
        new_example = {}
        input_signature = [s for s in input_signature if s in example]
        for x_name in input_signature:
            x = example[x_name]
            sign, param, example_entry, _to_wrap = process_individual_input(x, x_name)
            to_wrap = to_wrap or _to_wrap
            if sign is not None:
                input_sign_str.append(str(input_params[sign]))
            input_params_str.append(f"{x_name}={param}")
            if example_entry is not None:
                new_example[x_name] = example_entry
        if to_wrap:
            wrapper_class = wrapper_template.format(input_sign=", ".join(input_sign_str),
                                                    example_input=", ".join(input_params_str))
            wrapped_model = build_wrapper(wrapper_class, model)
            log.warning("Model has None in the example input. The input "
                        "with None will be removed from the resulting model.")
            return wrapped_model, new_example
    return model, example
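For intuition, with a model `forward(self, a, b)` and `example = (tensor, None)`, the formatted `wrapper_template` produces source along these lines (a sketch of the generated class, not the template's exact text):

```python
import torch

class ModelWrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, a):
        # `b` was None in the example: hard-coded here and dropped from the
        # wrapper's signature, hence from the converted model's inputs
        return self.model(a, None)
```

The wrapper is then traced with the pruned example `(tensor,)`, which is why `patch_none_example` returns both a new module and a new example.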
100 changes: 99 additions & 1 deletion tests/layer_tests/ovc_python_api_tests/test_pytorch.py
@@ -1012,6 +1012,100 @@ def forward(self, a, b):
), "output": "some_name"}


def create_pytorch_module_with_none_example(tmp_dir):
    class PTModel(torch.nn.Module):
        def forward(self, a, b):
            if b is None:
                b = torch.tensor(1., dtype=torch.float32)
            return a + b

    net = PTModel()
    a = ov.opset10.parameter(PartialShape([-1]), dtype=np.float32)
    add = ov.opset10.add(a, np.float32([1.]))
    ref_model = Model([add], [a], "test")
    return net, ref_model, {
        "example_input": (
            torch.tensor([5, 6], dtype=torch.float32),
            None
        ),
        "compress_to_fp16": False}


def create_pytorch_module_with_none_dict_example(tmp_dir):
    class PTModel(torch.nn.Module):
        def forward(self, a, b):
            if b is None:
                b = torch.tensor(1., dtype=torch.float32)
            return a + b

    net = PTModel()
    a = ov.opset10.parameter(PartialShape([-1]), dtype=np.float32)
    add = ov.opset10.add(a, np.float32([1.]))
    ref_model = Model([add], [a], "test")
    return net, ref_model, {
        "example_input": {
            "a": torch.tensor([5, 6], dtype=torch.float32),
            "b": None,
        },
        "compress_to_fp16": False}


def create_pytorch_module_with_none_in_tuple(tmp_dir):
    class PTModel(torch.nn.Module):
        def forward(self, a, b):
            x = a[0]
            if a[1] is None:
                x = x + torch.tensor(1., dtype=torch.float32)
            else:
                x = x + a[1]
            if a[2] is None:
                x = x + torch.tensor(1., dtype=torch.float32)
            else:
                x = x + a[2]
            return x + b

    net = PTModel()
    a = ov.opset10.parameter(PartialShape([-1]), dtype=np.float32)
    b = ov.opset10.parameter(PartialShape([-1]), dtype=np.float32)
    add = ov.opset10.add(a, np.float32([2.]))
    add2 = ov.opset10.add(add, b)
    ref_model = Model([add2], [a, b], "test")
    return net, ref_model, {
        "example_input": {
            "a": (torch.tensor([5, 6], dtype=torch.float32), None, None),
            "b": torch.tensor([5, 6], dtype=torch.float32),
        },
        "compress_to_fp16": False}


def create_pytorch_module_with_none_in_tuple_case2(tmp_dir):
    class PTModel(torch.nn.Module):
        def forward(self, a, b):
            x = a[0]
            if a[1] is None:
                x = x + torch.tensor(1., dtype=torch.float32)
            else:
                x = x + a[1]
            if a[2] is None:
                x = x + torch.tensor(1., dtype=torch.float32)
            else:
                x = x + a[2]
            return x + b

    net = PTModel()
    a = ov.opset10.parameter(PartialShape([-1]), dtype=np.float32)
    add = ov.opset10.add(a, np.float32([2.]))
    b = ov.opset10.parameter(PartialShape([-1]), dtype=np.float32)
    add2 = ov.opset10.add(add, b)
    ref_model = Model([add2], [a, b], "test")
    return net, ref_model, {
        "example_input": (
            (torch.tensor([5, 6], dtype=torch.float32), None, None),
            torch.tensor([5, 6], dtype=torch.float32),
        ),
        "compress_to_fp16": False}


class TestMoConvertPyTorch(CommonMOConvertTest):
    test_data = [
        'create_pytorch_nn_module_case1',
@@ -1062,7 +1156,11 @@ class TestMoConvertPyTorch(CommonMOConvertTest):
        'create_pytorch_module_with_nested_inputs6',
        'create_pytorch_module_with_nested_list_and_single_input',
        'create_pytorch_module_with_single_input_as_list',
        'create_pytorch_module_with_nested_dict_input'
        'create_pytorch_module_with_nested_dict_input',
        'create_pytorch_module_with_none_example',
        'create_pytorch_module_with_none_dict_example',
        'create_pytorch_module_with_none_in_tuple',
        'create_pytorch_module_with_none_in_tuple_case2',
    ]

    @pytest.mark.parametrize("create_model", test_data)
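Each fixture returns a (model, reference, convert-kwargs) triple, which the suite's harness feeds to conversion. Roughly, the harness does something like this (a sketch, not the actual `CommonMOConvertTest` code; `compress_to_fp16` is handled at save time in this sketch):

```python
import openvino as ov

def run_case(create_model, tmp_dir="/tmp"):
    net, ref_model, params = create_model(tmp_dir)
    params.pop("compress_to_fp16", None)  # not a convert_model argument here
    # convert with an example that may contain None, then compare to ref_model
    ov_model = ov.convert_model(net, **params)
    return ov_model, ref_model
```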
@@ -75,7 +75,11 @@ def get_pytorch_decoder(model, example_inputs, args):
    else:
        decoder = model
    args['input_model'] = decoder
    args["example_input"] = inputs
    ei = getattr(decoder, "_example_input", None)
    if ei is not None:
        args["example_input"] = ei
    else:
        args["example_input"] = inputs

return args

@@ -250,6 +254,8 @@ def to_torch_tensor(tensor):
        return tuple(to_torch_tensor(x) for x in tensor)
    if isinstance(tensor, dict) and all(isinstance(k, str) for k in tensor.keys()):
        return dict((k, to_torch_tensor(x)) for k, x in tensor.items())
    if tensor is None:
        return None
    else:
        raise Error("Unexpected type of example_input. Supported types torch.Tensor, np.array or ov.Tensor. "
                    "Got {}".format(type(tensor)))
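The `_example_input` read-back matters because `patch_none_example` may have replaced both the module and the example, and the decoder keeps the patched example. A rough sketch of the flow (attribute name taken from the hunk above; the pruned value shown is an assumption):

```python
import torch
from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder

class M(torch.nn.Module):
    def forward(self, a, b):
        return a if b is None else a + b

decoder = TorchScriptPythonDecoder(M(), example_input=(torch.zeros(2), None))
# After patching, the decoder's stored example no longer contains None, so
# downstream conversion reads it back instead of the user's raw inputs.
pruned = getattr(decoder, "_example_input", None)  # e.g. (tensor([0., 0.]),)
```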
