@@ -99,6 +99,7 @@ def sgd_train_linear_model(
     This will return the final training loss (averaged with
     `running_loss_window`)
     """
+
     loss_window: List[torch.Tensor] = []
     min_avg_loss = None
     convergence_counter = 0
@@ -144,77 +145,77 @@ def get_point(datapoint):
     if model.linear.bias is not None:
         model.linear.bias.zero_()

-    with torch.enable_grad():
-        optim = torch.optim.SGD(model.parameters(), lr=initial_lr)
-        if reduce_lr:
-            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
-                optim, factor=0.5, patience=patience, threshold=threshold
-            )
-
-        t1 = time.time()
-        epoch = 0
-        i = 0
-        while epoch < max_epoch:
-            while True:  # for x, y, w in dataloader
-                if running_loss_window is None:
-                    running_loss_window = x.shape[0] * len(dataloader)
-
-                y = y.view(x.shape[0], -1)
-                if w is not None:
-                    w = w.view(x.shape[0], -1)
-
-                i += 1
-
-                out = model(x)
-
-                loss = loss_fn(y, out, w)
-                if reg_term is not None:
-                    reg = torch.norm(model.linear.weight, p=reg_term)
-                    loss += reg.sum() * alpha
-
-                if len(loss_window) >= running_loss_window:
-                    loss_window = loss_window[1:]
-                loss_window.append(loss.clone().detach())
-                assert len(loss_window) <= running_loss_window
-
-                average_loss = torch.mean(torch.stack(loss_window))
-                if min_avg_loss is not None:
-                    # if we haven't improved by at least `threshold`
-                    if average_loss > min_avg_loss or torch.isclose(
-                        min_avg_loss, average_loss, atol=threshold
-                    ):
-                        convergence_counter += 1
-                        if convergence_counter >= patience:
-                            converged = True
-                            break
-                    else:
-                        convergence_counter = 0
-                if min_avg_loss is None or min_avg_loss >= average_loss:
-                    min_avg_loss = average_loss.clone()
-
-                if debug:
-                    print(
-                        f"lr={optim.param_groups[0]['lr']}, Loss={loss}, "
-                        + f"Aloss={average_loss}, min_avg_loss={min_avg_loss}"
-                    )
-
-                loss.backward()
-                optim.step()
-                model.zero_grad()
-                if scheduler:
-                    scheduler.step(average_loss)
-
-                temp = next(data_iter, None)
-                if temp is None:
-                    break
-                x, y, w = get_point(temp)
-
-            if converged:
+    optim = torch.optim.SGD(model.parameters(), lr=initial_lr)
+    if reduce_lr:
+        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+            optim, factor=0.5, patience=patience, threshold=threshold
+        )
+
+    t1 = time.time()
+    epoch = 0
+    i = 0
+    while epoch < max_epoch:
+        while True:  # for x, y, w in dataloader
+            if running_loss_window is None:
+                running_loss_window = x.shape[0] * len(dataloader)
+
+            y = y.view(x.shape[0], -1)
+            if w is not None:
+                w = w.view(x.shape[0], -1)
+
+            i += 1
+
+            out = model(x)
+
+            loss = loss_fn(y, out, w)
+            if reg_term is not None:
+                reg = torch.norm(model.linear.weight, p=reg_term)
+                loss += reg.sum() * alpha
+
+            if len(loss_window) >= running_loss_window:
+                loss_window = loss_window[1:]
+            loss_window.append(loss.clone().detach())
+            assert len(loss_window) <= running_loss_window
+
+            average_loss = torch.mean(torch.stack(loss_window))
+            if min_avg_loss is not None:
+                # if we haven't improved by at least `threshold`
+                if average_loss > min_avg_loss or torch.isclose(
+                    min_avg_loss, average_loss, atol=threshold
+                ):
+                    convergence_counter += 1
+                    if convergence_counter >= patience:
+                        converged = True
+                        break
+                else:
+                    convergence_counter = 0
+            if min_avg_loss is None or min_avg_loss >= average_loss:
+                min_avg_loss = average_loss.clone()
+
+            if debug:
+                print(
+                    f"lr={optim.param_groups[0]['lr']}, Loss={loss}, "
+                    + f"Aloss={average_loss}, min_avg_loss={min_avg_loss}"
+                )
+
+            loss.backward()
+
+            optim.step()
+            model.zero_grad()
+            if scheduler:
+                scheduler.step(average_loss)
+
+            temp = next(data_iter, None)
+            if temp is None:
                 break
+            x, y, w = get_point(temp)
+
+        if converged:
+            break

-            epoch += 1
-            data_iter = iter(dataloader)
-            x, y, w = get_point(next(data_iter))
+        epoch += 1
+        data_iter = iter(dataloader)
+        x, y, w = get_point(next(data_iter))

     t2 = time.time()
     return {
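
The substantive change in this diff is dropping the `with torch.enable_grad():` wrapper and dedenting the training loop one level. Gradient tracking is PyTorch's default mode, so the wrapper only changes behavior when the caller has disabled it with `torch.no_grad()`. A minimal standalone sketch (not part of this patch) illustrating that distinction:

import torch

w = torch.ones(3, requires_grad=True)

# Default grad mode: operations on `w` are tracked, backward() works.
loss = (w * 2).sum()
print(loss.requires_grad)  # True

# Under no_grad, the same computation builds no graph.
with torch.no_grad():
    loss = (w * 2).sum()
    print(loss.requires_grad)  # False: backward() here would raise

# enable_grad re-enables tracking even inside a no_grad block; this is
# the only situation the removed wrapper guarded against.
with torch.no_grad():
    with torch.enable_grad():
        loss = (w * 2).sum()
        print(loss.requires_grad)  # True

One consequence worth noting: after this change, invoking the trainer from inside a `torch.no_grad()` block would presumably fail at `loss.backward()`, so callers are assumed to run it in default grad mode.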
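The early-stopping rule embedded in the loop is easier to follow in isolation: losses are averaged over a sliding window of size `running_loss_window`, and training is declared converged once the windowed average has failed to improve on its best value by at least `threshold` for `patience` consecutive steps. The following restatement is a sketch under assumed names (`update_convergence_state` and `window_size` are illustrative, not part of the PR):

from typing import List, Optional, Tuple

import torch

def update_convergence_state(
    loss: torch.Tensor,
    loss_window: List[torch.Tensor],
    window_size: int,
    min_avg_loss: Optional[torch.Tensor],
    convergence_counter: int,
    patience: int,
    threshold: float,
) -> Tuple[List[torch.Tensor], Optional[torch.Tensor], int, bool]:
    # Keep only the most recent `window_size` losses.
    if len(loss_window) >= window_size:
        loss_window = loss_window[1:]
    loss_window.append(loss.clone().detach())

    average_loss = torch.mean(torch.stack(loss_window))

    if min_avg_loss is not None and (
        average_loss > min_avg_loss
        or torch.isclose(min_avg_loss, average_loss, atol=threshold)
    ):
        # No improvement of at least `threshold`: bump the counter.
        convergence_counter += 1
        if convergence_counter >= patience:
            # Matches the original loop, which breaks out here before
            # updating the best windowed average.
            return loss_window, min_avg_loss, convergence_counter, True
    else:
        convergence_counter = 0

    # Track the best windowed average seen so far.
    if min_avg_loss is None or min_avg_loss >= average_loss:
        min_avg_loss = average_loss.clone()

    return loss_window, min_avg_loss, convergence_counter, False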