Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PT FE] Support torch==2.6.0 #28879

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
Open
4 changes: 2 additions & 2 deletions .github/workflows/job_pytorch_layer_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ env:
jobs:
PyTorch_Layer_Tests:
name: PyTorch Layer Tests
timeout-minutes: 40
timeout-minutes: 60
runs-on: ${{ inputs.runner }}
container: ${{ fromJSON(inputs.container) }}
defaults:
Expand Down Expand Up @@ -141,7 +141,7 @@ jobs:
- name: PyTorch torch.compile TORCHFX Layer Tests
if: ${{ fromJSON(inputs.affected-components).PyTorch_FE.test && runner.os != 'macOS' && runner.arch != 'ARM64' && runner.os != 'Windows' }} # Ticket: 126287
run: |
python3 -m pytest ${{ env.LAYER_TESTS_INSTALL_DIR }}/pytorch_tests -m precommit_fx_backend -v --junitxml=${{ env.INSTALL_TEST_DIR }}/TEST-pytorch_compile.xml
python3 -m pytest ${{ env.LAYER_TESTS_INSTALL_DIR }}/pytorch_tests ${PARALLEL} -m precommit_fx_backend -v --junitxml=${{ env.INSTALL_TEST_DIR }}/TEST-pytorch_compile.xml
env:
TEST_DEVICE: CPU
TEST_PRECISION: FP32
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import logging
import inspect
from typing import Any, Optional
import torch

from openvino.frontend.pytorch.py_pytorch_frontend import _FrontEndPytorchDecoder as Decoder
Expand All @@ -15,7 +16,6 @@
make_constant, fetch_attr, pt_to_ov_type_map, torch_tensor_to_ov_const)

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)


class BaseFXDecoder (Decoder):
Expand Down Expand Up @@ -165,6 +165,7 @@ class TorchFXPythonDecoder (BaseFXDecoder):
"""
Decoder for PyTorch FX GraphModule and Node objects to OpenVINO IR.
"""
_decomp_table = None

