Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reverse weight decay #567

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ shared memory implementation can be used by passing `use_legacy_shared_mem_impl`
- Added MMLU multiple choice (A/B/C/D) 5-shot variant downstream tasks
- Tokenizer patch
- Added option to specify number of model replicas when using hybrid sharding.
- Added reverse_embedding_decay option.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This name also needs to be updated.


### Changed

Expand Down
4 changes: 3 additions & 1 deletion olmo/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@
gc_cuda,
get_fs_local_rank,
get_global_rank,
get_world_size, get_local_world_size, get_local_rank,
get_local_rank,
get_local_world_size,
get_world_size,
)
from .util import (
_get_s3_client,
Expand Down
6 changes: 6 additions & 0 deletions olmo/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,12 @@ class OptimizerConfig(BaseConfig):
If not set, defaults to the wandb `log_interval`.
"""

reverse_embedding_decay: bool = False
"""
Applying weight decay to embeddings may make them too small, potentially causing spikes.
Setting this parameter to true is a way of applying "reverse weight decay" to embeddings.
dirkgr marked this conversation as resolved.
Show resolved Hide resolved
"""

def __post_init__(self):
self.betas = tuple(self.betas) # type: ignore[assignment]

Expand Down
29 changes: 24 additions & 5 deletions olmo/optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def clip_grads_and_collect_metrics(
global_step: int,
collect_param_metrics: bool = True,
process_group: Optional[dist.ProcessGroup] = None,
reverse_embedding_decay: bool = False,
dirkgr marked this conversation as resolved.
Show resolved Hide resolved
) -> Dict[str, torch.Tensor]:
"""
Clips gradients for every group that has the field `max_grad_norm`.
Expand Down Expand Up @@ -83,13 +84,14 @@ def clip_grads_and_collect_metrics(
# with ReLoRa, for example.
assert group.get("sharded", True) is True

is_embedding_group = group["name"] == "embedding_decay_group"
for name, p in zip(group["param_names"], group["params"]):
name = self._clean_param_name(name)
# Always need to collect the norm of gradients for clipping, even if we're not collecting
# Always need to collect the norm of gradients and parameters for clipping, even if we're not collecting
# other metrics.
tensors: List[Optional[torch.Tensor]] = [p.grad]
prefixes: List[str] = [f"grad/{name}"]
if collect_param_metrics:
if collect_param_metrics or (reverse_embedding_decay and is_embedding_group):
state = self.get_state_for_param(p)
sorted_state_keys = sorted([k for k in state.keys()])
tensors.extend([p] + [state[key] for key in sorted_state_keys])
Expand Down Expand Up @@ -232,7 +234,7 @@ def is_grad_norm_metric(metric_name: str) -> bool:
all_metrics["clipping_rate"] = clipping_rate
return all_metrics
else:
return {}
return all_metrics

@torch.no_grad()
def _do_adaptive_clipping(
Expand Down Expand Up @@ -617,6 +619,7 @@ def get_param_groups(cfg: TrainConfig, model: nn.Module) -> List[Dict[str, Any]]
# Separate out parameters that we don't want to apply weight decay to, like norms and biases.
decay = set()
no_decay = set()
embeddings_decay = set()
all_params = {}
for mn, m in model.named_modules():
for pn, p in m.named_parameters():
Expand Down Expand Up @@ -644,12 +647,14 @@ def get_param_groups(cfg: TrainConfig, model: nn.Module) -> List[Dict[str, Any]]
elif pn.endswith("weight") and isinstance(m, nn.Embedding):
if cfg.optimizer.decay_embeddings:
decay.add(fpn)
elif cfg.optimizer.reverse_embedding_decay:
embeddings_decay.add(fpn)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens if these are both set? We should check against that somewhere.

else:
no_decay.add(fpn)

# Validate that we've considered every parameter
inter_params = decay & no_decay
union_params = decay | no_decay
inter_params = decay & no_decay & embeddings_decay
union_params = decay | no_decay | embeddings_decay
assert len(inter_params) == 0, f"parameters {inter_params} made it into both decay/no_decay sets!"
assert (
len(all_params.keys() - union_params) == 0
Expand All @@ -658,12 +663,15 @@ def get_param_groups(cfg: TrainConfig, model: nn.Module) -> List[Dict[str, Any]]
# Create the pytorch optimizer groups.
decay_sorted = sorted(list(decay))
no_decay_sorted = sorted(list(no_decay))
embeddings_decay_sorted = sorted(list(embeddings_decay))

param_groups = []
if len(decay_sorted) > 0:
param_groups.append(
{
"params": [all_params[pn] for pn in decay_sorted],
"param_names": decay_sorted,
"name": "decay_group",
**param_group_defaults,
}
)
Expand All @@ -673,6 +681,17 @@ def get_param_groups(cfg: TrainConfig, model: nn.Module) -> List[Dict[str, Any]]
"params": [all_params[pn] for pn in no_decay_sorted],
"param_names": no_decay_sorted,
"weight_decay": 0.0,
"name": "no_decay_group",
**param_group_defaults,
}
)
if len(embeddings_decay_sorted) > 0:
# the weight_decay value will be multiplied by emb_decay_factor in olmo/train.py
param_groups.append(
{
"params": [all_params[pn] for pn in embeddings_decay_sorted],
"param_names": embeddings_decay_sorted,
"name": "embedding_decay_group",
**param_group_defaults,
}
)
Expand Down
12 changes: 10 additions & 2 deletions olmo/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,10 @@
import torch
import torch.distributed as dist
import torch.nn.functional as F
import wandb
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.utils.data import DataLoader

import wandb

from .aliases import PathOrStr
from .checkpoint import Checkpointer, FullCheckpointer, build_sharded_checkpointer
from .config import (
Expand Down Expand Up @@ -720,8 +719,14 @@ def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) ->
# passing this process group here ensures metrics are reduced correctly when we're using
# HYBRID sharding.
process_group=self.fsdp_model.process_group,
reverse_embedding_decay=self.cfg.optimizer.reverse_embedding_decay,
)

emb_norm = optim_metrics["param/transformer.wte.weight.norm"]
emb_size = self.cfg.model.embedding_size or self.cfg.model.vocab_size
emb_std = math.sqrt(math.pow(emb_norm, 2) / float(emb_size * self.cfg.model.vocab_size))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe the denominator should be float(self.cfg.model.d_model * emb_size). And I'm not sure about the numerator either... I don't see how this is equivalent to standard deviation since the summation terms in the norm are not centered by the mean, no?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

update: @AkshitaB and I discussed this, we think we need to calculate this metric separately in optim.py.

We also talked about how this standard deviation will be a little biased since it will include parts of the embedding that never are never used, since we inflate the embedding size beyond vocab size to be a multiple of 128. But this is probably okay since that's only a small part of the embeddings.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, I think this is a big problem. Embeddings will want to be small, so this will push them up. Unused, or rarely used embeddings will never get updated, so they will get bigger and bigger, skewing the calculation of the stddev more and more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Figuring out which embeddings to exclude from the stddev computation is going to be tricky in the distributed setting though.

Copy link
Member

@epwalsh epwalsh May 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thinking out loud here... what if we force the unused params to be zero from the beginning? They would still bias standard deviation by as much as they are different from the mean, but they would always be zero.. I think

Copy link
Contributor Author

@AkshitaB AkshitaB May 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That would work if we were starting with this from scratch, but what about the case when we want to use this to "rescue" a run? Can we explicitly make the unused embeddings zero when we load the model? And will it matter if we do so halfway through training?

Copy link
Member

@epwalsh epwalsh May 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we explicitly make the unused embeddings zero when we load the model?

I think that's our best bet. I can't think of any issues that would introduce in the middle of training. I suspect those parameters are 0 anyway due to weight decay and zero gradients.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rare tokens would still be an issue, but not any more than they always are.

emb_decay_factor = 1.0 - emb_std
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we're using this to plug into the value for WD, that means it needs to be negative when we want to pull up the values. So then it would be emb_std - 1?


# Adjust the learning rate.
for group in self.optim.param_groups:
# TODO (epwalsh): if we want to enable different LRs or gradient clipping settings per group
Expand All @@ -737,6 +742,9 @@ def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) ->
self.cfg.max_grad_norm_ratio, self.scheduler_current, self.scheduler_max
)

