From f1d3e210be601de1846874d9ef404cfd55efe2f0 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Tue, 24 Sep 2024 14:21:55 +0100 Subject: [PATCH] Improve report messages with rich --- .pre-commit-config.yaml | 2 +- src/together/cli/api/finetune.py | 18 ++++++++++++------ src/together/resources/finetune.py | 14 +++++++++++--- 3 files changed, 24 insertions(+), 10 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b754a21..299013a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -17,5 +17,5 @@ repos: hooks: - id: mypy args: [--strict] - additional_dependencies: [types-requests, types-tqdm, types-tabulate, types-click, types-filelock, types-Pillow, pyarrow-stubs, pydantic, aiohttp] + additional_dependencies: [types-requests, types-tqdm, types-tabulate, types-click, types-filelock, types-Pillow, rich, pyarrow-stubs, pydantic, aiohttp] exclude: ^tests/ diff --git a/src/together/cli/api/finetune.py b/src/together/cli/api/finetune.py index 5530ab7..d1d3cc3 100644 --- a/src/together/cli/api/finetune.py +++ b/src/together/cli/api/finetune.py @@ -1,10 +1,12 @@ from __future__ import annotations import json +from datetime import datetime from textwrap import wrap import click from click.core import ParameterSource # type: ignore[attr-defined] +from rich import print as rich_print from tabulate import tabulate from together import Together @@ -145,14 +147,18 @@ def create( lora_trainable_modules=lora_trainable_modules, suffix=suffix, wandb_api_key=wandb_api_key, + verbose=True, ) - click.echo(json.dumps(response.model_dump(exclude_none=True), indent=4)) - - # TODO: Remove it after the 21st of August - log_warn( - "The default value of batch size has been changed from 32 to 16 since together version >= 1.2.6" - ) + report_string = f"Successfully submitted a fine-tuning job {response.id}" + if response.created_at is not None: + created_time = datetime.strptime( + response.created_at, "%Y-%m-%dT%H:%M:%S.%f%z" + ) + # created_at reports UTC time, we use .astimezone() to convert to local time + formatted_time = created_time.astimezone().strftime("%m/%d/%Y, %H:%M:%S") + report_string += f" at {formatted_time}" + rich_print(report_string) else: click.echo("No confirmation received, stopping job launch") diff --git a/src/together/resources/finetune.py b/src/together/resources/finetune.py index 5ee7658..3d172f8 100644 --- a/src/together/resources/finetune.py +++ b/src/together/resources/finetune.py @@ -2,6 +2,8 @@ from pathlib import Path +from rich import print + from together.abstract import api_requestor from together.filemanager import DownloadManager from together.together_response import TogetherResponse @@ -43,6 +45,7 @@ def create( lora_trainable_modules: str | None = "all-linear", suffix: str | None = None, wandb_api_key: str | None = None, + verbose: bool = False, ) -> FinetuneResponse: """ Method to initiate a fine-tuning job @@ -67,6 +70,8 @@ def create( Defaults to None. wandb_api_key (str, optional): API key for Weights & Biases integration. Defaults to None. + verbose (bool, optional): whether to print the job parameters before submitting a request. + Defaults to False. Returns: FinetuneResponse: Object containing information about fine-tuning job. @@ -85,7 +90,7 @@ def create( lora_trainable_modules=lora_trainable_modules, ) - parameter_payload = FinetuneRequest( + finetune_request = FinetuneRequest( model=model, training_file=training_file, validation_file=validation_file, @@ -97,7 +102,10 @@ def create( training_type=training_type, suffix=suffix, wandb_key=wandb_api_key, - ).model_dump(exclude_none=True) + ) + if verbose: + print(finetune_request) + parameter_payload = finetune_request.model_dump(exclude_none=True) response, _, _ = requestor.request( options=TogetherRequest( @@ -266,7 +274,7 @@ def download( raise ValueError( "Only DEFAULT checkpoint type is allowed for FullTrainingType" ) - url += f"&checkpoint=modelOutputPath" + url += "&checkpoint=modelOutputPath" elif isinstance(ft_job.training_type, LoRATrainingType): if checkpoint_type == DownloadCheckpointType.DEFAULT: checkpoint_type = DownloadCheckpointType.MERGED