Skip to content

Commit

Permalink
refactor: creating suggestions and responses (#4627)
Browse files Browse the repository at this point in the history
<!-- Thanks for your contribution! As part of our Community Growers
initiative 🌱, we're donating Justdiggit bunds in your name to reforest
sub-Saharan Africa. To claim your Community Growers certificate, please
contact David Berenstein in our Slack community or fill in this form
https://tally.so/r/n9XrxK once your PR has been merged. -->

# Description

This PR adds response and suggestion method helpers to simplify and
unify how users may work with responses and suggestions.

Instead of calling the init methods of the response and suggestion schemas
directly, users may use the `*.with_question_value` methods to create
responses and suggestions from a question and a response value. This
approach adds some extra data validation that is not available when using
the init methods directly.

For example, given the following dataset:
```python
dataset = rg.FeedbackDataset(
    fields=[rg.TextField(name="text")],
    questions=[
        rg.TextQuestion(name="question-1", required=True),
        rg.RatingQuestion(name="question-2", values=[1, 2], required=True),
        rg.LabelQuestion(name="question-3", labels=["a", "b", "c"], required=True),
        rg.MultiLabelQuestion(name="question-4", labels=["a", "b", "c"], required=True),
        rg.RankingQuestion(name="question-5", values=["a", "b"], required=True),
    ]
)
```
users could create responses and suggestions as follows:

```python
question_1 = dataset.question_by_name("question-1")
question_2 = dataset.question_by_name("question-2")
question_3 = dataset.question_by_name("question-3")
question_4 = dataset.question_by_name("question-4")
question_5 = dataset.question_by_name("question-5")

record = rg.FeedbackRecord(
   fields={ "text": "This is a text value for field"},
   responses=[
      ResponseSchema(status="submitted")
        .with_question_value(question_1, value="answer")
        .with_question_value(question_2, value=0)
        .with_question_value(question_3, value="a")
        .with_question_value(question_4, value=["a", "b"])
        .with_question_value(question_5, value=[{"rank": 1, "value": "a"}, {"rank": 2, "value": "b"}])
   ],
   suggestions=[
      SuggestionSchema.with_question_value(question_1, value="answer"),
      SuggestionSchema.with_question_value(question_2, value=0),
      SuggestionSchema.with_question_value(question_3, value="a"),
      SuggestionSchema.with_question_value(question_4, value=["a", "b"]),
      SuggestionSchema.with_question_value(question_5, value=[{"rank": 1, "value": "a"}, {"rank":2, "value": "b"}])
   ]
)
```

**Type of change**

(Please delete options that are not relevant. Remember to title the PR
according to the type of change)

- [ ] New feature (non-breaking change which adds functionality)
- [ ] Refactor (change restructuring the codebase without changing
functionality)
- [ ] Improvement (change adding some improvement to an existing
functionality)

**How Has This Been Tested**

(Please describe the tests that you ran to verify your changes. And
ideally, reference `tests`)

- [ ] Test A
- [ ] Test B

**Checklist**

- [ ] I added relevant documentation
- [ ] I followed the style guidelines of this project
- [ ] I did a self-review of my code
- [ ] I made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK)
(see text above)
- [ ] I have added relevant notes to the `CHANGELOG.md` file (See
https://keepachangelog.com/)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
frascuchon and pre-commit-ci[bot] authored Mar 6, 2024
1 parent a3def08 commit ae329fc
Show file tree
Hide file tree
Showing 9 changed files with 162 additions and 84 deletions.
14 changes: 5 additions & 9 deletions src/argilla/client/feedback/schemas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,9 @@
SpanQuestion,
TextQuestion,
)
from argilla.client.feedback.schemas.records import (
FeedbackRecord,
ResponseSchema,
SortBy,
ValueSchema,
)
from argilla.client.feedback.schemas.records import FeedbackRecord, SortBy
from argilla.client.feedback.schemas.response_values import RankingValueSchema, ResponseValue, SpanValueSchema
from argilla.client.feedback.schemas.responses import ResponseSchema, ResponseStatus, ValueSchema
from argilla.client.feedback.schemas.suggestions import SuggestionSchema
from argilla.client.feedback.schemas.vector_settings import VectorSettings

