Skip to content

Commit

Permalink
[AL-7548] Improve test coverage for send to annotate (#1322)
Browse files Browse the repository at this point in the history
  • Loading branch information
richardsun0713 authored Dec 11, 2023
1 parent afbb6d4 commit f16da35
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 29 deletions.
4 changes: 2 additions & 2 deletions labelbox/schema/send_to_annotate_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class SendToAnnotateFromCatalogParams(TypedDict):
:param override_existing_annotations_rule: Optional[ConflictResolutionStrategy] - The strategy defining how to
handle conflicts in classifications between the data rows that already exist in the project and incoming
predictions from the source model run or annotations from the source project. Defaults to
ConflictResolutionStrategy.SKIP.
ConflictResolutionStrategy.KEEP_EXISTING.
:param batch_priority: Optional[int] - The priority of the batch. Defaults to 5.
"""

Expand All @@ -49,7 +49,7 @@ class SendToAnnotateFromModelParams(TypedDict):
to False.
:param override_existing_annotations_rule: Optional[ConflictResolutionStrategy] - The strategy defining how to
handle conflicts in classifications between the data rows that already exist in the project and incoming
predictions from the source model run. Defaults to ConflictResolutionStrategy.SKIP.
predictions from the source model run. Defaults to ConflictResolutionStrategy.KEEP_EXISTING.
:param batch_priority: Optional[int] - The priority of the batch. Defaults to 5.
"""

Expand Down
15 changes: 15 additions & 0 deletions tests/integration/annotation_import/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,6 +579,21 @@ def configured_project(client, initial_dataset, ontology, rand_gen, image_url):
project.delete()


@pytest.fixture
def project_with_ontology(client, configured_project, ontology, rand_gen):
project = client.create_project(name=rand_gen(str),
queue_mode=QueueMode.Batch,
media_type=MediaType.Image)
editor = list(
client.get_labeling_frontends(
where=LabelingFrontend.name == "editor"))[0]
project.setup(editor, ontology)

yield project, ontology

project.delete()


@pytest.fixture
def configured_project_pdf(client, ontology, rand_gen, pdf_url):
project = client.create_project(name=rand_gen(str),
Expand Down
41 changes: 35 additions & 6 deletions tests/integration/annotation_import/test_send_to_annotate_mea.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest

from labelbox import UniqueIds
from labelbox import UniqueIds, OntologyBuilder
from labelbox.schema.conflict_resolution_strategy import ConflictResolutionStrategy


def test_send_to_annotate_from_model(client, configured_project,
Expand All @@ -9,19 +10,43 @@ def test_send_to_annotate_from_model(client, configured_project,
project_with_ontology):
model_run = model_run_with_data_rows
data_row_ids = [p['dataRow']['id'] for p in model_run_predictions]
assert len(data_row_ids) > 0

[destination_project, _] = project_with_ontology
destination_project, _ = project_with_ontology

queues = destination_project.task_queues()
initial_labeling_task = next(
q for q in queues if q.name == "Initial labeling task")
initial_review_task = next(
q for q in queues if q.name == "Initial review task")

# build an ontology mapping using the top level tools and classifications
source_ontology_builder = OntologyBuilder.from_project(configured_project)
feature_schema_ids = list(
tool.feature_schema_id for tool in source_ontology_builder.tools)
# create a dictionary of feature schema id to itself
ontology_mapping = dict(zip(feature_schema_ids, feature_schema_ids))

classification_feature_schema_ids = list(
classification.feature_schema_id
for classification in source_ontology_builder.classifications)
# create a dictionary of feature schema id to itself
classification_ontology_mapping = dict(
zip(classification_feature_schema_ids,
classification_feature_schema_ids))

# combine the two ontology mappings
ontology_mapping.update(classification_ontology_mapping)

task = model_run.send_to_annotate_from_model(
destination_project_id=destination_project.uid,
batch_name="batch",
data_rows=UniqueIds(data_row_ids),
task_queue_id=initial_labeling_task.uid,
params={})
task_queue_id=initial_review_task.uid,
params={
"predictions_ontology_mapping":
ontology_mapping,
"override_existing_annotations_rule":
ConflictResolutionStrategy.OverrideWithPredictions
})

task.wait_till_done()

Expand All @@ -32,3 +57,7 @@ def test_send_to_annotate_from_model(client, configured_project,
destination_data_rows = list(destination_batches[0].export_data_rows())
assert len(destination_data_rows) == len(data_row_ids)
assert all([dr.uid in data_row_ids for dr in destination_data_rows])

# Since data rows were added to a review queue, predictions should be imported into the project as labels
destination_project_labels = (list(destination_project.labels()))
assert len(destination_project_labels) == len(data_row_ids)
14 changes: 1 addition & 13 deletions tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,18 +372,6 @@ def project_with_empty_ontology(project):
yield project


@pytest.fixture
def project_with_ontology(client, rand_gen):
project = client.create_project(name=rand_gen(str),
queue_mode=QueueMode.Batch,
media_type=MediaType.Image)
ontology = _setup_ontology(project)

yield [project, ontology]

project.delete()


@pytest.fixture
def configured_project(project_with_empty_ontology, initial_dataset, rand_gen,
image_url):
Expand Down Expand Up @@ -520,7 +508,7 @@ def _setup_ontology(project):
project.setup(editor, ontology_builder.asdict())
# TODO: ontology may not be synchronous after setup. remove sleep when api is more consistent
time.sleep(2)
return ontology_builder.from_project(project)
return OntologyBuilder.from_project(project)


@pytest.fixture
Expand Down
32 changes: 24 additions & 8 deletions tests/integration/test_send_to_annotate.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,40 @@
import pytest

from labelbox import UniqueIds
from labelbox import UniqueIds, OntologyBuilder, LabelingFrontend
from labelbox.schema.conflict_resolution_strategy import ConflictResolutionStrategy


def test_send_to_annotate_include_annotations(
client, configured_batch_project_with_label, project_with_ontology):
client, configured_batch_project_with_label, project_pack):
[source_project, _, data_row, _] = configured_batch_project_with_label
[destination_project, _] = project_with_ontology
destination_project = project_pack[0]

source_ontology_builder = OntologyBuilder.from_project(source_project)
editor = list(
client.get_labeling_frontends(
where=LabelingFrontend.name == "editor"))[0]
destination_project.setup(editor, source_ontology_builder.asdict())

# build an ontology mapping using the top level tools
feature_schema_ids = list(
tool.feature_schema_id for tool in source_ontology_builder.tools)
# create a dictionary of feature schema id to itself
ontology_mapping = dict(zip(feature_schema_ids, feature_schema_ids))

try:
queues = destination_project.task_queues()
initial_labeling_task = next(
q for q in queues if q.name == "Initial labeling task")
initial_review_task = next(
q for q in queues if q.name == "Initial review task")

# Send the data row to the new project
task = client.send_to_annotate_from_catalog(
destination_project_id=destination_project.uid,
task_queue_id=initial_labeling_task.uid,
task_queue_id=initial_review_task.uid,
batch_name="test-batch",
data_rows=UniqueIds([data_row.uid]),
params={
"source_project_id":
source_project.uid,
"annotations_ontology_mapping":
ontology_mapping,
"override_existing_annotations_rule":
ConflictResolutionStrategy.OverrideWithAnnotations
})
Expand All @@ -36,5 +48,9 @@ def test_send_to_annotate_include_annotations(
destination_data_rows = list(destination_batches[0].export_data_rows())
assert len(destination_data_rows) == 1
assert destination_data_rows[0].uid == data_row.uid

# Verify annotations were copied into the destination project
destination_project_labels = (list(destination_project.labels()))
assert len(destination_project_labels) == 1
finally:
destination_project.delete()

0 comments on commit f16da35

Please sign in to comment.