Skip to content

Commit 745a869

Browse files
[PTDT-2372] Added support for MMC tasks annotations (#1787)
1 parent 906c2e9 commit 745a869

File tree

8 files changed

+189
-4
lines changed

8 files changed

+189
-4
lines changed

libs/labelbox/src/labelbox/data/annotation_types/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,4 +63,6 @@
6363
from .data.tiled_image import TileLayer
6464

6565
from .llm_prompt_response.prompt import PromptText
66-
from .llm_prompt_response.prompt import PromptClassificationAnnotation
66+
from .llm_prompt_response.prompt import PromptClassificationAnnotation
67+
68+
from .mmc import MessageInfo, OrderedMessageInfo, MessageSingleSelectionTask, MessageMultiSelectionTask, MessageRankingTask, MessageEvaluationTaskAnnotation

libs/labelbox/src/labelbox/data/annotation_types/label.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from .types import Cuid
1919
from .video import VideoClassificationAnnotation
2020
from .video import VideoObjectAnnotation, VideoMaskAnnotation
21+
from .mmc import MessageEvaluationTaskAnnotation
2122
from ..ontology import get_feature_schema_lookup
2223

2324
DataType = Union[VideoData, ImageData, TextData, TiledImageData, AudioData,
@@ -51,7 +52,7 @@ class Label(pydantic_compat.BaseModel):
5152
annotations: List[Union[ClassificationAnnotation, ObjectAnnotation,
5253
VideoMaskAnnotation, ScalarMetric,
5354
ConfusionMatrixMetric, RelationshipAnnotation,
54-
PromptClassificationAnnotation]] = []
55+
PromptClassificationAnnotation, MessageEvaluationTaskAnnotation]] = []
5556
extra: Dict[str, Any] = {}
5657
is_benchmark_reference: Optional[bool] = False
5758

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
from abc import ABC
2+
from typing import ClassVar, List, Union
3+
4+
from labelbox import pydantic_compat
5+
from labelbox.utils import _CamelCaseMixin
6+
from labelbox.data.annotation_types.annotation import BaseAnnotation
7+
8+
9+
class MessageInfo(_CamelCaseMixin):
    """Reference to a single model-generated message in a conversation."""

    # Unique identifier of the referenced message.
    message_id: str
    # Name of the model configuration that produced the message.
    model_config_name: str
12+
13+
14+
class OrderedMessageInfo(MessageInfo):
    """A message reference carrying its rank within a ranking task."""

    # 1-based position assigned to this message (validated by MessageRankingTask).
    order: int
16+
17+
18+
class _BaseMessageEvaluationTask(_CamelCaseMixin, ABC):
    """Common shape shared by all message-evaluation task payloads.

    `format` is a class-level discriminator string; each concrete subclass
    pins it to its task type for NDJSON serialization.
    """

    format: ClassVar[str]
    # Id of the prompt message whose candidate responses are evaluated.
    parent_message_id: str
21+
22+
23+
class MessageSingleSelectionTask(_BaseMessageEvaluationTask, MessageInfo):
    """Task selecting exactly one message; the selected message's info comes
    from the inherited MessageInfo fields."""

    format: ClassVar[str] = "message-single-selection"
25+
26+
27+
class MessageMultiSelectionTask(_BaseMessageEvaluationTask):
    """Task selecting any number of messages for one parent message."""

    format: ClassVar[str] = "message-multi-selection"
    # References to every message the labeler selected.
    selected_messages: List[MessageInfo]
30+
31+
32+
class MessageRankingTask(_BaseMessageEvaluationTask):
    """Task ranking candidate messages for one parent message.

    Raises:
        ValueError: if the `order` values are not exactly the consecutive
            integers 1..len(ranked_messages).
    """

    format: ClassVar[str] = "message-ranking"
    ranked_messages: List[OrderedMessageInfo]

    @pydantic_compat.validator("ranked_messages")
    def _validate_ranked_messages(cls, v: List[OrderedMessageInfo]):
        # Orders must form the exact set {1, ..., n}: no duplicates, no gaps,
        # starting at 1. (An empty list trivially satisfies this.)
        if {msg.order for msg in v} != set(range(1, len(v) + 1)):
            raise ValueError(
                "Messages must be ordered by unique and consecutive natural numbers starting from 1"
            )
        return v
