Commit f13c7a1

Change default batch_size for finetuning to max_batch_size for a model (#189)
* add code
* add auto
* style
* fix handlers
* auto to max
* renaming
* add warning message
* some fixes
* new warning msg
1 parent ec0772b commit f13c7a1
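
For users of the client, the practical effect is that omitting batch_size now resolves to the largest batch size the selected model supports, instead of a fixed integer default. A minimal sketch of the two call styles (the API key comes from the environment; the training file ID is a hypothetical placeholder, not a value from this commit):

    from together import Together

    client = Together()  # assumes TOGETHER_API_KEY is set in the environment

    # An explicit integer batch size keeps working as before.
    job = client.fine_tuning.create(
        training_file="file-abc123",  # hypothetical uploaded training file ID
        model="mistralai/Mixtral-8x7B-Instruct-v0.1",
        batch_size=16,
    )

    # With this change, leaving batch_size out is equivalent to batch_size="max",
    # which the client resolves from the model's training limits before submitting.
    job = client.fine_tuning.create(
        training_file="file-abc123",
        model="mistralai/Mixtral-8x7B-Instruct-v0.1",
    )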

9 files changed, +210 -24 lines changed


README.md

Lines changed: 1 addition & 1 deletion
@@ -242,7 +242,7 @@ client.fine_tuning.create(
     model = 'mistralai/Mixtral-8x7B-Instruct-v0.1',
     n_epochs = 3,
     n_checkpoints = 1,
-    batch_size = 4,
+    batch_size = "max",
     learning_rate = 1e-5,
     suffix = 'my-demo-finetune',
     wandb_api_key = '1a2b3c4d5e.......',

src/together/cli/api/finetune.py

Lines changed: 55 additions & 9 deletions
@@ -3,15 +3,17 @@
 import json
 from datetime import datetime
 from textwrap import wrap
+from typing import Any, Literal

 import click
 from click.core import ParameterSource  # type: ignore[attr-defined]
 from rich import print as rprint
 from tabulate import tabulate

 from together import Together
-from together.types.finetune import DownloadCheckpointType
+from together.cli.api.utils import INT_WITH_MAX
 from together.utils import finetune_price_to_dollars, log_warn, parse_timestamp
+from together.types.finetune import DownloadCheckpointType, FinetuneTrainingLimits


 _CONFIRMATION_MESSAGE = (
@@ -56,7 +58,7 @@ def fine_tuning(ctx: click.Context) -> None:
 @click.option(
     "--n-checkpoints", type=int, default=1, help="Number of checkpoints to save"
 )
-@click.option("--batch-size", type=int, default=16, help="Train batch size")
+@click.option("--batch-size", type=INT_WITH_MAX, default="max", help="Train batch size")
 @click.option("--learning-rate", type=float, default=1e-5, help="Learning rate")
 @click.option(
     "--lora/--no-lora",
@@ -93,7 +95,7 @@ def create(
     n_epochs: int,
     n_evals: int,
     n_checkpoints: int,
-    batch_size: int,
+    batch_size: int | Literal["max"],
     learning_rate: float,
     lora: bool,
     lora_r: int,
@@ -107,20 +109,64 @@ def create(
     """Start fine-tuning"""
     client: Together = ctx.obj

+    training_args: dict[str, Any] = dict(
+        training_file=training_file,
+        model=model,
+        n_epochs=n_epochs,
+        validation_file=validation_file,
+        n_evals=n_evals,
+        n_checkpoints=n_checkpoints,
+        batch_size=batch_size,
+        learning_rate=learning_rate,
+        lora=lora,
+        lora_r=lora_r,
+        lora_dropout=lora_dropout,
+        lora_alpha=lora_alpha,
+        lora_trainable_modules=lora_trainable_modules,
+        suffix=suffix,
+        wandb_api_key=wandb_api_key,
+    )
+
+    model_limits: FinetuneTrainingLimits = client.fine_tuning.get_model_limits(
+        model=model
+    )
+
     if lora:
-        learning_rate_source = click.get_current_context().get_parameter_source(  # type: ignore[attr-defined]
-            "learning_rate"
-        )
-        if learning_rate_source == ParameterSource.DEFAULT:
-            learning_rate = 1e-3
+        if model_limits.lora_training is None:
+            raise click.BadParameter(
+                f"LoRA fine-tuning is not supported for the model `{model}`"
+            )
+
+        default_values = {
+            "lora_r": model_limits.lora_training.max_rank,
+            "batch_size": model_limits.lora_training.max_batch_size,
+            "learning_rate": 1e-3,
+        }
+        for arg in default_values:
+            arg_source = ctx.get_parameter_source(arg)  # type: ignore[attr-defined]
+            if arg_source == ParameterSource.DEFAULT:
+                training_args[arg] = default_values[arg]
+
+        if ctx.get_parameter_source("lora_alpha") == ParameterSource.DEFAULT:  # type: ignore[attr-defined]
+            training_args["lora_alpha"] = training_args["lora_r"] * 2
     else:
+        if model_limits.full_training is None:
+            raise click.BadParameter(
+                f"Full fine-tuning is not supported for the model `{model}`"
+            )
+
         for param in ["lora_r", "lora_dropout", "lora_alpha", "lora_trainable_modules"]:
-            param_source = click.get_current_context().get_parameter_source(param)  # type: ignore[attr-defined]
+            param_source = ctx.get_parameter_source(param)  # type: ignore[attr-defined]
             if param_source != ParameterSource.DEFAULT:
                 raise click.BadParameter(
                     f"You set LoRA parameter `{param}` for a full fine-tuning job. "
                     f"Please change the job type with --lora or remove `{param}` from the arguments"
                 )
+
+        batch_size_source = ctx.get_parameter_source("batch_size")  # type: ignore[attr-defined]
+        if batch_size_source == ParameterSource.DEFAULT:
+            training_args["batch_size"] = model_limits.full_training.max_batch_size
+
     if n_evals <= 0 and validation_file:
         log_warn(
             "Warning: You have specified a validation file but the number of evaluation loops is set to 0. No evaluations will be performed."

src/together/cli/api/utils.py

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+import click
+
+from typing import Literal
+
+
+class AutoIntParamType(click.ParamType):
+    name = "integer"
+
+    def convert(
+        self, value: str, param: click.Parameter | None, ctx: click.Context | None
+    ) -> int | Literal["max"] | None:
+        if isinstance(value, int):
+            return value
+
+        if value == "max":
+            return "max"
+
+        self.fail(f"Invalid integer value: {value}")
+
+
+INT_WITH_MAX = AutoIntParamType()
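
AutoIntParamType follows click's custom parameter-type protocol: convert() passes through values that are already integers or the literal string "max" and rejects anything else via self.fail(). A small wiring sketch (the command name is illustrative, not part of the library); the together CLI later swaps a "max" value for the model's maximum allowed batch size:

    import click

    from together.cli.api.utils import INT_WITH_MAX


    @click.command()
    @click.option("--batch-size", type=INT_WITH_MAX, default="max")
    def show(batch_size: int | str) -> None:
        # default="max" survives conversion unchanged.
        if batch_size == "max":
            click.echo("batch size will be resolved to the model's maximum")
        else:
            click.echo(f"batch size passed through as: {batch_size!r}")


    if __name__ == "__main__":
        show()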

src/together/legacy/finetune.py

Lines changed: 2 additions & 2 deletions
@@ -1,7 +1,7 @@
 from __future__ import annotations

 import warnings
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Literal

 import together
 from together.legacy.base import API_KEY_WARNING, deprecated
@@ -43,7 +43,7 @@ def create(
             model=model,
             n_epochs=n_epochs,
             n_checkpoints=n_checkpoints,
-            batch_size=batch_size,
+            batch_size=batch_size if isinstance(batch_size, int) else "max",
             learning_rate=learning_rate,
             suffix=suffix,
             wandb_api_key=wandb_api_key,

src/together/resources/finetune.py

Lines changed: 99 additions & 11 deletions
@@ -1,6 +1,7 @@
 from __future__ import annotations

 from pathlib import Path
+from typing import Literal

 from rich import print as rprint

@@ -13,14 +14,15 @@
     FinetuneListEvents,
     FinetuneRequest,
     FinetuneResponse,
+    FinetuneTrainingLimits,
     FullTrainingType,
     LoRATrainingType,
     TogetherClient,
     TogetherRequest,
     TrainingType,
 )
 from together.types.finetune import DownloadCheckpointType
-from together.utils import log_warn, normalize_key
+from together.utils import log_warn_once, normalize_key


 class FineTuning:
@@ -36,16 +38,17 @@ def create(
         validation_file: str | None = "",
         n_evals: int | None = 0,
         n_checkpoints: int | None = 1,
-        batch_size: int | None = 16,
+        batch_size: int | Literal["max"] = "max",
         learning_rate: float | None = 0.00001,
         lora: bool = False,
-        lora_r: int | None = 8,
+        lora_r: int | None = None,
         lora_dropout: float | None = 0,
-        lora_alpha: float | None = 8,
+        lora_alpha: float | None = None,
         lora_trainable_modules: str | None = "all-linear",
         suffix: str | None = None,
         wandb_api_key: str | None = None,
         verbose: bool = False,
+        model_limits: FinetuneTrainingLimits | None = None,
     ) -> FinetuneResponse:
         """
         Method to initiate a fine-tuning job
@@ -58,7 +61,7 @@ def create(
             n_evals (int, optional): Number of evaluation loops to run. Defaults to 0.
             n_checkpoints (int, optional): Number of checkpoints to save during fine-tuning.
                 Defaults to 1.
-            batch_size (int, optional): Batch size for fine-tuning. Defaults to 32.
+            batch_size (int or "max", optional): Batch size for fine-tuning. Defaults to "max".
             learning_rate (float, optional): Learning rate multiplier to use for training
                 Defaults to 0.00001.
             lora (bool, optional): Whether to use LoRA adapters. Defaults to True.
@@ -72,24 +75,59 @@ def create(
                 Defaults to None.
             verbose (bool, optional): whether to print the job parameters before submitting a request.
                 Defaults to False.
+            model_limits (FinetuneTrainingLimits, optional): Limits for the fine-tuning hyperparameters of the model.
+                Defaults to None.

         Returns:
             FinetuneResponse: Object containing information about fine-tuning job.
         """

+        if batch_size == "max":
+            log_warn_once(
+                "Starting from together>=1.3.0, "
+                "the default batch size is set to the maximum allowed value for each model."
+            )
+
         requestor = api_requestor.APIRequestor(
             client=self._client,
         )

+        if model_limits is None:
+            model_limits = self.get_model_limits(model=model)
+
         training_type: TrainingType = FullTrainingType()
         if lora:
+            if model_limits.lora_training is None:
+                raise ValueError(
+                    "LoRA adapters are not supported for the selected model."
+                )
+            lora_r = (
+                lora_r if lora_r is not None else model_limits.lora_training.max_rank
+            )
+            lora_alpha = lora_alpha if lora_alpha is not None else lora_r * 2
             training_type = LoRATrainingType(
                 lora_r=lora_r,
                 lora_alpha=lora_alpha,
                 lora_dropout=lora_dropout,
                 lora_trainable_modules=lora_trainable_modules,
             )

+            batch_size = (
+                batch_size
+                if batch_size != "max"
+                else model_limits.lora_training.max_batch_size
+            )
+        else:
+            if model_limits.full_training is None:
+                raise ValueError(
+                    "Full training is not supported for the selected model."
+                )
+            batch_size = (
+                batch_size
+                if batch_size != "max"
+                else model_limits.full_training.max_batch_size
+            )
+
         finetune_request = FinetuneRequest(
             model=model,
             training_file=training_file,
@@ -121,12 +159,6 @@ def create(

         assert isinstance(response, TogetherResponse)

-        # TODO: Remove after next LoRA default change
-        log_warn(
-            "Some of the jobs run _directly_ from the together-python library might be trained using LoRA adapters. "
-            "The version range when this change occurred is from 1.2.3 to 1.2.6."
-        )
-
         return FinetuneResponse(**response.data)

     def list(self) -> FinetuneList:
@@ -305,6 +337,34 @@ def download(
             size=file_size,
         )

+    def get_model_limits(self, *, model: str) -> FinetuneTrainingLimits:
+        """
+        Requests training limits for a specific model
+
+        Args:
+            model (str): Name of the model to get limits for
+
+        Returns:
+            FinetuneTrainingLimits: Object containing training limits for the model
+        """
+
+        requestor = api_requestor.APIRequestor(
+            client=self._client,
+        )
+
+        model_limits_response, _, _ = requestor.request(
+            options=TogetherRequest(
+                method="GET",
+                url="fine-tunes/models/limits",
+                params={"model_name": model},
+            ),
+            stream=False,
+        )
+
+        model_limits = FinetuneTrainingLimits(**model_limits_response.data)
+
+        return model_limits
+

 class AsyncFineTuning:
     def __init__(self, client: TogetherClient) -> None:
@@ -493,3 +553,31 @@ async def download(
             "AsyncFineTuning.download not implemented. "
             "Please use FineTuning.download function instead."
         )
+
+    async def get_model_limits(self, *, model: str) -> FinetuneTrainingLimits:
+        """
+        Requests training limits for a specific model
+
+        Args:
+            model (str): Name of the model to get limits for
+
+        Returns:
+            FinetuneTrainingLimits: Object containing training limits for the model
+        """
+
+        requestor = api_requestor.APIRequestor(
+            client=self._client,
+        )
+
+        model_limits_response, _, _ = await requestor.arequest(
+            options=TogetherRequest(
+                method="GET",
+                url="fine-tunes/models/limits",
+                params={"model": model},
+            ),
+            stream=False,
+        )
+
+        model_limits = FinetuneTrainingLimits(**model_limits_response.data)
+
+        return model_limits
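
The limits endpoint can also be queried directly, and the returned object can be passed back into create() through the new model_limits argument so the client does not fetch it twice. A short sketch (the training file ID is a hypothetical placeholder; the printed fields depend on the model):

    from together import Together

    client = Together()  # assumes TOGETHER_API_KEY is set in the environment

    limits = client.fine_tuning.get_model_limits(
        model="mistralai/Mixtral-8x7B-Instruct-v0.1"
    )

    if limits.full_training is not None:
        print("full fine-tuning max batch size:", limits.full_training.max_batch_size)
    if limits.lora_training is not None:
        print("LoRA max rank:", limits.lora_training.max_rank)

    # Reusing the fetched limits avoids a second limits request inside create().
    job = client.fine_tuning.create(
        training_file="file-abc123",  # hypothetical uploaded training file ID
        model="mistralai/Mixtral-8x7B-Instruct-v0.1",
        lora=True,
        model_limits=limits,
    )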

src/together/types/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -29,6 +29,7 @@
     FullTrainingType,
     LoRATrainingType,
     TrainingType,
+    FinetuneTrainingLimits,
 )
 from together.types.images import (
     ImageRequest,
@@ -71,4 +72,5 @@
     "LoRATrainingType",
     "RerankRequest",
     "RerankResponse",
+    "FinetuneTrainingLimits",
 ]
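
With the re-export above in place, the limits type should be importable from the package-level types module as well as from together.types.finetune:

    from together.types import FinetuneTrainingLimits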

src/together/types/finetune.py

Lines changed: 18 additions & 0 deletions
@@ -263,3 +263,21 @@ class FinetuneDownloadResult(BaseModel):
     filename: str | None = None
     # size in bytes
     size: int | None = None
+
+
+class FinetuneFullTrainingLimits(BaseModel):
+    max_batch_size: int
+    min_batch_size: int
+
+
+class FinetuneLoraTrainingLimits(FinetuneFullTrainingLimits):
+    max_rank: int
+    target_modules: List[str]
+
+
+class FinetuneTrainingLimits(BaseModel):
+    max_num_epochs: int
+    max_learning_rate: float
+    min_learning_rate: float
+    full_training: FinetuneFullTrainingLimits | None = None
+    lora_training: FinetuneLoraTrainingLimits | None = None
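
These are plain pydantic models, so a limits payload can be validated on its own; nested dictionaries are coerced into the sub-models. A sketch with made-up numbers (not real limits for any model):

    from together.types.finetune import FinetuneTrainingLimits

    payload = {
        "max_num_epochs": 20,
        "max_learning_rate": 1.0,
        "min_learning_rate": 1e-6,
        "full_training": {"max_batch_size": 96, "min_batch_size": 8},
        "lora_training": {
            "max_batch_size": 128,
            "min_batch_size": 8,
            "max_rank": 64,
            "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj"],
        },
    }

    limits = FinetuneTrainingLimits(**payload)
    assert limits.lora_training is not None
    print(limits.lora_training.max_rank)  # 64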
