
Commit 58a7cde
add layout opt rule for group norm
PiperOrigin-RevId: 712372815
chunnienc authored and copybara-github committed Jan 6, 2025
1 parent 7e68dce commit 58a7cde
Showing 2 changed files with 22 additions and 0 deletions.
@@ -155,6 +155,7 @@ def _qdq_layout_sensitive_inputs_getter(node: Node):
 @layout_sensitive_inputs_getters.register(
     aten._native_batch_norm_legit_no_training
 )
+@layout_sensitive_inputs_getters.register(aten.group_norm)
 @layout_sensitive_inputs_getters.register(aten.native_group_norm)
 def _first_arg_getter(node):
   return [node.args[0]]
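
For context: these stacked `register` decorators presumably populate a dispatch table keyed by ATen target, so `_first_arg_getter` now also answers for `aten.group_norm`, declaring that only the first argument (the input tensor) is layout-sensitive. A minimal sketch of such a registry, with hypothetical names (the real implementation may differ):

class TargetRegistry:
  """Maps an ATen op to a handler function (assumed pattern)."""

  def __init__(self):
    self._handlers = {}

  def register(self, target):
    def wrapper(fn):
      self._handlers[target] = fn
      return fn
    return wrapper

  def get(self, target):
    return self._handlers.get(target)

Because the decorator returns `fn` unchanged, stacking several `register` calls, as above, binds the same getter to multiple ops.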
@@ -188,6 +189,14 @@ def _aten_norm_checker(node):
   return NHWCable(can_be=len(val[0].shape) == 4, must_be=False)
 
 
+@nhwcable_node_checkers.register(aten.group_norm)
+def _aten_group_norm_checker(node):
+  val = node.meta.get("val")
+  if not hasattr(val, "shape"):
+    return NHWCable(can_be=False, must_be=False)
+  return NHWCable(can_be=len(val.shape) == 4, must_be=False)
+
+
 @nhwcable_node_checkers.register(aten.native_group_norm)
 def _aten_native_group_norm_checker(node):
   val = node.meta.get("val")
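A node qualifies for the NHWC layout rewrite only when its output is rank-4. Note the difference from `_aten_norm_checker` above, which reads `val[0].shape` because ops like `aten.native_group_norm` return a tuple of tensors; `aten.group_norm` returns a single tensor, so the new checker reads `val.shape` directly. A self-contained sketch of the check, assuming `NHWCable` is a plain (can_be, must_be) record (hypothetical definition):

from typing import NamedTuple

import torch

class NHWCable(NamedTuple):  # assumed structure; the real class may differ
  can_be: bool   # the node may be transposed to NHWC
  must_be: bool  # the node requires NHWC

def check_group_norm_val(val):
  # Only rank-4 (N, C, H, W) outputs can be rewritten to NHWC.
  if not hasattr(val, "shape"):
    return NHWCable(can_be=False, must_be=False)
  return NHWCable(can_be=len(val.shape) == 4, must_be=False)

print(check_group_norm_val(torch.empty(1, 8, 4, 4)))  # can_be=True
print(check_group_norm_val(torch.empty(1, 8, 4)))     # can_be=False

The second changed file adds the matching rewriters.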
@@ -342,6 +342,18 @@ def batch_norm(input, weight, bias, running_mean, running_var, momentum, eps):
   node.target = batch_norm
 
 
+@rewriters.register(aten.group_norm.default)
+def _aten_group_norm(node):
+  def group_norm(input, num_groups: int, weight=None, bias=None, eps=1e-5):
+    # Disable NHWC rewriter with natively decomposed ops due to a precision issue.
+    # TODO(b/354780253): Re-enable NHWC rewriter with proper lowering.
+    input = utils.tensor_to_nchw(input)
+    res = aten.group_norm.default(input, num_groups, weight, bias, eps=eps)
+    return utils.tensor_to_nhwc(res)
+
+  node.target = group_norm
+
+
 @rewriters.register(aten.native_group_norm.default)
 def _aten_native_group_norm(node):

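The new rewriter forces `aten.group_norm` to run in NCHW: when the layout pass has transposed the surrounding graph to NHWC, the input is permuted back to NCHW, the op runs unchanged, and the result is permuted to NHWC again, sidestepping the precision issue mentioned in the comment. The `utils` helpers are presumably plain dimension permutes; a sketch under that assumption (hypothetical bodies):

import torch

def tensor_to_nchw(t: torch.Tensor) -> torch.Tensor:
  # Assumed behavior of utils.tensor_to_nchw: NHWC (N, H, W, C) -> NCHW (N, C, H, W).
  return torch.permute(t, (0, 3, 1, 2))

def tensor_to_nhwc(t: torch.Tensor) -> torch.Tensor:
  # Assumed behavior of utils.tensor_to_nhwc: NCHW (N, C, H, W) -> NHWC (N, H, W, C).
  return torch.permute(t, (0, 2, 3, 1))

This trades two extra transposes per call for exact agreement with the eager op; TODO(b/354780253) tracks replacing the workaround with a proper NHWC lowering.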
@@ -354,6 +366,7 @@ def native_group_norm(
       flattened_inner_size: int,
       num_groups: int,
       eps: float,
+      **kwargs,
   ):
     input_reshaped = torch.reshape(
         input,
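The second hunk adds `**kwargs` to the inner `native_group_norm` wrapper so it tolerates any extra keyword arguments forwarded from the original node. The wrapper body, truncated above, begins a reshape-based decomposition; a simplified, self-contained sketch of that style of group norm on an NCHW tensor (not the exact code, which is cut off here):

import torch

def group_norm_sketch(x, weight=None, bias=None, num_groups=1, eps=1e-5, **kwargs):
  # Normalize each group of channels over its channel and spatial extent.
  n, c, h, w = x.shape
  xg = x.reshape(n, num_groups, -1)
  mean = xg.mean(dim=-1, keepdim=True)
  var = xg.var(dim=-1, unbiased=False, keepdim=True)
  out = ((xg - mean) / torch.sqrt(var + eps)).reshape(n, c, h, w)
  if weight is not None:  # optional per-channel affine
    out = out * weight.view(1, c, 1, 1)
  if bias is not None:
    out = out + bias.view(1, c, 1, 1)
  return out  # `**kwargs` silently absorbs extra keyword arguments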