41+
42+
43+
class MessageEvaluationTaskAnnotation(BaseAnnotation):
    """Annotation whose value is one of the message-evaluation task types."""

    value: Union[MessageSingleSelectionTask, MessageMultiSelectionTask, MessageRankingTask]

libs/labelbox/src/labelbox/data/serialization/ndjson/converter.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from ...annotation_types.collection import LabelCollection, LabelGenerator
1616
from ...annotation_types.relationship import RelationshipAnnotation
17+
from ...annotation_types.mmc import MessageEvaluationTaskAnnotation
1718
from .label import NDLabel
1819

1920
logger = logging.getLogger(__name__)
@@ -71,8 +72,9 @@ def serialize(
7172
ScalarMetric,
7273
ConfusionMatrixMetric,
7374
RelationshipAnnotation,
75+
MessageEvaluationTaskAnnotation,
7476
]] = []
75-
# First pass to get all RelatiohnshipAnnotaitons
77+
# First pass to get all RelationshipAnnotations
7678
# and update the UUIDs of the source and target annotations
7779
for annotation in label.annotations:
7880
if isinstance(annotation, RelationshipAnnotation):

libs/labelbox/src/labelbox/data/serialization/ndjson/label.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,19 @@
1818
from ...annotation_types.classification import Dropdown
1919
from ...annotation_types.metrics import ScalarMetric, ConfusionMatrixMetric
2020
from ...annotation_types.llm_prompt_response.prompt import PromptClassificationAnnotation
21+
from ...annotation_types.mmc import MessageEvaluationTaskAnnotation
2122

2223
from .metric import NDScalarMetric, NDMetricAnnotation, NDConfusionMatrixMetric
2324
from .classification import NDChecklistSubclass, NDClassification, NDClassificationType, NDRadioSubclass, NDPromptClassification, NDPromptClassificationType, NDPromptText
2425
from .objects import NDObject, NDObjectType, NDSegments, NDDicomSegments, NDVideoMasks, NDDicomMasks
26+
from .mmc import NDMessageTask
2527
from .relationship import NDRelationship
2628
from .base import DataRow
2729

2830
AnnotationType = Union[NDObjectType, NDClassificationType, NDPromptClassificationType,
2931
NDConfusionMatrixMetric, NDScalarMetric, NDDicomSegments,
3032
NDSegments, NDDicomMasks, NDVideoMasks, NDRelationship,
31-
NDPromptText]
33+
NDPromptText, NDMessageTask]
3234

3335

