Reverse weight decay #567

Open · wants to merge 14 commits into main
Conversation

@AkshitaB (Contributor) commented May 3, 2024

Goal: Perform reverse weight decay on embeddings

Multiply the weight_decay factor for the embeddings layer by (1 - norm(embeddings)).

TODO:

  • What to do when the log metric interval is > 1?

I tried this on a tiny test model config and got an overflow error. Possibly this will not be an issue with the actual model.

Note: I created the branch from train-olmo-large. See this for actual diffs for this PR.
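
For illustration only, a minimal sketch of the mechanism described above, assuming the embeddings sit in their own AdamW parameter group; the group name, sizes, hyperparameters, and the RMS-style statistic are stand-ins, not this PR's exact code:

import torch

# Hypothetical setup: the embedding weights get their own parameter group so
# their weight decay can be adjusted on every step.
embeddings = torch.nn.Embedding(50304, 4096)
base_weight_decay = 0.1
optimizer = torch.optim.AdamW(
    [
        {
            "params": embeddings.parameters(),
            "weight_decay": base_weight_decay,
            "name": "embedding_decay_group",
        }
    ],
    lr=1e-4,
)

def update_embedding_weight_decay() -> None:
    # Per the description above: multiply the configured weight decay by
    # (1 - <scale statistic of the embeddings>), here a root-mean-square of
    # the entries. How that statistic and its sign should be defined is what
    # the review discussion below is about.
    with torch.no_grad():
        emb_scale = embeddings.weight.pow(2).mean().sqrt().item()
    for group in optimizer.param_groups:
        if group.get("name") == "embedding_decay_group":
            group["weight_decay"] = base_weight_decay * (1.0 - emb_scale)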

@dirkgr (Member) left a comment


I think this PR needs to go into the train-olmo-large branch, no?

olmo/train.py (outdated; resolved)
@dirkgr (Member) commented May 3, 2024 via email

@AkshitaB (Contributor, Author) commented May 6, 2024

You are right. Then we need to make sure we compute this every time.

Done

@AkshitaB requested a review from dirkgr May 6, 2024 17:15
olmo/config.py (outdated; resolved)
@dirkgr requested a review from epwalsh May 6, 2024 20:18
@dirkgr (Member) commented May 6, 2024

@epwalsh, can you look at this as well? This gets all up in your code.

olmo/optim.py (outdated; resolved)
olmo/optim.py Outdated
Comment on lines 648 to 651
if cfg.optimizer.decay_embeddings:
    decay.add(fpn)
elif cfg.optimizer.reverse_embedding_decay:
    embeddings_decay.add(fpn)
Member

What happens if these are both set? We should check against that somewhere.
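
A hedged sketch of the kind of guard being asked for here; the config field names come from the snippet above, while the helper function and the choice of exception are illustrative:

def validate_embedding_decay_options(cfg) -> None:
    # Hypothetical check: the two options route the embedding params into
    # different parameter groups, so enabling both should be rejected early.
    if cfg.optimizer.decay_embeddings and cfg.optimizer.reverse_embedding_decay:
        raise ValueError(
            "decay_embeddings and reverse_embedding_decay are mutually "
            "exclusive; enable at most one of them."
        )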

CHANGELOG.md Outdated
@@ -23,6 +23,7 @@ shared memory implementation can be used by passing `use_legacy_shared_mem_impl`
- Added MMLU multiple choice (A/B/C/D) 5-shot variant downstream tasks
- Tokenizer patch
- Added option to specify number of model replicas when using hybrid sharding.
- Added reverse_embedding_decay option.
Member

This name also needs to be updated.

@@ -43,6 +43,7 @@ def clip_grads_and_collect_metrics(
    global_step: int,
    collect_param_metrics: bool = True,
    process_group: Optional[dist.ProcessGroup] = None,
    regularize_embeddings: bool = False,
Member

Why is this a parameter to this function? Shouldn't it be just captured in the parameter groups? That's how all the other regularization works.
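
A hedged sketch of the alternative being pointed at: carry the behavior on the embedding parameter group itself instead of adding a new argument to clip_grads_and_collect_metrics. The group construction and the module path are illustrative assumptions:

import torch

def build_param_groups(model: torch.nn.Module, weight_decay: float) -> list[dict]:
    # Illustrative grouping: the embedding weights get their own named group,
    # so code that walks optimizer.param_groups can branch on group metadata
    # rather than on an extra function parameter.
    embedding_params = [model.transformer.wte.weight]  # assumed module path
    embedding_ids = {id(p) for p in embedding_params}
    other_params = [p for p in model.parameters() if id(p) not in embedding_ids]
    return [
        {"params": other_params, "weight_decay": weight_decay},
        {
            "params": embedding_params,
            "weight_decay": weight_decay,
            "name": "embedding_decay_group",
        },
    ]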

Comment on lines +745 to +746
if group["name"] == "embedding_decay_group":
group["weight_decay"] *= emb_decay_factor
Member

Doesn't this multiply up emb_decay_factor across batches? It feels like this should just be set, not multiplied? Or is there some other bit that resets group["weight_decay"] every time?
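
One hedged sketch of the non-compounding version suggested here, recomputing from a remembered base value each step; the base_weight_decay key and helper are illustrative:

import torch

def apply_embedding_decay_factor(
    optimizer: torch.optim.Optimizer, emb_decay_factor: float
) -> None:
    # Remember the configured weight decay the first time we see the group,
    # then always recompute from that base so the factor is set rather than
    # multiplied up across batches.
    for group in optimizer.param_groups:
        if group.get("name") == "embedding_decay_group":
            group.setdefault("base_weight_decay", group["weight_decay"])
            group["weight_decay"] = group["base_weight_decay"] * emb_decay_factor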

emb_norm = optim_metrics["param/transformer.wte.weight.norm"]
emb_size = self.cfg.model.embedding_size or self.cfg.model.vocab_size
emb_std = math.sqrt(math.pow(emb_norm, 2) / float(emb_size * self.cfg.model.vocab_size))
emb_decay_factor = 1.0 - emb_std
Member

If we're using this to plug into the value for WD, that means it needs to be negative when we want to pull up the values. So then it would be emb_std - 1?

)

emb_norm = optim_metrics["param/transformer.wte.weight.norm"]
emb_size = self.cfg.model.embedding_size or self.cfg.model.vocab_size
emb_std = math.sqrt(math.pow(emb_norm, 2) / float(emb_size * self.cfg.model.vocab_size))
Member

I believe the denominator should be float(self.cfg.model.d_model * emb_size). And I'm not sure about the numerator either... I don't see how this is equivalent to standard deviation since the summation terms in the norm are not centered by the mean, no?
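
For reference, a hedged sketch contrasting the two quantities under discussion: dividing the squared L2 norm by the full element count (d_model * emb_size) gives a root-mean-square, which only matches the standard deviation when the entries have (near-)zero mean.

import torch

def embedding_scale_stats(emb_weight: torch.Tensor) -> tuple[float, float]:
    # emb_weight has shape (emb_size, d_model), so numel() == emb_size * d_model.
    rms = (emb_weight.norm(p=2).pow(2) / emb_weight.numel()).sqrt()  # norm-derived value
    std = emb_weight.std()  # mean-centered standard deviation
    return rms.item(), std.item()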

Member

Update: @AkshitaB and I discussed this; we think we need to calculate this metric separately in optim.py.

We also talked about how this standard deviation will be a little biased since it will include parts of the embedding that are never used, since we inflate the embedding size beyond vocab size to be a multiple of 128. But this is probably okay since that's only a small part of the embeddings.

Member

Actually, I think this is a big problem. Embeddings will want to be small, so this will push them up. Unused or rarely used embeddings will never get updated, so they will get bigger and bigger, skewing the calculation of the stddev more and more.

Contributor (Author)

Figuring out which embeddings to exclude from the stddev computation is going to be tricky in the distributed setting though.

@epwalsh (Member) commented May 6, 2024

Thinking out loud here... what if we force the unused params to be zero from the beginning? They would still bias the standard deviation by as much as they differ from the mean, but they would always be zero... I think.

@AkshitaB (Contributor, Author) commented May 7, 2024

That would work if we were starting with this from scratch, but what about the case when we want to use this to "rescue" a run? Can we explicitly make the unused embeddings zero when we load the model? And will it matter if we do so halfway through training?

@epwalsh (Member) commented May 7, 2024

> Can we explicitly make the unused embeddings zero when we load the model?

I think that's our best bet. I can't think of any issues that doing so would introduce in the middle of training. I suspect those parameters are 0 anyway due to weight decay and zero gradients.
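
A hedged sketch of what that could look like at checkpoint-load time, assuming the first vocab_size rows of the table are the real tokens and everything above them is padding; the function name and call site are illustrative:

import torch

def zero_unused_embedding_rows(emb_weight: torch.Tensor, vocab_size: int) -> None:
    # Rows >= vocab_size only exist because the embedding table is padded up
    # to a multiple of 128; they never receive gradient, so zeroing them when
    # the checkpoint is loaded keeps them from skewing norm/stddev statistics.
    with torch.no_grad():
        emb_weight[vocab_size:].zero_()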

Member

Rare tokens would still be an issue, but not any more than they always are.
