From f1ef55df0016a24f1ca8861e9d993879463c0f46 Mon Sep 17 00:00:00 2001
From: Andrii Staikov
Date: Thu, 10 Oct 2024 12:58:56 +0200
Subject: [PATCH] [FRONTEND] Remove unnecessary Reshapes in PyTorch Conv1D
 conversion (#26960)

[FRONTEND] Remove unnecessary Reshapes in PyTorch Conv1D conversion

PyTorch's Conv1D operation utilizes additional reshaping in its
implementation because torch.addmm() doesn't support tensors with rank
higher than 2.

Remove unnecessary Reshapes in PyTorch's Conv1D conversion as
OpenVINO's MatMul supports tensors with rank higher than 2. This should
reduce the number of nodes in a graph and potentially improve
performance.

- Tickets:
 * CVS-150872
---
 src/frontends/pytorch/src/op/addmm.cpp    | 13 ++-----------
 tests/model_hub_tests/pytorch/test_llm.py |  8 ++++++--
 2 files changed, 8 insertions(+), 13 deletions(-)

diff --git a/src/frontends/pytorch/src/op/addmm.cpp b/src/frontends/pytorch/src/op/addmm.cpp
index 4ecfd403afc6dd..a61898aadbdbdb 100644
--- a/src/frontends/pytorch/src/op/addmm.cpp
+++ b/src/frontends/pytorch/src/op/addmm.cpp
@@ -73,17 +73,8 @@ OutputVector translate_conv1d_ext(const NodeContext& context) {
     auto bias = context.get_input(2);
     bias = context.mark_node(std::make_shared(bias, x));
 
-    auto neg_one = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1}));
-    auto zero = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0}));
-    auto shape_x = context.mark_node(std::make_shared(x, element::i32));
-    auto x_last_dim = context.mark_node(std::make_shared(shape_x, neg_one, zero));
-    auto x_new_shape = context.mark_node(std::make_shared(OutputVector{neg_one, x_last_dim}, 0));
-
-    auto x_new = context.mark_node(std::make_shared(x, x_new_shape, false));
-    auto mm = context.mark_node(std::make_shared(x_new, weight));
-    auto addmm = context.mark_node(std::make_shared(bias, mm));
-    auto size_out = context.mark_node(std::make_shared(shape_x, neg_one, neg_one, zero));
-    return {context.mark_node(std::make_shared(addmm, size_out, false))};
+    auto mm = context.mark_node(std::make_shared(x, weight));
+    return {context.mark_node(std::make_shared(mm, bias))};
 };
 
 } // namespace op
diff --git a/tests/model_hub_tests/pytorch/test_llm.py b/tests/model_hub_tests/pytorch/test_llm.py
index d48ac60e24db71..9acf8e2100c520 100644
--- a/tests/model_hub_tests/pytorch/test_llm.py
+++ b/tests/model_hub_tests/pytorch/test_llm.py
@@ -100,13 +100,16 @@ def load_model(self, name, type):
             config = {}
         model_kwargs = {"torchscript": True, "trust_remote_code": True}
         is_gptq = is_gptq_model(config)
+        is_gpt2 = name == "openai-community/gpt2"
+
         if is_gptq:
             self.cuda_available, self.gptq_postinit = patch_gptq()
             model_kwargs["torch_dtype"] = torch.float32
             self.ov_config = {"DYNAMIC_QUANTIZATION_GROUP_SIZE": "0"}
+        elif is_gpt2:
+            model_kwargs["torch_dtype"] = torch.float16
         else:
             model_kwargs["torch_dtype"] = "auto"
-            pass
 
         t = AutoTokenizer.from_pretrained(name, trust_remote_code=True)
         self.model = AutoModelForCausalLM.from_pretrained(name, **model_kwargs)
@@ -114,7 +117,7 @@ def load_model(self, name, type):
             model = self.model
         else:
             assert self.model.config.torch_dtype in [
-                torch.float16, torch.bfloat16]
+                torch.float16, torch.bfloat16] or is_gpt2
             model = copy.deepcopy(self.model).float()
 
         example = t("Some input text to verify that model works.",
@@ -188,6 +191,7 @@ def get_pkv(model, tokenizer):
 @pytest.mark.parametrize("type,name", [
     ("opt_gptq", "katuni4ka/opt-125m-gptq"),
     ("llama", "TinyLlama/TinyLlama-1.1B-Chat-v1.0"),
+    ("gpt2", "openai-community/gpt2")
 ])
 @pytest.mark.precommit
 @pytest.mark.nightly
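
Note (not part of the patch): a minimal PyTorch sketch of the reasoning in the
commit message, assuming the flatten/addmm/reshape Conv1D pattern that GPT-2's
traced graph produces. torch.addmm() only accepts 2-D matrices, so the input is
flattened before the matrix multiply and reshaped back afterwards; a rank-3
matmul with a broadcast bias yields the same result without the extra Reshape
nodes, which is what the simplified MatMul + Add conversion relies on. The
shapes below are arbitrary illustration values.

# Illustrative sketch only -- not part of the patch.
import torch

x = torch.randn(2, 5, 8)      # [batch, seq, in_features], rank 3
weight = torch.randn(8, 16)   # [in_features, out_features]
bias = torch.zeros(16)

# Conv1D-style path: torch.addmm() needs 2-D inputs, so flatten the leading
# dimensions, do the addmm, then restore the original shape.
flattened = torch.addmm(bias, x.view(-1, x.size(-1)), weight).view(2, 5, 16)

# Direct path: a rank-3 matmul plus a broadcast bias, no Reshapes needed.
direct = torch.matmul(x, weight) + bias

assert torch.allclose(flattened, direct, atol=1e-5)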