@@ -50,7 +50,7 @@
 def createFinetuneRequest(
     model_limits: FinetuneTrainingLimits,
     training_file: str,
-    model: str,
+    model: str | None = None,
     n_epochs: int = 1,
     validation_file: str | None = "",
     n_evals: int | None = 0,
@@ -77,6 +77,11 @@ def createFinetuneRequest(
     from_checkpoint: str | None = None,
 ) -> FinetuneRequest:
 
+    if model is not None and from_checkpoint is not None:
+        raise ValueError(
+            "You must specify either a model or a checkpoint to start a job from, not both"
+        )
+
     if batch_size == "max":
         log_warn_once(
             "Starting from together>=1.3.0, "
@@ -237,7 +242,7 @@ def create(
         self,
         *,
         training_file: str,
-        model: str,
+        model: str | None = None,
         n_epochs: int = 1,
         validation_file: str | None = "",
         n_evals: int | None = 0,
@@ -270,7 +275,7 @@ def create(
 
         Args:
             training_file (str): File-ID of a file uploaded to the Together API
-            model (str): Name of the base model to run fine-tune job on
+            model (str, optional): Name of the base model to run fine-tune job on
             n_epochs (int, optional): Number of epochs for fine-tuning. Defaults to 1.
             validation file (str, optional): File ID of a file uploaded to the Together API for validation.
             n_evals (int, optional): Number of evaluation loops to run. Defaults to 0.
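With model now optional in the signature and docstring above, a job can be started either from a base model or from a checkpoint. A usage sketch, assuming a Together client configured as in the library's quickstart; the file and job IDs below are placeholders:

from together import Together

client = Together()  # assumes TOGETHER_API_KEY is set in the environment

# Start from a base model, as before (IDs are placeholders).
job = client.fine_tuning.create(
    training_file="file-placeholder-id",
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
)

# Or omit model and continue from an earlier job's checkpoint instead.
job = client.fine_tuning.create(
    training_file="file-placeholder-id",
    from_checkpoint="ft-placeholder-job-id",
)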
@@ -320,12 +325,24 @@ def create(
             FinetuneResponse: Object containing information about fine-tuning job.
         """
 
+        if model is None and from_checkpoint is None:
+            raise ValueError("You must specify either a model or a checkpoint")
+
         requestor = api_requestor.APIRequestor(
             client=self._client,
         )
 
         if model_limits is None:
-            model_limits = self.get_model_limits(model=model)
+            # mypy doesn't understand that model or from_checkpoint is not None
+            if model is not None:
+                model_name = model
+            elif from_checkpoint is not None:
+                model_name = from_checkpoint.split(":")[0]
+            else:
+                # this branch is unreachable, but mypy doesn't know that
+                pass
+            model_limits = self.get_model_limits(model=model_name)
+
         finetune_request = createFinetuneRequest(
             model_limits=model_limits,
             training_file=training_file,
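When only from_checkpoint is given, the branch above derives the name used for the model-limits lookup by taking everything before the first colon, so a checkpoint reference can carry a suffix (presumably a step or revision) after ":". A small sketch of that convention; the checkpoint IDs are hypothetical:

# Mirrors the from_checkpoint.split(":")[0] lookup above; IDs are hypothetical.
def base_for_limits(from_checkpoint: str) -> str:
    return from_checkpoint.split(":")[0]


print(base_for_limits("ft-placeholder-job-id"))     # ft-placeholder-job-id
print(base_for_limits("ft-placeholder-job-id:20"))  # ft-placeholder-job-id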
@@ -610,7 +627,7 @@ async def create(
         self,
         *,
         training_file: str,
-        model: str,
+        model: str | None = None,
         n_epochs: int = 1,
         validation_file: str | None = "",
         n_evals: int | None = 0,
@@ -643,7 +660,7 @@ async def create(
 
         Args:
             training_file (str): File-ID of a file uploaded to the Together API
-            model (str): Name of the base model to run fine-tune job on
+            model (str, optional): Name of the base model to run fine-tune job on
             n_epochs (int, optional): Number of epochs for fine-tuning. Defaults to 1.
             validation file (str, optional): File ID of a file uploaded to the Together API for validation.
             n_evals (int, optional): Number of evaluation loops to run. Defaults to 0.
@@ -693,12 +710,23 @@ async def create(
             FinetuneResponse: Object containing information about fine-tuning job.
         """
 
+        if model is None and from_checkpoint is None:
+            raise ValueError("You must specify either a model or a checkpoint")
+
         requestor = api_requestor.APIRequestor(
             client=self._client,
         )
 
         if model_limits is None:
-            model_limits = await self.get_model_limits(model=model)
+            # mypy doesn't understand that model or from_checkpoint is not None
+            if model is not None:
+                model_name = model
+            elif from_checkpoint is not None:
+                model_name = from_checkpoint.split(":")[0]
+            else:
+                # this branch is unreachable, but mypy doesn't know that
+                pass
+            model_limits = await self.get_model_limits(model=model_name)
 
         finetune_request = createFinetuneRequest(
             model_limits=model_limits,
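The async path mirrors the synchronous one, including the awaited get_model_limits call. A usage sketch for the async client, again with placeholder IDs and assuming AsyncTogether is configured the same way as the synchronous client:

import asyncio

from together import AsyncTogether


async def main() -> None:
    client = AsyncTogether()  # assumes TOGETHER_API_KEY is set in the environment
    job = await client.fine_tuning.create(
        training_file="file-placeholder-id",
        from_checkpoint="ft-placeholder-job-id",
    )
    print(job)


asyncio.run(main())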