Skip to content

Commit

Permalink
Add log10 operator and test
Browse files Browse the repository at this point in the history
  • Loading branch information
siddhant-0707 committed Oct 20, 2023
1 parent 5018be8 commit ee0a0ed
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 0 deletions.
12 changes: 12 additions & 0 deletions src/frontends/pytorch/src/op/log.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,18 @@ OutputVector translate_log2(const NodeContext& context) {
return {res};
};

OutputVector translate_log10(const NodeContext& context) {
// torch.log10 returns a tensor with the logarithm to the base 10 of the elements of input.
num_inputs_check(context, 1, 1);
auto x = context.get_input(0);
auto ten = context.mark_node(v0::Constant::create(element::f32, Shape{}, {10}));
x = context.mark_node(std::make_shared<v0::Convert>(x, element::f32));
auto log10 = context.mark_node(std::make_shared<v0::Log>(ten));
auto log = context.mark_node(std::make_shared<v0::Log>(x));
auto res = context.mark_node(std::make_shared<v1::Divide>(log, log10));
return {res};
};

OutputVector translate_logsumexp(const NodeContext& context) {
num_inputs_check(context, 1, 2);
auto input = context.get_input(0);
Expand Down
2 changes: 2 additions & 0 deletions src/frontends/pytorch/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ OP_CONVERTER(translate_log);
OP_CONVERTER(translate_log1p);
OP_CONVERTER(translate_log_softmax);
OP_CONVERTER(translate_log2);
OP_CONVERTER(translate_log10);
OP_CONVERTER(translate_logsumexp);
OP_CONVERTER(translate_loop);
OP_CONVERTER(translate_masked_fill);
Expand Down Expand Up @@ -380,6 +381,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::log1p", op::translate_log1p},
{"aten::log1p_", op::inplace_op<op::translate_log1p>},
{"aten::log2", op::translate_log2},
{"aten::log10", op::translate_log10},
{"aten::log2_", op::inplace_op<op::translate_log2>},
{"aten::lt", op::translate_1to1_match_2_inputs_align_types<opset10::Less>},
{"aten::masked_fill", op::translate_masked_fill},
Expand Down
2 changes: 2 additions & 0 deletions tests/layer_tests/pytorch_tests/test_log.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def create_model(self, op):
"log_": torch.log_,
"log2": torch.log2,
"log2_": torch.log2_,
"log10": torch.log10,
"log1p": torch.log1p,
"log1p_": torch.log1p_
}
Expand Down Expand Up @@ -45,6 +46,7 @@ def forward(self, x):
["log2", "float32"],
["log2", "int32"],
["log2_", "float32"],
["log10", "float32"],
["log1p", "float32"],
["log1p", "int32"],
["log1p_", "float32"]])
Expand Down

0 comments on commit ee0a0ed

Please sign in to comment.