Reverse weight decay #567

Open
wants to merge 14 commits into main
4 changes: 3 additions & 1 deletion olmo/checkpoint.py
@@ -55,7 +55,9 @@
    gc_cuda,
    get_fs_local_rank,
    get_global_rank,
    get_world_size, get_local_world_size, get_local_rank,
    get_local_rank,
    get_local_world_size,
    get_world_size,
)
from .util import (
    _get_s3_client,
6 changes: 6 additions & 0 deletions olmo/config.py
@@ -481,6 +481,12 @@ class OptimizerConfig(BaseConfig):
    If not set, defaults to the wandb `log_interval`.
    """

    reverse_embedding_decay: bool = False
    """
    Applying weight decay to embeddings may make them too small, potentially causing loss spikes.
    Setting this parameter to true applies "reverse weight decay" to the embeddings instead.
    """

    def __post_init__(self):
        self.betas = tuple(self.betas)  # type: ignore[assignment]

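For orientation before the diffs below: the flag routes the embedding weights into their own optimizer param group (olmo/optim.py) and then, each step, scales that group's weight_decay by emb_decay_factor = 1 - ||E||, where ||E|| is the logged norm of transformer.wte.weight (olmo/train.py). With AdamW-style decoupled decay, the decay term subtracts lr * weight_decay * p from each parameter, so a negative scaled weight decay grows the embeddings rather than shrinking them, which appears to be the intent behind the name. A minimal, self-contained sketch of just that scaling (the function and names are illustrative, not code from this PR):

# Illustrative sketch of the "reverse weight decay" scaling only; not code from the PR.
import torch

def reverse_embedding_decay_step(emb: torch.Tensor, lr: float, base_wd: float) -> None:
    emb_decay_factor = 1.0 - emb.norm().item()  # mirrors 1 - ||transformer.wte.weight||
    effective_wd = base_wd * emb_decay_factor   # negative once the norm exceeds 1
    # Decoupled decay step: shrinks the weights if effective_wd > 0, grows them if < 0.
    emb.mul_(1.0 - lr * effective_wd)

embedding = torch.nn.Embedding(50257, 128).weight.data
reverse_embedding_decay_step(embedding, lr=1e-3, base_wd=0.1)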
21 changes: 19 additions & 2 deletions olmo/optim.py
@@ -617,6 +617,7 @@ def get_param_groups(cfg: TrainConfig, model: nn.Module) -> List[Dict[str, Any]]
    # Separate out parameters that we don't want to apply weight decay to, like norms and biases.
    decay = set()
    no_decay = set()
    embeddings_decay = set()
    all_params = {}
    for mn, m in model.named_modules():
        for pn, p in m.named_parameters():
@@ -644,12 +645,14 @@ def get_param_groups(cfg: TrainConfig, model: nn.Module) -> List[Dict[str, Any]]
            elif pn.endswith("weight") and isinstance(m, nn.Embedding):
                if cfg.optimizer.decay_embeddings:
                    decay.add(fpn)
                elif cfg.optimizer.reverse_embedding_decay:
                    embeddings_decay.add(fpn)
Member review comment: What happens if these are both set? We should check against that somewhere.
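One possible place for such a check (a sketch only, not part of this PR) is OptimizerConfig.__post_init__ in olmo/config.py, raising the repo's OLMoConfigurationError (or a plain ValueError) when both flags are enabled:

    def __post_init__(self):
        self.betas = tuple(self.betas)  # type: ignore[assignment]
        # Sketch: the two embedding-decay modes are mutually exclusive.
        if self.decay_embeddings and self.reverse_embedding_decay:
            raise OLMoConfigurationError(
                "decay_embeddings and reverse_embedding_decay cannot both be enabled"
            )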

                else:
                    no_decay.add(fpn)

    # Validate that we've considered every parameter
    inter_params = decay & no_decay
    union_params = decay | no_decay
    inter_params = (decay & no_decay) | (decay & embeddings_decay) | (no_decay & embeddings_decay)
    union_params = decay | no_decay | embeddings_decay
    assert len(inter_params) == 0, f"parameters {inter_params} made it into both decay/no_decay sets!"
    assert (
        len(all_params.keys() - union_params) == 0
@@ -658,12 +661,15 @@
    # Create the pytorch optimizer groups.
    decay_sorted = sorted(list(decay))
    no_decay_sorted = sorted(list(no_decay))
    embeddings_decay_sorted = sorted(list(embeddings_decay))

    param_groups = []
    if len(decay_sorted) > 0:
        param_groups.append(
            {
                "params": [all_params[pn] for pn in decay_sorted],
                "param_names": decay_sorted,
                "name": "decay_group",
                **param_group_defaults,
            }
        )
@@ -673,6 +679,17 @@
                "params": [all_params[pn] for pn in no_decay_sorted],
                "param_names": no_decay_sorted,
                "weight_decay": 0.0,
                "name": "no_decay_group",
                **param_group_defaults,
            }
        )
    if len(embeddings_decay_sorted) > 0:
        # the weight_decay value will be multiplied by emb_decay_factor in olmo/train.py
        param_groups.append(
            {
                "params": [all_params[pn] for pn in embeddings_decay_sorted],
                "param_names": embeddings_decay_sorted,
                "name": "embedding_decay_group",
                **param_group_defaults,
            }
        )
9 changes: 9 additions & 0 deletions olmo/train.py
@@ -711,6 +711,12 @@ def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) ->
            process_group=self.fsdp_model.process_group,
        )

        # TODO: what to do otherwise?
        if should_log_optim_metrics_this_step:
            emb_decay_factor = 1.0 - optim_metrics["param/transformer.wte.weight.norm"]
        else:
            emb_decay_factor = 1.0
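One way to resolve the TODO above (illustrative only, not part of this PR) is to carry the most recently computed factor forward instead of silently falling back to 1.0 on steps where optimizer metrics are skipped:

        # Illustrative alternative: reuse the last computed factor when metrics are
        # skipped this step. The attribute name `_last_emb_decay_factor` is hypothetical.
        if should_log_optim_metrics_this_step:
            emb_decay_factor = 1.0 - optim_metrics["param/transformer.wte.weight.norm"]
            self._last_emb_decay_factor = emb_decay_factor
        else:
            emb_decay_factor = getattr(self, "_last_emb_decay_factor", 1.0)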

        # Adjust the learning rate.
        for group in self.optim.param_groups:
            # TODO (epwalsh): if we want to enable different LRs or gradient clipping settings per group
@@ -726,6 +732,9 @@ def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) ->
                self.cfg.max_grad_norm_ratio, self.scheduler_current, self.scheduler_max
            )

            if group["name"] == "embedding_decay_group":
                group["weight_decay"] *= emb_decay_factor
Member review comment on lines +745 to +746: Doesn't this multiply up emb_decay_factor across batches? It feels like this should just be set, not multiplied? Or is there some other bit that resets group["weight_decay"] every time?
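A sketch of the "set, not multiplied" variant the comment asks about (illustrative only; it caches the group's base weight decay under a hypothetical "base_weight_decay" key so the factor is applied once per step rather than compounding):

            if group["name"] == "embedding_decay_group":
                # Illustrative: scale from a cached base value instead of mutating in place.
                base_wd = group.setdefault("base_weight_decay", group["weight_decay"])
                group["weight_decay"] = base_wd * emb_decay_factor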


        # Optimizer step.
        self.optim.step()

42 changes: 42 additions & 0 deletions test_fixtures/reverse_wd.yaml
@@ -0,0 +1,42 @@
run_name: "reverse_test"
save_folder: "/tmp/olmo-train-tiny"
wandb:
  name: ${run_name}
  project: reverse-test
model:
  d_model: 128
  n_heads: 4
  n_layers: 4
  mlp_ratio: 4
  alibi: false
  alibi_bias_max: 8.0
  attention_dropout: 0.1
  attention_layer_norm: false
  residual_dropout: 0.1
  embedding_dropout: 0.1
  max_sequence_length: 512
  vocab_size: 50257
  eos_token_id: 50256
  pad_token_id: 50256
  init_device: null
  init_std: 0.02
optimizer:
  learning_rate: 0.001
  reverse_embedding_decay: true
  metrics_log_interval: 1
scheduler:
  name: "cosine_with_warmup"
  t_warmup: 10
data:
  paths:
    - "test_fixtures/mup-sample-data/part-010-00002.npy"
  persistent_workers: false
  num_workers: 0
  prefetch_factor: null
tokenizer:
  identifier: "gpt2"
save_overwrite: true
max_duration: 4
global_train_batch_size: 8
device_train_microbatch_size: 8
precision: "fp32"
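As a quick sanity check on the fixture, it can be loaded through the project's config class (a sketch; it assumes TrainConfig.load accepts a YAML path, as the training entry point uses it):

# Sketch: confirm the new flag round-trips through config loading.
from olmo.config import TrainConfig

cfg = TrainConfig.load("test_fixtures/reverse_wd.yaml")
assert cfg.optimizer.reverse_embedding_decay is True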