Merge pull request #1111 from NLGithubWP/add_opt_ms
Add the implementation of a single optimization step for model selection
chrishkchris authored Oct 12, 2023
2 parents 830e1f7 + 619d119 commit 31ea937
Showing 1 changed file with 43 additions and 0 deletions.
43 changes: 43 additions & 0 deletions examples/model_selection_psql/ms_mlp/train_ms_model.py
@@ -159,6 +159,49 @@ def __init__(self,
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError(
                "Nesterov momentum requires a momentum and zero dampening")

    def apply(self, param_name, param_value, param_grad):
        """Performs a single optimization step.

        Args:
            param_name(String): the name of the param
            param_value(Tensor): param values to be updated in-place
            param_grad(Tensor): param gradients; the values may be updated
                in-place by this function and should not be used afterwards
        """
        assert param_value.shape == param_grad.shape, ("shape mismatch",
                                                       param_value.shape,
                                                       param_grad.shape)
        self.device_check(param_value, self.step_counter, self.lr_value,
                          self.mom_value, self.dam_value, self.decay_value)

        # derive dtype from input
        # assert param_value.dtype == self.dtype

        # TODO add branch operator
        # if self.decay_value != 0:
        if self.weight_decay.init_value != 0:
            # weight decay: param_grad += decay * param_value
            singa.Axpy(self.decay_value.data, param_value.data,
                       param_grad.data)

        if self.momentum.init_value != 0:
            # lazily create the momentum buffer, outside the computational graph
            if param_name not in self.moments:
                flag = param_value.device.graph_enabled()
                param_value.device.EnableGraph(False)
                self.moments[param_name] = tensor.zeros_like(param_value)
                param_value.device.EnableGraph(flag)

            buf = self.moments[param_name]
            # buf = momentum * buf + (1 - dampening) * param_grad
            buf *= self.mom_value
            alpha = 1.0 - self.dam_value
            singa.Axpy(alpha.data, param_grad.data, buf.data)

            if self.nesterov:
                # Nesterov: param_grad += momentum * buf
                singa.Axpy(self.mom_value.data, buf.data, param_grad.data)
            else:
                param_grad = buf

        # param_value -= lr * param_grad
        minus_lr = 0.0 - self.lr_value
        singa.Axpy(minus_lr.data, param_grad.data, param_value.data)

# Data augmentation
def augmentation(x, batch_size):
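For readers skimming the diff: the sequence of `singa.Axpy` calls above implements the standard SGD-with-momentum update, where `Axpy(a, x, y)` computes `y += a * x` in place. Below is a minimal NumPy sketch of the same update rule; the name `sgd_step` and the explicit `buf` argument are illustrative, not part of the commit:

```python
import numpy as np

def sgd_step(param, grad, buf, lr=0.01, weight_decay=0.0,
             momentum=0.9, dampening=0.0, nesterov=False):
    """One SGD step mirroring the Axpy sequence in apply (updates param/buf in place)."""
    if weight_decay != 0:
        grad = grad + weight_decay * param       # Axpy(decay, param, grad)
    if momentum != 0:
        buf *= momentum
        buf += (1.0 - dampening) * grad          # Axpy(1 - dampening, grad, buf)
        grad = grad + momentum * buf if nesterov else buf
    param -= lr * grad                           # Axpy(-lr, grad, param)
    return param, buf
```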

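A quick sanity check of the sketch with hand-computed values (the numbers are illustrative only):

```python
p = np.array([1.0, -2.0])
g = np.array([0.5, 0.5])
m = np.zeros(2)  # fresh momentum buffer, as in self.moments on first use
p, m = sgd_step(p, g, m, lr=0.1, momentum=0.9)
# First step: buf = 1.0 * g, so p -> p - 0.1 * g
print(p)  # [ 0.95 -2.05]
```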