Skip to content

Commit

Permalink
Rename some methods for datasets while it is beta-feature (#51)
Browse files Browse the repository at this point in the history
  • Loading branch information
vhaldemar authored Jan 21, 2025
1 parent d4e6e9c commit f88a3e5
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 12 deletions.
11 changes: 5 additions & 6 deletions examples/async/tuning/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,23 @@ async def main() -> None:
folder_id='b1ghsjum2v37c2un8h64',
)

dataset_draft = sdk.datasets.from_path_deferred(
dataset_draft = sdk.datasets.draft_from_path(
task_type='TextToTextGeneration',
path=local_path('example_dataset'),
upload_format='jsonlines',
name='foo',
)

operation = await dataset_draft.upload()
dataset = await operation
dataset = await dataset_draft.upload()
print(f'new {dataset=}')

dataset_draft = sdk.datasets.completions.from_path_deferred(
dataset_draft = sdk.datasets.completions.draft_from_path(
local_path('example_bad_dataset')
)
dataset_draft.upload_format = 'jsonlines'
dataset_draft.name = 'foo'

operation = await dataset_draft.upload()
operation = await dataset_draft.upload_deferred()
try:
dataset = await operation
except DatasetValidationError as error:
Expand All @@ -44,7 +43,7 @@ async def main() -> None:
print(f"going to delete {bad_dataset=}")
await bad_dataset.delete()

operation = await dataset_draft.upload(raise_on_validation_failure=False)
operation = await dataset_draft.upload_deferred(raise_on_validation_failure=False)
bad_dataset = await operation
print(f"New {bad_dataset=} have a bad status {dataset.status=}")
await dataset.delete()
Expand Down
2 changes: 1 addition & 1 deletion src/yandex_cloud_ml_sdk/_datasets/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class BaseDatasets(BaseDomain, Generic[DatasetTypeT, DatasetDraftT]):
text_classifiers_multilabel = TaskTypeProxy(KnownTaskType.TextClassificationMultilabel)
text_classifiers_multiclass = TaskTypeProxy(KnownTaskType.TextClassificationMulticlass)

def from_path_deferred(
def draft_from_path(
self,
path: PathLike,
*,
Expand Down
62 changes: 59 additions & 3 deletions src/yandex_cloud_ml_sdk/_datasets/draft.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ async def _validate_deferred(
default_poll_timeout=DEFAULT_OPERATION_POLL_TIMEOUT,
)

async def _upload(
async def _upload_deferred(
self,
*,
timeout: float = 60,
Expand Down Expand Up @@ -154,41 +154,97 @@ async def _upload(
)
return operation

async def _upload(
self,
*,
timeout: float = 60,
poll_timeout: int = DEFAULT_OPERATION_POLL_TIMEOUT,
poll_interval: float = 60,
**kwargs,
) -> DatasetTypeT:
operation = await self._upload_deferred(
**kwargs,
timeout=timeout,
)
# pylint: disable=protected-access
result = await operation._wait(
timeout=timeout,
poll_timeout=poll_timeout,
poll_interval=poll_interval,
)
return result



class AsyncDatasetDraft(BaseDatasetDraft[AsyncDataset, AsyncOperation[AsyncDataset]]):
_dataset_impl = AsyncDataset
_operation_impl = AsyncOperation[AsyncDataset]

async def upload(
async def upload_deferred(
self,
*,
timeout: float = 60,
upload_timeout: float = 360,
raise_on_validation_failure: bool = True,
) -> AsyncOperation[AsyncDataset]:
return await self._upload_deferred(
timeout=timeout,
upload_timeout=upload_timeout,
raise_on_validation_failure=raise_on_validation_failure,
)

async def upload(
self,
*,
timeout: float = 60,
upload_timeout: float = 360,
raise_on_validation_failure: bool = True,
poll_timeout: int = DEFAULT_OPERATION_POLL_TIMEOUT,
poll_interval: float = 60,
):
return await self._upload(
timeout=timeout,
upload_timeout=upload_timeout,
raise_on_validation_failure=raise_on_validation_failure,
poll_timeout=poll_timeout,
poll_interval=poll_interval
)


class DatasetDraft(BaseDatasetDraft[Dataset, Operation[Dataset]]):
_dataset_impl = Dataset
_operation_impl = Operation[Dataset]
__upload_deferred = run_sync(BaseDatasetDraft._upload_deferred)
__upload = run_sync(BaseDatasetDraft._upload)

def upload(
def upload_deferred(
self,
*,
timeout: float = 60,
upload_timeout: float = 360,
raise_on_validation_failure: bool = True,
) -> Operation[Dataset]:
return self.__upload_deferred(
timeout=timeout,
upload_timeout=upload_timeout,
raise_on_validation_failure=raise_on_validation_failure,
)

def upload(
self,
*,
timeout: float = 60,
upload_timeout: float = 360,
raise_on_validation_failure: bool = True,
poll_timeout: int = DEFAULT_OPERATION_POLL_TIMEOUT,
poll_interval: float = 60,
):
return self.__upload(
timeout=timeout,
upload_timeout=upload_timeout,
raise_on_validation_failure=raise_on_validation_failure,
poll_timeout=poll_timeout,
poll_interval=poll_interval
)


Expand Down
4 changes: 2 additions & 2 deletions src/yandex_cloud_ml_sdk/_datasets/task_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ def task_type(self) -> str:
return self._task_type

@property
def from_path_deferred(self):
return partial(self._domain.from_path_deferred, task_type=self._task_type)
def draft_from_path(self):
return partial(self._domain.draft_from_path, task_type=self._task_type)

@property
def list_upload_formats(self):
Expand Down

0 comments on commit f88a3e5

Please sign in to comment.