Skip to content

Commit

Permalink
feat: span questions support for SDK (#4643)
Browse files Browse the repository at this point in the history
<!-- Thanks for your contribution! As part of our Community Growers
initiative 🌱, we're donating Justdiggit bunds in your name to reforest
sub-Saharan Africa. To claim your Community Growers certificate, please
contact David Berenstein in our Slack community or fill in this form
https://tally.so/r/n9XrxK once your PR has been merged. -->

# Description

This PR contains the feature branch for span question support from SDK,
allowing to create span questions and providing suggestions and
responses from SDK.

Ref #1868

Closes #4618 
Closes #4635

**Type of change**

(Please delete options that are not relevant. Remember to title the PR
according to the type of change)

- [X] New feature (non-breaking change which adds functionality)
- [ ] Refactor (change restructuring the codebase without changing
functionality)
- [X] Improvement (change adding some improvement to an existing
functionality)

**How Has This Been Tested**

(Please describe the tests that you ran to verify your changes. And
ideally, reference `tests`)

Tested locally

- [ ] Test A
- [ ] Test B

**Checklist**

- [ ] I added relevant documentation
- [ ] I followed the style guidelines of this project
- [X] I did a self-review of my code
- [ ] I made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK)
(see text above)
- [X] I have added relevant notes to the `CHANGELOG.md` file (See
https://keepachangelog.com/)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Francisco Aranda <[email protected]>
  • Loading branch information
3 people authored Mar 18, 2024
1 parent 93e910c commit 5d2bfd2
Show file tree
Hide file tree
Showing 25 changed files with 1,071 additions and 290 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@ These are the section headers that we use:

### Added

- If you expand the labels of a `single or multi` label Question, the state is maintained during the entire annotation process. ([#4630](https://github.com/argilla-io/argilla/pull/4630))
- Added support for span questions in the Python SDK. ([#4617](https://github.com/argilla-io/argilla/pull/4617))
- Added support for spans values in suggestions and responses. ([#4623](https://github.com/argilla-io/argilla/pull/4623))
- Added `Span` questions for `FeedbackDataset` ([#4622](https://github.com/argilla-io/argilla/pull/4622))
- If you expand the labels of a `single or multi` label Question, the state is maintained during the entire annotation process ([#4630](https://github.com/argilla-io/argilla/pull/4630))

### Fixed

Expand Down
2 changes: 1 addition & 1 deletion environment_dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,6 @@ dependencies:
- ipynbname>=2023.2.0.0
- httpx~=0.26.0
# For now we can just install argilla-server from the GitHub repo
- git+https://github.com/argilla-io/argilla-server.git@main
- git+https://github.com/argilla-io/argilla-server.git
# install Argilla in editable mode
- -e .[listeners]
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,10 @@ dynamic = ["version"]
[project.optional-dependencies]

server = [
"argilla-server ~= 1.25.0", # TODO: Fix the versions when publishing a new release
"argilla-server", # TODO: Fix the versions when publishing a new release
]
server-postgresql = [
"argilla-server[postgresql] ~= 1.25.0", # TODO: Fix the versions when publishing a new release
"argilla-server[postgresql]", # TODO: Fix the versions when publishing a new release
]
listeners = ["schedule ~= 1.1.0", "prodict ~= 0.8.0"]
integrations = [
Expand Down
27 changes: 3 additions & 24 deletions src/argilla/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,30 +74,7 @@
configure_dataset_settings,
load_dataset_settings,
)
from argilla.feedback import (
FeedbackDataset,
FeedbackRecord,
FloatMetadataFilter,
FloatMetadataProperty,
IntegerMetadataFilter,
IntegerMetadataProperty,
LabelQuestion,
MultiLabelQuestion,
RankingQuestion,
RatingQuestion,
RecordSortField,
ResponseSchema,
ResponseStatusFilter,
SortBy,
SortOrder,
SuggestionSchema,
TermsMetadataFilter,
TermsMetadataProperty,
TextField,
TextQuestion,
ValueSchema,
VectorSettings,
)
from argilla.feedback import * # noqa
from argilla.listeners import Metrics, RGListenerContext, Search, listener
from argilla.monitoring.model_monitor import monitor

Expand All @@ -115,6 +92,8 @@
"MultiLabelQuestion",
"RatingQuestion",
"RankingQuestion",
"SpanQuestion",
"SpanLabelOption",
"ResponseSchema",
"ResponseStatusFilter",
"TextField",
Expand Down
22 changes: 6 additions & 16 deletions src/argilla/client/feedback/dataset/local/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,8 @@
RemoteTermsMetadataProperty,
)
from argilla.client.feedback.schemas.remote.questions import (
RemoteLabelQuestion,
RemoteMultiLabelQuestion,
RemoteRankingQuestion,
RemoteRatingQuestion,
RemoteTextQuestion,
QUESTION_TYPE_TO_QUESTION,
AllowedRemoteQuestionTypes,
)
from argilla.client.feedback.schemas.types import AllowedMetadataPropertyTypes
from argilla.client.feedback.schemas.vector_settings import VectorSettings
Expand Down Expand Up @@ -125,20 +122,13 @@ def __add_fields(

@staticmethod
def _parse_to_remote_question(question: "FeedbackQuestionModel") -> "AllowedRemoteQuestionTypes":
if question.settings["type"] == QuestionTypes.rating:
question = RemoteRatingQuestion.from_api(question)
elif question.settings["type"] == QuestionTypes.text:
question = RemoteTextQuestion.from_api(question)
elif question.settings["type"] == QuestionTypes.label_selection:
question = RemoteLabelQuestion.from_api(question)
elif question.settings["type"] == QuestionTypes.multi_label_selection:
question = RemoteMultiLabelQuestion.from_api(question)
elif question.settings["type"] == QuestionTypes.ranking:
question = RemoteRankingQuestion.from_api(question)
question_type = question.settings["type"]
if question_type in QUESTION_TYPE_TO_QUESTION:
question = QUESTION_TYPE_TO_QUESTION[question_type].from_api(question)
else:
raise ValueError(
f"Question '{question.name}' is not a supported question in the current Python package"
f" version, supported question types are: `{'`, `'.join([arg.value for arg in QuestionTypes])}`."
f" version, supported question types are: `{'`, `'.join(QuestionTypes.values())}`."
)

return question
Expand Down
68 changes: 62 additions & 6 deletions src/argilla/client/feedback/integrations/huggingface/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
import tempfile
import warnings
from copy import copy
from typing import TYPE_CHECKING, Any, Optional, Type, Union

from packaging.version import parse as parse_version
Expand Down Expand Up @@ -50,7 +50,7 @@ def _huggingface_format(dataset: Union["FeedbackDataset", "RemoteFeedbackDataset
questions, and metadata_properties formatted as `datasets.Features`.
Examples:
>>> from argilla.client.feedback.integrations.dataset import HuggingFaceDatasetMixin
>>> from argilla.client.feedback.integrations.huggingface import HuggingFaceDatasetMixin
>>> dataset = FeedbackDataset(...) or RemoteFeedbackDataset(...)
>>> huggingface_dataset = HuggingFaceDatasetMixin._huggingface_format(dataset)
"""
Expand All @@ -71,17 +71,40 @@ def _huggingface_format(dataset: Union["FeedbackDataset", "RemoteFeedbackDataset
for question in dataset.questions:
if question.type in [QuestionTypes.text, QuestionTypes.label_selection]:
value = Value(dtype="string", id="question")
suggestion_value = copy(value)
elif question.type == QuestionTypes.rating:
value = Value(dtype="int32", id="question")
suggestion_value = copy(value)
elif question.type == QuestionTypes.ranking:
value = Sequence({"rank": Value(dtype="uint8"), "value": Value(dtype="string")}, id="question")
suggestion_value = copy(value)
elif question.type in QuestionTypes.multi_label_selection:
value = Sequence(Value(dtype="string"), id="question")
suggestion_value = copy(value)
elif question.type in QuestionTypes.span:
value = Sequence(
{
"start": Value(dtype="int32"),
"end": Value(dtype="int32"),
"label": Value(dtype="string"),
"text": Value(dtype="string"),
},
id="question",
)
suggestion_value = Sequence(
{
"start": Value(dtype="int32"),
"end": Value(dtype="int32"),
"label": Value(dtype="string"),
"text": Value(dtype="string"),
"score": Value(dtype="float32"),
}
)
else:
raise ValueError(
f"Question {question.name} is of type `{question.type}`,"
" for the moment only the following question types are supported:"
f" `{'`, `'.join([arg.value for arg in QuestionTypes])}`."
f" `{'`, `'.join(QuestionTypes.values())}`."
)

hf_features[question.name] = [
Expand All @@ -94,8 +117,8 @@ def _huggingface_format(dataset: Union["FeedbackDataset", "RemoteFeedbackDataset
if question.name not in hf_dataset:
hf_dataset[question.name] = []

value.id = "suggestion"
hf_features[f"{question.name}-suggestion"] = value
suggestion_value.id = "suggestion"
hf_features[f"{question.name}-suggestion"] = suggestion_value
if f"{question.name}-suggestion" not in hf_dataset:
hf_dataset[f"{question.name}-suggestion"] = []

Expand Down Expand Up @@ -138,6 +161,16 @@ def _huggingface_format(dataset: Union["FeedbackDataset", "RemoteFeedbackDataset
}
if question.type == QuestionTypes.ranking:
value = [r.dict() for r in response.values[question.name].value]
elif question.type == QuestionTypes.span:
value = [
{
"start": span.start,
"end": span.end,
"label": span.label,
"text": record.fields[question.field][span.start : span.end],
}
for span in response.values[question.name].value
]
else:
value = response.values[question.name].value
formatted_response["value"] = value
Expand All @@ -148,7 +181,20 @@ def _huggingface_format(dataset: Union["FeedbackDataset", "RemoteFeedbackDataset
if record.suggestions:
for suggestion in record.suggestions:
if question.name == suggestion.question_name:
suggestion_value = suggestion.dict(include={"value"})["value"]
if question.type == QuestionTypes.span:
suggestion_value = [
{
"start": span.start,
"end": span.end,
"label": span.label,
"score": span.score,
"text": record.fields[question.field][span.start : span.end],
}
for span in suggestion.value
]
else:
suggestion_value = suggestion.dict(include={"value"})["value"]

suggestion_metadata = {
"type": suggestion.type,
"score": suggestion.score,
Expand Down Expand Up @@ -421,6 +467,11 @@ def from_huggingface(
if value is not None:
if question.type == QuestionTypes.ranking:
value = [{"rank": r, "value": v} for r, v in zip(value["rank"], value["value"])]
elif question.type == QuestionTypes.span:
value = [
{"start": s, "end": e, "label": l}
for s, e, l in zip(value["start"], value["end"], value["label"])
]
responses[user_id or "user_without_id"]["values"].update({question.name: {"value": value}})

# First if-condition is here for backwards compatibility
Expand All @@ -431,6 +482,11 @@ def from_huggingface(
value = hfds[index][f"{question.name}-suggestion"]
if question.type == QuestionTypes.ranking:
value = [{"rank": r, "value": v} for r, v in zip(value["rank"], value["value"])]
elif question.type == QuestionTypes.span:
value = [
{"start": s, "end": e, "label": l}
for s, e, l in zip(value["start"], value["end"], value["label"])
]

suggestion = {"question_name": question.name, "value": value}
if hfds[index][f"{question.name}-suggestion-metadata"] is not None:
Expand Down
2 changes: 1 addition & 1 deletion src/argilla/client/feedback/metrics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def get_unified_responses_and_suggestions(
unified_responses = [
tuple(ranking_schema.rank for ranking_schema in response) for response in unified_responses
]
suggestions = [tuple(s["rank"] for s in suggestion) for suggestion in suggestions]
suggestions = [tuple(s.rank for s in suggestion) for suggestion in suggestions]

return unified_responses, suggestions

Expand Down
21 changes: 12 additions & 9 deletions src/argilla/client/feedback/schemas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,14 @@
QuestionSchema,
RankingQuestion,
RatingQuestion,
SpanLabelOption,
SpanQuestion,
TextQuestion,
)
from argilla.client.feedback.schemas.records import (
FeedbackRecord,
RankingValueSchema,
ResponseSchema,
SortBy,
SuggestionSchema,
ValueSchema,
)
from argilla.client.feedback.schemas.records import FeedbackRecord, SortBy
from argilla.client.feedback.schemas.response_values import RankingValueSchema, ResponseValue, SpanValueSchema
from argilla.client.feedback.schemas.responses import ResponseSchema, ResponseStatus, ValueSchema
from argilla.client.feedback.schemas.suggestions import SuggestionSchema
from argilla.client.feedback.schemas.vector_settings import VectorSettings

__all__ = [
Expand All @@ -62,11 +60,16 @@
"RankingQuestion",
"RatingQuestion",
"TextQuestion",
"SpanQuestion",
"SpanLabelOption",
"FeedbackRecord",
"RankingValueSchema",
"ResponseSchema",
"ResponseValue",
"ResponseStatus",
"SuggestionSchema",
"ValueSchema",
"RankingValueSchema",
"SpanValueSchema",
"SortOrder",
"SortBy",
"RecordSortField",
Expand Down
7 changes: 6 additions & 1 deletion src/argilla/client/feedback/schemas/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum
from typing import List


class FieldTypes(str, Enum):
Expand All @@ -25,6 +25,11 @@ class QuestionTypes(str, Enum):
label_selection = "label_selection"
multi_label_selection = "multi_label_selection"
ranking = "ranking"
span = "span"

@classmethod
def values(cls) -> List[str]:
return [_type.value for _type in cls]


class MetadataPropertyTypes(str, Enum):
Expand Down
Loading

0 comments on commit 5d2bfd2

Please sign in to comment.