Skip to content

Commit

Permalink
Add a confirmation message before submitting a fine-tuning job
Browse files Browse the repository at this point in the history
  • Loading branch information
mryab committed Sep 24, 2024
1 parent 8925c6e commit 2fefeb6
Showing 1 changed file with 45 additions and 23 deletions.
68 changes: 45 additions & 23 deletions src/together/cli/api/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,18 @@
from tabulate import tabulate

from together import Together
from together.utils import finetune_price_to_dollars, log_warn, parse_timestamp
from together.types.finetune import DownloadCheckpointType
from together.utils import finetune_price_to_dollars, log_warn, parse_timestamp


_CONFIRMATION_MESSAGE = (
"You are about to launch a fine-tuning job. "
"The cost of your job will be determined by the model size, the number of tokens "
"in the training file, the number of tokens in the validation file, the number of epochs, and "
"the number of evaluations. Visit https://www.together.ai/pricing to get a price estimate.\n"
"You can add `-y` or `--confirm` to skip this message.\n\n"
"Do you want to proceed?"
)


class DownloadCheckpointTypeChoice(click.Choice):
Expand Down Expand Up @@ -65,6 +75,14 @@ def fine_tuning(ctx: click.Context) -> None:
"--suffix", type=str, default=None, help="Suffix for the fine-tuned model name"
)
@click.option("--wandb-api-key", type=str, default=None, help="Wandb API key")
@click.option(
"--confirm",
"-y",
type=bool,
is_flag=True,
default=False,
help="Whether to skip the launch confirmation message",
)
def create(
ctx: click.Context,
training_file: str,
Expand All @@ -82,6 +100,7 @@ def create(
lora_trainable_modules: str,
suffix: str,
wandb_api_key: str,
confirm: bool,
) -> None:
"""Start fine-tuning"""
client: Together = ctx.obj
Expand Down Expand Up @@ -109,30 +128,33 @@ def create(
"You have specified a number of evaluation loops but no validation file."
)

response = client.fine_tuning.create(
training_file=training_file,
model=model,
n_epochs=n_epochs,
validation_file=validation_file,
n_evals=n_evals,
n_checkpoints=n_checkpoints,
batch_size=batch_size,
learning_rate=learning_rate,
lora=lora,
lora_r=lora_r,
lora_dropout=lora_dropout,
lora_alpha=lora_alpha,
lora_trainable_modules=lora_trainable_modules,
suffix=suffix,
wandb_api_key=wandb_api_key,
)
if confirm or click.confirm(_CONFIRMATION_MESSAGE, default=True, show_default=True):
response = client.fine_tuning.create(
training_file=training_file,
model=model,
n_epochs=n_epochs,
validation_file=validation_file,
n_evals=n_evals,
n_checkpoints=n_checkpoints,
batch_size=batch_size,
learning_rate=learning_rate,
lora=lora,
lora_r=lora_r,
lora_dropout=lora_dropout,
lora_alpha=lora_alpha,
lora_trainable_modules=lora_trainable_modules,
suffix=suffix,
wandb_api_key=wandb_api_key,
)

click.echo(json.dumps(response.model_dump(exclude_none=True), indent=4))
click.echo(json.dumps(response.model_dump(exclude_none=True), indent=4))

# TODO: Remove it after the 21st of August
log_warn(
"The default value of batch size has been changed from 32 to 16 since together version >= 1.2.6"
)
# TODO: Remove it after the 21st of August
log_warn(
"The default value of batch size has been changed from 32 to 16 since together version >= 1.2.6"
)
else:
click.echo("No confirmation received, stopping job launch")


@fine_tuning.command()
Expand Down

0 comments on commit 2fefeb6

Please sign in to comment.