def __init__(self, pt_module, fx_gm=None, nodes=None,
mark_node_callback=None, input_shapes=[], input_types=[], dynamic_shapes=False):
Expand Down Expand Up @@ -230,6 +231,32 @@ def __init__(self, pt_module, fx_gm=None, nodes=None,
self.input_types.append(
BaseFXDecoder.get_type_for_value(arg))

@classmethod
def from_exported_program(cls, exported_program: torch.export.ExportedProgram) -> 'TorchFXPythonDecoder':
    """
    Create a TorchFXPythonDecoder instance from an exported PyTorch program.

    Depending on the installed torch version, decompositions are run on the
    program before it is wrapped:

    * torch >= 2.6: a ``CustomDecompTable`` is built once and cached on the
      class in ``_decomp_table``; every op returned by
      ``ops_to_not_decompose()`` is removed from it before use.
    * 2.2 <= torch < 2.6: the table is built from
      ``get_export_decomposition_list()`` on every call.
    * older versions: no decompositions are run.

    Args:
        exported_program: program to convert.

    Returns:
        A decoder constructed from ``exported_program.module()`` with
        ``dynamic_shapes=True``.
    """
    from packaging import version
    # Parse the installed version once and reuse it for both comparisons.
    torch_version = version.parse(torch.__version__)
    if torch_version >= version.parse("2.6"):
        if cls._decomp_table is None:
            from torch.export.decomp_utils import CustomDecompTable
            from openvino.frontend.pytorch.torchdynamo.decompositions import ops_to_not_decompose
            cls._decomp_table = CustomDecompTable()
            for op in ops_to_not_decompose():
                try:
                    cls._decomp_table.pop(op)
                except KeyError as e:
                    # Report through the module logger (not the root logger) so the
                    # message follows this module's logging configuration.
                    logger.warning("Operation %s not found in decomp table", op, exc_info=e)
        exported_program = exported_program.run_decompositions(cls._decomp_table)
    elif torch_version >= version.parse("2.2"):
        from torch._decomp import get_decompositions
        from openvino.frontend.pytorch.torchdynamo.decompositions import get_export_decomposition_list
        decomp = get_decompositions(get_export_decomposition_list())
        exported_program = exported_program.run_decompositions(decomp_table=decomp)
    gm = exported_program.module()
    logger.debug(gm.code)
    return cls(gm, dynamic_shapes=True)

@staticmethod
def get_found_shape(value) -> str:
# If input is a tensor, read the shape from meta data
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -297,3 +297,22 @@ def get_export_decomposition_list():
except ImportError:
pass
return decomp


def ops_to_not_decompose():
    """Return the list of aten operator overloads that shouldn't be decomposed."""
    aten = torch.ops.aten
    preserved_ops = []
    # Ops that are kept as a single overload.
    preserved_ops.append(aten.col2im.default)
    preserved_ops.append(aten.linear.default)
    # Nearest-neighbour upsampling: both the default and the vec overloads.
    for upsample_nearest in (aten.upsample_nearest1d, aten.upsample_nearest2d, aten.upsample_nearest3d):
        preserved_ops.append(upsample_nearest.default)
        preserved_ops.append(upsample_nearest.vec)
    # Interpolating upsampling: only the vec overloads.
    preserved_ops.extend((
        aten.upsample_linear1d.vec,
        aten.upsample_bilinear2d.vec,
        aten.upsample_trilinear3d.vec,
        aten.upsample_bicubic2d.vec,
    ))
    preserved_ops.append(aten.scaled_dot_product_attention.default)
    return preserved_ops
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ class NodeContext : public frontend::NodeContext {
std::shared_ptr<ov::Model> convert_subgraph(size_t index) const;

private:
ov::Any apply_additional_conversion_rules(const ov::Any& data, const std::type_info& type_info) const override;

std::shared_ptr<TorchDecoder> m_decoder;
const TensorMap& m_ext_tensor_map;
std::shared_ptr<TensorMap> m_tensor_map;
Expand Down
20 changes: 20 additions & 0 deletions src/frontends/pytorch/src/node_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,26 @@ Any NodeContext::get_values_from_const_input(int index) const {
return 0;
}

/// @brief Converts an attribute stored as a graph output into a plain scalar when requested.
///
/// If @p data holds an Output<Node> but the caller asked for a different type, the producing
/// node must be a v0::Constant; its first element is returned as the requested type.
/// Only bool and int32_t are handled. Any other value is returned unchanged.
///
/// @param data       Attribute value to convert.
/// @param type_info  Type the caller requested.
/// @return The converted scalar wrapped in ov::Any, or @p data as-is when no conversion applies.
/// @throws Front-end check failure if the node is not a constant, the constant has no
///         elements, or the requested type is not supported.
ov::Any NodeContext::apply_additional_conversion_rules(const ov::Any& data, const std::type_info& type_info) const {
    // Nothing to do unless a Node output was stored and a non-Node type is requested.
    if (!data.is<Output<Node>>() || type_info == typeid(Output<Node>)) {
        return data;
    }

    auto const_node = as_type_ptr<v0::Constant>(data.as<Output<Node>>().get_node_shared_ptr());
    FRONT_END_GENERAL_CHECK(const_node, "Attribute must be const if requested as not a Node.");

    if (type_info == typeid(bool)) {
        const auto values = const_node->cast_vector<bool>();
        // Indexing an empty vector is undefined behaviour, so reject empty constants explicitly.
        FRONT_END_GENERAL_CHECK(!values.empty(),
                                "Could not decode attribute for ",
                                get_name(),
                                " node. Constant has no elements.");
        bool res = values[0];
        return res;
    }
    if (type_info == typeid(int32_t)) {
        const auto values = const_node->cast_vector<int32_t>();
        FRONT_END_GENERAL_CHECK(!values.empty(),
                                "Could not decode attribute for ",
                                get_name(),
                                " node. Constant has no elements.");
        int32_t res = values[0];
        return res;
    }

    FRONT_END_GENERAL_CHECK(false, "Could not decode attribute for ", get_name(), " node. Provided type is not known.");
    return data;
}

} // namespace pytorch
} // namespace frontend
} // namespace ov
4 changes: 4 additions & 0 deletions src/frontends/pytorch/src/op/bucketize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,15 @@ OutputVector translate_bucketize(const NodeContext& context) {
element::Type output_type = ov::element::i64;
if (!context.input_is_none(2) && context.const_input<bool>(2)) {
output_type = ov::element::i32;
} else if (context.has_attribute("out_int32") && context.get_attribute<bool>("out_int32")) {
output_type = ov::element::i32;
}

bool with_right_bound = true;
if (!context.input_is_none(3)) {
with_right_bound = !context.const_input<bool>(3);
} else if (context.has_attribute("right")) {
with_right_bound = !context.get_attribute<bool>("right");
}

auto bucketize =
Expand Down
7 changes: 5 additions & 2 deletions src/frontends/pytorch/src/op/embedding_bag.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ OutputVector translate_embedding_bag_common(const NodeContext& context) {
// per_sample_weights=None, include_last_offset=False, padding_idx=None)
// we have only EmbeddingBag case support, check it before translation

auto mode = context.const_input<int64_t>(4);
int64_t mode = 0;
if (!context.input_is_none(4)) {
mode = context.const_input<int64_t>(4);
}
PYTORCH_OP_CONVERSION_CHECK(mode <= 1, "Only sum and mean mode supported for aten::embedding_bag translation");
auto weight = context.get_input(0);
auto indices = context.get_input(1);
Expand Down Expand Up @@ -78,7 +81,7 @@ OutputVector translate_embedding_bag(const NodeContext& context) {
}

OutputVector translate_embedding_bag_fx(const NodeContext& context) {
num_inputs_check(context, 7, 9);
num_inputs_check(context, 3, 9);
ov::OutputVector output = translate_embedding_bag_common(context);
return {context.mark_node(make_list_construct(output))};
}
Expand Down
22 changes: 11 additions & 11 deletions src/frontends/pytorch/src/op/max_poolnd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ namespace pytorch {
namespace op {

using namespace ov::op;
OutputVector translate_max_pool_base(const NodeContext& context, int dims) {
num_inputs_check(context, 3, 6);
OutputVector translate_max_pool_base(const NodeContext& context, int dims, bool return_indices) {
num_inputs_check(context, 2, 6);
auto input = context.get_input(0);
auto input_shape = context.mark_node(std::make_shared<v3::ShapeOf>(input));

Expand Down Expand Up @@ -68,7 +68,7 @@ OutputVector translate_max_pool_base(const NodeContext& context, int dims) {
} else {
pads = context.const_input<Shape>(3); // pytorch supports only symmetric paddings
}
Strides dilations;
auto dilations = Strides(dims, 1);
if (!context.input_is_none(4)) {
dilations = context.const_input<Strides>(4);
}
Expand All @@ -91,7 +91,7 @@ OutputVector translate_max_pool_base(const NodeContext& context, int dims) {
2));
if (is_static) {
if (no_batch_dim) {
if (context.get_output_size() == 2) {
if (return_indices) {
auto out1 = res->output(0);
auto out2 = res->output(1);
out1 = context.mark_node(std::make_shared<v0::Squeeze>(out1, const_0));
Expand All @@ -102,7 +102,7 @@ OutputVector translate_max_pool_base(const NodeContext& context, int dims) {
return {res};
}
} else {
if (context.get_output_size() == 2) {
if (return_indices) {
auto out1 = res->output(0);
auto out2 = res->output(1);
return {std::move(out1), std::move(out2)};
Expand All @@ -125,7 +125,7 @@ OutputVector translate_max_pool_base(const NodeContext& context, int dims) {

auto concat_shape = context.mark_node(
std::make_shared<v0::Concat>(OutputVector{slice_input_shape, slice_pooled_output_shape}, 0));
if (context.get_output_size() == 2) {
if (return_indices) {
auto out1 = res->output(0);
auto out2 = res->output(1);
out1 = context.mark_node(std::make_shared<v1::Reshape>(out1, concat_shape, true));
Expand All @@ -139,24 +139,24 @@ OutputVector translate_max_pool_base(const NodeContext& context, int dims) {
};

OutputVector translate_max_pool1d(const NodeContext& context) {
return translate_max_pool_base(context, 1);
return translate_max_pool_base(context, 1, context.get_output_size() == 2);
};

OutputVector translate_max_pool2d(const NodeContext& context) {
return translate_max_pool_base(context, 2);
return translate_max_pool_base(context, 2, context.get_output_size() == 2);
};

OutputVector translate_max_pool3d(const NodeContext& context) {
return translate_max_pool_base(context, 3);
return translate_max_pool_base(context, 3, context.get_output_size() == 2);
};

OutputVector translate_max_pool2d_fx(const NodeContext& context) {
auto output = translate_max_pool2d(context);
auto output = translate_max_pool_base(context, 2, true);
return {context.mark_node(make_list_construct(output))};
};

OutputVector translate_max_pool3d_fx(const NodeContext& context) {
auto output = translate_max_pool3d(context);
auto output = translate_max_pool_base(context, 3, true);
return {context.mark_node(make_list_construct(output))};
};

Expand Down
7 changes: 5 additions & 2 deletions src/frontends/pytorch/src/op/norm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,11 +181,14 @@ OutputVector translate_linalg_vector_norm(const NodeContext& context) {
// dtype=None) -> Tensor
// aten::linalg_vector_norm.out(Tensor self, Scalar ord=2, int[1]? dim=None, bool
// keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!):
num_inputs_check(context, 4, 6);
num_inputs_check(context, 3, 6);
auto x = context.get_input(0);
// ord defines the vector norm that is computed.
auto ord = context.const_input<float>(1);
bool keep_dim = context.const_input<bool>(3);
bool keep_dim = false;
if (!context.input_is_none(3)) {
keep_dim = context.const_input<bool>(3);
}
Output<Node> dim;
Output<Node> result;
// If dim= None, x will be flattened before the norm is computed.
Expand Down
22 changes: 15 additions & 7 deletions src/frontends/pytorch/src/op/scaled_dot_product_attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,21 @@ std::shared_ptr<ov::Node> translate_scaled_dot_product_attention_common(const No

if (!context.input_is_none(3))
inputs.push_back(context.get_input(3));
else if (!context.input_is_none(6)) {
// need to fill a gap in inputs with scalar 0 to be able to pass one extra input after that
auto zero = op::v0::Constant::create(element::f32, Shape{}, {0});
inputs.push_back(context.mark_node(std::make_shared<v1::ConvertLike>(zero, query)));
}
if (!context.input_is_none(6))
if (!context.input_is_none(6)) {
if (inputs.size() < 4) {
// need to fill a gap in inputs with scalar 0 to be able to pass one extra input after that
auto zero = op::v0::Constant::create(element::f32, Shape{}, {0});
inputs.push_back(context.mark_node(std::make_shared<v1::ConvertLike>(zero, query)));
}
inputs.push_back(context.mark_node(std::make_shared<v1::ConvertLike>(context.get_input(6), query)));
} else if (context.has_attribute("scale")) {
const auto scale = context.get_input("scale");
if (inputs.size() < 4) {
auto zero = op::v0::Constant::create(element::f32, Shape{}, {0});
inputs.push_back(context.mark_node(std::make_shared<v1::ConvertLike>(zero, query)));
}
inputs.push_back(context.mark_node(std::make_shared<v1::ConvertLike>(scale, query)));
}
if (!context.input_is_none(7)) {
auto enable_gqa = context.const_input<bool>(7);
PYTORCH_OP_CONVERSION_CHECK(enable_gqa == false,
Expand All @@ -46,7 +54,7 @@ std::shared_ptr<ov::Node> translate_scaled_dot_product_attention_common(const No
OutputVector translate_scaled_dot_product_attention(const NodeContext& context) {
// aten::scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float
// dropout_p=0., bool is_causal=False, float scale=None, bool enable_gqa=False)
num_inputs_check(context, 6, 8);
num_inputs_check(context, 3, 8);
return {translate_scaled_dot_product_attention_common(context)};
};

Expand Down
30 changes: 6 additions & 24 deletions src/frontends/pytorch/src/op/var_mean.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,17 +98,11 @@ OutputVector translate_var_mean_fx(const NodeContext& context) {
}
int32_t correction = 0;
if (context.has_attribute("correction")) {
auto correction_node = context.get_attribute<Output<Node>>("correction");
auto const_node = as_type_ptr<v0::Constant>(correction_node.get_node_shared_ptr());
PYTORCH_OP_CONVERSION_CHECK(const_node, "correction must be const.");
correction = const_node->cast_vector<int32_t>()[0];
correction = context.get_attribute<int32_t>("correction");
}
bool keepdim = false;
if (context.has_attribute("keepdim")) {
auto keepdim_node = context.get_attribute<Output<Node>>("keepdim");
auto const_node = as_type_ptr<v0::Constant>(keepdim_node.get_node_shared_ptr());
PYTORCH_OP_CONVERSION_CHECK(const_node, "keepdim must be const.");
keepdim = const_node->cast_vector<bool>()[0];
keepdim = context.get_attribute<bool>("keepdim");
}
auto res = translate_var_mean_common(context, data, axes, correction, keepdim);
return {context.mark_node(make_list_construct(res))};
Expand All @@ -124,17 +118,11 @@ OutputVector translate_var_fx(const NodeContext& context) {
}
int32_t correction = 0;
if (context.has_attribute("correction")) {
auto correction_node = context.get_attribute<Output<Node>>("correction");
auto const_node = as_type_ptr<v0::Constant>(correction_node.get_node_shared_ptr());
PYTORCH_OP_CONVERSION_CHECK(const_node, "correction must be const.");
correction = const_node->cast_vector<int32_t>()[0];
correction = context.get_attribute<int32_t>("correction");
}
bool keepdim = false;
if (context.has_attribute("keepdim")) {
auto keepdim_node = context.get_attribute<Output<Node>>("keepdim");
auto const_node = as_type_ptr<v0::Constant>(keepdim_node.get_node_shared_ptr());
PYTORCH_OP_CONVERSION_CHECK(const_node, "keepdim must be const.");
keepdim = const_node->cast_vector<bool>()[0];
keepdim = context.get_attribute<bool>("keepdim");
}
auto res = translate_var_mean_common(context, data, axes, correction, keepdim);
return {res[0]};
Expand All @@ -155,17 +143,11 @@ OutputVector translate_std_fx(const NodeContext& context) {
}
int32_t correction = 0;
if (context.has_attribute("correction")) {
auto correction_node = context.get_attribute<Output<Node>>("correction");
auto const_node = as_type_ptr<v0::Constant>(correction_node.get_node_shared_ptr());
PYTORCH_OP_CONVERSION_CHECK(const_node, "correction must be const.");
correction = const_node->cast_vector<int32_t>()[0];
correction = context.get_attribute<int32_t>("correction");
}
bool keepdim = false;
if (context.has_attribute("keepdim")) {
auto keepdim_node = context.get_attribute<Output<Node>>("keepdim");
auto const_node = as_type_ptr<v0::Constant>(keepdim_node.get_node_shared_ptr());
PYTORCH_OP_CONVERSION_CHECK(const_node, "keepdim must be const.");
keepdim = const_node->cast_vector<bool>()[0];
keepdim = context.get_attribute<bool>("keepdim");
}
auto res = translate_var_mean_common(context, data, axes, correction, keepdim);

Expand Down
Loading
Loading