
Commit

add haps configuration (cycle lr schedule) (#709)
* add haps configuration (cycle lr schedule)

* tiny fixes
blahBlahhhJ authored Sep 2, 2024
1 parent ffa8e28 commit fd7888d
Showing 1 changed file with 42 additions and 22 deletions.
src/levanter/optim/config.py (64 changes: 42 additions & 22 deletions)
@@ -24,14 +24,16 @@ class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
 
     min_lr_ratio: float = 0.1
     warmup_ratio: Optional[float] = None  # Deprecated. fraction of training steps to use as warmup
-    """The lr scheduler operates on 4 stages: [warmup] - [stable] - [decay] - [cooldown]"""
+    """The lr scheduler operates on 4 stages: [warmup] - {[stable] - [decay]} x haps - [cooldown]"""
     warmup: float = 0.01
     """fraction of training steps to use as warmup, or steps to use. 0.0 means no warmup"""
     stable: float = 0.00
     """fraction of training steps to use as stable, or steps to use. 0.0 means no stable"""
     cooldown: float = 0.0
     """fraction of training steps to use as cooldown, or steps to use. 0.0 means no cooldown"""
     lr_schedule: str = "cosine"  # constant, cosine, linear
+    haps: Optional[list[int]] = None
+    """list of integers indicating pit stop steps. See paper https://openreview.net/pdf?id=RSsavSvAvN"""
     weight_decay_modules: Optional[list[str] | str] = None
     """A regex or a list of strings to identify where to mask weight.
     For nano-GPT, this field can be set as `r".*attn.*weight|.*mlp.*weight|.*token_embeddings|.*position_embeddings"`"""
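For orientation, here is a minimal sketch of how the new cycled schedule might be configured (illustrative only, not part of the commit; it assumes the concrete AdamConfig subclass registered elsewhere in this module, and made-up hyperparameters):

    from levanter.optim.config import AdamConfig

    # Two stable/decay cycles separated by a pit stop at step 5000,
    # plus the usual warmup and cooldown phases at the ends.
    opt = AdamConfig(
        learning_rate=3e-4,
        lr_schedule="cosine",
        warmup=0.01,    # 1% of training steps
        cooldown=0.05,  # 5% of training steps
        stable=0.1,     # fraction of each cycle held at peak lr
        haps=[5000],    # pit stop step(s) splitting the main phase into cycles
    )
    schedule = opt.lr_scheduler(num_train_steps=10_000)  # optax schedule: step -> lr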
@@ -138,22 +140,13 @@ def mask_fn(model):
 
     def lr_scheduler(self, num_train_steps):
         warmup_steps = self._convert_warmup(num_train_steps)
-        stable_steps = _convert_ratio_or_steps(self.stable, num_train_steps)
         cooldown_steps = _convert_ratio_or_steps(self.cooldown, num_train_steps)
-        lr_decay_steps = num_train_steps - warmup_steps - stable_steps - cooldown_steps
-        min_lr = self.learning_rate * self.min_lr_ratio
+        if self.haps is None:
+            self.haps = []
+        self.haps.insert(0, warmup_steps)
+        self.haps.append(num_train_steps - cooldown_steps)
 
-        match self.lr_schedule:
-            case "constant":
-                schedule = optax.constant_schedule(self.learning_rate)
-            case "cosine":
-                schedule = optax.cosine_decay_schedule(self.learning_rate, lr_decay_steps, self.min_lr_ratio)
-            case "linear":
-                schedule = optax.linear_schedule(self.learning_rate, min_lr, lr_decay_steps)
-            case "inv_sqrt":
-                schedule = _inv_sqrt_decay_schedule(self.learning_rate, min_lr, warmup_steps, 10000)
-            case _:
-                raise ValueError(f"Unknown lr_schedule: {self.lr_schedule}")
+        min_lr = self.learning_rate * self.min_lr_ratio
 
         schedules = []
         boundaries = []
@@ -163,18 +156,37 @@ def lr_scheduler(self, num_train_steps):
             schedules.append(warmup)
             boundaries.append(warmup_steps)
 
-        if stable_steps != 0:
-            stable = optax.constant_schedule(self.learning_rate)
-            schedules.append(stable)
-            boundaries.append(warmup_steps + stable_steps)
-
-        schedules.append(schedule)
+        for start, end in zip(self.haps[:-1], self.haps[1:]):
+            cycle_steps = end - start
+            stable_steps = _convert_ratio_or_steps(self.stable, cycle_steps)
+            lr_decay_steps = cycle_steps - stable_steps
+
+            if stable_steps != 0:
+                stable = optax.constant_schedule(self.learning_rate)
+                schedules.append(stable)
+                boundaries.append(start + stable_steps)
+
+            match self.lr_schedule:
+                case "constant":
+                    schedule = optax.constant_schedule(self.learning_rate)
+                case "cosine":
+                    schedule = optax.cosine_decay_schedule(self.learning_rate, lr_decay_steps, self.min_lr_ratio)
+                case "linear":
+                    schedule = optax.linear_schedule(self.learning_rate, min_lr, lr_decay_steps)
+                case "inv_sqrt":
+                    schedule = _inv_sqrt_decay_schedule(self.learning_rate, min_lr, warmup_steps, 10000)
+                case "inv":
+                    schedule = _inv_decay_schedule(self.learning_rate, min_lr, lr_decay_steps)
+                case _:
+                    raise ValueError(f"Unknown lr_schedule: {self.lr_schedule}")
+
+            schedules.append(schedule)
+            boundaries.append(end)
 
         if cooldown_steps != 0:
             final_main_lr = schedule(lr_decay_steps)
             cooldown = optax.linear_schedule(final_main_lr, min_lr, cooldown_steps)
             schedules.append(cooldown)
             boundaries.append(num_train_steps - cooldown_steps)
 
         if len(schedules) > 1:
             schedule = optax.join_schedules(schedules, boundaries)
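To make the cycle bookkeeping concrete, here is a small standalone sketch (made-up numbers, plain Python, no levanter imports) of how the haps list becomes per-cycle boundaries in the new lr_scheduler; values below 1.0 are read as fractions of the step budget, per the field docstrings above:

    num_train_steps = 10_000
    warmup_steps = 100      # warmup=0.01
    cooldown_steps = 500    # cooldown=0.05
    haps = [5000]           # user-provided pit stop

    # lr_scheduler prepends the end of warmup and appends the start of cooldown:
    haps = [warmup_steps] + haps + [num_train_steps - cooldown_steps]
    # -> [100, 5000, 9500]

    # consecutive pairs become [stable] - [decay] cycles:
    for start, end in zip(haps[:-1], haps[1:]):
        cycle_steps = end - start                  # 4900, then 4500
        stable_steps = int(0.1 * cycle_steps)      # stable=0.1 -> 490, then 450
        lr_decay_steps = cycle_steps - stable_steps
        print(start, start + stable_steps, end, lr_decay_steps)

Each cycle contributes a constant segment up to start + stable_steps and then one decay segment (cosine, linear, inv_sqrt, or the new inv) down to end; after the last cycle, the cooldown segment runs linearly from the final schedule value to min_lr.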
@@ -197,6 +209,14 @@ def schedule(count):
     return schedule
 
 
+def _inv_decay_schedule(lr: float, min_lr: float, decay_steps: int):
+    def schedule(count):
+        decay = jnp.minimum(1.0, 1.0 / ((lr / min_lr - 1) * jnp.maximum(count, 1) / decay_steps + 1))
+        return jnp.maximum(lr * decay, min_lr)
+
+    return schedule
+
+
 def _convert_ratio_or_steps(ratio_or_steps: float, num_train_steps: int):
     if ratio_or_steps < 1.0:
         return int(ratio_or_steps * num_train_steps)
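A quick endpoint check of the new inverse-decay helper (illustrative only, not part of the commit): at count=0 the correction term (lr/min_lr - 1)/decay_steps is tiny, so the schedule starts essentially at lr; at count=decay_steps the denominator collapses to lr/min_lr, giving exactly min_lr; beyond that the outer jnp.maximum clips at min_lr.

    import jax.numpy as jnp

    lr, min_lr, decay_steps = 3e-4, 3e-5, 4500  # made-up values

    def inv_decay(count):
        # same formula as _inv_decay_schedule above
        decay = jnp.minimum(1.0, 1.0 / ((lr / min_lr - 1) * jnp.maximum(count, 1) / decay_steps + 1))
        return jnp.maximum(lr * decay, min_lr)

    print(inv_decay(0))                 # ~3.0e-4, i.e. roughly lr
    print(inv_decay(decay_steps))       # 3.0e-5, i.e. min_lr
    print(inv_decay(10 * decay_steps))  # still 3.0e-5, clipped by jnp.maximum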

