@@ -47,6 +47,18 @@ def log_sigmoid_forward(self: Tensor) -> Tuple[Tensor, Tensor]:
    return min - torch.log1p(z), buffer


+def recompute_mean_var(input: Tensor, rstd: Tensor, inner_dim_indices: List[int], keepdim: bool):
+    # for most norm decompositions, the decomposition is the same as the core version except for here:
+    # we recompute the mean and variance so that they track gradients through input
+
+    mean = torch.mean(input, dim=inner_dim_indices, keepdim=keepdim)
+    var = torch.var(input, dim=inner_dim_indices, unbiased=False, keepdim=keepdim)
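+    # the saved rstd was computed as 1 / sqrt(var + eps), so the original eps can be
+    # recovered from it and the recomputed var; detaching it keeps gradients flowing
+    # only through the recomputed mean and var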
+    eps = torch.pow(1 / rstd, 2) - var  # this makes me so sad inside
+    eps = eps.detach()
+    rstd = 1 / torch.sqrt(var + eps)
+    return mean, rstd
+
+
@register_decomposition_for_jvp(aten.native_layer_norm_backward)
def native_layer_norm_backward(
    grad_out: Tensor,
@@ -80,13 +92,7 @@ def native_layer_norm_backward(
            input.new_zeros(input_shape[axis:]),
        )

-    # this is exactly the same as the other decomposition except for here. We recompute the mean and variance
-    # so that they track gradients through input
-    mean_ = torch.mean(input, dim=inner_dim_indices, keepdim=True)
-    var = torch.var(input, dim=inner_dim_indices, unbiased=False, keepdim=True)
-    eps = torch.pow(1 / rstd, 2) - var  # this makes me so sad inside
-    eps = eps.detach()
-    rstd_ = 1 / torch.sqrt(var + eps)
+    mean_, rstd_ = recompute_mean_var(input, rstd, inner_dim_indices, keepdim=True)

    x_hat = (input - mean_) * rstd_
    if weight is not None:
@@ -128,3 +134,84 @@ def native_layer_norm_backward(
        d_bias = torch.zeros(())  # should be None but doesn't work with vjp

    return (d_input, d_weight, d_bias)
+
+
+def prod(x: List[int]):
+    r = 1
+    for i in x:
+        r *= i
+    return r
+
+
+@register_decomposition(aten.native_batch_norm_backward)  # @register_decomposition_for_jvp after in core
+def native_batch_norm_backward(
+    grad_out: Tensor,
+    input: Tensor,
+    weight: Optional[Tensor],
+    running_mean: Optional[Tensor],
+    running_var: Optional[Tensor],
+    save_mean: Optional[Tensor],
+    save_invstd: Optional[Tensor],
+    train: bool,
+    eps: float,
+    output_mask: List[bool],
+) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
+    input_shape = input.shape
+    input_rank = input.dim()
+    assert input_rank >= 2, "rank of the input must be at least 2"
+
+    axis = 1
+    num_features = prod(input_shape) / input_shape[axis]
+    mean = save_mean
+    invstd = save_invstd
+    if train:
+        assert save_mean is not None and save_invstd is not None, "when train=True, save_mean and save_invstd are required"
+
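+        # recompute mean/invstd from input (rather than reusing the detached saved
+        # statistics) so that, as in the layer norm decomposition above, gradients
+        # are tracked through input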
+        reduction_dims = [0] + list(range(2, input.dim()))
+        assert invstd is not None  # for typing
+        mean, invstd = recompute_mean_var(input, invstd, reduction_dims, keepdim=False)
+    else:
+        assert running_mean is not None and running_var is not None
+        mean = running_mean
+        invstd = torch.rsqrt(running_var + eps)
+
+    broadcast_mask = [1] * input_rank
+    broadcast_mask[axis] = input_shape[axis]
+
+    reduction_axes: List[int] = []
+    for i in range(input_rank):
+        if i != axis:
+            reduction_axes.append(i)
+
+    mean = torch.reshape(mean, broadcast_mask)
+    norm = 1.0 / num_features
+    grad_output_sum = torch.sum(grad_out, reduction_axes)
+    dot_p = torch.sum(grad_out * (input - mean), reduction_axes)
+
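+    # grad_output_sum * norm is the per-channel mean of grad_out; proj_scale rescales
+    # the projection of grad_out onto (input - mean), both broadcast to the input shape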
+    grad_mean = torch.reshape(grad_output_sum * norm, broadcast_mask)
+    proj_scale = torch.reshape(torch.mul(dot_p * norm, invstd * invstd), broadcast_mask)
+
+    if weight is None:
+        grad_scale = torch.reshape(invstd, broadcast_mask) * 1.0
+    else:
+        grad_scale = torch.reshape(invstd * weight, broadcast_mask)
+
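+    # training mode is the standard batch norm backward,
+    #   grad_input = (grad_out - mean(grad_out) - x_hat * mean(grad_out * x_hat)) * invstd * weight;
+    # in eval mode the statistics are constants, so only the scale term remains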
+    if train:
+        proj = (input - mean) * proj_scale
+        grad_input = ((grad_out - proj) - grad_mean) * grad_scale
+    else:
+        grad_input = grad_out * grad_scale
+
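+    # d(weight * x_hat)/d(weight) accumulates to sum(grad_out * x_hat) = dot_p * invstd;
+    # the bias gradient is just grad_out summed over the non-channel dims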
+    if output_mask[1]:
+        grad_weight = dot_p * invstd
+    elif weight is not None:
+        grad_weight = torch.zeros_like(weight)  # should be None but doesn't work with vjp
+    else:
+        grad_weight = torch.zeros(())  # should be None but doesn't work with vjp
+
+    if output_mask[2]:
+        grad_bias = grad_output_sum
+    else:
+        grad_bias = torch.zeros_like(grad_output_sum)  # should be None but doesn't work with vjp
+
+    return (grad_input, grad_weight, grad_bias)
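A minimal sanity check one might run against this decomposition (a sketch, not part of the diff): it assumes the decomposed native_batch_norm_backward above is importable (the module path below is hypothetical) and compares its outputs to the ATen kernel on a random training-mode batch.

import torch
# hypothetical import path; point this at wherever the decomposition above lives
from decompositions import native_batch_norm_backward

torch.manual_seed(0)
x = torch.randn(4, 3, 8, 8, dtype=torch.double)
weight = torch.randn(3, dtype=torch.double)
bias = torch.randn(3, dtype=torch.double)
running_mean = torch.zeros(3, dtype=torch.double)
running_var = torch.ones(3, dtype=torch.double)
eps = 1e-5

# the forward kernel returns the saved statistics that the backward needs
out, save_mean, save_invstd = torch.ops.aten.native_batch_norm(
    x, weight, bias, running_mean, running_var, True, 0.1, eps
)
grad_out = torch.randn_like(out)
mask = [True, True, True]

expected = torch.ops.aten.native_batch_norm_backward(
    grad_out, x, weight, running_mean, running_var, save_mean, save_invstd, True, eps, mask
)
actual = native_batch_norm_backward(
    grad_out, x, weight, running_mean, running_var, save_mean, save_invstd, True, eps, mask
)
for e, a in zip(expected, actual):
    torch.testing.assert_close(e, a)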