Commit f74bec6
Add files via upload (mlcommons#336)
pkanwar23 authored Feb 6, 2020
1 parent b58c18e commit f74bec6
Showing 1 changed file with 253 additions and 0 deletions.
translation/tensorflow/bert/lamb_optimizer_v1.py (253 additions, 0 deletions)
@@ -0,0 +1,253 @@
"""LAMB (Layer-wise Adaptive Moments) optimizer as TF1 tf.train.Optimizer.
See paper [Large Batch Optimization for Deep Learning: Training BERT in 76
minutes](https://arxiv.org/abs/1904.00962).
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re

# pylint: disable=g-direct-tensorflow-import
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.training import optimizer
# pylint: enable=g-direct-tensorflow-import


class LAMBOptimizer(optimizer.Optimizer):
"""Optimizer that implements the LAMBOptimizer as tf.train.Optimizer."""

  def __init__(self,
               learning_rate=0.001,
               beta_1=0.9,
               beta_2=0.999,
               epsilon=1e-6,
               weight_decay_rate=0.0,
               exclude_from_weight_decay=None,
               exclude_from_layer_adaptation=None,
               use_locking=False,
               name="LAMB"):
    super(LAMBOptimizer, self).__init__(use_locking, name)
    self._lr = learning_rate
    self._beta1 = beta_1
    self._beta2 = beta_2
    self._epsilon = epsilon
    self._weight_decay_rate = weight_decay_rate
    self.exclude_from_weight_decay = exclude_from_weight_decay
    # exclude_from_layer_adaptation is set to exclude_from_weight_decay if the
    # arg is None.
    if exclude_from_layer_adaptation:
      self.exclude_from_layer_adaptation = exclude_from_layer_adaptation
    else:
      self.exclude_from_layer_adaptation = exclude_from_weight_decay

    # Tensor versions of the constructor arguments, created in _prepare().
    self._lr_t = None
    self._beta1_t = None
    self._beta2_t = None
    self._epsilon_t = None
    self._weight_decay_rate_t = None
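
  # beta1_power and beta2_power track beta1**t and beta2**t across steps; they
  # are advanced in _finish() and consumed for bias correction in the apply
  # methods below.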

  def _get_beta_accumulators(self):
    with ops.init_scope():
      if context.executing_eagerly():
        graph = None
      else:
        graph = ops.get_default_graph()
      return (self._get_non_slot_variable("beta1_power", graph=graph),
              self._get_non_slot_variable("beta2_power", graph=graph))

  def _create_slots(self, var_list):
    # Create the beta1 and beta2 accumulators on the same device as the first
    # variable. Sort the var_list to make sure this device is consistent across
    # workers (these need to go on the same PS, otherwise some updates are
    # silently ignored).
    first_var = min(var_list, key=lambda x: x.name)
    self._create_non_slot_variable(
        initial_value=self._beta1, name="beta1_power", colocate_with=first_var)
    self._create_non_slot_variable(
        initial_value=self._beta2, name="beta2_power", colocate_with=first_var)

    # Create slots for the first and second moments.
    for v in var_list:
      self._zeros_slot(v, "m", self._name)
      self._zeros_slot(v, "v", self._name)

  def _prepare(self):
    lr = self._call_if_callable(self._lr)
    beta1 = self._call_if_callable(self._beta1)
    beta2 = self._call_if_callable(self._beta2)
    epsilon = self._call_if_callable(self._epsilon)
    weight_decay_rate = self._call_if_callable(self._weight_decay_rate)

    self._lr_t = ops.convert_to_tensor(lr, name="learning_rate")
    self._beta1_t = ops.convert_to_tensor(beta1, name="beta1")
    self._beta2_t = ops.convert_to_tensor(beta2, name="beta2")
    self._epsilon_t = ops.convert_to_tensor(epsilon, name="epsilon")
    self._weight_decay_rate_t = ops.convert_to_tensor(
        weight_decay_rate, name="weight_decay_rate")

  def _apply_dense(self, grad, var):
    return self._resource_apply_dense(grad, var)
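
  # _resource_apply_dense below implements the per-variable LAMB step:
  #   m_t = beta1 * m + (1 - beta1) * g          (first moment)
  #   v_t = beta2 * v + (1 - beta2) * g^2        (second moment)
  #   u   = m_t_hat / (sqrt(v_t_hat) + eps)      (bias-corrected Adam-style update)
  #   u  += weight_decay_rate * w                (optional decoupled weight decay)
  #   w  <- w - lr * (||w||_2 / ||u||_2) * u     (layer-wise trust-ratio scaling)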

  def _resource_apply_dense(self, grad, var):
    beta1_power, beta2_power = self._get_beta_accumulators()
    beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
    beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
    lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
    beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
    beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
    epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
    weight_decay_rate_t = math_ops.cast(self._weight_decay_rate_t,
                                        var.dtype.base_dtype)
    # m_t = beta1 * m + (1 - beta1) * g_t
    m = self.get_slot(var, "m")
    m_scaled_g_values = grad * (1 - beta1_t)
    m_t = m * beta1_t + m_scaled_g_values
    m_t = state_ops.assign(m, m_t, use_locking=self._use_locking)
    # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
    v = self.get_slot(var, "v")
    v_scaled_g_values = (grad * grad) * (1 - beta2_t)
    v_t = v * beta2_t + v_scaled_g_values
    v_t = state_ops.assign(v, v_t, use_locking=self._use_locking)

    # ==== This variant uses the bias-corrected moments m_t_hat and v_t_hat.
    m_t_hat = m_t / (1. - beta1_power)
    v_t_hat = v_t / (1. - beta2_power)

    v_sqrt = math_ops.sqrt(v_t_hat)
    update = m_t_hat / (v_sqrt + epsilon_t)

    # ==== The original LAMBOptimizer implementation used the uncorrected
    # first moment instead:
    # v_sqrt = math_ops.sqrt(v_t_hat)
    # update = m_t / (v_sqrt + epsilon_t)

    var_name = self._get_variable_name(var.name)
    if self._do_use_weight_decay(var_name):
      update += weight_decay_rate_t * var

    ratio = 1.0
    if self._do_layer_adaptation(var_name):
      w_norm = linalg_ops.norm(var, ord=2)
      g_norm = linalg_ops.norm(update, ord=2)
      ratio = array_ops.where(
          math_ops.greater(w_norm, 0),
          array_ops.where(math_ops.greater(g_norm, 0), (w_norm / g_norm), 1.0),
          1.0)

    var_update = var - ratio * lr_t * update
    return state_ops.assign(var, var_update, use_locking=self._use_locking).op

  def _apply_sparse_shared(self, grad, var, indices, scatter_add):
    beta1_power, beta2_power = self._get_beta_accumulators()
    beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
    beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
    lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
    beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
    beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
    epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
    weight_decay_rate_t = math_ops.cast(self._weight_decay_rate_t,
                                        var.dtype.base_dtype)
    # m_t = beta1 * m + (1 - beta1) * g_t
    m = self.get_slot(var, "m")
    m_scaled_g_values = grad * (1 - beta1_t)
    m_t = state_ops.assign(m, m * beta1_t, use_locking=self._use_locking)
    with ops.control_dependencies([m_t]):
      m_t = scatter_add(m, indices, m_scaled_g_values)
    # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
    v = self.get_slot(var, "v")
    v_scaled_g_values = (grad * grad) * (1 - beta2_t)
    v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking)
    with ops.control_dependencies([v_t]):
      v_t = scatter_add(v, indices, v_scaled_g_values)

    # ==== This variant uses the bias-corrected moments m_t_hat and v_t_hat.
    m_t_hat = m_t / (1. - beta1_power)
    v_t_hat = v_t / (1. - beta2_power)

    v_sqrt = math_ops.sqrt(v_t_hat)
    update = m_t_hat / (v_sqrt + epsilon_t)

    # ==== The original LAMBOptimizer implementation used the uncorrected
    # first moment instead:
    # v_sqrt = math_ops.sqrt(v_t_hat)
    # update = m_t / (v_sqrt + epsilon_t)

    var_name = self._get_variable_name(var.name)
    if self._do_use_weight_decay(var_name):
      update += weight_decay_rate_t * var

    ratio = 1.0
    if self._do_layer_adaptation(var_name):
      w_norm = linalg_ops.norm(var, ord=2)
      g_norm = linalg_ops.norm(update, ord=2)
      ratio = array_ops.where(
          math_ops.greater(w_norm, 0),
          array_ops.where(math_ops.greater(g_norm, 0), (w_norm / g_norm), 1.0),
          1.0)
    var_update = state_ops.assign_sub(
        var, ratio * lr_t * update, use_locking=self._use_locking)
    return control_flow_ops.group(*[var_update, m_t, v_t])
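
  # Both sparse entry points below delegate to _apply_sparse_shared; they
  # differ only in the scatter-add primitive used: state_ops.scatter_add for
  # ref variables and resource_variable_ops.resource_scatter_add for resource
  # variables.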

  def _apply_sparse(self, grad, var):
    return self._apply_sparse_shared(
        grad.values,
        var,
        grad.indices,
        lambda x, i, v: state_ops.scatter_add(  # pylint: disable=g-long-lambda
            x,
            i,
            v,
            use_locking=self._use_locking))

  def _resource_scatter_add(self, x, i, v):
    with ops.control_dependencies(
        [resource_variable_ops.resource_scatter_add(x.handle, i, v)]):
      return x.value()

  def _resource_apply_sparse(self, grad, var, indices):
    return self._apply_sparse_shared(grad, var, indices,
                                     self._resource_scatter_add)

  def _finish(self, update_ops, name_scope):
    # Update the power accumulators.
    with ops.control_dependencies(update_ops):
      beta1_power, beta2_power = self._get_beta_accumulators()
      with ops.colocate_with(beta1_power):
        update_beta1 = beta1_power.assign(
            beta1_power * self._beta1_t, use_locking=self._use_locking)
        update_beta2 = beta2_power.assign(
            beta2_power * self._beta2_t, use_locking=self._use_locking)
      return control_flow_ops.group(
          *update_ops + [update_beta1, update_beta2], name=name_scope)

  def _do_use_weight_decay(self, param_name):
    """Whether to use L2 weight decay for `param_name`."""
    if self.exclude_from_weight_decay:
      for r in self.exclude_from_weight_decay:
        if re.search(r, param_name) is not None:
          return False
    return True

  def _do_layer_adaptation(self, param_name):
    """Whether to do layer-wise learning rate adaptation for `param_name`."""
    if self.exclude_from_layer_adaptation:
      for r in self.exclude_from_layer_adaptation:
        if re.search(r, param_name) is not None:
          return False
    return True

  def _get_variable_name(self, param_name):
    """Get the variable name from the tensor name."""
    m = re.match("^(.*):\\d+$", param_name)
    if m is not None:
      param_name = m.group(1)
    return param_name
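
A minimal usage sketch follows (an illustration, not part of this commit: the
module import path, toy loss, and exclusion patterns are assumptions for a
standard TF1 graph-mode training loop):

  import tensorflow.compat.v1 as tf
  from lamb_optimizer_v1 import LAMBOptimizer  # hypothetical import path

  tf.disable_eager_execution()
  w = tf.get_variable("w", shape=[10], initializer=tf.zeros_initializer())
  loss = tf.reduce_mean(tf.square(w - 1.0))  # toy quadratic loss

  optimizer = LAMBOptimizer(
      learning_rate=1e-3,
      weight_decay_rate=0.01,
      exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])
  train_op = optimizer.minimize(loss)

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(100):
      sess.run(train_op)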