Expand All @@ -67,13 +63,13 @@
"SpanQuestion",
"SpanLabelOption",
"FeedbackRecord",
"RankingValueSchema",
"ResponseSchema",
"SuggestionSchema",
"ResponseValue",
"ResponseStatus",
"SuggestionSchema",
"ValueSchema",
"RankingValueSchema",
"SpanValueSchema",
"ValueSchema",
"SortOrder",
"SortBy",
"RecordSortField",
Expand Down
13 changes: 13 additions & 0 deletions src/argilla/client/feedback/schemas/questions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
from typing import Any, Dict, List, Literal, Optional, Union

from argilla.client.feedback.schemas.enums import QuestionTypes
from argilla.client.feedback.schemas.response_values import parse_value_response_for_question
from argilla.client.feedback.schemas.responses import ResponseValue, ValueSchema
from argilla.client.feedback.schemas.suggestions import SuggestionSchema
from argilla.client.feedback.schemas.utils import LabelMappingMixin
from argilla.client.feedback.schemas.validators import title_must_have_value
from argilla.pydantic_v1 import BaseModel, Extra, Field, conint, conlist, root_validator, validator
Expand Down Expand Up @@ -77,6 +80,16 @@ def to_server_payload(self) -> Dict[str, Any]:
"settings": self.server_settings,
}

def suggestion(self, value: ResponseValue, **kwargs) -> SuggestionSchema:
    """Build a `SuggestionSchema` for this question from a suggested value.

    The value is first passed, together with this question, through
    `parse_value_response_for_question`; the result becomes the suggestion value.

    Args:
        value: the suggested value for this question.
        **kwargs: extra keyword arguments forwarded unchanged to `SuggestionSchema`.

    Returns:
        A `SuggestionSchema` whose `question_name` is this question's name.
    """
    parsed_value = parse_value_response_for_question(self, value)
    return SuggestionSchema(question_name=self.name, value=parsed_value, **kwargs)

def response(self, value: ResponseValue) -> Dict[str, ValueSchema]:
    """Build a single-entry response mapping for this question.

    The value is first passed, together with this question, through
    `parse_value_response_for_question` and then wrapped in a `ValueSchema`.

    Args:
        value: the response value for this question.

    Returns:
        A dictionary with one entry: this question's name mapped to the
        `ValueSchema` holding the parsed value.
    """
    parsed_value = parse_value_response_for_question(self, value)
    answer = ValueSchema(value=parsed_value)
    return {self.name: answer}


class TextQuestion(QuestionSchema):
"""Schema for the `FeedbackDataset` text questions, which are the ones that will
Expand Down
4 changes: 2 additions & 2 deletions src/argilla/client/feedback/schemas/records.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class FeedbackRecord(BaseModel):
Defaults to None.
Examples:
>>> from argilla.client.feedback.schemas.records import FeedbackRecord, ResponseSchema, SuggestionSchema, ValueSchema
>>> from argilla.feedback import FeedbackRecord, ResponseSchema, SuggestionSchema, ValueSchema
>>> FeedbackRecord(
... fields={"text": "This is the first record", "label": "positive"},
... metadata={"first": True, "nested": {"more": "stuff"}},
Expand All @@ -72,7 +72,7 @@ class FeedbackRecord(BaseModel):
... value="This is the first suggestion",
... agent="agent-1",
... ),
... ]
... ],
... external_id="entry-1",
... )
Expand Down
42 changes: 17 additions & 25 deletions src/argilla/client/feedback/schemas/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,16 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from uuid import UUID

from argilla.client.feedback.schemas.enums import QuestionTypes, ResponseStatus
from argilla.client.feedback.schemas.enums import ResponseStatus
from argilla.client.feedback.schemas.response_values import (
RankingValueSchema,
ResponseValue,
SpanValueSchema,
normalize_response_value,
parse_value_response_for_question,
)
from argilla.pydantic_v1 import BaseModel, Extra, validator

if TYPE_CHECKING:
from argilla.client.feedback.schemas.questions import QuestionSchema
from argilla.client.feedback.schemas.records import FeedbackRecord


class ValueSchema(BaseModel):
Expand Down Expand Up @@ -63,8 +60,20 @@ class ResponseSchema(BaseModel):
"""

