diff --git a/CHANGELOG.md b/CHANGELOG.md index d3b0b4d3ec..032f6ed234 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,8 +18,10 @@ These are the section headers that we use: ### Added +- If you expand the labels of a `single or multi` label Question, the state is maintained during the entire annotation process. ([#4630](https://github.com/argilla-io/argilla/pull/4630)) +- Added support for span questions in the Python SDK. ([#4617](https://github.com/argilla-io/argilla/pull/4617)) +- Added support for spans values in suggestions and responses. ([#4623](https://github.com/argilla-io/argilla/pull/4623)) - Added `Span` questions for `FeedbackDataset` ([#4622](https://github.com/argilla-io/argilla/pull/4622)) -- If you expand the labels of a `single or multi` label Question, the state is maintained during the entire annotation process ([#4630](https://github.com/argilla-io/argilla/pull/4630)) ### Fixed diff --git a/environment_dev.yml b/environment_dev.yml index a04197db63..a596ddd667 100644 --- a/environment_dev.yml +++ b/environment_dev.yml @@ -66,6 +66,6 @@ dependencies: - ipynbname>=2023.2.0.0 - httpx~=0.26.0 # For now we can just install argilla-server from the GitHub repo - - git+https://github.com/argilla-io/argilla-server.git@main + - git+https://github.com/argilla-io/argilla-server.git # install Argilla in editable mode - -e .[listeners] diff --git a/pyproject.toml b/pyproject.toml index 3abd15996b..7c96758d94 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,10 +51,10 @@ dynamic = ["version"] [project.optional-dependencies] server = [ - "argilla-server ~= 1.25.0", # TODO: Fix the versions when publishing a new release + "argilla-server", # TODO: Fix the versions when publishing a new release ] server-postgresql = [ - "argilla-server[postgresql] ~= 1.25.0", # TODO: Fix the versions when publishing a new release + "argilla-server[postgresql]", # TODO: Fix the versions when publishing a new release ] listeners = ["schedule ~= 1.1.0", "prodict ~= 0.8.0"] integrations = [ diff --git a/src/argilla/__init__.py b/src/argilla/__init__.py index 594dbc2853..0d9deb9683 100644 --- a/src/argilla/__init__.py +++ b/src/argilla/__init__.py @@ -74,30 +74,7 @@ configure_dataset_settings, load_dataset_settings, ) - from argilla.feedback import ( - FeedbackDataset, - FeedbackRecord, - FloatMetadataFilter, - FloatMetadataProperty, - IntegerMetadataFilter, - IntegerMetadataProperty, - LabelQuestion, - MultiLabelQuestion, - RankingQuestion, - RatingQuestion, - RecordSortField, - ResponseSchema, - ResponseStatusFilter, - SortBy, - SortOrder, - SuggestionSchema, - TermsMetadataFilter, - TermsMetadataProperty, - TextField, - TextQuestion, - ValueSchema, - VectorSettings, - ) + from argilla.feedback import * # noqa from argilla.listeners import Metrics, RGListenerContext, Search, listener from argilla.monitoring.model_monitor import monitor @@ -115,6 +92,8 @@ "MultiLabelQuestion", "RatingQuestion", "RankingQuestion", + "SpanQuestion", + "SpanLabelOption", "ResponseSchema", "ResponseStatusFilter", "TextField", diff --git a/src/argilla/client/feedback/dataset/local/mixins.py b/src/argilla/client/feedback/dataset/local/mixins.py index e73697fe45..48b1f1ff76 100644 --- a/src/argilla/client/feedback/dataset/local/mixins.py +++ b/src/argilla/client/feedback/dataset/local/mixins.py @@ -34,11 +34,8 @@ RemoteTermsMetadataProperty, ) from argilla.client.feedback.schemas.remote.questions import ( - RemoteLabelQuestion, - RemoteMultiLabelQuestion, - RemoteRankingQuestion, - RemoteRatingQuestion, - 
RemoteTextQuestion, + QUESTION_TYPE_TO_QUESTION, + AllowedRemoteQuestionTypes, ) from argilla.client.feedback.schemas.types import AllowedMetadataPropertyTypes from argilla.client.feedback.schemas.vector_settings import VectorSettings @@ -125,20 +122,13 @@ def __add_fields( @staticmethod def _parse_to_remote_question(question: "FeedbackQuestionModel") -> "AllowedRemoteQuestionTypes": - if question.settings["type"] == QuestionTypes.rating: - question = RemoteRatingQuestion.from_api(question) - elif question.settings["type"] == QuestionTypes.text: - question = RemoteTextQuestion.from_api(question) - elif question.settings["type"] == QuestionTypes.label_selection: - question = RemoteLabelQuestion.from_api(question) - elif question.settings["type"] == QuestionTypes.multi_label_selection: - question = RemoteMultiLabelQuestion.from_api(question) - elif question.settings["type"] == QuestionTypes.ranking: - question = RemoteRankingQuestion.from_api(question) + question_type = question.settings["type"] + if question_type in QUESTION_TYPE_TO_QUESTION: + question = QUESTION_TYPE_TO_QUESTION[question_type].from_api(question) else: raise ValueError( f"Question '{question.name}' is not a supported question in the current Python package" - f" version, supported question types are: `{'`, `'.join([arg.value for arg in QuestionTypes])}`." + f" version, supported question types are: `{'`, `'.join(QuestionTypes.values())}`." ) return question diff --git a/src/argilla/client/feedback/integrations/huggingface/dataset.py b/src/argilla/client/feedback/integrations/huggingface/dataset.py index 06ee29dfe1..71fda2d26d 100644 --- a/src/argilla/client/feedback/integrations/huggingface/dataset.py +++ b/src/argilla/client/feedback/integrations/huggingface/dataset.py @@ -11,11 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import json import logging import tempfile import warnings +from copy import copy from typing import TYPE_CHECKING, Any, Optional, Type, Union from packaging.version import parse as parse_version @@ -50,7 +50,7 @@ def _huggingface_format(dataset: Union["FeedbackDataset", "RemoteFeedbackDataset questions, and metadata_properties formatted as `datasets.Features`. Examples: - >>> from argilla.client.feedback.integrations.dataset import HuggingFaceDatasetMixin + >>> from argilla.client.feedback.integrations.huggingface import HuggingFaceDatasetMixin >>> dataset = FeedbackDataset(...) or RemoteFeedbackDataset(...) 
>>> huggingface_dataset = HuggingFaceDatasetMixin._huggingface_format(dataset) """ @@ -71,17 +71,40 @@ def _huggingface_format(dataset: Union["FeedbackDataset", "RemoteFeedbackDataset for question in dataset.questions: if question.type in [QuestionTypes.text, QuestionTypes.label_selection]: value = Value(dtype="string", id="question") + suggestion_value = copy(value) elif question.type == QuestionTypes.rating: value = Value(dtype="int32", id="question") + suggestion_value = copy(value) elif question.type == QuestionTypes.ranking: value = Sequence({"rank": Value(dtype="uint8"), "value": Value(dtype="string")}, id="question") + suggestion_value = copy(value) elif question.type in QuestionTypes.multi_label_selection: value = Sequence(Value(dtype="string"), id="question") + suggestion_value = copy(value) + elif question.type in QuestionTypes.span: + value = Sequence( + { + "start": Value(dtype="int32"), + "end": Value(dtype="int32"), + "label": Value(dtype="string"), + "text": Value(dtype="string"), + }, + id="question", + ) + suggestion_value = Sequence( + { + "start": Value(dtype="int32"), + "end": Value(dtype="int32"), + "label": Value(dtype="string"), + "text": Value(dtype="string"), + "score": Value(dtype="float32"), + } + ) else: raise ValueError( f"Question {question.name} is of type `{question.type}`," " for the moment only the following question types are supported:" - f" `{'`, `'.join([arg.value for arg in QuestionTypes])}`." + f" `{'`, `'.join(QuestionTypes.values())}`." ) hf_features[question.name] = [ @@ -94,8 +117,8 @@ def _huggingface_format(dataset: Union["FeedbackDataset", "RemoteFeedbackDataset if question.name not in hf_dataset: hf_dataset[question.name] = [] - value.id = "suggestion" - hf_features[f"{question.name}-suggestion"] = value + suggestion_value.id = "suggestion" + hf_features[f"{question.name}-suggestion"] = suggestion_value if f"{question.name}-suggestion" not in hf_dataset: hf_dataset[f"{question.name}-suggestion"] = [] @@ -138,6 +161,16 @@ def _huggingface_format(dataset: Union["FeedbackDataset", "RemoteFeedbackDataset } if question.type == QuestionTypes.ranking: value = [r.dict() for r in response.values[question.name].value] + elif question.type == QuestionTypes.span: + value = [ + { + "start": span.start, + "end": span.end, + "label": span.label, + "text": record.fields[question.field][span.start : span.end], + } + for span in response.values[question.name].value + ] else: value = response.values[question.name].value formatted_response["value"] = value @@ -148,7 +181,20 @@ def _huggingface_format(dataset: Union["FeedbackDataset", "RemoteFeedbackDataset if record.suggestions: for suggestion in record.suggestions: if question.name == suggestion.question_name: - suggestion_value = suggestion.dict(include={"value"})["value"] + if question.type == QuestionTypes.span: + suggestion_value = [ + { + "start": span.start, + "end": span.end, + "label": span.label, + "score": span.score, + "text": record.fields[question.field][span.start : span.end], + } + for span in suggestion.value + ] + else: + suggestion_value = suggestion.dict(include={"value"})["value"] + suggestion_metadata = { "type": suggestion.type, "score": suggestion.score, @@ -421,6 +467,11 @@ def from_huggingface( if value is not None: if question.type == QuestionTypes.ranking: value = [{"rank": r, "value": v} for r, v in zip(value["rank"], value["value"])] + elif question.type == QuestionTypes.span: + value = [ + {"start": s, "end": e, "label": l} + for s, e, l in zip(value["start"], value["end"], 
value["label"]) + ] responses[user_id or "user_without_id"]["values"].update({question.name: {"value": value}}) # First if-condition is here for backwards compatibility @@ -431,6 +482,11 @@ def from_huggingface( value = hfds[index][f"{question.name}-suggestion"] if question.type == QuestionTypes.ranking: value = [{"rank": r, "value": v} for r, v in zip(value["rank"], value["value"])] + elif question.type == QuestionTypes.span: + value = [ + {"start": s, "end": e, "label": l} + for s, e, l in zip(value["start"], value["end"], value["label"]) + ] suggestion = {"question_name": question.name, "value": value} if hfds[index][f"{question.name}-suggestion-metadata"] is not None: diff --git a/src/argilla/client/feedback/metrics/utils.py b/src/argilla/client/feedback/metrics/utils.py index 99eb9e880b..f4e05e8b95 100644 --- a/src/argilla/client/feedback/metrics/utils.py +++ b/src/argilla/client/feedback/metrics/utils.py @@ -191,7 +191,7 @@ def get_unified_responses_and_suggestions( unified_responses = [ tuple(ranking_schema.rank for ranking_schema in response) for response in unified_responses ] - suggestions = [tuple(s["rank"] for s in suggestion) for suggestion in suggestions] + suggestions = [tuple(s.rank for s in suggestion) for suggestion in suggestions] return unified_responses, suggestions diff --git a/src/argilla/client/feedback/schemas/__init__.py b/src/argilla/client/feedback/schemas/__init__.py index 1c5f2abdf4..e776355f89 100644 --- a/src/argilla/client/feedback/schemas/__init__.py +++ b/src/argilla/client/feedback/schemas/__init__.py @@ -33,16 +33,14 @@ QuestionSchema, RankingQuestion, RatingQuestion, + SpanLabelOption, + SpanQuestion, TextQuestion, ) -from argilla.client.feedback.schemas.records import ( - FeedbackRecord, - RankingValueSchema, - ResponseSchema, - SortBy, - SuggestionSchema, - ValueSchema, -) +from argilla.client.feedback.schemas.records import FeedbackRecord, SortBy +from argilla.client.feedback.schemas.response_values import RankingValueSchema, ResponseValue, SpanValueSchema +from argilla.client.feedback.schemas.responses import ResponseSchema, ResponseStatus, ValueSchema +from argilla.client.feedback.schemas.suggestions import SuggestionSchema from argilla.client.feedback.schemas.vector_settings import VectorSettings __all__ = [ @@ -62,11 +60,16 @@ "RankingQuestion", "RatingQuestion", "TextQuestion", + "SpanQuestion", + "SpanLabelOption", "FeedbackRecord", - "RankingValueSchema", "ResponseSchema", + "ResponseValue", + "ResponseStatus", "SuggestionSchema", "ValueSchema", + "RankingValueSchema", + "SpanValueSchema", "SortOrder", "SortBy", "RecordSortField", diff --git a/src/argilla/client/feedback/schemas/enums.py b/src/argilla/client/feedback/schemas/enums.py index ef053e5908..bde8f5164c 100644 --- a/src/argilla/client/feedback/schemas/enums.py +++ b/src/argilla/client/feedback/schemas/enums.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- from enum import Enum +from typing import List class FieldTypes(str, Enum): @@ -25,6 +25,11 @@ class QuestionTypes(str, Enum): label_selection = "label_selection" multi_label_selection = "multi_label_selection" ranking = "ranking" + span = "span" + + @classmethod + def values(cls) -> List[str]: + return [_type.value for _type in cls] class MetadataPropertyTypes(str, Enum): diff --git a/src/argilla/client/feedback/schemas/questions.py b/src/argilla/client/feedback/schemas/questions.py index 7acda4ee8c..42317bde36 100644 --- a/src/argilla/client/feedback/schemas/questions.py +++ b/src/argilla/client/feedback/schemas/questions.py @@ -17,6 +17,9 @@ from typing import Any, Dict, List, Literal, Optional, Union from argilla.client.feedback.schemas.enums import QuestionTypes +from argilla.client.feedback.schemas.response_values import parse_value_response_for_question +from argilla.client.feedback.schemas.responses import ResponseValue, ValueSchema +from argilla.client.feedback.schemas.suggestions import SuggestionSchema from argilla.client.feedback.schemas.utils import LabelMappingMixin from argilla.client.feedback.schemas.validators import title_must_have_value from argilla.pydantic_v1 import BaseModel, Extra, Field, conint, conlist, root_validator, validator @@ -77,6 +80,16 @@ def to_server_payload(self) -> Dict[str, Any]: "settings": self.server_settings, } + def suggestion(self, value: ResponseValue, **kwargs) -> SuggestionSchema: + """Method that will be used to create a `SuggestionSchema` from the question and a suggested value.""" + value = parse_value_response_for_question(self, value) + return SuggestionSchema(question_name=self.name, value=value, **kwargs) + + def response(self, value: ResponseValue) -> Dict[str, ValueSchema]: + """Method that will be used to create a response from the question and a value.""" + value = parse_value_response_for_question(self, value) + return {self.name: ValueSchema(value=value)} + class TextQuestion(QuestionSchema): """Schema for the `FeedbackDataset` text questions, which are the ones that will @@ -93,7 +106,7 @@ class TextQuestion(QuestionSchema): >>> TextQuestion(name="text_question", title="Text Question") """ - type: Literal[QuestionTypes.text] = Field(QuestionTypes.text.value, allow_mutation=False) + type: Literal[QuestionTypes.text] = Field(QuestionTypes.text.value, allow_mutation=False, const=True) use_markdown: bool = False @property @@ -120,7 +133,7 @@ class RatingQuestion(QuestionSchema, LabelMappingMixin): >>> RatingQuestion(name="rating_question", title="Rating Question", values=[1, 2, 3, 4, 5]) """ - type: Literal[QuestionTypes.rating] = Field(QuestionTypes.rating.value, allow_mutation=False) + type: Literal[QuestionTypes.rating] = Field(QuestionTypes.rating.value, allow_mutation=False, const=True) values: List[int] = Field(..., unique_items=True, ge=1, le=10, min_items=2) @property @@ -230,7 +243,9 @@ class LabelQuestion(_LabelQuestion): >>> LabelQuestion(name="label_question", title="Label Question", labels=["label_1", "label_2"]) """ - type: Literal[QuestionTypes.label_selection] = Field(QuestionTypes.label_selection.value, allow_mutation=False) + type: Literal[QuestionTypes.label_selection] = Field( + QuestionTypes.label_selection.value, allow_mutation=False, const=True + ) class MultiLabelQuestion(_LabelQuestion): @@ -254,7 +269,7 @@ class MultiLabelQuestion(_LabelQuestion): """ type: Literal[QuestionTypes.multi_label_selection] = Field( - QuestionTypes.multi_label_selection.value, allow_mutation=False + 
QuestionTypes.multi_label_selection.value, allow_mutation=False, const=True ) @@ -277,7 +292,7 @@ class RankingQuestion(QuestionSchema, LabelMappingMixin): >>> RankingQuestion(name="ranking_question", title="Ranking Question", values=["label_1", "label_2"]) """ - type: Literal[QuestionTypes.ranking] = Field(QuestionTypes.ranking.value, allow_mutation=False) + type: Literal[QuestionTypes.ranking] = Field(QuestionTypes.ranking.value, allow_mutation=False, const=True) values: Union[conlist(str, unique_items=True, min_items=2), Dict[str, str]] @validator("values", always=True) @@ -296,3 +311,112 @@ def server_settings(self) -> Dict[str, Any]: elif isinstance(self.values, list): settings["options"] = [{"value": label, "text": label} for label in self.values] return settings + + +class SpanLabelOption(BaseModel): + """Schema for the `FeedbackDataset` span label options, which are the ones that will be + used in the `SpanQuestion` to define the labels that the user can select. + + Args: + value: The value of the span label. This is the value that will be shown in the UI. + text: The text of the span label. This is the text that will be shown in the UI. + + Examples: + >>> from argilla.client.feedback.schemas.questions import SpanLabelOption + >>> SpanLabelOption(value="span_label_1", text="Span Label 1") + """ + + value: str + text: Optional[str] + description: Optional[str] + + def __eq__(self, other): + return other and self.value == other.value + + @validator("text", pre=True, always=True) + def default_text_value(cls, v: str, values: Dict[str, Any]) -> str: + if v is None: + return values["value"] + return v + + +class SpanQuestion(QuestionSchema): + """Schema for the `FeedbackDataset` span questions, which are the ones that will + require a span response from the user. More specifically, the user will be asked + to select a span of text from the input. + + Examples: + >>> from argilla.client.feedback.schemas.questions import SpanQuestion + >>> SpanQuestion(name="span_question", field="prompt", title="Span Question", labels=["person", "org"]) + """ + + _DEFAULT_MAX_VISIBLE_LABELS = 20 + _MIN_VISIBLE_LABELS = 3 + + type: Literal[QuestionTypes.span] = Field(QuestionTypes.span, allow_mutation=False, const=True) + + field: str = Field(..., description="The field in the input that the user will be asked to annotate.") + labels: Union[Dict[str, str], conlist(Union[str, SpanLabelOption], min_items=1, unique_items=True)] + visible_labels: Union[conint(ge=3), None] = _DEFAULT_MAX_VISIBLE_LABELS + + @validator("labels", pre=True) + def parse_labels_dict(cls, labels) -> List[SpanLabelOption]: + if isinstance(labels, dict): + return [SpanLabelOption(value=label, text=text) for label, text in labels.items()] + return labels + + @validator("labels", always=True) + def normalize_labels(cls, v: List[Union[str, SpanLabelOption]]) -> List[SpanLabelOption]: + return [SpanLabelOption(value=label, text=label) if isinstance(label, str) else label for label in v] + + @validator("labels") + def labels_must_be_valid(cls, labels: List[SpanLabelOption]) -> List[SpanLabelOption]: + # This validator is needed since the conlist constraint does not work. 
+ assert len(labels) > 0, "At least one label must be provided" + return labels + + @root_validator(skip_on_failure=True) + def check_visible_labels_value(cls, values) -> Optional[int]: + visible_labels_key = "visible_labels" + + v = values[visible_labels_key] + if v is None: + return values + + msg = None + number_of_labels = len(values.get("labels", [])) + + if cls._MIN_VISIBLE_LABELS > number_of_labels < v: + msg = f"Since `labels` has less than {cls._MIN_VISIBLE_LABELS} labels, `visible_labels` will be set to `None`." + v = None + elif v > number_of_labels: + msg = ( + f"`visible_labels={v}` is greater than the total number of labels ({number_of_labels}), " + f"so it will be set to `{number_of_labels}`." + ) + v = number_of_labels + + if msg: + warnings.warn(msg, UserWarning, stacklevel=1) + + values[visible_labels_key] = v + return values + + @property + def server_settings(self) -> Dict[str, Any]: + return { + "type": self.type, + "field": self.field, + "visible_options": self.visible_labels, + "options": [label.dict() for label in self.labels], + } + + +AllowedQuestionTypes = Union[ + TextQuestion, + RatingQuestion, + LabelQuestion, + MultiLabelQuestion, + RankingQuestion, + SpanQuestion, +] diff --git a/src/argilla/client/feedback/schemas/records.py b/src/argilla/client/feedback/schemas/records.py index 1b72cb7836..8b997226f3 100644 --- a/src/argilla/client/feedback/schemas/records.py +++ b/src/argilla/client/feedback/schemas/records.py @@ -13,130 +13,21 @@ # limitations under the License. import warnings -from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union from uuid import UUID -from argilla.client.feedback.schemas.enums import RecordSortField, ResponseStatus, SortOrder -from argilla.pydantic_v1 import BaseModel, Extra, Field, PrivateAttr, StrictInt, StrictStr, conint, validator +from argilla.client.feedback.schemas.enums import RecordSortField, SortOrder + +# Support backward compatibility for import of RankingValueSchema from records module +from argilla.client.feedback.schemas.response_values import RankingValueSchema # noqa +from argilla.client.feedback.schemas.responses import ResponseSchema, ValueSchema # noqa +from argilla.client.feedback.schemas.suggestions import SuggestionSchema +from argilla.pydantic_v1 import BaseModel, Extra, Field, PrivateAttr, validator if TYPE_CHECKING: from argilla.client.feedback.unification import UnifiedValueSchema -class RankingValueSchema(BaseModel): - """Schema for the `RankingQuestion` response value for a `RankingQuestion`. Note that - we may have more than one record in the same rank. - - Args: - value: The value of the record. - rank: The rank of the record. - """ - - value: StrictStr - rank: Optional[conint(ge=1)] = None - - -class ValueSchema(BaseModel): - """Schema for any `FeedbackRecord` response value. - - Args: - value: The value of the record. - """ - - value: Union[StrictStr, StrictInt, List[str], List[RankingValueSchema]] - - -class ResponseSchema(BaseModel): - """Schema for the `FeedbackRecord` response. - - Args: - user_id: ID of the user that provided the response. Defaults to None, and is - automatically fulfilled internally once the question is pushed to Argilla. - values: Values of the response, should match the questions in the record. - status: Status of the response. Defaults to `submitted`. - - Examples: - >>> from argilla.client.feedback.schemas.records import ResponseSchema - >>> ResponseSchema( - ... 
values={ - ... "question_1": {"value": "answer_1"}, - ... "question_2": {"value": "answer_2"}, - ... } - ... ) - """ - - user_id: Optional[UUID] = None - values: Union[Dict[str, ValueSchema], None] - status: ResponseStatus = ResponseStatus.submitted - - class Config: - extra = Extra.forbid - validate_assignment = True - - @validator("user_id", always=True) - def user_id_must_have_value(cls, v): - if not v: - warnings.warn( - "`user_id` not provided, so it will be set to `None`. Which is not an" - " issue, unless you're planning to log the response in Argilla, as" - " it will be automatically set to the active `user_id`.", - ) - return v - - def to_server_payload(self) -> Dict[str, Any]: - """Method that will be used to create the payload that will be sent to Argilla - to create a `ResponseSchema` for a `FeedbackRecord`.""" - return { - # UUID is not json serializable!!! - "user_id": self.user_id, - "values": {question_name: value.dict() for question_name, value in self.values.items()} - if self.values is not None - else None, - "status": self.status.value if hasattr(self.status, "value") else self.status, - } - - -class SuggestionSchema(BaseModel): - """Schema for the suggestions for the questions related to the record. - - Args: - question_name: name of the question in the `FeedbackDataset`. - type: type of the question. Defaults to None. Possible values are `model` or `human`. - score: score of the suggestion. Defaults to None. - value: value of the suggestion, which should match the type of the question. - agent: agent that generated the suggestion. Defaults to None. - - Examples: - >>> from argilla.client.feedback.schemas.records import SuggestionSchema - >>> SuggestionSchema( - ... question_name="question-1", - ... type="model", - ... score=0.9, - ... value="This is the first suggestion", - ... agent="agent-1", - ... ) - """ - - question_name: str - type: Optional[Literal["model", "human"]] = None - score: Optional[float] = None - value: Any - agent: Optional[str] = None - - class Config: - extra = Extra.forbid - validate_assignment = True - - def to_server_payload(self, question_name_to_id: Dict[str, UUID]) -> Dict[str, Any]: - """Method that will be used to create the payload that will be sent to Argilla - to create a `SuggestionSchema` for a `FeedbackRecord`.""" - # We can do this because there is no default values for the fields - payload = self.dict(exclude_unset=True, include={"type", "score", "value", "agent"}) - payload["question_id"] = str(question_name_to_id[self.question_name]) - - return payload - - class FeedbackRecord(BaseModel): """Schema for the records of a `FeedbackDataset`. @@ -159,7 +50,7 @@ class FeedbackRecord(BaseModel): Defaults to None. Examples: - >>> from argilla.client.feedback.schemas.records import FeedbackRecord, ResponseSchema, SuggestionSchema, ValueSchema + >>> from argilla.feedback import FeedbackRecord, ResponseSchema, SuggestionSchema, ValueSchema >>> FeedbackRecord( ... fields={"text": "This is the first record", "label": "positive"}, ... metadata={"first": True, "nested": {"more": "stuff"}}, @@ -181,6 +72,7 @@ class FeedbackRecord(BaseModel): ... value="This is the first suggestion", ... agent="agent-1", ... ), + ... ], ... external_id="entry-1", ... 
) diff --git a/src/argilla/client/feedback/schemas/remote/questions.py b/src/argilla/client/feedback/schemas/remote/questions.py index 272d1d96b7..92cf13ea2b 100644 --- a/src/argilla/client/feedback/schemas/remote/questions.py +++ b/src/argilla/client/feedback/schemas/remote/questions.py @@ -14,11 +14,14 @@ from typing import TYPE_CHECKING, Dict, List, Union +from argilla.client.feedback.schemas import QuestionTypes from argilla.client.feedback.schemas.questions import ( LabelQuestion, MultiLabelQuestion, RankingQuestion, RatingQuestion, + SpanLabelOption, + SpanQuestion, TextQuestion, ) from argilla.client.feedback.schemas.remote.shared import RemoteSchema @@ -146,3 +149,50 @@ def from_api(cls, payload: "FeedbackQuestionModel") -> "RemoteRankingQuestion": required=payload.required, values=_parse_options_from_api(payload), ) + + +class RemoteSpanQuestion(SpanQuestion, RemoteSchema): + def to_local(self) -> SpanQuestion: + return SpanQuestion( + name=self.name, + title=self.title, + field=self.field, + required=self.required, + labels=self.labels, + visible_labels=self.visible_labels, + ) + + @classmethod + def _parse_options_from_api(cls, options: List[Dict[str, str]]) -> List[SpanLabelOption]: + return [SpanLabelOption(value=option["value"], text=option["text"]) for option in options] + + @classmethod + def from_api(cls, payload: "FeedbackQuestionModel") -> "RemoteSpanQuestion": + return RemoteSpanQuestion( + id=payload.id, + name=payload.name, + title=payload.title, + field=payload.settings["field"], + required=payload.required, + visible_labels=payload.settings["visible_options"], + labels=cls._parse_options_from_api(payload.settings["options"]), + ) + + +AllowedRemoteQuestionTypes = Union[ + RemoteTextQuestion, + RemoteRatingQuestion, + RemoteLabelQuestion, + RemoteMultiLabelQuestion, + RemoteRankingQuestion, + RemoteSpanQuestion, +] + +QUESTION_TYPE_TO_QUESTION = { + QuestionTypes.text: RemoteTextQuestion, + QuestionTypes.rating: RemoteRatingQuestion, + QuestionTypes.label_selection: RemoteLabelQuestion, + QuestionTypes.multi_label_selection: RemoteMultiLabelQuestion, + QuestionTypes.ranking: RemoteRankingQuestion, + QuestionTypes.span: RemoteSpanQuestion, +} diff --git a/src/argilla/client/feedback/schemas/response_values.py b/src/argilla/client/feedback/schemas/response_values.py new file mode 100644 index 0000000000..40fc6de7de --- /dev/null +++ b/src/argilla/client/feedback/schemas/response_values.py @@ -0,0 +1,112 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
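To make the new question type concrete, a minimal sketch of declaring one locally (all names are illustrative; the `visible_labels` clamping follows the root validator shown earlier):

```python
from argilla.client.feedback.schemas.questions import SpanLabelOption, SpanQuestion

question = SpanQuestion(
    name="entities",
    field="text",         # the record field whose text will be annotated
    labels={"PER": "Person", "ORG": "Organization"},
    visible_labels=None,  # show all labels; a default (20) above the label count is clamped with a UserWarning
)

# A labels dict is parsed into SpanLabelOption entries (equality compares `value` only):
assert question.labels == [SpanLabelOption(value="PER"), SpanLabelOption(value="ORG")]
```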
+ +from typing import TYPE_CHECKING, List, Optional, Union + +from argilla.client.feedback.schemas.enums import QuestionTypes +from argilla.pydantic_v1 import ( + BaseModel, + StrictInt, + StrictStr, + confloat, + conint, + constr, + parse_obj_as, + root_validator, +) + +if TYPE_CHECKING: + from argilla.client.feedback.schemas.questions import QuestionSchema + + +class RankingValueSchema(BaseModel): + """Schema for the `RankingQuestion` response value for a `RankingQuestion`. Note that + we may have more than one record in the same rank. + + Args: + value: The value of the record. + rank: The rank of the record. + """ + + value: StrictStr + rank: Optional[conint(ge=1)] = None + + +class SpanValueSchema(BaseModel): + """Schema for the `SpanQuestion` response value for a `SpanQuestion`. + + Args: + label: The label value of the span. + start: The start of the span. + end: The end of the span. + score: The score of the span. + """ + + label: constr(min_length=1) + start: conint(ge=0) + end: conint(ge=0) + score: Optional[confloat(ge=0.0, le=1.0)] = None + + @root_validator + def check_span(cls, values): + if values["end"] <= values["start"]: + raise ValueError("The end of the span must be greater than the start.") + return values + + +ResponseValue = Union[ + StrictStr, + StrictInt, + List[str], + List[dict], + List[RankingValueSchema], + List[SpanValueSchema], +] + +RESPONSE_VALUE_FOR_QUESTION_TYPE = { + QuestionTypes.text: str, + QuestionTypes.label_selection: str, + QuestionTypes.multi_label_selection: List[str], + QuestionTypes.ranking: List[RankingValueSchema], + QuestionTypes.rating: int, + QuestionTypes.span: List[SpanValueSchema], +} + + +def parse_value_response_for_question(question: "QuestionSchema", value: ResponseValue) -> ResponseValue: + question_type = question.type + response_type = RESPONSE_VALUE_FOR_QUESTION_TYPE[question_type] + + if isinstance(value, (dict, list)): + return parse_obj_as(response_type, value) + elif not isinstance(value, response_type): + raise ValueError(f"Value {value} is not valid for question type {question_type}. Expected {response_type}.") + + return value + + +def normalize_response_value(value: ResponseValue) -> ResponseValue: + """Normalize the response value.""" + if not isinstance(value, list) or not all([isinstance(v, dict) for v in value]): + return value + + new_value = [] + for v in value: + if "start" in v and "end" in v: + new_value.append(SpanValueSchema(**v)) + elif "value" in v: + new_value.append(RankingValueSchema(**v)) + else: + raise ValueError("Invalid value", value) + return new_value diff --git a/src/argilla/client/feedback/schemas/responses.py b/src/argilla/client/feedback/schemas/responses.py new file mode 100644 index 0000000000..0de48e256d --- /dev/null +++ b/src/argilla/client/feedback/schemas/responses.py @@ -0,0 +1,96 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
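A quick sketch of the span value schema defined in `response_values.py` above; the offsets and label are invented:

```python
from argilla.client.feedback.schemas.response_values import SpanValueSchema

span = SpanValueSchema(start=0, end=4, label="PER", score=0.9)

# The root validator rejects empty and inverted spans with a validation
# error reading "The end of the span must be greater than the start.":
try:
    SpanValueSchema(start=4, end=4, label="PER")
except ValueError as error:
    print(error)
```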
+ +import warnings +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from uuid import UUID + +from argilla.client.feedback.schemas.enums import ResponseStatus +from argilla.client.feedback.schemas.response_values import ( + ResponseValue, + normalize_response_value, + parse_value_response_for_question, +) +from argilla.pydantic_v1 import BaseModel, Extra, validator + +if TYPE_CHECKING: + from argilla.client.feedback.schemas.questions import QuestionSchema + + +class ValueSchema(BaseModel): + """Schema for any `FeedbackRecord` response value. + + Args: + value: The value of the record. + """ + + value: ResponseValue + + _normalize_value = validator("value", allow_reuse=True, always=True)(normalize_response_value) + + +class ResponseSchema(BaseModel): + """Schema for the `FeedbackRecord` response. + + Args: + user_id: ID of the user that provided the response. Defaults to None, and is + automatically fulfilled internally once the question is pushed to Argilla. + values: Values of the response, should match the questions in the record. + status: Status of the response. Defaults to `submitted`. + + Examples: + >>> from argilla.client.feedback.schemas.responses import ResponseSchema, ValueSchema + >>> ResponseSchema( + ... values={ + ... "question_1": ValueSchema(value="answer_1"), + ... "question_2": ValueSchema(value="answer_2"), + ... } + ... ) + """ + + user_id: Optional[UUID] = None + values: Union[List[Dict[str, ValueSchema]], Dict[str, ValueSchema], None] + status: Union[ResponseStatus, str] = ResponseStatus.submitted + + @validator("values", always=True) + def normalize_values(cls, values): + if isinstance(values, list) and all(isinstance(value, dict) for value in values): + return {k: v for value in values for k, v in value.items()} + return values + + @validator("status") + def normalize_status(cls, v) -> ResponseStatus: + if isinstance(v, str): + return ResponseStatus(v) + return v + + @validator("user_id", always=True) + def user_id_must_have_value(cls, v): + if not v: + warnings.warn( + "`user_id` not provided, so it will be set to `None`. Which is not an" + " issue, unless you're planning to log the response in Argilla, as" + " it will be automatically set to the active `user_id`.", + ) + return v + + class Config: + extra = Extra.forbid + validate_assignment = True + + def to_server_payload(self) -> Dict[str, Any]: + """Method that will be used to create the payload that will be sent to Argilla + to create a `ResponseSchema` for a `FeedbackRecord`.""" + payload = {"user_id": self.user_id, "status": self.status, **self.dict(exclude_unset=True, include={"values"})} + return payload diff --git a/src/argilla/client/feedback/schemas/suggestions.py b/src/argilla/client/feedback/schemas/suggestions.py new file mode 100644 index 0000000000..cabb23caf7 --- /dev/null +++ b/src/argilla/client/feedback/schemas/suggestions.py @@ -0,0 +1,68 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
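The list form accepted by `normalize_values` in `responses.py` above is what makes the new per-question `response(...)` helper composable; a short sketch with invented question names:

```python
from argilla.client.feedback.schemas.responses import ResponseSchema, ValueSchema

# A list of single-entry dicts (the shape question.response() returns) is
# merged into one values dict by the normalize_values validator:
response = ResponseSchema(
    status="submitted",
    values=[
        {"question-1": ValueSchema(value="answer")},
        {"question-2": ValueSchema(value=1)},
    ],
)
assert response.values == {
    "question-1": ValueSchema(value="answer"),
    "question-2": ValueSchema(value=1),
}
```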
+ +from typing import TYPE_CHECKING, Any, Dict, Literal, Optional +from uuid import UUID + +from argilla.client.feedback.schemas.response_values import ( + ResponseValue, + normalize_response_value, + parse_value_response_for_question, +) +from argilla.pydantic_v1 import BaseModel, Extra, confloat, validator + +if TYPE_CHECKING: + from argilla.client.feedback.schemas.questions import QuestionSchema + + +class SuggestionSchema(BaseModel): + """Schema for the suggestions for the questions related to the record. + + Args: + question_name: name of the question in the `FeedbackDataset`. + type: type of the question. Defaults to None. Possible values are `model` or `human`. + score: score of the suggestion. Defaults to None. + value: value of the suggestion, which should match the type of the question. + agent: agent that generated the suggestion. Defaults to None. + + Examples: + >>> from argilla.client.feedback.schemas.suggestions import SuggestionSchema + >>> SuggestionSchema( + ... question_name="question-1", + ... type="model", + ... score=0.9, + ... value="This is the first suggestion", + ... agent="agent-1", + ... ) + """ + + question_name: str + value: ResponseValue + score: Optional[confloat(ge=0, le=1)] = None + type: Optional[Literal["model", "human"]] = None + agent: Optional[str] = None + + _normalize_value = validator("value", allow_reuse=True, always=True)(normalize_response_value) + + class Config: + extra = Extra.forbid + validate_assignment = True + + def to_server_payload(self, question_name_to_id: Dict[str, UUID]) -> Dict[str, Any]: + """Method that will be used to create the payload that will be sent to Argilla + to create a `SuggestionSchema` for a `FeedbackRecord`.""" + payload = self.dict(exclude_unset=True, include={"type", "score", "value", "agent"}) + payload["question_id"] = str(question_name_to_id[self.question_name]) + + return payload diff --git a/src/argilla/client/feedback/schemas/types.py b/src/argilla/client/feedback/schemas/types.py index ed2c917153..ea7c39973f 100644 --- a/src/argilla/client/feedback/schemas/types.py +++ b/src/argilla/client/feedback/schemas/types.py @@ -20,35 +20,19 @@ IntegerMetadataProperty, TermsMetadataProperty, ) -from argilla.client.feedback.schemas.questions import ( - LabelQuestion, - MultiLabelQuestion, - RankingQuestion, - RatingQuestion, - TextQuestion, -) +from argilla.client.feedback.schemas.questions import AllowedQuestionTypes # noqa from argilla.client.feedback.schemas.remote.fields import RemoteTextField from argilla.client.feedback.schemas.remote.metadata import ( RemoteFloatMetadataProperty, RemoteIntegerMetadataProperty, RemoteTermsMetadataProperty, ) -from argilla.client.feedback.schemas.remote.questions import ( - RemoteLabelQuestion, - RemoteMultiLabelQuestion, - RemoteRankingQuestion, - RemoteRatingQuestion, - RemoteTextQuestion, -) +from argilla.client.feedback.schemas.remote.questions import AllowedRemoteQuestionTypes # noqa from argilla.client.feedback.schemas.remote.vector_settings import RemoteVectorSettings from argilla.client.feedback.schemas.vector_settings import VectorSettings AllowedFieldTypes = TextField AllowedRemoteFieldTypes = RemoteTextField -AllowedQuestionTypes = Union[TextQuestion, RatingQuestion, LabelQuestion, MultiLabelQuestion, RankingQuestion] -AllowedRemoteQuestionTypes = Union[ - RemoteTextQuestion, RemoteRatingQuestion, RemoteLabelQuestion, RemoteMultiLabelQuestion, RemoteRankingQuestion -] AllowedMetadataPropertyTypes = Union[TermsMetadataProperty, FloatMetadataProperty, 
IntegerMetadataProperty] AllowedRemoteMetadataPropertyTypes = Union[ RemoteTermsMetadataProperty, RemoteIntegerMetadataProperty, RemoteFloatMetadataProperty diff --git a/src/argilla/client/sdk/v1/datasets/models.py b/src/argilla/client/sdk/v1/datasets/models.py index 379491a0e3..d6d0793837 100644 --- a/src/argilla/client/sdk/v1/datasets/models.py +++ b/src/argilla/client/sdk/v1/datasets/models.py @@ -40,7 +40,7 @@ class FeedbackRankingValueModel(BaseModel): class FeedbackValueModel(BaseModel): - value: Union[StrictStr, StrictInt, List[str], List[FeedbackRankingValueModel]] + value: Any class FeedbackResponseStatus(str, Enum): diff --git a/tests/integration/client/conftest.py b/tests/integration/client/conftest.py index 37f7c7466b..2f5aeef423 100644 --- a/tests/integration/client/conftest.py +++ b/tests/integration/client/conftest.py @@ -17,6 +17,7 @@ from typing import TYPE_CHECKING, Generator, List import pytest +from argilla import SpanQuestion from argilla.client.api import log from argilla.client.datasets import read_datasets from argilla.client.feedback.dataset.local.dataset import FeedbackDataset @@ -416,6 +417,7 @@ def feedback_dataset_questions() -> List["AllowedQuestionTypes"]: LabelQuestion(name="question-3", labels=["a", "b", "c"], required=True), MultiLabelQuestion(name="question-4", labels=["a", "b", "c"], required=True), RankingQuestion(name="question-5", values=["a", "b"], required=True), + SpanQuestion(name="question-6", field="text", labels=["a", "b"], required=False), ] @@ -461,6 +463,7 @@ def feedback_dataset_records() -> List[FeedbackRecord]: "question-3": {"value": "a"}, "question-4": {"value": ["a", "b"]}, "question-5": {"value": [{"rank": 1, "value": "a"}, {"rank": 2, "value": "b"}]}, + "question-6": {"value": [{"start": 0, "end": 4, "label": "a"}]}, }, "status": "submitted", }, @@ -485,6 +488,7 @@ def feedback_dataset_records() -> List[FeedbackRecord]: "question-5": { "value": [RankingValueSchema(rank=1, value="a"), RankingValueSchema(rank=2, value="b")] }, + "question-6": {"value": [{"start": 0, "end": 4, "label": "a"}]}, }, "status": "submitted", } @@ -525,6 +529,13 @@ def feedback_dataset_records() -> List[FeedbackRecord]: "score": 0.0, "agent": "agent-1", }, + { + "question_name": "question-6", + "value": [{"start": 0, "end": 4, "label": "a"}], + "type": "human", + "score": 0.0, + "agent": "agent-1", + }, ], external_id="3", ), @@ -538,6 +549,7 @@ def feedback_dataset_records() -> List[FeedbackRecord]: "question-3": {"value": "c"}, "question-4": {"value": ["a", "c"]}, "question-5": {"value": [{"rank": 1, "value": "a"}, {"rank": 2, "value": "b"}]}, + "question-6": {"value": [{"start": 0, "end": 4, "label": "a"}]}, }, "status": "submitted", } diff --git a/tests/integration/client/feedback/dataset/local/test_dataset.py b/tests/integration/client/feedback/dataset/local/test_dataset.py index 5ac92a9db8..431190426a 100644 --- a/tests/integration/client/feedback/dataset/local/test_dataset.py +++ b/tests/integration/client/feedback/dataset/local/test_dataset.py @@ -18,7 +18,7 @@ import argilla.client.singleton import datasets import pytest -from argilla import User, Workspace +from argilla import ResponseSchema, User, Workspace from argilla.client.feedback.config import DatasetConfig from argilla.client.feedback.constants import FETCHING_BATCH_SIZE from argilla.client.feedback.dataset import FeedbackDataset @@ -28,12 +28,14 @@ IntegerMetadataProperty, TermsMetadataProperty, ) -from argilla.client.feedback.schemas.questions import TextQuestion +from 
argilla.client.feedback.schemas.questions import SpanLabelOption, SpanQuestion, TextQuestion from argilla.client.feedback.schemas.records import FeedbackRecord from argilla.client.feedback.schemas.remote.records import RemoteSuggestionSchema from argilla.client.feedback.schemas.vector_settings import VectorSettings from argilla.client.feedback.training.schemas.base import TrainingTask from argilla.client.models import Framework +from argilla.client.sdk.commons.errors import ValidationApiError +from argilla.feedback import SpanValueSchema if TYPE_CHECKING: from argilla.client.feedback.schemas.types import AllowedFieldTypes, AllowedQuestionTypes @@ -52,7 +54,7 @@ def test_create_dataset_with_suggestions(argilla_user: "ServerUser") -> None: records=[ FeedbackRecord( fields={"text": "this is a text"}, - suggestions=[{"question_name": "text", "value": "This is a suggestion"}], + suggestions=[ds.question_by_name("text").suggestion(value="This is a suggestion")], ) ] ) @@ -65,6 +67,22 @@ def test_create_dataset_with_suggestions(argilla_user: "ServerUser") -> None: assert remote_dataset.records[0].suggestions[0].question_id == remote_dataset.question_by_name("text").id +def test_create_dataset_with_span_questions(argilla_user: "ServerUser") -> None: + argilla.client.singleton.init(api_key=argilla_user.api_key) + + ds = FeedbackDataset( + fields=[TextField(name="text")], + questions=[SpanQuestion(name="spans", field="text", labels=["label1", "label2"])], + ) + + rg_dataset = ds.push_to_argilla(name="new_dataset") + + assert rg_dataset.id + assert rg_dataset.questions[0].name == "spans" + assert rg_dataset.questions[0].field == "text" + assert rg_dataset.questions[0].labels == [SpanLabelOption(value="label1"), SpanLabelOption(value="label2")] + + @pytest.mark.asyncio async def test_update_dataset_records_with_suggestions(argilla_user: "ServerUser", db: "AsyncSession"): argilla.client.singleton.init(api_key=argilla_user.api_key) @@ -79,7 +97,7 @@ async def test_update_dataset_records_with_suggestions(argilla_user: "ServerUser assert remote_dataset.records[0].id is not None assert remote_dataset.records[0].suggestions == () - remote_dataset.records[0].update(suggestions=[{"question_name": "text", "value": "This is a suggestion"}]) + remote_dataset.records[0].update(suggestions=[ds.question_by_name("text").suggestion(value="This is a suggestion")]) # TODO: Review this requirement for tests and explain, try to avoid use or at least, document. 
await db.refresh(argilla_user, attribute_names=["datasets"]) @@ -98,9 +116,7 @@ def test_add_records( feedback_dataset_questions: List["AllowedQuestionTypes"], ) -> None: dataset = FeedbackDataset( - guidelines=feedback_dataset_guidelines, - fields=feedback_dataset_fields, - questions=feedback_dataset_questions, + guidelines=feedback_dataset_guidelines, fields=feedback_dataset_fields, questions=feedback_dataset_questions ) assert dataset.records == [] @@ -125,6 +141,13 @@ def test_add_records( assert not dataset.records[0].responses assert not dataset.records[0].suggestions + question_1 = dataset.question_by_name("question-1") + question_2 = dataset.question_by_name("question-2") + question_3 = dataset.question_by_name("question-3") + question_4 = dataset.question_by_name("question-4") + question_5 = dataset.question_by_name("question-5") + question_6 = dataset.question_by_name("question-6") + dataset.add_records( [ FeedbackRecord( @@ -134,38 +157,25 @@ def test_add_records( }, metadata={"unit": "test"}, responses=[ - { - "values": { - "question-1": {"value": "answer"}, - "question-2": {"value": 0}, - "question-3": {"value": "a"}, - "question-4": {"value": ["a", "b"]}, - "question-5": {"value": [{"rank": 1, "value": "a"}, {"rank": 2, "value": "b"}]}, - }, - "status": "submitted", - }, + ResponseSchema( + status="submitted", + values=[ + question_1.response(value="answer"), + question_2.response(value=0), + question_3.response(value="a"), + question_4.response(value=["a", "b"]), + question_5.response(value=[{"rank": 1, "value": "a"}, {"rank": 2, "value": "b"}]), + question_6.response(value=[SpanValueSchema(start=0, end=4, label="a")]), + ], + ), ], suggestions=[ - { - "question_name": "question-1", - "value": "answer", - }, - { - "question_name": "question-2", - "value": 0, - }, - { - "question_name": "question-3", - "value": "a", - }, - { - "question_name": "question-4", - "value": ["a", "b"], - }, - { - "question_name": "question-5", - "value": [{"rank": 1, "value": "a"}, {"rank": 2, "value": "b"}], - }, + question_1.suggestion(value="answer"), + question_2.suggestion(value=0), + question_3.suggestion(value="a"), + question_4.suggestion(value=["a", "b"]), + question_5.suggestion(value=[{"rank": 1, "value": "a"}, {"rank": 2, "value": "b"}]), + question_6.suggestion(value=[SpanValueSchema(start=0, end=4, label="a")]), ], external_id="test-id", ), @@ -186,6 +196,7 @@ def test_add_records( "question-3": {"value": "a"}, "question-4": {"value": ["a", "b"]}, "question-5": {"value": [{"rank": 1, "value": "a"}, {"rank": 2, "value": "b"}]}, + "question-6": {"value": [{"start": 0, "end": 4, "label": "a", "score": None}]}, }, "status": "submitted", } @@ -227,6 +238,45 @@ def test_add_records( assert len(dataset) == len(dataset.records) +@pytest.mark.parametrize( + "spans, expected_message_match", + [ + ( + [{"start": 0, "end": 400, "label": "label1"}], + "value `end` must have a value lower or equal than record field `text` length", + ), + ( + [ + SpanValueSchema(start=0, end=4, label="wrong-label"), + ], + "undefined label 'wrong-label' for span question.", + ), + ], +) +def test_add_records_with_wrong_spans_suggestions( + argilla_user: "ServerUser", spans: list, expected_message_match: str +) -> None: + argilla.client.singleton.init(api_key=argilla_user.api_key) + + dataset_cfg = FeedbackDataset( + fields=[TextField(name="text")], + questions=[SpanQuestion(name="spans", field="text", labels=["label1", "label2"])], + ) + + dataset = dataset_cfg.push_to_argilla(name="test-dataset") + question = 
dataset.question_by_name("spans") + + with pytest.raises(ValidationApiError, match=expected_message_match): + dataset.add_records( + [ + FeedbackRecord( + fields={"text": "this is a text"}, + suggestions=[question.suggestion(value=spans)], + ) + ] + ) + + def test_add_records_with_vectors() -> None: dataset = FeedbackDataset( fields=[TextField(name="text", required=True)], @@ -351,6 +401,13 @@ async def test_push_to_argilla_and_from_argilla( fields=feedback_dataset_fields, questions=feedback_dataset_questions, ) + + question_1 = dataset.question_by_name("question-1") + question_2 = dataset.question_by_name("question-2") + question_3 = dataset.question_by_name("question-3") + question_4 = dataset.question_by_name("question-4") + question_5 = dataset.question_by_name("question-5") + question_6 = dataset.question_by_name("question-6") # Make sure UUID in `user_id` is pushed to Argilla with no issues as it should be # converted to a string dataset.add_records( @@ -361,26 +418,28 @@ async def test_push_to_argilla_and_from_argilla( "label": "F", }, responses=[ - { - "values": { - "question-1": {"value": "answer"}, - "question-2": {"value": 1}, - "question-3": {"value": "a"}, - "question-4": {"value": ["a", "b"]}, - "question-5": {"value": [{"rank": 1, "value": "a"}, {"rank": 2, "value": "b"}]}, - }, - "status": "submitted", - }, - { - "values": { - "question-1": {"value": "answer"}, - "question-2": {"value": 1}, - "question-3": {"value": "a"}, - "question-4": {"value": ["a", "b"]}, - "question-5": {"value": [{"rank": 1, "value": "a"}, {"rank": 2, "value": "b"}]}, - }, - "status": "submitted", - }, + ResponseSchema( + status="submitted", + values=[ + question_1.response(value="answer"), + question_2.response(value=1), + question_3.response(value="a"), + question_4.response(value=["a", "b"]), + question_5.response(value=[{"rank": 1, "value": "a"}, {"rank": 2, "value": "b"}]), + question_6.response(value=[SpanValueSchema(start=0, end=1, label="a")]), + ], + ), + ResponseSchema( + status="submitted", + values=[ + question_1.response(value="answer"), + question_2.response(value=1), + question_3.response(value="a"), + question_4.response(value=["a", "b"]), + question_5.response(value=[{"rank": 1, "value": "a"}, {"rank": 2, "value": "b"}]), + question_6.response(value=[SpanValueSchema(start=0, end=1, label="a")]), + ], + ), ], ), ] @@ -597,26 +656,36 @@ def test_push_to_huggingface_and_from_huggingface( FeedbackRecord( fields={"text": "This is a negative example", "label": "negative"}, responses=[ - { - "values": { - "question-1": {"value": "This is a response to question 1"}, - "question-2": {"value": 0}, - "question-3": {"value": "b"}, - "question-4": {"value": ["b", "c"]}, - "question-5": {"value": [{"rank": 1, "value": "a"}, {"rank": 2, "value": "b"}]}, - }, - "status": "submitted", - }, - { - "values": { - "question-1": {"value": "This is a response to question 1"}, - "question-2": {"value": 0}, - "question-3": {"value": "b"}, - "question-4": {"value": ["b", "c"]}, - "question-5": {"value": [{"rank": 1, "value": "a"}, {"rank": 2, "value": "b"}]}, - }, - "status": "submitted", - }, + ResponseSchema( + status="submitted", + values=[ + dataset.question_by_name("question-1").response(value="This is a response to question 1"), + dataset.question_by_name("question-2").response(value=0), + dataset.question_by_name("question-3").response(value="b"), + dataset.question_by_name("question-4").response(value=["b", "c"]), + dataset.question_by_name("question-5").response( + value=[{"rank": 1, "value": "a"}, 
{"rank": 2, "value": "b"}] + ), + dataset.question_by_name("question-6").response( + value=[SpanValueSchema(start=0, end=4, label="a")] + ), + ], + ), + ResponseSchema( + status="submitted", + values=[ + dataset.question_by_name("question-1").response(value="This is a response to question 1"), + dataset.question_by_name("question-2").response(value=0), + dataset.question_by_name("question-3").response(value="b"), + dataset.question_by_name("question-4").response(value=["b", "c"]), + dataset.question_by_name("question-5").response( + value=[{"rank": 1, "value": "a"}, {"rank": 2, "value": "b"}] + ), + dataset.question_by_name("question-6").response( + value=[SpanValueSchema(start=0, end=4, label="a")] + ), + ], + ), ], ), ], diff --git a/tests/integration/client/feedback/dataset/remote/test_dataset.py b/tests/integration/client/feedback/dataset/remote/test_dataset.py index 1ada1fb9de..f335a67c94 100644 --- a/tests/integration/client/feedback/dataset/remote/test_dataset.py +++ b/tests/integration/client/feedback/dataset/remote/test_dataset.py @@ -18,14 +18,10 @@ import argilla as rg import argilla.client.singleton -import numpy import pytest -from argilla import ( - FeedbackRecord, -) +from argilla import FeedbackRecord, SuggestionSchema from argilla.client.feedback.dataset import FeedbackDataset from argilla.client.feedback.dataset.remote.dataset import RemoteFeedbackDataset -from argilla.client.feedback.schemas import SuggestionSchema from argilla.client.feedback.schemas.fields import TextField from argilla.client.feedback.schemas.metadata import ( FloatMetadataProperty, @@ -166,12 +162,11 @@ async def test_update_records_with_suggestions( ) remote = test_dataset_with_metadata_properties.push_to_argilla(name="test_dataset", workspace=ws) + question = remote.question_by_name("question") records = [] for record in remote: - record.suggestions = [ - SuggestionSchema(question_name="question", value=f"Hello world! for {record.fields['text']}") - ] + record.suggestions = [question.suggestion(f"Hello world! 
for {record.fields['text']}")] records.append(record) remote.update_records(records) @@ -192,14 +187,17 @@ async def test_update_records_with_empty_list_of_suggestions( ws = rg.Workspace.create(name="test-workspace") remote = test_dataset_with_metadata_properties.push_to_argilla(name="test_dataset", workspace=ws) + question = remote.question_by_name("question") remote.add_records( [ FeedbackRecord( - fields={"text": "Hello world!"}, suggestions=[{"question_name": "question", "value": "test"}] + fields={"text": "Hello world!"}, + suggestions=[question.suggestion("test")], ), FeedbackRecord( - fields={"text": "Another record"}, suggestions=[{"question_name": "question", "value": "test"}] + fields={"text": "Another record"}, + suggestions=[question.suggestion(value="test")], ), ] ) diff --git a/tests/unit/client/feedback/schemas/remote/test_questions.py b/tests/unit/client/feedback/schemas/remote/test_questions.py index d41e1578e0..2b2103d713 100644 --- a/tests/unit/client/feedback/schemas/remote/test_questions.py +++ b/tests/unit/client/feedback/schemas/remote/test_questions.py @@ -30,6 +30,7 @@ RemoteMultiLabelQuestion, RemoteRankingQuestion, RemoteRatingQuestion, + RemoteSpanQuestion, RemoteTextQuestion, ) from argilla.client.sdk.v1.datasets.models import FeedbackQuestionModel @@ -447,3 +448,61 @@ def test_remote_ranking_question_from_api(payload: FeedbackQuestionModel) -> Non assert ranking_question.type == QuestionTypes.ranking assert ranking_question.server_settings == payload.settings assert ranking_question.to_server_payload() == payload.dict(exclude={"id", "inserted_at", "updated_at"}) + + +def test_span_questions_from_api(): + model = FeedbackQuestionModel( + id=uuid4(), + name="question", + title="Question", + required=True, + settings={ + "type": "span", + "field": "field", + "visible_options": None, + "options": [ + {"text": "Span label a", "value": "a", "description": None}, + { + "text": "Span label b", + "value": "b", + "description": None, + }, + ], + }, + inserted_at=datetime.now(), + updated_at=datetime.now(), + ) + question = RemoteSpanQuestion.from_api(model) + + assert question.type == QuestionTypes.span + assert question.server_settings == model.settings + assert question.to_server_payload() == model.dict(exclude={"id", "inserted_at", "updated_at"}) + assert question.to_local().type == QuestionTypes.span + + +def test_span_questions_from_api_with_visible_labels(): + model = FeedbackQuestionModel( + id=uuid4(), + name="question", + title="Question", + required=True, + settings={ + "type": "span", + "field": "field", + "visible_options": 3, + "options": [ + {"text": "Span label a", "value": "a", "description": None}, + {"text": "Span label b", "value": "b", "description": None}, + {"text": "Span label c", "value": "c", "description": None}, + {"text": "Span label d", "value": "d", "description": None}, + ], + }, + inserted_at=datetime.now(), + updated_at=datetime.now(), + ) + question = RemoteSpanQuestion.from_api(model) + + assert question.type == QuestionTypes.span + assert question.server_settings == model.settings + assert question.to_server_payload() == model.dict(exclude={"id", "inserted_at", "updated_at"}) + assert question.to_local().type == QuestionTypes.span diff --git a/tests/unit/client/feedback/schemas/remote/test_records.py b/tests/unit/client/feedback/schemas/remote/test_records.py index 8d113267a6..b8debcecb5 100644 --- a/tests/unit/client/feedback/schemas/remote/test_records.py +++ b/tests/unit/client/feedback/schemas/remote/test_records.py @@ -194,6 
+194,7 @@ def test_remote_response_schema(schema_kwargs: Dict[str, Any], server_payload: D "question-4": FeedbackValueModel( value=[FeedbackRankingValueModel(value="a", rank=1), FeedbackRankingValueModel(value="b", rank=2)] ), + "question-5": FeedbackValueModel(value=[{"start": 0, "end": 1, "label": "a"}]), }, status="submitted", user_id=uuid4(), @@ -208,6 +209,14 @@ def test_remote_response_schema(schema_kwargs: Dict[str, Any], server_payload: D inserted_at=datetime.now(), updated_at=datetime.now(), ), + FeedbackResponseModel( + id=uuid4(), + values={"span-question": FeedbackValueModel(value=[{"start": 0, "end": 1, "label": "a"}])}, + status="discarded", + user_id=uuid4(), + inserted_at=datetime.now(), + updated_at=datetime.now(), + ), ], ) def test_remote_response_schema_from_api(payload: FeedbackResponseModel) -> None: diff --git a/tests/unit/client/feedback/schemas/test_questions.py b/tests/unit/client/feedback/schemas/test_questions.py index afa08e1cbe..0e45e46cb8 100644 --- a/tests/unit/client/feedback/schemas/test_questions.py +++ b/tests/unit/client/feedback/schemas/test_questions.py @@ -21,6 +21,8 @@ MultiLabelQuestion, RankingQuestion, RatingQuestion, + SpanLabelOption, + SpanQuestion, TextQuestion, _LabelQuestion, ) @@ -444,3 +446,172 @@ def test_ranking_question(schema_kwargs: Dict[str, Any], server_payload: Dict[st def test_ranking_question_errors(schema_kwargs: Dict[str, Any], exception_cls: Any, exception_message: str) -> None: with pytest.raises(exception_cls, match=exception_message): RankingQuestion(**schema_kwargs) + + +def test_span_question() -> None: + question = SpanQuestion( + name="question", + field="field", + title="Question", + description="Description", + required=True, + labels=["a", "b"], + ) + + assert question.type == QuestionTypes.span + assert question.server_settings == { + "type": "span", + "field": "field", + "visible_options": None, + "options": [{"value": "a", "text": "a", "description": None}, {"value": "b", "text": "b", "description": None}], + } + + +def test_span_question_with_labels_dict() -> None: + question = SpanQuestion( + name="question", + field="field", + title="Question", + description="Description", + labels={"a": "A text", "b": "B text"}, + ) + + assert question.type == QuestionTypes.span + assert question.server_settings == { + "type": "span", + "field": "field", + "visible_options": None, + "options": [ + {"value": "a", "text": "A text", "description": None}, + {"value": "b", "text": "B text", "description": None}, + ], + } + + +def test_span_question_with_visible_labels() -> None: + question = SpanQuestion( + name="question", + field="field", + title="Question", + description="Description", + labels=["a", "b", "c", "d"], + visible_labels=3, + ) + + assert question.type == QuestionTypes.span + assert question.server_settings == { + "type": "span", + "field": "field", + "visible_options": 3, + "options": [ + {"value": "a", "text": "a", "description": None}, + {"value": "b", "text": "b", "description": None}, + {"value": "c", "text": "c", "description": None}, + {"value": "d", "text": "d", "description": None}, + ], + } + + +def test_span_question_with_visible_labels_default_value(): + question = SpanQuestion( + name="question", + field="field", + title="Question", + description="Description", + labels=list(range(21)), + ) + + assert question.visible_labels == 20 + + +def test_span_question_with_default_visible_label_when_labels_is_less_than_20(): + with pytest.warns(UserWarning, match=""): + question = SpanQuestion( + name="question", 
+ field="field", + title="Question", + description="Description", + labels=list(range(19)), + ) + + assert question.visible_labels == 19 + + +def test_span_question_when_visible_labels_is_greater_than_total_labels(): + with pytest.warns( + UserWarning, + match="`visible_labels=4` is greater than the total number of labels \(3\)", + ): + question = SpanQuestion( + name="question", + field="field", + title="Question", + description="Description", + labels=["a", "b", "c"], + visible_labels=4, + ) + + assert question.visible_labels == 3 + + +def test_span_question_with_visible_labels_less_than_total_labels(): + with pytest.warns( + UserWarning, match="Since `labels` has less than 3 labels, `visible_labels` will be set to `None`." + ): + question = SpanQuestion( + name="question", + field="field", + title="Question", + description="Description", + labels=["a", "b"], + visible_labels=3, + ) + + assert question.visible_labels is None + + +def test_span_question_with_visible_labels_less_than_min_value(): + with pytest.raises(ValidationError, match="ensure this value is greater than or equal to 3"): + SpanQuestion( + name="question", + field="field", + title="Question", + description="Description", + labels=["a", "b"], + visible_labels=2, + ) + + +def test_span_questions_with_default_visible_labels_and_less_labels_than_default(): + with pytest.warns(UserWarning, match="visible_labels=20` is greater than the total number of labels"): + question = SpanQuestion( + name="question", + field="field", + title="Question", + description="Description", + labels=list(range(10)), + ) + + assert question.visible_labels == 10 + + +def test_span_question_with_no_labels() -> None: + with pytest.raises(ValidationError, match="At least one label must be provided"): + SpanQuestion( + name="question", + field="field", + title="Question", + description="Description", + labels=[], + ) + + +def test_span_question_with_duplicated_labels() -> None: + with pytest.raises(ValidationError, match="the list has duplicated items"): + SpanQuestion( + name="question", + title="Question", + field="field", + description="Description", + labels=[SpanLabelOption(value="a", text="A text"), SpanLabelOption(value="a", text="Text for A")], + ) diff --git a/tests/unit/client/feedback/schemas/test_responses.py b/tests/unit/client/feedback/schemas/test_responses.py new file mode 100644 index 0000000000..43281e8bc3 --- /dev/null +++ b/tests/unit/client/feedback/schemas/test_responses.py @@ -0,0 +1,72 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from uuid import uuid4 + +import pytest +from argilla.feedback import ( + FeedbackDataset, + ResponseSchema, + ResponseStatus, + SpanValueSchema, + TextQuestion, + ValueSchema, +) + + +def test_create_span_response_wrong_limits(): + with pytest.raises(ValueError, match="The end of the span must be greater than the start."): + SpanValueSchema(start=10, end=8, label="test") + + +def test_create_response(): + question = TextQuestion(name="text") + response = ResponseSchema(status="draft", values=[question.response("Value for text")]) + + assert response.status == ResponseStatus.draft + assert question.name in response.values + assert response.values[question.name].value == "Value for text" + + +def test_create_responses_with_multiple_questions(): + question1 = TextQuestion(name="text") + question2 = TextQuestion(name="text2") + response = ResponseSchema( + status="draft", + values=[ + question1.response("Value for text"), + question2.response("Value for text2"), + ], + ) + + assert response.status == ResponseStatus.draft + assert question1.name in response.values + assert response.values[question1.name].value == "Value for text" + assert question2.name in response.values + assert response.values[question2.name].value == "Value for text2" + + +def test_create_response_with_wrong_value(): + with pytest.raises(ValueError, match="Value 10 is not valid for question type text. Expected ."): + ResponseSchema(status="draft", values=[TextQuestion(name="text").response(10)]) + + +def test_response_to_server_payload_with_string_status(): + assert ResponseSchema(status="draft").to_server_payload() == {"user_id": None, "status": "draft"} + + +def test_response_to_server_payload_with_no_values(): + assert ResponseSchema().to_server_payload() == {"user_id": None, "status": "submitted"} + assert ResponseSchema(values=None).to_server_payload() == {"user_id": None, "status": "submitted", "values": None} + assert ResponseSchema(values=[]).to_server_payload() == {"user_id": None, "status": "submitted", "values": {}} + assert ResponseSchema(values={}).to_server_payload() == {"user_id": None, "status": "submitted", "values": {}} diff --git a/tests/unit/client/feedback/schemas/test_suggestions.py b/tests/unit/client/feedback/schemas/test_suggestions.py new file mode 100644 index 0000000000..5d56f356fb --- /dev/null +++ b/tests/unit/client/feedback/schemas/test_suggestions.py @@ -0,0 +1,30 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from argilla.feedback import SuggestionSchema, TextQuestion + + +def test_create_suggestion(): + question = TextQuestion(name="text") + + suggestion = question.suggestion("Value for text", agent="mock") + + assert suggestion.question_name == question.name + assert suggestion.agent == "mock" + + +def test_create_suggestion_with_wrong_value(): + with pytest.raises(ValueError, match="Value 10 is not valid for question type text. Expected ."): + TextQuestion(name="text").suggestion(value=10, agent="Mock")
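For readers skimming the patch, the pieces above fit together roughly as follows. This is a minimal usage sketch, not part of the patch: the question, field, and label names are invented, spans are assumed to use end-exclusive character offsets (consistent with the `start < end` validation tested above), and it leans only on names the tests themselves import (`SpanQuestion`, `SpanValueSchema`, `ResponseSchema`, `question.response(...)`, `question.suggestion(...)`).

import argilla as rg
from argilla.feedback import ResponseSchema, SpanValueSchema

# A span question asks annotators to highlight character ranges of one field
# and tag each range with a label.
question = rg.SpanQuestion(
    name="entities",                                    # hypothetical names, for illustration only
    field="text",                                       # the field the spans point into
    title="Highlight the entities",
    labels={"ORG": "Organization", "LOC": "Location"},  # value -> display text
)

dataset = rg.FeedbackDataset(
    fields=[rg.TextField(name="text")],
    questions=[question],
)

record = rg.FeedbackRecord(
    fields={"text": "Argilla was founded in Madrid"},
    # A model prediction, built with the new `question.suggestion(...)` helper;
    # assuming end-exclusive offsets, (0, 7) covers "Argilla".
    suggestions=[
        question.suggestion([SpanValueSchema(start=0, end=7, label="ORG")], agent="my-model"),
    ],
    # A human answer, built with the new `question.response(...)` helper;
    # (23, 29) covers "Madrid".
    responses=[
        ResponseSchema(
            status="submitted",
            values=[question.response(value=[SpanValueSchema(start=23, end=29, label="LOC")])],
        ),
    ],
)
dataset.add_records([record])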