[FRONTEND] Remove unnecessary Reshapes in PyTorch Conv1D conversion (openvinotoolkit#26960)

PyTorch's Conv1D operation reshapes its input down to rank 2 and
reshapes the result back, because torch.addmm() doesn't support tensors
with rank higher than 2.
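For context, a simplified sketch of that Conv1D layer (the GPT-2-style Conv1D from transformers, trimmed here for illustration; not part of this commit):

import torch


class Conv1D(torch.nn.Module):
    """GPT-2-style "Conv1D": in effect a linear layer with a transposed weight."""

    def __init__(self, nf, nx):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(nx, nf).normal_(std=0.02))
        self.bias = torch.nn.Parameter(torch.zeros(nf))

    def forward(self, x):
        size_out = x.size()[:-1] + (self.weight.size(-1),)
        # torch.addmm() only accepts rank-2 inputs, hence the flatten ...
        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
        # ... and the reshape back to the original leading dimensions.
        return x.view(size_out)

When such a module is traced, this view/addmm/view sequence is what previously forced the extra Reshape nodes into the converted graph.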

Remove the unnecessary Reshapes from the Conv1D conversion, since
OpenVINO's MatMul supports tensors with rank higher than 2. This should
reduce the number of nodes in the converted graph and potentially
improve performance.
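A minimal sketch of OpenVINO's MatMul consuming a rank-3 input directly, without any surrounding Reshapes (illustrative only; assumes a recent OpenVINO Python API with opset13 and an available CPU device):

import numpy as np
import openvino as ov
import openvino.runtime.opset13 as ops

# Rank-3 activations [batch, seq, in_features] multiplied by a rank-2 weight.
x = ops.parameter([2, 5, 8], dtype=np.float32, name="x")
w = ops.constant(np.random.rand(8, 4).astype(np.float32))
mm = ops.matmul(x, w, transpose_a=False, transpose_b=False)

compiled = ov.compile_model(ov.Model([mm], [x], "matmul_rank3"), "CPU")
result = compiled([np.random.rand(2, 5, 8).astype(np.float32)])[0]
print(result.shape)  # (2, 5, 4): the rank-2 weight is broadcast over the leading dims

No Reshape nodes are needed around the MatMul, which is what the simplified conversion in addmm.cpp below relies on.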

- Tickets:
	* CVS-150872
CuriousPanCake authored Oct 10, 2024
1 parent 7d5b7cf commit f1ef55d
Showing 2 changed files with 8 additions and 13 deletions.
13 changes: 2 additions & 11 deletions src/frontends/pytorch/src/op/addmm.cpp
@@ -73,17 +73,8 @@ OutputVector translate_conv1d_ext(const NodeContext& context) {
     auto bias = context.get_input(2);
     bias = context.mark_node(std::make_shared<ov::op::v1::ConvertLike>(bias, x));
 
-    auto neg_one = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1}));
-    auto zero = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0}));
-    auto shape_x = context.mark_node(std::make_shared<v3::ShapeOf>(x, element::i32));
-    auto x_last_dim = context.mark_node(std::make_shared<v8::Gather>(shape_x, neg_one, zero));
-    auto x_new_shape = context.mark_node(std::make_shared<v0::Concat>(OutputVector{neg_one, x_last_dim}, 0));
-
-    auto x_new = context.mark_node(std::make_shared<v1::Reshape>(x, x_new_shape, false));
-    auto mm = context.mark_node(std::make_shared<v0::MatMul>(x_new, weight));
-    auto addmm = context.mark_node(std::make_shared<v1::Add>(bias, mm));
-    auto size_out = context.mark_node(std::make_shared<v12::ScatterElementsUpdate>(shape_x, neg_one, neg_one, zero));
-    return {context.mark_node(std::make_shared<v1::Reshape>(addmm, size_out, false))};
+    auto mm = context.mark_node(std::make_shared<v0::MatMul>(x, weight));
+    return {context.mark_node(std::make_shared<v1::Add>(mm, bias))};
 };
 
 } // namespace op
8 changes: 6 additions & 2 deletions tests/model_hub_tests/pytorch/test_llm.py
@@ -100,21 +100,24 @@ def load_model(self, name, type):
config = {}
model_kwargs = {"torchscript": True, "trust_remote_code": True}
is_gptq = is_gptq_model(config)
is_gpt2 = name == "openai-community/gpt2"

if is_gptq:
self.cuda_available, self.gptq_postinit = patch_gptq()
model_kwargs["torch_dtype"] = torch.float32
self.ov_config = {"DYNAMIC_QUANTIZATION_GROUP_SIZE": "0"}
elif is_gpt2:
model_kwargs["torch_dtype"] = torch.float16
else:
model_kwargs["torch_dtype"] = "auto"
pass

t = AutoTokenizer.from_pretrained(name, trust_remote_code=True)
self.model = AutoModelForCausalLM.from_pretrained(name, **model_kwargs)
if is_gptq:
model = self.model
else:
assert self.model.config.torch_dtype in [
torch.float16, torch.bfloat16]
torch.float16, torch.bfloat16] or is_gpt2
model = copy.deepcopy(self.model).float()

example = t("Some input text to verify that model works.",
@@ -188,6 +191,7 @@ def get_pkv(model, tokenizer):
@pytest.mark.parametrize("type,name", [
("opt_gptq", "katuni4ka/opt-125m-gptq"),
("llama", "TinyLlama/TinyLlama-1.1B-Chat-v1.0"),
("gpt2", "openai-community/gpt2")
])
@pytest.mark.precommit
@pytest.mark.nightly
