Skip to content
This repository has been archived by the owner on Oct 2, 2024. It is now read-only.

Commit

Permalink
feat: add datasets (#3)
Browse files Browse the repository at this point in the history
* feat: http client

* feat: Define workspace classes

* tests: Add workspace related tests

* chore: Configure pre-commit hooks

* ci: setup basic ci workflow

* chore: format files

* feat: add dataset classes

* feat: retrieve workspace datasets

* chore: class rename

* chore: review Dataset class definition

* refactor: Moving workspace dataset collection

* feat: Add generic iterator class

* fix: dataseat.create should use self.client

* Apply suggestions from code review

Co-authored-by: José Francisco Calvo <[email protected]>

---------

Co-authored-by: José Francisco Calvo <[email protected]>
  • Loading branch information
frascuchon and jfcalvo authored Jan 17, 2024
1 parent 01f1773 commit 903afa6
Show file tree
Hide file tree
Showing 10 changed files with 554 additions and 1 deletion.
1 change: 1 addition & 0 deletions src/argilla_sdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import httpx

from argilla_sdk._api import HTTPClientConfig, create_http_client # noqa
from argilla_sdk.datasets import * # noqa
from argilla_sdk.workspaces import * # noqa

DEFAULT_HTTP_CLIENT: Optional[httpx.Client] = None
Expand Down
1 change: 1 addition & 0 deletions src/argilla_sdk/_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from argilla_sdk._api._datasets import * # noqa 403
from argilla_sdk._api._http import * # noqa 403
from argilla_sdk._api._workspaces import * # noqa 403
135 changes: 135 additions & 0 deletions src/argilla_sdk/_api/_datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
# Copyright 2024-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
from dataclasses import dataclass, field
from typing import List, Literal, Optional
from uuid import UUID

import httpx

import argilla_sdk
from argilla_sdk._api import _http

__all__ = ["Dataset"]


@dataclass
class Dataset:
name: str
status: Literal["draft", "ready"] = "draft"
guidelines: Optional[str] = None
allow_extra_metadata: bool = True

id: Optional[UUID] = None
workspace_id: Optional[UUID] = None
inserted_at: Optional[datetime.datetime] = None
updated_at: Optional[datetime.datetime] = None
last_activity_at: Optional[datetime.datetime] = None

client: Optional[httpx.Client] = field(default=None, repr=False, compare=False)

def to_dict(self) -> dict:
return {
"id": self.id,
"name": self.name,
"guidelines": self.guidelines,
"allow_extra_metadata": self.allow_extra_metadata,
"workspace_id": self.workspace_id,
"inserted_at": self.inserted_at,
"updated_at": self.updated_at,
}

@classmethod
def from_dict(cls, data: dict) -> "Dataset":
return cls(**data)

@classmethod
def list(cls, workspace_id: Optional[UUID] = None) -> List["Dataset"]:
client = argilla_sdk.get_default_http_client()

response = client.get("/api/v1/me/datasets")
_http.raise_for_status(response)

json_response = response.json()
datasets = [cls._create_from_json(json_dataset, client) for json_dataset in json_response["items"]]

if workspace_id:
datasets = [dataset for dataset in datasets if dataset.workspace_id == workspace_id]

return datasets

@classmethod
def get(cls, dataset_id: UUID) -> "Dataset":
client = argilla_sdk.get_default_http_client()

response = client.get(f"/api/v1/datasets/{dataset_id}")
_http.raise_for_status(response)

return cls._create_from_json(response.json(), client)

@classmethod
def get_by_name_and_workspace_id(cls, name: str, workspace_id: UUID) -> Optional["Dataset"]:
datasets = cls.list(workspace_id=workspace_id)

for dataset in datasets:
if dataset.name == name:
return dataset

def create(self) -> "Dataset":
body = {
"name": self.name,
"workspace_id": self.workspace_id,
"guidelines": self.guidelines,
"allow_extra_metadata": self.allow_extra_metadata,
}

response = self.client.post("/api/v1/datasets", json=body)
_http.raise_for_status(response)

return self._update_from_api_response(response)

def update(self) -> "Dataset":
body = {
"guidelines": self.guidelines,
"allow_extra_metadata": self.allow_extra_metadata,
}

response = self.client.patch(f"/api/v1/datasets/{self.id}", json=body)

_http.raise_for_status(response)

return self._update_from_api_response(response)

def delete(self) -> "Dataset":
response = self.client.delete(f"/api/v1/datasets/{self.id}")
_http.raise_for_status(response)

return self._update_from_api_response(response)

def publish(self) -> "Dataset":
response = self.client.put(f"/api/v1/datasets/{self.id}/publish")
_http.raise_for_status(response)

return self._update_from_api_response(response)

@classmethod
def _create_from_json(cls, json: dict, client: httpx.Client) -> "Dataset":
return cls.from_dict(dict(**json, client=client))

def _update_from_api_response(self, response: httpx.Response) -> "Dataset":
new_instance = self._create_from_json(response.json(), client=self.client)
self.__dict__.update(new_instance.__dict__)

return self
14 changes: 14 additions & 0 deletions src/argilla_sdk/_api/_http/_helpers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
# Copyright 2024-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import httpx
from httpx import HTTPStatusError

Expand Down
14 changes: 14 additions & 0 deletions src/argilla_sdk/_helpers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright 2024-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

35 changes: 35 additions & 0 deletions src/argilla_sdk/_helpers/_iterator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright 2024-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Generic, List, TypeVar

Item = TypeVar("Item")


class GenericIterator(Generic[Item]):
"""Generic iterator for any collection of items."""

def __init__(self, collection: List[Item]):
self._collection = [v for v in collection]
self._index = 0

def __iter__(self):
return self

def __next__(self):
if self._index < len(self._collection):
result = self._collection[self._index]
self._index += 1
return result
raise StopIteration
63 changes: 63 additions & 0 deletions src/argilla_sdk/datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright 2024-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, TYPE_CHECKING

from argilla_sdk import _api
from argilla_sdk._helpers._iterator import GenericIterator # noqa

if TYPE_CHECKING:
from argilla_sdk import Workspace

__all__ = ["Dataset", "WorkspaceDatasets"]


class Dataset(_api.Dataset):
@property
def workspace(self) -> Optional["Workspace"]:
from argilla_sdk.workspaces import Workspace

if self.workspace_id:
return Workspace.get(self.workspace_id)

@workspace.setter
def workspace(self, workspace: "Workspace") -> None:
self.workspace_id = workspace.id

@classmethod
def get_by_name_and_workspace(cls, name: str, workspace: "Workspace") -> Optional["Dataset"]:
return cls.get_by_name_and_workspace_id(name, workspace.id)


DatasetsIterator = GenericIterator["Dataset"]


class WorkspaceDatasets:
def __init__(self, workspace: "Workspace"):
self.workspace = workspace

def list(self) -> List["Dataset"]:
from argilla_sdk import Dataset

return Dataset.list(workspace_id=self.workspace.id)

def add(self, dataset: "Dataset") -> "Dataset":
dataset.workspace_id = self.workspace.id
return dataset.create()

def get_by_name(self, name: str) -> "Dataset":
return Dataset.get_by_name_and_workspace(name=name, workspace=self.workspace)

def __iter__(self):
return DatasetsIterator(self.list())
12 changes: 11 additions & 1 deletion src/argilla_sdk/workspaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from argilla_sdk import _api
from argilla_sdk._helpers._iterator import GenericIterator # noqa

if TYPE_CHECKING:
from argilla_sdk.datasets import WorkspaceDatasets

__all__ = ["Workspace"]


class Workspace(_api.Workspace):
pass
@property
def datasets(self) -> "WorkspaceDatasets":
from argilla_sdk.datasets import WorkspaceDatasets

return WorkspaceDatasets(self)
Loading

0 comments on commit 903afa6

Please sign in to comment.