diff --git a/tf/tfprocess.py b/tf/tfprocess.py
index c6ddba2e..b44e92f0 100644
--- a/tf/tfprocess.py
+++ b/tf/tfprocess.py
@@ -121,6 +121,7 @@ def __init__(self, cfg):
         self.training = tf.placeholder(tf.bool)
         self.global_step = tf.Variable(0, name='global_step', trainable=False)
         self.learning_rate = tf.placeholder(tf.float32)
+        self.target_lr = None
 
     def init(self, dataset, train_iterator, test_iterator):
         # TF variables
@@ -200,7 +201,7 @@ def init_net(self, next_batch):
 
         # You need to change the learning rate here if you are training
         # from a self-play training set, for example start with 0.005 instead.
-        opt_op = tf.train.MomentumOptimizer(
+        self.opt_op = tf.train.MomentumOptimizer(
             learning_rate=self.learning_rate, momentum=0.9, use_nesterov=True)
 
         # Do swa after we contruct the net
@@ -227,10 +228,19 @@ def init_net(self, next_batch):
             var.initialized_value()), trainable=False) for var in tf.trainable_variables()]
         self.zero_op = [var.assign(tf.zeros_like(var)) for var in gradient_accum]
 
+        if self.cfg['training'].get('lr_search', False):
+            self.backup_vars = [tf.Variable(tf.zeros_like(
+                var.initialized_value()), trainable=False) for var in tf.trainable_variables()]
+            self.backup_momentums = [tf.Variable(tf.zeros_like(
+                var.initialized_value()), trainable=False) for var in tf.trainable_variables()]
+            self.opt_backup_op = None
+            self.opt_restore_op = None
+            self.last_lr = tf.Variable(0., name='last_lr', trainable=False)
+            self.last_lr_cached = None
 
         self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
         with tf.control_dependencies(self.update_ops):
-            gradients = opt_op.compute_gradients(loss)
+            gradients = self.opt_op.compute_gradients(loss)
         self.accum_op = [accum.assign_add(
             gradient[0]) for accum, gradient in zip(gradient_accum, gradients)]
         # gradients are num_batch_splits times higher due to accumulation by summing, so the norm will be too
@@ -238,8 +248,10 @@ def init_net(self, next_batch):
             'max_grad_norm', 10000.0) * self.cfg['training'].get('num_batch_splits', 1)
         gradient_accum, self.grad_norm = tf.clip_by_global_norm(
             gradient_accum, max_grad_norm)
-        self.train_op = opt_op.apply_gradients(
+        self.train_op = self.opt_op.apply_gradients(
             [(accum, gradient[1]) for accum, gradient in zip(gradient_accum, gradients)], global_step=self.global_step)
+        self.quiet_train_op = self.opt_op.apply_gradients(
+            [(accum, gradient[1]) for accum, gradient in zip(gradient_accum, gradients)])
 
         correct_policy_prediction = \
             tf.equal(tf.argmax(self.y_conv, 1), tf.argmax(self.y_, 1))
@@ -274,6 +286,7 @@ def init_net(self, next_batch):
         self.session.run(self.init)
 
     def replace_weights(self, new_weights):
+        all_evals = []
         for e, weights in enumerate(self.weights):
             if weights.shape.ndims == 4:
                 # Rescale rule50 related weights as clients do not normalize the input.
@@ -295,7 +308,7 @@ def replace_weights(self, new_weights):
                 s = weights.shape.as_list()
                 shape = [s[i] for i in [3, 2, 0, 1]]
                 new_weight = tf.constant(new_weights[e], shape=shape)
-                self.session.run(weights.assign(
+                all_evals.append(weights.assign(
                     tf.transpose(new_weight, [2, 3, 1, 0])))
             elif weights.shape.ndims == 2:
                 # Fully connected layers are [in, out] in TF
@@ -305,12 +318,13 @@ def replace_weights(self, new_weights):
                 s = weights.shape.as_list()
                 shape = [s[i] for i in [1, 0]]
                 new_weight = tf.constant(new_weights[e], shape=shape)
-                self.session.run(weights.assign(
+                all_evals.append(weights.assign(
                     tf.transpose(new_weight, [1, 0])))
             else:
                 # Biases, batchnorm etc
                 new_weight = tf.constant(new_weights[e], shape=weights.shape)
-                self.session.run(tf.assign(weights, new_weight))
+                all_evals.append(tf.assign(weights, new_weight))
+        self.session.run(all_evals)
 
         # This should result in identical file to the starting one
         # self.save_leelaz_weights('restored.txt')
@@ -360,6 +374,20 @@ def process(self, batch_size, test_batches, batch_splits=1):
         lr_boundaries = self.cfg['training']['lr_boundaries']
         steps_total = steps % self.cfg['training']['total_steps']
         self.lr = lr_values[bisect.bisect_right(lr_boundaries, steps_total)]
+        lr_search = self.cfg['training'].get('lr_search', False)
+        lr_searching = lr_search and steps % self.cfg['training']['lr_search_freq'] == 0 and steps != 0
+        if lr_search:
+            if self.last_lr_cached is None:
+                self.last_lr_cached = self.session.run(self.last_lr)
+            if self.last_lr_cached > 0:
+                self.lr = self.last_lr_cached
+            if self.target_lr is not None:
+                target_progress = steps % self.cfg['training']['lr_search_freq']
+                if target_progress == 0:
+                    self.lr = self.target_lr
+                else:
+                    self.lr = self.lr + (self.target_lr - self.lr) * (target_progress / self.cfg['training']['lr_search_freq'])
+        raw_lr = self.lr
 
         if self.warmup_steps > 0 and steps < self.warmup_steps:
             self.lr = self.lr * (steps + 1) / self.warmup_steps
@@ -383,6 +411,32 @@ def process(self, batch_size, test_batches, batch_splits=1):
             self.avg_value_loss.append(value_loss)
             self.avg_mse_loss.append(mse_loss)
             self.avg_reg_term.append(reg_term)
+        if lr_searching:
+            if self.opt_backup_op is None:
+                self.opt_backup_op = [var.assign(val) for var, val in zip(self.backup_vars, tf.trainable_variables())] +\
+                    [var.assign(val) for var, val in zip(self.backup_momentums, [self.opt_op.get_slot(var, "momentum") for var in tf.trainable_variables()])]
+            if self.opt_restore_op is None:
+                self.opt_restore_op = [val.assign(var) for var, val in zip(self.backup_vars, tf.trainable_variables())] +\
+                    [val.assign(var) for var, val in zip(self.backup_momentums, [self.opt_op.get_slot(var, "momentum") for var in tf.trainable_variables()])]
+            self.session.run(self.opt_backup_op)
+            best_reg_term = None
+            best_x = 0
+            for x in np.arange(-1, 1, 0.1):
+                corrected_lr = raw_lr*(2**x) / batch_splits
+                _, grad_norm = self.session.run([self.quiet_train_op, self.grad_norm],
+                    feed_dict={self.learning_rate: corrected_lr, self.training: True, self.handle: self.train_handle})
+                new_reg = self.session.run(self.reg_term)
+                if best_reg_term is None or new_reg < best_reg_term:
+                    best_reg_term = new_reg
+                    best_x = x
+                print("LR Search {} {}".format(raw_lr*(2**x), new_reg))
+                self.session.run(self.opt_restore_op)
+            self.last_lr_cached = raw_lr
+            self.session.run(self.last_lr.assign(raw_lr))
+            best_x = best_x / 5
+            self.target_lr = raw_lr*(2**best_x)
+            print("LR Target {} {}".format(raw_lr*(2**best_x), best_reg_term))
+
         # Gradients of batch splits are summed, not averaged like usual, so need to scale lr accordingly to correct for this.
         corrected_lr = self.lr / batch_splits
         _, grad_norm = self.session.run([self.train_op, self.grad_norm],
@@ -437,6 +491,9 @@ def process(self, batch_size, test_batches, batch_splits=1):
             if self.swa_enabled:
                 self.calculate_swa_summaries(test_batches, steps)
 
+        if lr_search and steps % self.cfg['training']['total_steps'] == 0:
+            self.session.run(self.last_lr.assign(self.target_lr))
+
         # Save session and weights at end, and also optionally every 'checkpoint_steps'.
         if steps % self.cfg['training']['total_steps'] == 0 or (
             'checkpoint_steps' in self.cfg['training'] and steps % self.cfg['training']['checkpoint_steps'] == 0):
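
Note on the new `lr_search` path: every `lr_search_freq` steps the trainer snapshots the weights and momentum slots (`opt_backup_op`), applies the already-accumulated gradients once per candidate rate `raw_lr * 2**x` for `x` in `[-1, 1)` via `quiet_train_op` (which does not advance `global_step`), scores each candidate by `reg_term`, restores the snapshot, and then anneals the live learning rate toward `raw_lr * 2**(best_x / 5)` over the following `lr_search_freq` steps. Below is a minimal, framework-free sketch of that probe loop for illustration only; `apply_step` and `score` are hypothetical stand-ins for `quiet_train_op` and the `reg_term` evaluation, not functions from this patch.

```python
import numpy as np

def probe_learning_rate(params, apply_step, score, raw_lr, batch_splits=1):
    """Sketch of the lr_search probe under the assumptions stated above.

    params      -- NumPy array standing in for the snapshotted weights/momentum
    apply_step  -- hypothetical stand-in for quiet_train_op: one update at a given lr
    score       -- hypothetical stand-in for reading reg_term after the trial step
    """
    snapshot = params.copy()                      # like opt_backup_op
    best_x, best_score = 0.0, None
    for x in np.arange(-1, 1, 0.1):               # same 2**x grid as the patch
        trial = apply_step(snapshot.copy(), raw_lr * (2 ** x) / batch_splits)
        s = score(trial)
        if best_score is None or s < best_score:
            best_score, best_x = s, x
        # the snapshot is never mutated, mirroring opt_restore_op after each probe
    return raw_lr * 2 ** (best_x / 5)             # damped target rate, as in the patch
```

The separate `quiet_train_op` exists so these probe steps can reuse the same accumulated gradients without incrementing `global_step`, and `last_lr` is kept in a TF variable so the chosen rate survives a checkpoint restart.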