Skip to content

Commit

Permalink
decompose
Browse files Browse the repository at this point in the history
  • Loading branch information
chenhu-wang committed Jan 2, 2024
1 parent dd929e9 commit d30d70b
Showing 1 changed file with 14 additions and 7 deletions.
21 changes: 14 additions & 7 deletions src/common/snippets/src/pass/group_normalization_decomposition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,19 +37,20 @@ GroupNormalizationDecomposition::GroupNormalizationDecomposition() {
// reshape [N, C, spatial] to [N, group, C / group, spatial]
const auto orig_shape = group_norm_node->get_input_shape(0);
size_t orig_rank = orig_shape.size();
ov::Shape group_shape(orig_rank + 1);
size_t group_rank = orig_rank + 1;
ov::Shape group_shape(group_rank);
group_shape[0] = orig_shape[0];
group_shape[1] = num_groups;
group_shape[2] = orig_shape[1] / num_groups;
for (size_t i = 3; i < orig_rank + 1; ++i) {
for (size_t i = 3; i < group_rank; ++i) {
group_shape[i] = orig_shape[i - 1];
}
auto group_shape_node = std::make_shared<ov::op::v0::Constant>(element::i64, Shape{group_shape.size()}, group_shape);
auto group_shape_node = std::make_shared<ov::op::v0::Constant>(element::i64, Shape{group_rank}, group_shape);
const auto reshaped_node = std::make_shared<ov::op::v1::Reshape>(data, group_shape_node, true);

// reduceSum on dimension [C / group, spatial]
int64_t axis_start = 2;
std::vector<int64_t> axis(group_shape.size() - axis_start);
std::vector<int64_t> axis(group_rank - axis_start);
std::iota(axis.begin(), axis.end(), axis_start); // axis:[2, 3, 4...]
auto axis_node = std::make_shared<ov::op::v0::Constant>(element::i64, Shape{axis.size()}, axis);
// todo: snippets op ReduceSum to have emitter to generate
Expand Down Expand Up @@ -79,7 +80,7 @@ GroupNormalizationDecomposition::GroupNormalizationDecomposition() {

// ( (x - mean) / variance) * scale + bias
// reshape scale and bias
std::vector<size_t> c_shape(group_shape.size(), 1);
std::vector<size_t> c_shape(group_rank, 1);
c_shape[1] = group_shape[1];
c_shape[2] = group_shape[2];
auto c_reshape = std::make_shared<ov::op::v0::Constant>(element::i64, Shape{c_shape.size()}, c_shape);
Expand All @@ -100,9 +101,15 @@ GroupNormalizationDecomposition::GroupNormalizationDecomposition() {
auto orig_shape_node = std::make_shared<ov::op::v0::Constant>(element::i64, Shape{orig_shape.size()}, orig_shape);
const auto reshape_back_node = std::make_shared<ov::op::v1::Reshape>(biased_node, orig_shape_node, true);

ov::replace_node_update_name(group_norm_node, biased_node);
std::vector<size_t> subtensor(group_rank, 1);
for (size_t i = axis_start; i < group_rank; ++i)
subtensor[i] = PortDescriptor::ServiceDimensions::FULL_DIM;
PortDescriptorUtils::set_port_descriptor_ptr(reduce_sum->input(0), std::make_shared<PortDescriptor>(reduce_sum->input(0), subtensor));
PortDescriptorUtils::set_port_descriptor_ptr(reduce_sum->output(0), std::make_shared<PortDescriptor>(reduce_sum->output(0), subtensor));
PortDescriptorUtils::set_port_descriptor_ptr(mean_sum_variance->input(0), std::make_shared<PortDescriptor>(mean_sum_variance->input(0), subtensor));
PortDescriptorUtils::set_port_descriptor_ptr(mean_sum_variance->output(0), std::make_shared<PortDescriptor>(mean_sum_variance->output(0), subtensor));

return true;
return ov::replace_node_update_name(group_norm_node, reshape_back_node);
};

auto m = std::make_shared<ov::pass::pattern::Matcher>(group_norm_pattern, matcher_name);
Expand Down

0 comments on commit d30d70b

Please sign in to comment.