-
Notifications
You must be signed in to change notification settings - Fork 389
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: create responses and suggestions with spans #4623
Changes from all commits
d44be46
39c9c6e
28a7ffa
febd300
3b010bb
9759153
5819d87
d0ffa74
9776e2d
8cafa01
86f568d
490dec6
56f360b
5b638e7
f337160
5317900
35f6392
ceded58
424a5c8
11a0068
f4bea43
53af61b
2760348
4edf802
433f27b
540979f
0899176
be02eb2
22f6947
be13a7d
84dd077
63a369f
a1c13e9
5831642
b39cbb3
2e87977
5f14b28
ca0b649
816a839
6a2f636
fd664f4
7f43202
d0b1351
a3def08
ae329fc
5b65e11
af035dc
bb7dba4
b6497ae
9439624
c86b20d
6136414
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,11 +11,11 @@ | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import json | ||
import logging | ||
import tempfile | ||
import warnings | ||
from copy import copy | ||
from typing import TYPE_CHECKING, Any, Optional, Type, Union | ||
|
||
from packaging.version import parse as parse_version | ||
|
@@ -50,7 +50,7 @@ def _huggingface_format(dataset: Union["FeedbackDataset", "RemoteFeedbackDataset | |
questions, and metadata_properties formatted as `datasets.Features`. | ||
|
||
Examples: | ||
>>> from argilla.client.feedback.integrations.dataset import HuggingFaceDatasetMixin | ||
>>> from argilla.client.feedback.integrations.huggingface import HuggingFaceDatasetMixin | ||
>>> dataset = FeedbackDataset(...) or RemoteFeedbackDataset(...) | ||
>>> huggingface_dataset = HuggingFaceDatasetMixin._huggingface_format(dataset) | ||
""" | ||
|
@@ -71,17 +71,38 @@ def _huggingface_format(dataset: Union["FeedbackDataset", "RemoteFeedbackDataset | |
for question in dataset.questions: | ||
if question.type in [QuestionTypes.text, QuestionTypes.label_selection]: | ||
value = Value(dtype="string", id="question") | ||
suggestion_value = copy(value) | ||
elif question.type == QuestionTypes.rating: | ||
value = Value(dtype="int32", id="question") | ||
suggestion_value = copy(value) | ||
elif question.type == QuestionTypes.ranking: | ||
value = Sequence({"rank": Value(dtype="uint8"), "value": Value(dtype="string")}, id="question") | ||
suggestion_value = copy(value) | ||
elif question.type in QuestionTypes.multi_label_selection: | ||
value = Sequence(Value(dtype="string"), id="question") | ||
suggestion_value = copy(value) | ||
elif question.type in QuestionTypes.span: | ||
value = Sequence( | ||
{ | ||
"start": Value(dtype="int32"), | ||
"end": Value(dtype="int32"), | ||
"label": Value(dtype="string"), | ||
}, | ||
id="question", | ||
) | ||
suggestion_value = Sequence( | ||
{ | ||
"start": Value(dtype="int32"), | ||
"end": Value(dtype="int32"), | ||
"label": Value(dtype="string"), | ||
"score": Value(dtype="float32"), | ||
} | ||
) | ||
else: | ||
raise ValueError( | ||
f"Question {question.name} is of type `{question.type}`," | ||
" for the moment only the following question types are supported:" | ||
f" `{'`, `'.join([arg.value for arg in QuestionTypes])}`." | ||
f" `{'`, `'.join(QuestionTypes.values())}`." | ||
) | ||
|
||
hf_features[question.name] = [ | ||
|
@@ -94,8 +115,8 @@ def _huggingface_format(dataset: Union["FeedbackDataset", "RemoteFeedbackDataset | |
if question.name not in hf_dataset: | ||
hf_dataset[question.name] = [] | ||
|
||
value.id = "suggestion" | ||
hf_features[f"{question.name}-suggestion"] = value | ||
suggestion_value.id = "suggestion" | ||
hf_features[f"{question.name}-suggestion"] = suggestion_value | ||
if f"{question.name}-suggestion" not in hf_dataset: | ||
hf_dataset[f"{question.name}-suggestion"] = [] | ||
|
||
|
@@ -138,6 +159,15 @@ def _huggingface_format(dataset: Union["FeedbackDataset", "RemoteFeedbackDataset | |
} | ||
if question.type == QuestionTypes.ranking: | ||
value = [r.dict() for r in response.values[question.name].value] | ||
elif question.type == QuestionTypes.span: | ||
value = [ | ||
{ | ||
"start": span.start, | ||
"end": span.end, | ||
"label": span.label, | ||
} | ||
for span in response.values[question.name].value | ||
] | ||
else: | ||
value = response.values[question.name].value | ||
formatted_response["value"] = value | ||
|
@@ -421,6 +451,11 @@ def from_huggingface( | |
if value is not None: | ||
if question.type == QuestionTypes.ranking: | ||
value = [{"rank": r, "value": v} for r, v in zip(value["rank"], value["value"])] | ||
elif question.type == QuestionTypes.span: | ||
value = [ | ||
{"start": s, "end": e, "label": l} | ||
for s, e, l in zip(value["start"], value["end"], value["label"]) | ||
] | ||
responses[user_id or "user_without_id"]["values"].update({question.name: {"value": value}}) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. perhaps we should add else statements here to be sure it raises errors/has behaviour when we add question types? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm using the current approach to support span values mappings. I wouldn't add cross-cutting solution in the current version of the SDK. This is something that we can add for new SDK implementation |
||
|
||
# First if-condition is here for backwards compatibility | ||
|
@@ -431,6 +466,11 @@ def from_huggingface( | |
value = hfds[index][f"{question.name}-suggestion"] | ||
if question.type == QuestionTypes.ranking: | ||
value = [{"rank": r, "value": v} for r, v in zip(value["rank"], value["value"])] | ||
elif question.type == QuestionTypes.span: | ||
value = [ | ||
{"start": s, "end": e, "label": l} | ||
for s, e, l in zip(value["start"], value["end"], value["label"]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. in terms of human-readability, I also like to add the extracted text normally. Something like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This change is not hard to add, if it makes sense. |
||
] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. perhaps we should add else statements here to be sure it raises errors/has behaviour when we add question types? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm using the current approach to support span values mappings. I wouldn't add cross-cutting solution in the current version of the SDK. This is something that we can add for new SDK implementation |
||
|
||
suggestion = {"question_name": question.name, "value": value} | ||
if hfds[index][f"{question.name}-suggestion-metadata"] is not None: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,6 +17,9 @@ | |
from typing import Any, Dict, List, Literal, Optional, Union | ||
|
||
from argilla.client.feedback.schemas.enums import QuestionTypes | ||
from argilla.client.feedback.schemas.response_values import parse_value_response_for_question | ||
from argilla.client.feedback.schemas.responses import ResponseValue, ValueSchema | ||
from argilla.client.feedback.schemas.suggestions import SuggestionSchema | ||
from argilla.client.feedback.schemas.utils import LabelMappingMixin | ||
from argilla.client.feedback.schemas.validators import title_must_have_value | ||
from argilla.pydantic_v1 import BaseModel, Extra, Field, conint, conlist, root_validator, validator | ||
|
@@ -77,6 +80,16 @@ def to_server_payload(self) -> Dict[str, Any]: | |
"settings": self.server_settings, | ||
} | ||
|
||
def suggestion(self, value: ResponseValue, **kwargs) -> SuggestionSchema: | ||
"""Method that will be used to create a `SuggestionSchema` from the question and a suggested value.""" | ||
value = parse_value_response_for_question(self, value) | ||
return SuggestionSchema(question_name=self.name, value=value, **kwargs) | ||
|
||
def response(self, value: ResponseValue) -> Dict[str, ValueSchema]: | ||
"""Method that will be used to create a response from the question and a value.""" | ||
value = parse_value_response_for_question(self, value) | ||
return {self.name: ValueSchema(value=value)} | ||
|
||
|
||
class TextQuestion(QuestionSchema): | ||
"""Schema for the `FeedbackDataset` text questions, which are the ones that will | ||
|
@@ -334,21 +347,35 @@ class SpanQuestion(QuestionSchema): | |
|
||
Examples: | ||
>>> from argilla.client.feedback.schemas.questions import SpanQuestion | ||
>>> SpanQuestion(name="span_question", title="Span Question", labels=["person", "org"]) | ||
>>> SpanQuestion(name="span_question", field="prompt", title="Span Question", labels=["person", "org"]) | ||
""" | ||
|
||
type: Literal[QuestionTypes.span] = Field(QuestionTypes.span, allow_mutation=False, const=True) | ||
|
||
labels: conlist(Union[str, SpanLabelOption], min_items=1, unique_items=True) | ||
field: str = Field(..., description="The field in the input that the user will be asked to annotate.") | ||
labels: Union[Dict[str, str], conlist(Union[str, SpanLabelOption], min_items=1, unique_items=True)] | ||
|
||
@validator("labels", pre=True) | ||
def parse_labels_dict(cls, labels) -> List[SpanLabelOption]: | ||
if isinstance(labels, dict): | ||
return [SpanLabelOption(value=label, text=text) for label, text in labels.items()] | ||
return labels | ||
|
||
@validator("labels", always=True) | ||
def normalize_labels(cls, v: List[Union[str, SpanLabelOption]]) -> List[SpanLabelOption]: | ||
return [SpanLabelOption(value=label, text=label) if isinstance(label, str) else label for label in v] | ||
|
||
@validator("labels") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. wasn't there a max for the labels too which we defined on the server side? perhaps we can use it here as well? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, this is not hard to implement. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Anyway, this is not the current behaviour that exists for single and multi-label settings. In both cases, there is a min validation but not a max one. Also, if we plan to support a configurable value for this, adding a hard validation here may introduce workflow problems. So, I will let as is. |
||
def labels_must_be_valid(cls, labels: List[SpanLabelOption]) -> List[SpanLabelOption]: | ||
# This validator is needed since the conlist constraint does not work. | ||
assert len(labels) > 0, "At least one label must be provided" | ||
return labels | ||
|
||
@property | ||
def server_settings(self) -> Dict[str, Any]: | ||
return { | ||
"type": self.type, | ||
"field": self.field, | ||
"options": [label.dict() for label in self.labels], | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
in terms of human-readability, I also like to add the extracted text normally. Something like
value["text"][value["start"]:value["end"]]
but perhaps this is difficult top map back into the correct format when callingfrom_huggingface
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This change is not hard to add, if it makes sense.