user_id: Optional[UUID] = None
values: Union[Dict[str, ValueSchema], None]
status: ResponseStatus = ResponseStatus.submitted
values: Union[List[Dict[str, ValueSchema]], Dict[str, ValueSchema], None]
status: Union[ResponseStatus, str] = ResponseStatus.submitted

@validator("values", always=True)
def normalize_values(cls, values):
    """Collapse a list of response dictionaries into a single dictionary.

    When `values` is a list made up only of dictionaries, their entries are
    merged in order (a later entry overrides an earlier one with the same key).
    Any other input is returned unchanged.
    """
    if not isinstance(values, list):
        return values
    if any(not isinstance(item, dict) for item in values):
        return values

    merged = {}
    for item in values:
        merged.update(item)
    return merged

@validator("status")
def normalize_status(cls, v) -> ResponseStatus:
    """Convert a string status into `ResponseStatus`; return other values unchanged."""
    return ResponseStatus(v) if isinstance(v, str) else v

@validator("user_id", always=True)
def user_id_must_have_value(cls, v):
Expand All @@ -80,25 +89,8 @@ class Config:
extra = Extra.forbid
validate_assignment = True

def with_question_value(self, question: "QuestionSchema", value: ResponseValue) -> "ResponseSchema":
"""Returns the response value for the given record."""
value = parse_value_response_for_question(question, value)

values = self.values or {}
values[question.name] = ValueSchema(value=value)

self.values = values

return self

def to_server_payload(self) -> Dict[str, Any]:
"""Method that will be used to create the payload that will be sent to Argilla
to create a `ResponseSchema` for a `FeedbackRecord`."""
return {
# UUID is not json serializable!!!
"user_id": self.user_id,
"values": {question_name: value.dict(exclude_unset=True) for question_name, value in self.values.items()}
if self.values is not None
else None,
"status": self.status.value if hasattr(self.status, "value") else self.status,
}
payload = {"user_id": self.user_id, "status": self.status, **self.dict(exclude_unset=True, include={"values"})}
return payload
15 changes: 11 additions & 4 deletions src/argilla/client/feedback/schemas/suggestions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Literal, Optional
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional
from uuid import UUID

from argilla.client.feedback.schemas.response_values import ResponseValue, normalize_response_value
from argilla.client.feedback.schemas.response_values import (
ResponseValue,
normalize_response_value,
parse_value_response_for_question,
)
from argilla.pydantic_v1 import BaseModel, Extra, confloat, validator

if TYPE_CHECKING:
from argilla.client.feedback.schemas.questions import QuestionSchema


class SuggestionSchema(BaseModel):
"""Schema for the suggestions for the questions related to the record.
Expand All @@ -41,9 +48,9 @@ class SuggestionSchema(BaseModel):
"""

question_name: str
type: Optional[Literal["model", "human"]] = None
score: Optional[confloat(ge=0, le=1)] = None
value: ResponseValue
score: Optional[confloat(ge=0, le=1)] = None
type: Optional[Literal["model", "human"]] = None
agent: Optional[str] = None

