
Commit 10ad24d

Add Cosine LR Scheduler for Fine-Tuning (#273)
* Port cosine lr scheduler init
* Port cosine lr scheduler init
* Upgrade 1.4.7
* Typos, type error
* Use subclasses instead of validation
* Update num_cycles description
* Change cli arg from num_cycles to scheduler_num_cycles
* Update version
1 parent 7346356 commit 10ad24d
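The new options flow end-to-end through the fine-tuning create call and the matching --lr-scheduler-type / --scheduler-num-cycles CLI flags added below. A minimal usage sketch (not part of this commit), assuming the usual Together client entry point; the training file ID and model name are placeholders:

from together import Together

client = Together()  # assumes TOGETHER_API_KEY is set in the environment

job = client.fine_tuning.create(
    training_file="file-xxxxxxxx",  # placeholder training file ID
    model="meta-llama/Meta-Llama-3.1-8B-Instruct-Reference",  # placeholder model name
    lr_scheduler_type="cosine",   # new in this commit; defaults to "linear"
    scheduler_num_cycles=0.5,     # new in this commit; default 0.5
    min_lr_ratio=0.1,             # final LR is 10% of the peak learning rate
    warmup_ratio=0.03,
)
print(job.id)

On the CLI the same two knobs surface as --lr-scheduler-type and --scheduler-num-cycles (see src/together/cli/api/finetune.py below).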

File tree

5 files changed: +81 -16 lines


pyproject.toml

+1 -1

@@ -12,7 +12,7 @@ build-backend = "poetry.masonry.api"
 
 [tool.poetry]
 name = "together"
-version = "1.5.1"
+version = "1.5.2"
 authors = [
     "Together AI <[email protected]>"
 ]

src/together/cli/api/finetune.py

+17 -1

@@ -79,17 +79,29 @@ def fine_tuning(ctx: click.Context) -> None:
     "--batch-size", "-b", type=INT_WITH_MAX, default="max", help="Train batch size"
 )
 @click.option("--learning-rate", "-lr", type=float, default=1e-5, help="Learning rate")
+@click.option(
+    "--lr-scheduler-type",
+    type=click.Choice(["linear", "cosine"]),
+    default="linear",
+    help="Learning rate scheduler type",
+)
 @click.option(
     "--min-lr-ratio",
     type=float,
     default=0.0,
     help="The ratio of the final learning rate to the peak learning rate",
 )
+@click.option(
+    "--scheduler-num-cycles",
+    type=float,
+    default=0.5,
+    help="Number or fraction of cycles for the cosine learning rate scheduler.",
+)
 @click.option(
     "--warmup-ratio",
     type=float,
     default=0.0,
-    help="Warmup ratio for learning rate scheduler.",
+    help="Warmup ratio for the learning rate scheduler.",
 )
 @click.option(
     "--max-grad-norm",
@@ -174,7 +186,9 @@ def create(
     n_checkpoints: int,
     batch_size: int | Literal["max"],
     learning_rate: float,
+    lr_scheduler_type: Literal["linear", "cosine"],
     min_lr_ratio: float,
+    scheduler_num_cycles: float,
     warmup_ratio: float,
     max_grad_norm: float,
     weight_decay: float,
@@ -206,7 +220,9 @@ def create(
         n_checkpoints=n_checkpoints,
         batch_size=batch_size,
         learning_rate=learning_rate,
+        lr_scheduler_type=lr_scheduler_type,
         min_lr_ratio=min_lr_ratio,
+        scheduler_num_cycles=scheduler_num_cycles,
         warmup_ratio=warmup_ratio,
         max_grad_norm=max_grad_norm,
         weight_decay=weight_decay,

src/together/resources/finetune.py

+35 -6

@@ -22,7 +22,10 @@
     TogetherRequest,
     TrainingType,
     FinetuneLRScheduler,
+    FinetuneLinearLRScheduler,
+    FinetuneCosineLRScheduler,
     FinetuneLinearLRSchedulerArgs,
+    FinetuneCosineLRSchedulerArgs,
     TrainingMethodDPO,
     TrainingMethodSFT,
     FinetuneCheckpoint,
@@ -57,7 +60,9 @@ def createFinetuneRequest(
     n_checkpoints: int | None = 1,
     batch_size: int | Literal["max"] = "max",
     learning_rate: float | None = 0.00001,
+    lr_scheduler_type: Literal["linear", "cosine"] = "linear",
     min_lr_ratio: float = 0.0,
+    scheduler_num_cycles: float = 0.5,
     warmup_ratio: float = 0.0,
     max_grad_norm: float = 1.0,
     weight_decay: float = 0.0,
@@ -134,10 +139,22 @@ def createFinetuneRequest(
             f"training_method must be one of {', '.join(AVAILABLE_TRAINING_METHODS)}"
         )
 
-    lrScheduler = FinetuneLRScheduler(
-        lr_scheduler_type="linear",
-        lr_scheduler_args=FinetuneLinearLRSchedulerArgs(min_lr_ratio=min_lr_ratio),
-    )
+    # Default to generic lr scheduler
+    lrScheduler: FinetuneLRScheduler = FinetuneLRScheduler(lr_scheduler_type="linear")
+
+    if lr_scheduler_type == "cosine":
+        if scheduler_num_cycles <= 0.0:
+            raise ValueError("Number of cycles should be greater than 0")
+
+        lrScheduler = FinetuneCosineLRScheduler(
+            lr_scheduler_args=FinetuneCosineLRSchedulerArgs(
+                min_lr_ratio=min_lr_ratio, num_cycles=scheduler_num_cycles
+            ),
+        )
+    else:
+        lrScheduler = FinetuneLinearLRScheduler(
+            lr_scheduler_args=FinetuneLinearLRSchedulerArgs(min_lr_ratio=min_lr_ratio),
+        )
 
     training_method_cls: TrainingMethodSFT | TrainingMethodDPO = TrainingMethodSFT()
     if training_method == "dpo":
@@ -249,7 +266,9 @@ def create(
         n_checkpoints: int | None = 1,
         batch_size: int | Literal["max"] = "max",
         learning_rate: float | None = 0.00001,
+        lr_scheduler_type: Literal["linear", "cosine"] = "linear",
         min_lr_ratio: float = 0.0,
+        scheduler_num_cycles: float = 0.5,
         warmup_ratio: float = 0.0,
         max_grad_norm: float = 1.0,
         weight_decay: float = 0.0,
@@ -284,9 +303,11 @@ def create(
             batch_size (int or "max"): Batch size for fine-tuning. Defaults to max.
             learning_rate (float, optional): Learning rate multiplier to use for training
                 Defaults to 0.00001.
+            lr_scheduler_type (Literal["linear", "cosine"]): Learning rate scheduler type. Defaults to "linear".
             min_lr_ratio (float, optional): Min learning rate ratio of the initial learning rate for
                 the learning rate scheduler. Defaults to 0.0.
-            warmup_ratio (float, optional): Warmup ratio for learning rate scheduler.
+            scheduler_num_cycles (float, optional): Number or fraction of cycles for the cosine learning rate scheduler. Defaults to 0.5.
+            warmup_ratio (float, optional): Warmup ratio for the learning rate scheduler.
             max_grad_norm (float, optional): Max gradient norm. Defaults to 1.0, set to 0 to disable.
             weight_decay (float, optional): Weight decay. Defaults to 0.0.
             lora (bool, optional): Whether to use LoRA adapters. Defaults to True.
@@ -353,7 +374,9 @@ def create(
            n_checkpoints=n_checkpoints,
            batch_size=batch_size,
            learning_rate=learning_rate,
+           lr_scheduler_type=lr_scheduler_type,
            min_lr_ratio=min_lr_ratio,
+           scheduler_num_cycles=scheduler_num_cycles,
            warmup_ratio=warmup_ratio,
            max_grad_norm=max_grad_norm,
            weight_decay=weight_decay,
@@ -634,7 +657,9 @@ async def create(
         n_checkpoints: int | None = 1,
         batch_size: int | Literal["max"] = "max",
         learning_rate: float | None = 0.00001,
+        lr_scheduler_type: Literal["linear", "cosine"] = "linear",
         min_lr_ratio: float = 0.0,
+        scheduler_num_cycles: float = 0.5,
         warmup_ratio: float = 0.0,
         max_grad_norm: float = 1.0,
         weight_decay: float = 0.0,
@@ -669,9 +694,11 @@ async def create(
             batch_size (int, optional): Batch size for fine-tuning. Defaults to max.
             learning_rate (float, optional): Learning rate multiplier to use for training
                 Defaults to 0.00001.
+            lr_scheduler_type (Literal["linear", "cosine"]): Learning rate scheduler type. Defaults to "linear".
             min_lr_ratio (float, optional): Min learning rate ratio of the initial learning rate for
                 the learning rate scheduler. Defaults to 0.0.
-            warmup_ratio (float, optional): Warmup ratio for learning rate scheduler.
+            scheduler_num_cycles (float, optional): Number or fraction of cycles for the cosine learning rate scheduler. Defaults to 0.5.
+            warmup_ratio (float, optional): Warmup ratio for the learning rate scheduler.
             max_grad_norm (float, optional): Max gradient norm. Defaults to 1.0, set to 0 to disable.
             weight_decay (float, optional): Weight decay. Defaults to 0.0.
             lora (bool, optional): Whether to use LoRA adapters. Defaults to True.
@@ -738,7 +765,9 @@ async def create(
            n_checkpoints=n_checkpoints,
            batch_size=batch_size,
            learning_rate=learning_rate,
+           lr_scheduler_type=lr_scheduler_type,
            min_lr_ratio=min_lr_ratio,
+           scheduler_num_cycles=scheduler_num_cycles,
            warmup_ratio=warmup_ratio,
            max_grad_norm=max_grad_norm,
            weight_decay=weight_decay,
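The SDK changes above only validate these values and forward them in the request; the schedule itself is computed by the fine-tuning backend and is not part of this diff. Purely for intuition about what num_cycles and min_lr_ratio typically control, here is a sketch of a common cosine-with-floor formulation (in the spirit of Hugging Face's get_cosine_schedule_with_warmup); the function and variable names are illustrative and do not exist in this SDK:

import math

def illustrative_cosine_lr(
    step: int,
    total_steps: int,
    peak_lr: float,
    min_lr_ratio: float = 0.0,
    num_cycles: float = 0.5,
) -> float:
    """Toy cosine schedule: num_cycles=0.5 decays once from peak_lr down to
    min_lr_ratio * peak_lr; values above 0.5 make the LR oscillate.
    Not the backend's actual implementation; for intuition only."""
    progress = step / max(1, total_steps)            # fraction of training done
    cosine = 0.5 * (1.0 + math.cos(2.0 * math.pi * num_cycles * progress))
    floor = peak_lr * min_lr_ratio                   # lowest LR the schedule reaches
    return floor + (peak_lr - floor) * cosine

# With this commit's defaults (num_cycles=0.5, min_lr_ratio=0.0) the rate starts
# at peak_lr and reaches ~0 at the last step; warmup_ratio is handled separately
# and is not modelled here.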

src/together/types/__init__.py

+7 -1

@@ -34,11 +34,14 @@
     TrainingMethodDPO,
     TrainingMethodSFT,
     FinetuneCheckpoint,
+    FinetuneCosineLRScheduler,
+    FinetuneCosineLRSchedulerArgs,
     FinetuneDownloadResult,
+    FinetuneLinearLRScheduler,
     FinetuneLinearLRSchedulerArgs,
+    FinetuneLRScheduler,
     FinetuneList,
     FinetuneListEvents,
-    FinetuneLRScheduler,
     FinetuneRequest,
     FinetuneResponse,
     FinetuneTrainingLimits,
@@ -69,7 +72,10 @@
     "FinetuneListEvents",
     "FinetuneDownloadResult",
     "FinetuneLRScheduler",
+    "FinetuneLinearLRScheduler",
     "FinetuneLinearLRSchedulerArgs",
+    "FinetuneCosineLRScheduler",
+    "FinetuneCosineLRSchedulerArgs",
     "FileRequest",
     "FileResponse",
     "FileList",

src/together/types/finetune.py

+21 -7

@@ -1,9 +1,9 @@
 from __future__ import annotations
 
 from enum import Enum
-from typing import List, Literal
+from typing import List, Literal, Union
 
-from pydantic import StrictBool, Field, validator, field_validator
+from pydantic import StrictBool, Field, validator, field_validator, ValidationInfo
 
 from together.types.abstract import BaseModel
 from together.types.common import (
@@ -176,7 +176,7 @@ class FinetuneRequest(BaseModel):
     # training learning rate
     learning_rate: float
     # learning rate scheduler type and args
-    lr_scheduler: FinetuneLRScheduler | None = None
+    lr_scheduler: FinetuneLinearLRScheduler | FinetuneCosineLRScheduler | None = None
     # learning rate warmup ratio
     warmup_ratio: float
     # max gradient norm
@@ -239,7 +239,7 @@ class FinetuneResponse(BaseModel):
     # training learning rate
     learning_rate: float | None = None
     # learning rate scheduler type and args
-    lr_scheduler: FinetuneLRScheduler | None = None
+    lr_scheduler: FinetuneLinearLRScheduler | FinetuneCosineLRScheduler | None = None
     # learning rate warmup ratio
     warmup_ratio: float | None = None
     # max gradient norm
@@ -345,13 +345,27 @@ class FinetuneTrainingLimits(BaseModel):
     lora_training: FinetuneLoraTrainingLimits | None = None
 
 
+class FinetuneLinearLRSchedulerArgs(BaseModel):
+    min_lr_ratio: float | None = 0.0
+
+
+class FinetuneCosineLRSchedulerArgs(BaseModel):
+    min_lr_ratio: float | None = 0.0
+    num_cycles: float | None = 0.5
+
+
 class FinetuneLRScheduler(BaseModel):
     lr_scheduler_type: str
-    lr_scheduler_args: FinetuneLinearLRSchedulerArgs | None = None
 
 
-class FinetuneLinearLRSchedulerArgs(BaseModel):
-    min_lr_ratio: float | None = 0.0
+class FinetuneLinearLRScheduler(FinetuneLRScheduler):
+    lr_scheduler_type: Literal["linear"] = "linear"
+    lr_scheduler: FinetuneLinearLRSchedulerArgs | None = None
+
+
+class FinetuneCosineLRScheduler(FinetuneLRScheduler):
+    lr_scheduler_type: Literal["cosine"] = "cosine"
+    lr_scheduler: FinetuneCosineLRSchedulerArgs | None = None
 
 
 class FinetuneCheckpoint(BaseModel):
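To see what the new scheduler models serialize to, a small sketch below constructs them directly; it assumes Pydantic v2 (suggested by the field_validator import above) and uses the field names exactly as declared in this file, where the subclasses carry their arguments in a field named lr_scheduler:

from together.types import (
    FinetuneCosineLRScheduler,
    FinetuneCosineLRSchedulerArgs,
    FinetuneLinearLRScheduler,
    FinetuneLinearLRSchedulerArgs,
)

# Cosine variant: lr_scheduler_type is pinned to "cosine" by the subclass default.
cosine = FinetuneCosineLRScheduler(
    lr_scheduler=FinetuneCosineLRSchedulerArgs(min_lr_ratio=0.1, num_cycles=0.5)
)
print(cosine.model_dump())
# expected shape: {'lr_scheduler_type': 'cosine',
#                  'lr_scheduler': {'min_lr_ratio': 0.1, 'num_cycles': 0.5}}

# Linear variant keeps the previous behaviour as its own subclass.
linear = FinetuneLinearLRScheduler(
    lr_scheduler=FinetuneLinearLRSchedulerArgs(min_lr_ratio=0.0)
)
print(linear.model_dump())

In practice createFinetuneRequest above builds these objects from lr_scheduler_type and scheduler_num_cycles, so most callers never construct them by hand.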
