Skip to content

Commit ea39fb5

Browse files
authored
Fix model_name requirements (#280)
* Fix model_name requirements
* Add an exception when both fields (model and from_checkpoint) are set
1 parent b24a5c3 commit ea39fb5

File tree

3 files changed

+45
-10
lines changed

3 files changed

+45
-10
lines changed

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ build-backend = "poetry.masonry.api"
1212

1313
[tool.poetry]
1414
name = "together"
15-
version = "1.5.0"
15+
version = "1.5.1"
1616
authors = [
1717
"Together AI <[email protected]>"
1818
]

src/together/cli/api/finetune.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def fine_tuning(ctx: click.Context) -> None:
6060
@click.option(
6161
"--training-file", type=str, required=True, help="Training file ID from Files API"
6262
)
63-
@click.option("--model", type=str, required=True, help="Base model name")
63+
@click.option("--model", type=str, help="Base model name")
6464
@click.option("--n-epochs", type=int, default=1, help="Number of epochs to train for")
6565
@click.option(
6666
"--validation-file", type=str, default="", help="Validation file ID from Files API"
@@ -214,8 +214,15 @@ def create(
214214
from_checkpoint=from_checkpoint,
215215
)
216216

217+
if model is None and from_checkpoint is None:
218+
raise click.BadParameter("You must specify either a model or a checkpoint")
219+
220+
model_name = model
221+
if from_checkpoint is not None:
222+
model_name = from_checkpoint.split(":")[0]
223+
217224
model_limits: FinetuneTrainingLimits = client.fine_tuning.get_model_limits(
218-
model=model
225+
model=model_name
219226
)
220227

221228
if lora:

src/together/resources/finetune.py

+35-7
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
def createFinetuneRequest(
5151
model_limits: FinetuneTrainingLimits,
5252
training_file: str,
53-
model: str,
53+
model: str | None = None,
5454
n_epochs: int = 1,
5555
validation_file: str | None = "",
5656
n_evals: int | None = 0,
@@ -77,6 +77,11 @@ def createFinetuneRequest(
7777
from_checkpoint: str | None = None,
7878
) -> FinetuneRequest:
7979

80+
if model is not None and from_checkpoint is not None:
81+
raise ValueError(
82+
"You must specify either a model or a checkpoint to start a job from, not both"
83+
)
84+
8085
if batch_size == "max":
8186
log_warn_once(
8287
"Starting from together>=1.3.0, "
@@ -237,7 +242,7 @@ def create(
237242
self,
238243
*,
239244
training_file: str,
240-
model: str,
245+
model: str | None = None,
241246
n_epochs: int = 1,
242247
validation_file: str | None = "",
243248
n_evals: int | None = 0,
@@ -270,7 +275,7 @@ def create(
270275
271276
Args:
272277
training_file (str): File-ID of a file uploaded to the Together API
273-
model (str): Name of the base model to run fine-tune job on
278+
model (str, optional): Name of the base model to run fine-tune job on
274279
n_epochs (int, optional): Number of epochs for fine-tuning. Defaults to 1.
275280
validation file (str, optional): File ID of a file uploaded to the Together API for validation.
276281
n_evals (int, optional): Number of evaluation loops to run. Defaults to 0.
@@ -320,12 +325,24 @@ def create(
320325
FinetuneResponse: Object containing information about fine-tuning job.
321326
"""
322327

328+
if model is None and from_checkpoint is None:
329+
raise ValueError("You must specify either a model or a checkpoint")
330+
323331
requestor = api_requestor.APIRequestor(
324332
client=self._client,
325333
)
326334

327335
if model_limits is None:
328-
model_limits = self.get_model_limits(model=model)
336+
# mypy doesn't understand that model or from_checkpoint is not None
337+
if model is not None:
338+
model_name = model
339+
elif from_checkpoint is not None:
340+
model_name = from_checkpoint.split(":")[0]
341+
else:
342+
# this branch is unreachable, but mypy doesn't know that
343+
pass
344+
model_limits = self.get_model_limits(model=model_name)
345+
329346
finetune_request = createFinetuneRequest(
330347
model_limits=model_limits,
331348
training_file=training_file,
@@ -610,7 +627,7 @@ async def create(
610627
self,
611628
*,
612629
training_file: str,
613-
model: str,
630+
model: str | None = None,
614631
n_epochs: int = 1,
615632
validation_file: str | None = "",
616633
n_evals: int | None = 0,
@@ -643,7 +660,7 @@ async def create(
643660
644661
Args:
645662
training_file (str): File-ID of a file uploaded to the Together API
646-
model (str): Name of the base model to run fine-tune job on
663+
model (str, optional): Name of the base model to run fine-tune job on
647664
n_epochs (int, optional): Number of epochs for fine-tuning. Defaults to 1.
648665
validation file (str, optional): File ID of a file uploaded to the Together API for validation.
649666
n_evals (int, optional): Number of evaluation loops to run. Defaults to 0.
@@ -693,12 +710,23 @@ async def create(
693710
FinetuneResponse: Object containing information about fine-tuning job.
694711
"""
695712

713+
if model is None and from_checkpoint is None:
714+
raise ValueError("You must specify either a model or a checkpoint")
715+
696716
requestor = api_requestor.APIRequestor(
697717
client=self._client,
698718
)
699719

700720
if model_limits is None:
701-
model_limits = await self.get_model_limits(model=model)
721+
# mypy doesn't understand that model or from_checkpoint is not None
722+
if model is not None:
723+
model_name = model
724+
elif from_checkpoint is not None:
725+
model_name = from_checkpoint.split(":")[0]
726+
else:
727+
# this branch is unreachable, but mypy doesn't know that
728+
pass
729+
model_limits = await self.get_model_limits(model=model_name)
702730

703731
finetune_request = createFinetuneRequest(
704732
model_limits=model_limits,

0 commit comments

Comments (0)