3436
class NDLabel(pydantic_compat.BaseModel):
@@ -126,6 +128,8 @@ def _generate_annotations(
126128
elif isinstance(ndjson_annotation, NDPromptClassificationType):
127129
annotation = NDPromptClassification.to_common(ndjson_annotation)
128130
annotations.append(annotation)
131+
elif isinstance(ndjson_annotation, NDMessageTask):
132+
annotations.append(ndjson_annotation.to_common())
129133
else:
130134
raise TypeError(
131135
f"Unsupported annotation. {type(ndjson_annotation)}")
@@ -277,6 +281,8 @@ def _create_non_video_annotations(cls, label: Label):
277281
yield NDRelationship.from_common(annotation, label.data)
278282
elif isinstance(annotation, PromptClassificationAnnotation):
279283
yield NDPromptClassification.from_common(annotation, label.data)
284+
elif isinstance(annotation, MessageEvaluationTaskAnnotation):
285+
yield NDMessageTask.from_common(annotation, label.data)
280286
else:
281287
raise TypeError(
282288
f"Unable to convert object to MAL format. `{type(getattr(annotation, 'value',annotation))}`"
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
from typing import Any, Dict, List, Optional, Union
2+
3+
from labelbox.utils import _CamelCaseMixin
4+
5+
from .base import DataRow, NDAnnotation
6+
from ...annotation_types.types import Cuid
7+
from ...annotation_types.mmc import MessageSingleSelectionTask, MessageMultiSelectionTask, MessageRankingTask, MessageEvaluationTaskAnnotation
8+
9+
10+
class MessageTaskData(_CamelCaseMixin):
    """NDJSON payload pairing a task's `format` discriminator with its data."""

    # Discriminator string, e.g. "message-ranking"; mirrors data.format.
    format: str
    data: Union[MessageSingleSelectionTask, MessageMultiSelectionTask, MessageRankingTask]
13+
14+
15+
class NDMessageTask(NDAnnotation):
    """NDJSON representation of a message-evaluation task annotation."""

    message_evaluation_task: MessageTaskData

    def to_common(self) -> MessageEvaluationTaskAnnotation:
        """Convert this NDJSON annotation into the common annotation type."""
        return MessageEvaluationTaskAnnotation(
            name=self.name,
            feature_schema_id=self.schema_id,
            value=self.message_evaluation_task.data,
            extra={"uuid": self.uuid},
        )

    @classmethod
    def from_common(
        cls,
        annotation: MessageEvaluationTaskAnnotation,
        # NOTE(review): presumably Union[ImageData, TextData] — only `.uid`
        # and `.global_key` are read; narrow the hint once imports allow.
        data: Any,
    ) -> "NDMessageTask":
        """Build an NDJSON annotation from the common annotation type."""
        return cls(
            uuid=str(annotation._uuid),
            name=annotation.name,
            schema_id=annotation.feature_schema_id,
            data_row=DataRow(id=data.uid, global_key=data.global_key),
            message_evaluation_task=MessageTaskData(
                format=annotation.value.format,
                data=annotation.value,
            ),
        )
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
[
2+
{
3+
"dataRow": {
4+
"id": "cnjencjencjfencvj"
5+
},
6+
"uuid": "c1be3a57-597e-48cb-8d8d-a852665f9e72",
7+
"name": "single-selection",
8+
"messageEvaluationTask": {
9+
"format": "message-single-selection",
10+
"data": {
11+
"messageId": "clxfzocbm00083b6v8vczsept",
12+
"parentMessageId": "clxfznjb800073b6v43ppx9ca",
13+
"modelConfigName": "GPT 5"
14+
}
15+
}
16+
},
17+
{
18+
"dataRow": {
19+
"id": "cfcerfvergerfefj"
20+
},
21+
"uuid": "gferf3a57-597e-48cb-8d8d-a8526fefe72",
22+
"name": "multi-selection",
23+
"messageEvaluationTask": {
24+
"format": "message-multi-selection",
25+
"data": {
26+
"parentMessageId": "clxfznjb800073b6v43ppx9ca",
27+
"selectedMessages": [
28+
{
29+
"messageId": "clxfzocbm00083b6v8vczsept",
30+
"modelConfigName": "GPT 5"
31+
}
32+
]
33+
}
34+
}
35+
},
36+
{
37+
"dataRow": {
38+
"id": "cwefgtrgrthveferfferffr"
39+
},
40+
"uuid": "hybe3a57-5gt7e-48tgrb-8d8d-a852dswqde72",
41+
"name": "ranking",
42+
"messageEvaluationTask": {
43+
"format": "message-ranking",
44+
"data": {
45+
"parentMessageId": "clxfznjb800073b6v43ppx9ca",
46+
"rankedMessages": [
47+
{
48+
"messageId": "clxfzocbm00083b6v8vczsept",
49+
"modelConfigName": "GPT 4 with temperature 0.7",
50+
"order": 1
51+
},
52+
{
53+
"messageId": "clxfzocbm00093b6vx4ndisub",
54+
"modelConfigName": "GPT 5",
55+
"order": 2
56+
}
57+
]
58+
}
59+
}
60+
}
61+
]
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import json
2+
3+
import pytest
4+
5+
from labelbox.data.serialization import NDJsonConverter
6+
from labelbox.pydantic_compat import ValidationError
7+
8+
9+
def test_message_task_annotation_serialization():
    """Round-trip the MMC fixture: deserialize then serialize must be lossless."""
    with open('tests/data/assets/ndjson/mmc_import.json', 'r') as file:
        expected = json.load(file)

    labels = list(NDJsonConverter.deserialize(expected))
    round_tripped = list(NDJsonConverter.serialize(labels))

    assert round_tripped == expected
17+
18+
19+
def test_message_ranking_task_wrong_order_serialization():
    """Deserializing a ranking task whose orders are not 1..n must raise."""
    with open('tests/data/assets/ndjson/mmc_import.json', 'r') as file:
        data = json.load(file)

    some_ranking_task = next(
        task for task in data
        if task["messageEvaluationTask"]["format"] == "message-ranking")
    # Break the consecutive-from-1 invariant: orders become {3, 2}.
    some_ranking_task["messageEvaluationTask"]["data"]["rankedMessages"][0]["order"] = 3

    with pytest.raises(ValidationError):
        list(NDJsonConverter.deserialize([some_ranking_task]))

0 commit comments

Comments
 (0)