Skip to content

Commit

Permalink
Improve example & fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
vhaldemar committed Dec 6, 2024
1 parent 1ec3608 commit f174e8d
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 5 deletions.
14 changes: 10 additions & 4 deletions examples/async/tuning/tuning_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,18 @@ async def main() -> None:
description="cool tuning",
labels={'good': 'yes'},
seed=500,
lr=0.0005,
lr=1e-4,
n_samples=100,
tuning_type=tt.TuningTypePromptTune(virtual_tokens=50),
scheduler=ts.SchedulerConstant(warmup_ratio=0.1),
tuning_type=tt.TuningTypePromptTune(virtual_tokens=20),
scheduler=ts.SchedulerLinear(
warmup_ratio=10,
min_lr=0
),
optimizer=to.OptimizerAdamw(
beta1=0.5
beta1=0.9,
beta2=0.999,
eps=1e-8,
weight_decay=0.1,
)
)

Expand Down
7 changes: 7 additions & 0 deletions src/yandex_cloud_ml_sdk/_models/completions/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,9 @@ async def run_deferred(
timeout=timeout,
)

async def attach_deferred(self, operation_id: str, timeout: float = 60) -> AsyncOperation[GPTModelResult]:
return await self._attach_deferred(operation_id=operation_id, timeout=timeout)

async def tokenize(
self,
messages: MessageInputType,
Expand Down Expand Up @@ -327,6 +330,7 @@ class GPTModel(BaseGPTModel[Operation[GPTModelResult], TuningTask['GPTModel']]):
__run = run_sync(BaseGPTModel._run)
__run_stream = run_sync_generator(BaseGPTModel._run_stream)
__run_deferred = run_sync(BaseGPTModel._run_deferred)
__attach_deferred = run_sync(BaseGPTModel._attach_deferred)
__tokenize = run_sync(BaseGPTModel._tokenize)
__tune_deferred = run_sync(BaseGPTModel._tune_deferred)
__tune = run_sync(BaseGPTModel._tune)
Expand Down Expand Up @@ -365,6 +369,9 @@ def run_deferred(
timeout=timeout,
)

def attach_deferred(self, operation_id: str, timeout: float = 60) -> Operation[GPTModelResult]:
return self.__attach_deferred(operation_id=operation_id, timeout=timeout)

def tokenize(
self,
messages: MessageInputType,
Expand Down
7 changes: 7 additions & 0 deletions src/yandex_cloud_ml_sdk/_models/image_generation/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,14 @@ async def run_deferred(
timeout=timeout
)

async def attach_deferred(self, operation_id: str, timeout: float = 60) -> AsyncOperation[ImageGenerationModelResult]:
return await self._attach_deferred(operation_id=operation_id, timeout=timeout)


class ImageGenerationModel(BaseImageGenerationModel[Operation[ImageGenerationModelResult]]):
_operation_type = Operation[ImageGenerationModelResult]
__run_deferred = run_sync(BaseImageGenerationModel[Operation[ImageGenerationModelResult]]._run_deferred)
__attach_deferred = run_sync(BaseImageGenerationModel[Operation[ImageGenerationModelResult]]._attach_deferred)

def run_deferred(
self,
Expand All @@ -111,3 +115,6 @@ def run_deferred(
messages=messages,
timeout=timeout
)

def attach_deferred(self, operation_id: str, timeout: float = 60) -> Operation[ImageGenerationModelResult]:
return self.__attach_deferred(operation_id=operation_id, timeout=timeout)
2 changes: 1 addition & 1 deletion src/yandex_cloud_ml_sdk/_types/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ async def _run_deferred(self, *args, **kwargs) -> OperationTypeT:
pass

# pylint: disable=unused-argument
async def attach_deferred(self, operation_id: str, timeout: float = 60) -> OperationTypeT:
async def _attach_deferred(self, operation_id: str, timeout: float = 60) -> OperationTypeT:
return self._operation_type(
id=operation_id,
sdk=self._sdk,
Expand Down

0 comments on commit f174e8d

Please sign in to comment.