
Commit 347334c

Samantha Andow authored
batch norm forward over reverse coverage with decomposition (#877)
1 parent e82e64a commit 347334c

4 files changed: +100 -10 lines changed

functorch/_src/decompositions.py (+94 -7)
@@ -47,6 +47,18 @@ def log_sigmoid_forward(self: Tensor) -> Tuple[Tensor, Tensor]:
     return min - torch.log1p(z), buffer


+def recompute_mean_var(input: Tensor, rstd: Tensor, inner_dim_indices: List[int], keepdim: bool):
+    # For most norm decompositions, this is the only place that differs from the core version:
+    # we recompute the mean and variance so that they track gradients through input.
+
+    mean = torch.mean(input, dim=inner_dim_indices, keepdim=keepdim)
+    var = torch.var(input, dim=inner_dim_indices, unbiased=False, keepdim=keepdim)
+    eps = torch.pow(1 / rstd, 2) - var  # this makes me so sad inside
+    eps = eps.detach()
+    rstd = 1 / torch.sqrt(var + eps)
+    return mean, rstd
+
+
 @register_decomposition_for_jvp(aten.native_layer_norm_backward)
 def native_layer_norm_backward(
     grad_out: Tensor,
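
An aside on the eps recovery above (a minimal sketch for this page, not part of the diff): because the kernel saves rstd = 1 / sqrt(var + eps), the original eps can be recovered as (1 / rstd)^2 - var; detaching it and rebuilding rstd from the freshly computed var reproduces the saved value while letting var carry gradients from input. For example:

# Sketch only: check that eps recovered from a saved rstd matches the eps used to build it.
import torch

x = torch.randn(8, 16)
eps = 1e-5
var = torch.var(x, dim=0, unbiased=False)
rstd = 1 / torch.sqrt(var + eps)               # what a norm kernel would save
recovered_eps = torch.pow(1 / rstd, 2) - var   # same recovery as in recompute_mean_var
print(torch.allclose(recovered_eps, torch.full_like(var, eps), atol=1e-6))  # True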
@@ -80,13 +92,7 @@ def native_layer_norm_backward(
         input.new_zeros(input_shape[axis:]),
     )

-    # this is exactly the same as the other decomposition except for here. We recompute the mean and variance
-    # so that they track gradients through input
-    mean_ = torch.mean(input, dim=inner_dim_indices, keepdim=True)
-    var = torch.var(input, dim=inner_dim_indices, unbiased=False, keepdim=True)
-    eps = torch.pow(1 / rstd, 2) - var  # this makes me so sad inside
-    eps = eps.detach()
-    rstd_ = 1 / torch.sqrt(var + eps)
+    mean_, rstd_ = recompute_mean_var(input, rstd, inner_dim_indices, keepdim=True)

     x_hat = (input - mean_) * rstd_
     if weight is not None:
@@ -128,3 +134,84 @@ def native_layer_norm_backward(
     d_bias = torch.zeros(())  # should be None but doesn't work with vjp

     return (d_input, d_weight, d_bias)
+
+
+def prod(x: List[int]):
+    r = 1
+    for i in x:
+        r *= i
+    return r
+
+
+@register_decomposition(aten.native_batch_norm_backward)  # @register_decomposition_for_jvp after in core
+def native_batch_norm_backward(
+    grad_out: Tensor,
+    input: Tensor,
+    weight: Optional[Tensor],
+    running_mean: Optional[Tensor],
+    running_var: Optional[Tensor],
+    save_mean: Optional[Tensor],
+    save_invstd: Optional[Tensor],
+    train: bool,
+    eps: float,
+    output_mask: List[bool],
+) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
+    input_shape = input.shape
+    input_rank = input.dim()
+    assert input_rank >= 2, "rank of the input must be at least 2"
+
+    axis = 1
+    num_features = prod(input_shape) / input_shape[axis]
+    mean = save_mean
+    invstd = save_invstd
+    if train:
+        assert save_mean is not None and save_invstd is not None, "when train=True, save_mean and save_invstd are required"
+
+        reduction_dims = [0] + list(range(2, input.dim()))
+        assert invstd is not None  # for typing
+        mean, invstd = recompute_mean_var(input, invstd, reduction_dims, keepdim=False)
+    else:
+        assert running_mean is not None and running_var is not None
+        mean = running_mean
+        invstd = torch.rsqrt(running_var + eps)
+
+    broadcast_mask = [1] * input_rank
+    broadcast_mask[axis] = input_shape[axis]
+
+    reduction_axes: List[int] = []
+    for i in range(input_rank):
+        if i != axis:
+            reduction_axes.append(i)
+
+    mean = torch.reshape(mean, broadcast_mask)
+    norm = 1.0 / num_features
+    grad_output_sum = torch.sum(grad_out, reduction_axes)
+    dot_p = torch.sum(grad_out * (input - mean), reduction_axes)
+
+    grad_mean = torch.reshape(grad_output_sum * norm, broadcast_mask)
+    proj_scale = torch.reshape(torch.mul(dot_p * norm, invstd * invstd), broadcast_mask)
+
+    if weight is None:
+        grad_scale = torch.reshape(invstd, broadcast_mask) * 1.0
+    else:
+        grad_scale = torch.reshape(invstd * weight, broadcast_mask)
+
+    if train:
+        proj = (input - mean) * proj_scale
+        grad_input = ((grad_out - proj) - grad_mean) * grad_scale
+    else:
+        grad_input = grad_out * grad_scale
+
+    if output_mask[1]:
+        grad_weight = dot_p * invstd
+    elif weight is not None:
+        grad_weight = torch.zeros_like(weight)  # should be None but doesn't work with vjp
+    else:
+        grad_weight = torch.zeros(())  # should be None but doesn't work with vjp
+
+    if output_mask[2]:
+        grad_bias = grad_output_sum
+    else:
+        grad_bias = torch.zeros_like(grad_output_sum)  # should be None but doesn't work with vjp
+
+    return (grad_input, grad_weight, grad_bias)
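
For context, the coverage this decomposition targets, per the commit title, is forward-over-reverse AD through batch norm. A minimal usage sketch with functorch's public jvp/vjp APIs (the shapes and helper names are illustrative, not taken from the commit):

# Sketch only: forward-mode AD over a reverse-mode gradient of batch norm.
import torch
import torch.nn.functional as F
from functorch import jvp, vjp

x = torch.randn(4, 3, 8, 8)
weight = torch.randn(3)
bias = torch.randn(3)

def f(inp):
    # training-mode batch norm without running stats
    return F.batch_norm(inp, None, None, weight, bias, training=True)

def vjp_of_f(inp, cotangent):
    _, vjp_fn = vjp(f, inp)
    return vjp_fn(cotangent)[0]

cotangent = torch.randn_like(f(x))
tangent = torch.randn_like(x)
# jvp of the vjp: this is the path that now routes through the
# native_batch_norm_backward decomposition registered above
_, jvp_out = jvp(lambda inp: vjp_of_f(inp, cotangent), (x,), (tangent,))
print(jvp_out.shape)  # same shape as x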

functorch/_src/eager_transforms.py (+2)
@@ -1491,5 +1491,7 @@ def _register_python_decomposition_vmap(decomp):
 _register_jit_decomposition(torch.ops.aten._softmax_backward_data.default)
 _register_jit_decomposition(torch.ops.aten.log_sigmoid_forward.default)
 _register_jit_decomposition(torch.ops.aten.native_layer_norm_backward.default)
+_register_jit_decomposition(torch.ops.aten.native_batch_norm_backward.default)
+_register_jit_decomposition(torch.ops.aten.cudnn_batch_norm_backward.default)
 _register_python_decomposition_vmap(torch.ops.aten.mse_loss_backward.default)
 _register_python_decomposition_vmap(torch.ops.aten.addr.default)

functorch/csrc/DynamicLayer.cpp (+2)
@@ -502,6 +502,8 @@ TORCH_LIBRARY_IMPL(aten, FT_DYNAMIC_LAYER_FRONT_MODE_KEY, m) {
   OP_DECOMPOSE(log_sigmoid);
   JVP_DECOMP(log_sigmoid_forward);
   JVP_DECOMP(native_layer_norm_backward);
+  JVP_DECOMP(native_batch_norm_backward);
+  JVP_DECOMP(cudnn_batch_norm_backward);
 }

test/test_ops.py (+2 -3)
@@ -1149,8 +1149,6 @@ def get_vjp(cotangents, *primals):
         xfail('logdet', ''),
         xfail('nanmean', ''),
         xfail('nansum', ''),
-        xfail('nn.functional.batch_norm', ''),
-        xfail('nn.functional.batch_norm', 'without_cudnn', device_type='cuda'),
         xfail('nn.functional.embedding'),
         xfail('nn.functional.embedding', 'functorch'),
         xfail('nn.functional.embedding_bag', ''),
@@ -1249,7 +1247,8 @@ def reference(primals, cotangents, primals_tangents, cotangents_tangents):
             'softmax',
             'log_softmax',
             'nn.functional.cross_entropy',
-            'nn.functional.layer_norm'
+            'nn.functional.layer_norm',
+            'nn.functional.batch_norm',
         }
         if op.name in FUNCTORCH_HAS_FORMULA_BUT_NOT_PYTORCH:
             self.assertFalse(op.supports_fwgrad_bwgrad,
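
With the xfails removed and 'nn.functional.batch_norm' added to FUNCTORCH_HAS_FORMULA_BUT_NOT_PYTORCH, the decomposition is expected to agree numerically with the aten kernel. A spot-check sketch along those lines (it assumes the import path added by this commit; the shapes are illustrative and this is not part of the test suite):

# Sketch only: compare the Python decomposition against the aten batch norm backward.
import torch
from functorch._src.decompositions import native_batch_norm_backward

x = torch.randn(4, 3, 5, 5)
weight = torch.randn(3)
grad_out = torch.randn_like(x)
eps = 1e-5

# training-mode forward to obtain the saved statistics
_, save_mean, save_invstd = torch.ops.aten.native_batch_norm(
    x, weight, None, None, None, True, 0.1, eps)

args = (grad_out, x, weight, None, None, save_mean, save_invstd, True, eps,
        [True, True, True])
ref = torch.ops.aten.native_batch_norm_backward(*args)
dec = native_batch_norm_backward(*args)

for r, d in zip(ref, dec):
    print(torch.allclose(r, d, atol=1e-5))  # expected: True for all three grads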
