diff --git a/README.md b/README.md
index e39418c..7f0ff9e 100644
--- a/README.md
+++ b/README.md
@@ -242,7 +242,7 @@ client.fine_tuning.create(
     model = 'mistralai/Mixtral-8x7B-Instruct-v0.1',
     n_epochs = 3,
     n_checkpoints = 1,
-    batch_size = 4,
+    batch_size = "max",
     learning_rate = 1e-5,
     suffix = 'my-demo-finetune',
     wandb_api_key = '1a2b3c4d5e.......',
diff --git a/src/together/cli/api/finetune.py b/src/together/cli/api/finetune.py
index 6601d49..1f2a05b 100644
--- a/src/together/cli/api/finetune.py
+++ b/src/together/cli/api/finetune.py
@@ -3,6 +3,7 @@
 import json
 from datetime import datetime
 from textwrap import wrap
+from typing import Any, Literal
 
 import click
 from click.core import ParameterSource  # type: ignore[attr-defined]
@@ -10,8 +11,9 @@
 from tabulate import tabulate
 
 from together import Together
-from together.types.finetune import DownloadCheckpointType
+from together.cli.api.utils import INT_WITH_MAX
 from together.utils import finetune_price_to_dollars, log_warn, parse_timestamp
+from together.types.finetune import DownloadCheckpointType, FinetuneTrainingLimits
 
 
 _CONFIRMATION_MESSAGE = (
@@ -56,7 +58,7 @@ def fine_tuning(ctx: click.Context) -> None:
 @click.option(
     "--n-checkpoints", type=int, default=1, help="Number of checkpoints to save"
 )
-@click.option("--batch-size", type=int, default=16, help="Train batch size")
+@click.option("--batch-size", type=INT_WITH_MAX, default="max", help="Train batch size")
 @click.option("--learning-rate", type=float, default=1e-5, help="Learning rate")
 @click.option(
     "--lora/--no-lora",
@@ -93,7 +95,7 @@ def create(
     n_epochs: int,
     n_evals: int,
     n_checkpoints: int,
-    batch_size: int,
+    batch_size: int | Literal["max"],
     learning_rate: float,
     lora: bool,
     lora_r: int,
@@ -107,20 +109,64 @@ def create(
     """Start fine-tuning"""
     client: Together = ctx.obj
 
+    training_args: dict[str, Any] = dict(
+        training_file=training_file,
+        model=model,
+        n_epochs=n_epochs,
+        validation_file=validation_file,
+        n_evals=n_evals,
+        n_checkpoints=n_checkpoints,
+        batch_size=batch_size,
+        learning_rate=learning_rate,
+        lora=lora,
+        lora_r=lora_r,
+        lora_dropout=lora_dropout,
+        lora_alpha=lora_alpha,
+        lora_trainable_modules=lora_trainable_modules,
+        suffix=suffix,
+        wandb_api_key=wandb_api_key,
+    )
+
+    model_limits: FinetuneTrainingLimits = client.fine_tuning.get_model_limits(
+        model=model
+    )
+
     if lora:
-        learning_rate_source = click.get_current_context().get_parameter_source(  # type: ignore[attr-defined]
-            "learning_rate"
-        )
-        if learning_rate_source == ParameterSource.DEFAULT:
-            learning_rate = 1e-3
+        if model_limits.lora_training is None:
+            raise click.BadParameter(
+                f"LoRA fine-tuning is not supported for the model `{model}`"
+            )
+
+        default_values = {
+            "lora_r": model_limits.lora_training.max_rank,
+            "batch_size": model_limits.lora_training.max_batch_size,
+            "learning_rate": 1e-3,
+        }
+        for arg in default_values:
+            arg_source = ctx.get_parameter_source(arg)  # type: ignore[attr-defined]
+            if arg_source == ParameterSource.DEFAULT:
+                training_args[arg] = default_values[arg]
+
+        if ctx.get_parameter_source("lora_alpha") == ParameterSource.DEFAULT:  # type: ignore[attr-defined]
+            training_args["lora_alpha"] = training_args["lora_r"] * 2
     else:
+        if model_limits.full_training is None:
+            raise click.BadParameter(
+                f"Full fine-tuning is not supported for the model `{model}`"
+            )
+
         for param in ["lora_r", "lora_dropout", "lora_alpha", "lora_trainable_modules"]:
-            param_source = click.get_current_context().get_parameter_source(param)  # type: ignore[attr-defined]
+            param_source = ctx.get_parameter_source(param)  # type: ignore[attr-defined]
             if param_source != ParameterSource.DEFAULT:
                 raise click.BadParameter(
                     f"You set LoRA parameter `{param}` for a full fine-tuning job. "
                     f"Please change the job type with --lora or remove `{param}` from the arguments"
                 )
+
+        batch_size_source = ctx.get_parameter_source("batch_size")  # type: ignore[attr-defined]
+        if batch_size_source == ParameterSource.DEFAULT:
+            training_args["batch_size"] = model_limits.full_training.max_batch_size
+
     if n_evals <= 0 and validation_file:
         log_warn(
             "Warning: You have specified a validation file but the number of evaluation loops is set to 0. No evaluations will be performed."
diff --git a/src/together/cli/api/utils.py b/src/together/cli/api/utils.py
new file mode 100644
index 0000000..71ee378
--- /dev/null
+++ b/src/together/cli/api/utils.py
@@ -0,0 +1,24 @@
+import click
+
+from typing import Literal
+
+
+class AutoIntParamType(click.ParamType):
+    name = "integer"
+
+    def convert(
+        self, value: str, param: click.Parameter | None, ctx: click.Context | None
+    ) -> int | Literal["max"] | None:
+        if isinstance(value, int):
+            return value
+
+        if value == "max":
+            return "max"
+
+        try:
+            return int(value)
+        except ValueError:
+            self.fail(f"Invalid integer value: {value}", param, ctx)
+
+
+INT_WITH_MAX = AutoIntParamType()
diff --git a/src/together/legacy/finetune.py b/src/together/legacy/finetune.py
index baad2ab..fe53be0 100644
--- a/src/together/legacy/finetune.py
+++ b/src/together/legacy/finetune.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import warnings
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Literal
 
 import together
 from together.legacy.base import API_KEY_WARNING, deprecated
@@ -43,7 +43,7 @@ def create(
         model=model,
         n_epochs=n_epochs,
         n_checkpoints=n_checkpoints,
-        batch_size=batch_size,
+        batch_size=batch_size if isinstance(batch_size, int) else "max",
         learning_rate=learning_rate,
         suffix=suffix,
         wandb_api_key=wandb_api_key,
diff --git a/src/together/resources/finetune.py b/src/together/resources/finetune.py
index 8a09d42..3c8fb6a 100644
--- a/src/together/resources/finetune.py
+++ b/src/together/resources/finetune.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 from pathlib import Path
+from typing import Literal
 
 from rich import print as rprint
 
@@ -13,6 +14,7 @@
     FinetuneListEvents,
     FinetuneRequest,
     FinetuneResponse,
+    FinetuneTrainingLimits,
     FullTrainingType,
     LoRATrainingType,
     TogetherClient,
@@ -20,7 +22,7 @@
     TrainingType,
 )
 from together.types.finetune import DownloadCheckpointType
-from together.utils import log_warn, normalize_key
+from together.utils import log_warn_once, normalize_key
 
 
 class FineTuning:
@@ -36,16 +38,17 @@ def create(
         validation_file: str | None = "",
         n_evals: int | None = 0,
         n_checkpoints: int | None = 1,
-        batch_size: int | None = 16,
+        batch_size: int | Literal["max"] = "max",
         learning_rate: float | None = 0.00001,
         lora: bool = False,
-        lora_r: int | None = 8,
+        lora_r: int | None = None,
         lora_dropout: float | None = 0,
-        lora_alpha: float | None = 8,
+        lora_alpha: float | None = None,
         lora_trainable_modules: str | None = "all-linear",
         suffix: str | None = None,
         wandb_api_key: str | None = None,
         verbose: bool = False,
+        model_limits: FinetuneTrainingLimits | None = None,
     ) -> FinetuneResponse:
         """
         Method to initiate a fine-tuning job
@@ -58,7 +61,7 @@
             n_evals (int, optional): Number of evaluation loops to run. Defaults to 0.
             n_checkpoints (int, optional): Number of checkpoints to save during fine-tuning.
                 Defaults to 1.
-            batch_size (int, optional): Batch size for fine-tuning. Defaults to 32.
+            batch_size (int, optional): Batch size for fine-tuning. Defaults to "max".
             learning_rate (float, optional): Learning rate multiplier to use for training
                 Defaults to 0.00001.
             lora (bool, optional): Whether to use LoRA adapters. Defaults to True.
@@ -72,17 +75,36 @@
                 Defaults to None.
             verbose (bool, optional): whether to print the job parameters before submitting a request.
                 Defaults to False.
+            model_limits (FinetuneTrainingLimits, optional): Training limits for the hyperparameters of the model being fine-tuned.
+                Defaults to None.
 
         Returns:
             FinetuneResponse: Object containing information about fine-tuning job.
         """
 
+        if batch_size == "max":
+            log_warn_once(
+                "Starting from together>=1.3.0, "
+                "the default batch size is set to the maximum allowed value for each model."
+            )
+
         requestor = api_requestor.APIRequestor(
             client=self._client,
         )
 
+        if model_limits is None:
+            model_limits = self.get_model_limits(model=model)
+
         training_type: TrainingType = FullTrainingType()
         if lora:
+            if model_limits.lora_training is None:
+                raise ValueError(
+                    "LoRA adapters are not supported for the selected model."
+                )
+            lora_r = (
+                lora_r if lora_r is not None else model_limits.lora_training.max_rank
+            )
+            lora_alpha = lora_alpha if lora_alpha is not None else lora_r * 2
             training_type = LoRATrainingType(
                 lora_r=lora_r,
                 lora_alpha=lora_alpha,
@@ -90,6 +112,22 @@
                 lora_trainable_modules=lora_trainable_modules,
             )
 
+            batch_size = (
+                batch_size
+                if batch_size != "max"
+                else model_limits.lora_training.max_batch_size
+            )
+        else:
+            if model_limits.full_training is None:
+                raise ValueError(
+                    "Full training is not supported for the selected model."
+                )
+            batch_size = (
+                batch_size
+                if batch_size != "max"
+                else model_limits.full_training.max_batch_size
+            )
+
         finetune_request = FinetuneRequest(
             model=model,
             training_file=training_file,
@@ -121,12 +159,6 @@
 
         assert isinstance(response, TogetherResponse)
 
-        # TODO: Remove after next LoRA default change
-        log_warn(
-            "Some of the jobs run _directly_ from the together-python library might be trained using LoRA adapters. "
-            "The version range when this change occurred is from 1.2.3 to 1.2.6."
-        )
-
         return FinetuneResponse(**response.data)
 
     def list(self) -> FinetuneList:
@@ -305,6 +337,34 @@ def download(
             size=file_size,
         )
 
+    def get_model_limits(self, *, model: str) -> FinetuneTrainingLimits:
+        """
+        Requests training limits for a specific model
+
+        Args:
+            model (str): Name of the model to get limits for
+
+        Returns:
+            FinetuneTrainingLimits: Object containing training limits for the model
+        """
+
+        requestor = api_requestor.APIRequestor(
+            client=self._client,
+        )
+
+        model_limits_response, _, _ = requestor.request(
+            options=TogetherRequest(
+                method="GET",
+                url="fine-tunes/models/limits",
+                params={"model_name": model},
+            ),
+            stream=False,
+        )
+
+        model_limits = FinetuneTrainingLimits(**model_limits_response.data)
+
+        return model_limits
+
 
 class AsyncFineTuning:
     def __init__(self, client: TogetherClient) -> None:
@@ -493,3 +553,31 @@
             "AsyncFineTuning.download not implemented. "
             "Please use FineTuning.download function instead."
         )
+
+    async def get_model_limits(self, *, model: str) -> FinetuneTrainingLimits:
+        """
+        Requests training limits for a specific model
+
+        Args:
+            model (str): Name of the model to get limits for
+
+        Returns:
+            FinetuneTrainingLimits: Object containing training limits for the model
+        """
+
+        requestor = api_requestor.APIRequestor(
+            client=self._client,
+        )
+
+        model_limits_response, _, _ = await requestor.arequest(
+            options=TogetherRequest(
+                method="GET",
+                url="fine-tunes/models/limits",
+                params={"model_name": model},
+            ),
+            stream=False,
+        )
+
+        model_limits = FinetuneTrainingLimits(**model_limits_response.data)
+
+        return model_limits
diff --git a/src/together/types/__init__.py b/src/together/types/__init__.py
index baca178..9aa8c16 100644
--- a/src/together/types/__init__.py
+++ b/src/together/types/__init__.py
@@ -29,6 +29,7 @@
     FullTrainingType,
     LoRATrainingType,
     TrainingType,
+    FinetuneTrainingLimits,
 )
 from together.types.images import (
     ImageRequest,
@@ -71,4 +72,5 @@
     "LoRATrainingType",
     "RerankRequest",
     "RerankResponse",
+    "FinetuneTrainingLimits",
 ]
diff --git a/src/together/types/finetune.py b/src/together/types/finetune.py
index 26d2f2a..4e36457 100644
--- a/src/together/types/finetune.py
+++ b/src/together/types/finetune.py
@@ -263,3 +263,21 @@ class FinetuneDownloadResult(BaseModel):
     filename: str | None = None
     # size in bytes
     size: int | None = None
+
+
+class FinetuneFullTrainingLimits(BaseModel):
+    max_batch_size: int
+    min_batch_size: int
+
+
+class FinetuneLoraTrainingLimits(FinetuneFullTrainingLimits):
+    max_rank: int
+    target_modules: List[str]
+
+
+class FinetuneTrainingLimits(BaseModel):
+    max_num_epochs: int
+    max_learning_rate: float
+    min_learning_rate: float
+    full_training: FinetuneFullTrainingLimits | None = None
+    lora_training: FinetuneLoraTrainingLimits | None = None
diff --git a/src/together/utils/__init__.py b/src/together/utils/__init__.py
index c5de4f3..0e59966 100644
--- a/src/together/utils/__init__.py
+++ b/src/together/utils/__init__.py
@@ -1,4 +1,4 @@
-from together.utils._log import log_debug, log_info, log_warn, logfmt
+from together.utils._log import log_debug, log_info, log_warn, log_warn_once, logfmt
 from together.utils.api_helpers import default_api_key, get_headers
 from together.utils.files import check_file
 from together.utils.tools import (
@@ -18,6 +18,7 @@
     "log_debug",
     "log_info",
     "log_warn",
+    "log_warn_once",
     "logfmt",
     "enforce_trailing_slash",
     "normalize_key",
diff --git a/src/together/utils/_log.py b/src/together/utils/_log.py
index 5efe51c..23abe21 100644
--- a/src/together/utils/_log.py
+++ b/src/together/utils/_log.py
@@ -13,6 +13,8 @@
 
 TOGETHER_LOG = os.environ.get("TOGETHER_LOG")
 
+WARNING_MESSAGES_ONCE = set()
+
 
 def _console_log_level() -> str | None:
     if together.log in ["debug", "info"]:
@@ -59,3 +61,11 @@ def log_warn(message: str | Any, **params: Any) -> None:
     msg = logfmt(dict(message=message, **params))
     print(msg, file=sys.stderr)
     logger.warn(msg)
+
+
+def log_warn_once(message: str | Any, **params: Any) -> None:
+    msg = logfmt(dict(message=message, **params))
+    if msg not in WARNING_MESSAGES_ONCE:
+        print(msg, file=sys.stderr)
+        logger.warn(msg)
+        WARNING_MESSAGES_ONCE.add(msg)
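
Reviewer note: a minimal usage sketch of the surface this diff introduces, assuming the public Together client shown in the README. The model name is copied from the README example above; the training file ID is a placeholder, not a real file.

from together import Together

client = Together()

# New in this diff: query per-model training limits.
limits = client.fine_tuning.get_model_limits(
    model="mistralai/Mixtral-8x7B-Instruct-v0.1"
)
print(limits.max_num_epochs, limits.min_learning_rate, limits.max_learning_rate)
if limits.lora_training is not None:
    print(limits.lora_training.max_batch_size, limits.lora_training.max_rank)

# batch_size now defaults to "max" (the largest batch size the model supports);
# an explicit integer is still accepted. With lora=True, omitted lora_r / lora_alpha
# fall back to the model's LoRA limits (max_rank and 2 * lora_r).
job = client.fine_tuning.create(
    training_file="file-0000000000000000",  # placeholder file ID
    model="mistralai/Mixtral-8x7B-Instruct-v0.1",
    batch_size="max",
    lora=True,
)
print(job)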