Commit

Change default batch_size for finetuning to max_batch_size for a model (#189)

* add code

* add auto

* style

* fix handlers

* auto to max

* renaming

* add warning message

* some fixes

* new warning msg
artek0chumak authored Sep 25, 2024
1 parent ec0772b commit f13c7a1
Showing 9 changed files with 210 additions and 24 deletions.
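
With this commit, fine-tuning jobs created without an explicit batch size default to the string "max", which the client resolves to the model's maximum allowed batch size via the new training-limits endpoint. A minimal sketch of the resulting call, assuming the public client API shown in the README diff below (the API key and training file ID are placeholders):

from together import Together

client = Together(api_key="xxxx")  # placeholder key

# Omitting batch_size (or passing "max") now uses the largest batch size
# allowed for the chosen model instead of the old fixed default.
job = client.fine_tuning.create(
    training_file="file-abc123",  # placeholder file ID
    model="mistralai/Mixtral-8x7B-Instruct-v0.1",
    n_epochs=3,
    batch_size="max",
)

Passing an integer batch size still works and is sent through unchanged.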
2 changes: 1 addition & 1 deletion README.md
@@ -242,7 +242,7 @@ client.fine_tuning.create(
model = 'mistralai/Mixtral-8x7B-Instruct-v0.1',
n_epochs = 3,
n_checkpoints = 1,
batch_size = 4,
batch_size = "max",
learning_rate = 1e-5,
suffix = 'my-demo-finetune',
wandb_api_key = '1a2b3c4d5e.......',
64 changes: 55 additions & 9 deletions src/together/cli/api/finetune.py
@@ -3,15 +3,17 @@
import json
from datetime import datetime
from textwrap import wrap
from typing import Any, Literal

import click
from click.core import ParameterSource # type: ignore[attr-defined]
from rich import print as rprint
from tabulate import tabulate

from together import Together
from together.types.finetune import DownloadCheckpointType
from together.cli.api.utils import INT_WITH_MAX
from together.utils import finetune_price_to_dollars, log_warn, parse_timestamp
from together.types.finetune import DownloadCheckpointType, FinetuneTrainingLimits


_CONFIRMATION_MESSAGE = (
@@ -56,7 +58,7 @@ def fine_tuning(ctx: click.Context) -> None:
@click.option(
"--n-checkpoints", type=int, default=1, help="Number of checkpoints to save"
)
@click.option("--batch-size", type=int, default=16, help="Train batch size")
@click.option("--batch-size", type=INT_WITH_MAX, default="max", help="Train batch size")
@click.option("--learning-rate", type=float, default=1e-5, help="Learning rate")
@click.option(
"--lora/--no-lora",
@@ -93,7 +95,7 @@ def create(
n_epochs: int,
n_evals: int,
n_checkpoints: int,
batch_size: int,
batch_size: int | Literal["max"],
learning_rate: float,
lora: bool,
lora_r: int,
@@ -107,20 +109,64 @@
"""Start fine-tuning"""
client: Together = ctx.obj

training_args: dict[str, Any] = dict(
training_file=training_file,
model=model,
n_epochs=n_epochs,
validation_file=validation_file,
n_evals=n_evals,
n_checkpoints=n_checkpoints,
batch_size=batch_size,
learning_rate=learning_rate,
lora=lora,
lora_r=lora_r,
lora_dropout=lora_dropout,
lora_alpha=lora_alpha,
lora_trainable_modules=lora_trainable_modules,
suffix=suffix,
wandb_api_key=wandb_api_key,
)

model_limits: FinetuneTrainingLimits = client.fine_tuning.get_model_limits(
model=model
)

if lora:
learning_rate_source = click.get_current_context().get_parameter_source( # type: ignore[attr-defined]
"learning_rate"
)
if learning_rate_source == ParameterSource.DEFAULT:
learning_rate = 1e-3
if model_limits.lora_training is None:
raise click.BadParameter(
f"LoRA fine-tuning is not supported for the model `{model}`"
)

default_values = {
"lora_r": model_limits.lora_training.max_rank,
"batch_size": model_limits.lora_training.max_batch_size,
"learning_rate": 1e-3,
}
for arg in default_values:
arg_source = ctx.get_parameter_source(arg)  # type: ignore[attr-defined]
if arg_source == ParameterSource.DEFAULT:
training_args[arg] = default_values[arg]

if ctx.get_parameter_source("lora_alpha") == ParameterSource.DEFAULT: # type: ignore[attr-defined]
training_args["lora_alpha"] = training_args["lora_r"] * 2
else:
if model_limits.full_training is None:
raise click.BadParameter(
f"Full fine-tuning is not supported for the model `{model}`"
)

for param in ["lora_r", "lora_dropout", "lora_alpha", "lora_trainable_modules"]:
param_source = click.get_current_context().get_parameter_source(param) # type: ignore[attr-defined]
param_source = ctx.get_parameter_source(param) # type: ignore[attr-defined]
if param_source != ParameterSource.DEFAULT:
raise click.BadParameter(
f"You set LoRA parameter `{param}` for a full fine-tuning job. "
f"Please change the job type with --lora or remove `{param}` from the arguments"
)

batch_size_source = ctx.get_parameter_source("batch_size") # type: ignore[attr-defined]
if batch_size_source == ParameterSource.DEFAULT:
training_args["batch_size"] = model_limits.full_training.max_batch_size

if n_evals <= 0 and validation_file:
log_warn(
"Warning: You have specified a validation file but the number of evaluation loops is set to 0. No evaluations will be performed."
21 changes: 21 additions & 0 deletions src/together/cli/api/utils.py
@@ -0,0 +1,21 @@
import click

from typing import Literal


class AutoIntParamType(click.ParamType):
name = "integer"

def convert(
self, value: str, param: click.Parameter | None, ctx: click.Context | None
) -> int | Literal["max"] | None:
if isinstance(value, int):
return value

if value == "max":
return "max"

try:
return int(value)
except ValueError:
self.fail(f"Invalid integer value: {value}")


INT_WITH_MAX = AutoIntParamType()
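
For context, a small usage sketch (not part of this diff) of how a click option backed by AutoIntParamType behaves; the demo command is hypothetical:

import click

from together.cli.api.utils import INT_WITH_MAX

@click.command()
@click.option("--batch-size", type=INT_WITH_MAX, default="max", help="Train batch size")
def demo(batch_size):
    # batch_size arrives either as an int or as the literal string "max"
    click.echo(f"batch_size={batch_size!r}")

if __name__ == "__main__":
    demo()

Running the command with no flag passes the default "max" straight through; explicit values are routed through AutoIntParamType.convert for validation.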
4 changes: 2 additions & 2 deletions src/together/legacy/finetune.py
@@ -1,7 +1,7 @@
from __future__ import annotations

import warnings
from typing import Any, Dict, List
from typing import Any, Dict, List, Literal

import together
from together.legacy.base import API_KEY_WARNING, deprecated
@@ -43,7 +43,7 @@ def create(
model=model,
n_epochs=n_epochs,
n_checkpoints=n_checkpoints,
batch_size=batch_size,
batch_size=batch_size if isinstance(batch_size, int) else "max",
learning_rate=learning_rate,
suffix=suffix,
wandb_api_key=wandb_api_key,
110 changes: 99 additions & 11 deletions src/together/resources/finetune.py
@@ -1,6 +1,7 @@
from __future__ import annotations

from pathlib import Path
from typing import Literal

from rich import print as rprint

@@ -13,14 +14,15 @@
FinetuneListEvents,
FinetuneRequest,
FinetuneResponse,
FinetuneTrainingLimits,
FullTrainingType,
LoRATrainingType,
TogetherClient,
TogetherRequest,
TrainingType,
)
from together.types.finetune import DownloadCheckpointType
from together.utils import log_warn, normalize_key
from together.utils import log_warn_once, normalize_key


class FineTuning:
@@ -36,16 +38,17 @@ def create(
validation_file: str | None = "",
n_evals: int | None = 0,
n_checkpoints: int | None = 1,
batch_size: int | None = 16,
batch_size: int | Literal["max"] = "max",
learning_rate: float | None = 0.00001,
lora: bool = False,
lora_r: int | None = 8,
lora_r: int | None = None,
lora_dropout: float | None = 0,
lora_alpha: float | None = 8,
lora_alpha: float | None = None,
lora_trainable_modules: str | None = "all-linear",
suffix: str | None = None,
wandb_api_key: str | None = None,
verbose: bool = False,
model_limits: FinetuneTrainingLimits | None = None,
) -> FinetuneResponse:
"""
Method to initiate a fine-tuning job
@@ -58,7 +61,7 @@
n_evals (int, optional): Number of evaluation loops to run. Defaults to 0.
n_checkpoints (int, optional): Number of checkpoints to save during fine-tuning.
Defaults to 1.
batch_size (int, optional): Batch size for fine-tuning. Defaults to 32.
batch_size (int or "max", optional): Batch size for fine-tuning. Defaults to "max".
learning_rate (float, optional): Learning rate multiplier to use for training
Defaults to 0.00001.
lora (bool, optional): Whether to use LoRA adapters. Defaults to False.
@@ -72,24 +75,59 @@
Defaults to None.
verbose (bool, optional): whether to print the job parameters before submitting a request.
Defaults to False.
model_limits (FinetuneTrainingLimits, optional): Training limits for the hyperparameters of the model being fine-tuned.
Defaults to None.
Returns:
FinetuneResponse: Object containing information about fine-tuning job.
"""

if batch_size == "max":
log_warn_once(
"Starting from together>=1.3.0, "
"the default batch size is set to the maximum allowed value for each model."
)

requestor = api_requestor.APIRequestor(
client=self._client,
)

if model_limits is None:
model_limits = self.get_model_limits(model=model)

training_type: TrainingType = FullTrainingType()
if lora:
if model_limits.lora_training is None:
raise ValueError(
"LoRA adapters are not supported for the selected model."
)
lora_r = (
lora_r if lora_r is not None else model_limits.lora_training.max_rank
)
lora_alpha = lora_alpha if lora_alpha is not None else lora_r * 2
training_type = LoRATrainingType(
lora_r=lora_r,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
lora_trainable_modules=lora_trainable_modules,
)

batch_size = (
batch_size
if batch_size != "max"
else model_limits.lora_training.max_batch_size
)
else:
if model_limits.full_training is None:
raise ValueError(
"Full training is not supported for the selected model."
)
batch_size = (
batch_size
if batch_size != "max"
else model_limits.full_training.max_batch_size
)

finetune_request = FinetuneRequest(
model=model,
training_file=training_file,
@@ -121,12 +159,6 @@ def create(

assert isinstance(response, TogetherResponse)

# TODO: Remove after next LoRA default change
log_warn(
"Some of the jobs run _directly_ from the together-python library might be trained using LoRA adapters. "
"The version range when this change occurred is from 1.2.3 to 1.2.6."
)

return FinetuneResponse(**response.data)

def list(self) -> FinetuneList:
@@ -305,6 +337,34 @@ def download(
size=file_size,
)

def get_model_limits(self, *, model: str) -> FinetuneTrainingLimits:
"""
Requests training limits for a specific model
Args:
model (str): Name of the model to get limits for
Returns:
FinetuneTrainingLimits: Object containing training limits for the model
"""

requestor = api_requestor.APIRequestor(
client=self._client,
)

model_limits_response, _, _ = requestor.request(
options=TogetherRequest(
method="GET",
url="fine-tunes/models/limits",
params={"model_name": model},
),
stream=False,
)

model_limits = FinetuneTrainingLimits(**model_limits_response.data)

return model_limits


class AsyncFineTuning:
def __init__(self, client: TogetherClient) -> None:
@@ -493,3 +553,31 @@ async def download(
"AsyncFineTuning.download not implemented. "
"Please use FineTuning.download function instead."
)

async def get_model_limits(self, *, model: str) -> FinetuneTrainingLimits:
"""
Requests training limits for a specific model
Args:
model (str): Name of the model to get limits for
Returns:
FinetuneTrainingLimits: Object containing training limits for the model
"""

requestor = api_requestor.APIRequestor(
client=self._client,
)

model_limits_response, _, _ = await requestor.arequest(
options=TogetherRequest(
method="GET",
url="fine-tunes/models/limits",
params={"model": model},
),
stream=False,
)

model_limits = FinetuneTrainingLimits(**model_limits_response.data)

return model_limits
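
A quick sketch (not part of the diff) of inspecting the limits that drive the new defaults, using the get_model_limits method added above; the model name is only an example:

from together import Together

client = Together()

limits = client.fine_tuning.get_model_limits(
    model="mistralai/Mixtral-8x7B-Instruct-v0.1"
)

# Fields come from FinetuneTrainingLimits (defined in src/together/types/finetune.py
# below); either training type may be unavailable for a model, hence the None checks.
if limits.full_training is not None:
    print("full fine-tuning max batch size:", limits.full_training.max_batch_size)
if limits.lora_training is not None:
    print("LoRA max batch size:", limits.lora_training.max_batch_size)
    print("LoRA max rank:", limits.lora_training.max_rank)

This is the same lookup FineTuning.create performs when batch_size is left at "max".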
2 changes: 2 additions & 0 deletions src/together/types/__init__.py
@@ -29,6 +29,7 @@
FullTrainingType,
LoRATrainingType,
TrainingType,
FinetuneTrainingLimits,
)
from together.types.images import (
ImageRequest,
@@ -71,4 +72,5 @@
"LoRATrainingType",
"RerankRequest",
"RerankResponse",
"FinetuneTrainingLimits",
]
18 changes: 18 additions & 0 deletions src/together/types/finetune.py
@@ -263,3 +263,21 @@ class FinetuneDownloadResult(BaseModel):
filename: str | None = None
# size in bytes
size: int | None = None


class FinetuneFullTrainingLimits(BaseModel):
max_batch_size: int
min_batch_size: int


class FinetuneLoraTrainingLimits(FinetuneFullTrainingLimits):
max_rank: int
target_modules: List[str]


class FinetuneTrainingLimits(BaseModel):
max_num_epochs: int
max_learning_rate: float
min_learning_rate: float
full_training: FinetuneFullTrainingLimits | None = None
lora_training: FinetuneLoraTrainingLimits | None = None
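
To illustrate the shape these models expect, a limits payload parses into nested objects; the numbers below are invented and not the real limits of any model:

from together.types import FinetuneTrainingLimits

payload = {
    "max_num_epochs": 20,
    "max_learning_rate": 1.0,
    "min_learning_rate": 1e-6,
    "full_training": {"max_batch_size": 96, "min_batch_size": 8},
    "lora_training": {
        "max_batch_size": 128,
        "min_batch_size": 8,
        "max_rank": 64,
        "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj"],
    },
}

limits = FinetuneTrainingLimits(**payload)
print(limits.lora_training.max_rank)        # 64
print(limits.full_training.max_batch_size)  # 96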