if group["name"] == "embedding_decay_group":
group["weight_decay"] *= emb_decay_factor
Comment on lines +745 to +746
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does't this multiply up emb_decay_factor across batches? It feels like this should just be set, not multiplied? Or is there some other bit that resets group["weight_decay"] every time?


# Optimizer step.
self.optim.step()

Expand Down
43 changes: 43 additions & 0 deletions test_fixtures/reverse_wd.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
run_name: "reverse_test"
save_folder: "/tmp/olmo-train-tiny"
wandb:
name: ${run_name}
project: reverse-test
model:
d_model: 128
n_heads: 4
n_layers: 4
mlp_ratio: 4
alibi: false
alibi_bias_max: 8.0
attention_dropout: 0.1
attention_layer_norm: false
residual_dropout: 0.1
embedding_dropout: 0.1
max_sequence_length: 512
vocab_size: 50257
eos_token_id: 50256
pad_token_id: 50256
init_device: null
init_std: 0.02
optimizer:
learning_rate: 0.001
reverse_embedding_decay: true
metrics_log_interval: 100
scheduler:
name: "cosine_with_warmup"
t_warmup: 10
data:
paths:
- "/net/nfs.cirrascale/allennlp/llm-data/c4/en/c4-train.00000-00099.npy"
persistent_workers: false
num_workers: 0
prefetch_factor: null
tokenizer:
identifier: "gpt2"
save_overwrite: true
max_duration: 16
stop_at: ${max_duration}
global_train_batch_size: 8
device_train_microbatch_size: 8
precision: "fp32"
Loading