_normalize_value = validator("value", allow_reuse=True, always=True)(normalize_response_value)
Expand Down
57 changes: 24 additions & 33 deletions tests/integration/client/feedback/dataset/local/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import argilla.client.singleton
import datasets
import pytest
from argilla import User, Workspace
from argilla import ResponseSchema, User, Workspace
from argilla.client.feedback.config import DatasetConfig
from argilla.client.feedback.dataset import FeedbackDataset
from argilla.client.feedback.schemas.fields import TextField
Expand All @@ -30,6 +30,7 @@
from argilla.client.feedback.schemas.questions import SpanLabelOption, SpanQuestion, TextQuestion
from argilla.client.feedback.schemas.records import FeedbackRecord
from argilla.client.feedback.schemas.remote.records import RemoteSuggestionSchema
from argilla.client.feedback.schemas.suggestions import SuggestionSchema
from argilla.client.feedback.schemas.vector_settings import VectorSettings
from argilla.client.feedback.training.schemas.base import TrainingTask
from argilla.client.models import Framework
Expand All @@ -51,7 +52,7 @@ def test_create_dataset_with_suggestions(argilla_user: "ServerUser") -> None:
records=[
FeedbackRecord(
fields={"text": "this is a text"},
suggestions=[{"question_name": "text", "value": "This is a suggestion"}],
suggestions=[ds.question_by_name("text").suggestion(value="This is a suggestion")],
)
]
)
Expand Down Expand Up @@ -94,7 +95,7 @@ async def test_update_dataset_records_with_suggestions(argilla_user: "ServerUser
assert remote_dataset.records[0].id is not None
assert remote_dataset.records[0].suggestions == ()

remote_dataset.records[0].update(suggestions=[{"question_name": "text", "value": "This is a suggestion"}])
remote_dataset.records[0].update(suggestions=[ds.question_by_name("text").suggestion(value="This is a suggestion")])

# TODO: Review this requirement for tests and explain, try to avoid use or at least, document.
await db.refresh(argilla_user, attribute_names=["datasets"])
Expand Down Expand Up @@ -140,6 +141,11 @@ def test_add_records(
assert not dataset.records[0].responses
assert not dataset.records[0].suggestions

question_1 = dataset.question_by_name("question-1")
question_2 = dataset.question_by_name("question-2")
question_3 = dataset.question_by_name("question-3")
question_4 = dataset.question_by_name("question-4")
question_5 = dataset.question_by_name("question-5")
dataset.add_records(
[
FeedbackRecord(
Expand All @@ -149,38 +155,23 @@ def test_add_records(
},
metadata={"unit": "test"},
responses=[
{
"values": {
"question-1": {"value": "answer"},
"question-2": {"value": 0},
"question-3": {"value": "a"},
"question-4": {"value": ["a", "b"]},
"question-5": {"value": [{"rank": 1, "value": "a"}, {"rank": 2, "value": "b"}]},
},
"status": "submitted",
},
ResponseSchema(
status="submitted",
values=[
question_1.response(value="answer"),
question_2.response(value=0),
question_3.response(value="a"),
question_4.response(value=["a", "b"]),
question_5.response(value=[{"rank": 1, "value": "a"}, {"rank": 2, "value": "b"}]),
],
),
],
suggestions=[
{
"question_name": "question-1",
"value": "answer",
},
{
"question_name": "question-2",
"value": 0,
},
{
"question_name": "question-3",
"value": "a",
},
{
"question_name": "question-4",
"value": ["a", "b"],
},
{
"question_name": "question-5",
"value": [{"rank": 1, "value": "a"}, {"rank": 2, "value": "b"}],
},
question_1.suggestion(value="answer"),
question_2.suggestion(value=0),
question_3.suggestion(value="a"),
question_4.suggestion(value=["a", "b"]),
question_5.suggestion(value=[{"rank": 1, "value": "a"}, {"rank": 2, "value": "b"}]),
],
external_id="test-id",
),
Expand Down
18 changes: 8 additions & 10 deletions tests/integration/client/feedback/dataset/remote/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,10 @@

import argilla as rg
import argilla.client.singleton
import numpy
import pytest
from argilla import (
FeedbackRecord,
)
from argilla import FeedbackRecord, SuggestionSchema
from argilla.client.feedback.dataset import FeedbackDataset
from argilla.client.feedback.dataset.remote.dataset import RemoteFeedbackDataset
from argilla.client.feedback.schemas import SuggestionSchema
from argilla.client.feedback.schemas.fields import TextField
from argilla.client.feedback.schemas.metadata import (
FloatMetadataProperty,
Expand Down Expand Up @@ -166,12 +162,11 @@ async def test_update_records_with_suggestions(
)

remote = test_dataset_with_metadata_properties.push_to_argilla(name="test_dataset", workspace=ws)
question = remote.question_by_name("question")

records = []
for record in remote:
record.suggestions = [
SuggestionSchema(question_name="question", value=f"Hello world! for {record.fields['text']}")
]
record.suggestions = [question.suggestion(f"Hello world! for {record.fields['text']}")]
records.append(record)

remote.update_records(records)
Expand All @@ -192,14 +187,17 @@ async def test_update_records_with_empty_list_of_suggestions(
ws = rg.Workspace.create(name="test-workspace")

remote = test_dataset_with_metadata_properties.push_to_argilla(name="test_dataset", workspace=ws)
question = remote.question_by_name("question")

remote.add_records(
[
FeedbackRecord(
fields={"text": "Hello world!"}, suggestions=[{"question_name": "question", "value": "test"}]
fields={"text": "Hello world!"},
suggestions=[question.suggestion("test")],
),
FeedbackRecord(
fields={"text": "Another record"}, suggestions=[{"question_name": "question", "value": "test"}]
fields={"text": "Another record"},
suggestions=[question.suggestion(value="test")],
),
]
)
Expand Down
53 changes: 52 additions & 1 deletion tests/unit/client/feedback/schemas/test_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,62 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from uuid import uuid4

import pytest
from argilla.feedback import SpanValueSchema
from argilla.feedback import (
FeedbackDataset,
ResponseSchema,
ResponseStatus,
SpanValueSchema,
TextQuestion,
ValueSchema,
)


def test_create_span_response_wrong_limits():
    """A span whose end is lower than its start must be rejected with a `ValueError`."""
    expected_message = "The end of the span must be greater than the start."
    with pytest.raises(ValueError, match=expected_message):
        SpanValueSchema(start=10, end=8, label="test")


def test_create_response():
    """A response built from one `question.response(...)` entry exposes that value by question name."""
    text_question = TextQuestion(name="text")
    response_values = [text_question.response("Value for text")]
    created = ResponseSchema(status="draft", values=response_values)

    assert created.status == ResponseStatus.draft
    assert text_question.name in created.values
    assert created.values[text_question.name].value == "Value for text"


def test_create_responses_with_multiple_questions():
    """Entries produced by several questions end up in one `values` mapping keyed by question name."""
    first_question = TextQuestion(name="text")
    second_question = TextQuestion(name="text2")
    response_values = [
        first_question.response("Value for text"),
        second_question.response("Value for text2"),
    ]
    created = ResponseSchema(status="draft", values=response_values)

    assert created.status == ResponseStatus.draft
    assert first_question.name in created.values
    assert created.values[first_question.name].value == "Value for text"
    assert second_question.name in created.values
    assert created.values[second_question.name].value == "Value for text2"


def test_create_response_with_wrong_value():
    """Building a text-question response from a non-string value must raise a `ValueError`."""
    expected_message = "Value 10 is not valid for question type text. Expected <class 'str'>."
    with pytest.raises(ValueError, match=expected_message):
        ResponseSchema(status="draft", values=[TextQuestion(name="text").response(10)])


def test_response_to_server_payload_with_string_status():
    """A response created with a string status and no values yields a payload without a `values` key."""
    payload = ResponseSchema(status="draft").to_server_payload()
    assert payload == {"user_id": None, "status": "draft"}


def test_response_to_server_payload_with_no_values():
    """Check the payload for an unset, `None`, empty-list and empty-dict `values` argument."""
    unset_payload = ResponseSchema().to_server_payload()
    assert unset_payload == {"user_id": None, "status": "submitted"}

    none_payload = ResponseSchema(values=None).to_server_payload()
    assert none_payload == {"user_id": None, "status": "submitted", "values": None}

    empty_list_payload = ResponseSchema(values=[]).to_server_payload()
    assert empty_list_payload == {"user_id": None, "status": "submitted", "values": {}}

    empty_dict_payload = ResponseSchema(values={}).to_server_payload()
    assert empty_dict_payload == {"user_id": None, "status": "submitted", "values": {}}
assert ResponseSchema(values={}).to_server_payload() == {"user_id": None, "status": "submitted", "values": {}}
Loading

0 comments on commit ae329fc

Please sign in to comment.