Skip to content

Commit

Permalink
Merge branch '173-not-to-fail-at-missing-labeling-files' into 'develop'
Browse files Browse the repository at this point in the history
173 Not to fail at missing labeling files

Closes #173

See merge request epm-emrd/model_garden!214
  • Loading branch information
druzhynin-oleksii committed Sep 8, 2020
2 parents d3a9584 + 51f550f commit 3b5b189
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 26 deletions.
26 changes: 18 additions & 8 deletions backend/model_garden/management/commands/process_task_statuses.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@
logger = logging.getLogger(__name__)


class NoAnnotationException(Exception):
    """Raised when a labeling task's annotation archive yields no annotation files."""


class Command(BaseCommand):
help = "Process tasks statuses"

Expand Down Expand Up @@ -74,6 +78,9 @@ def _process_labeling_tasks(self, labeling_tasks: List[LabelingTask]) -> None:
for labeling_task, result_future in self._upload_annotations(labeling_tasks=labeling_tasks_to_upload):
try:
result_future.result()
except NoAnnotationException as noAnnotationException:
logger.error(f"{noAnnotationException}")
labeling_task.set_failed(error=f"{noAnnotationException}")
except Exception as e:
logger.error(f"{e}")
labeling_task.set_failed(error=f"{e}")
Expand Down Expand Up @@ -131,6 +138,7 @@ def _upload_labeling_task_annotations(self, labeling_task: LabelingTask):

zip_fp = BytesIO(annotations_content_zip)
zf = ZipFile(file=zip_fp)

annotation_filenames = {
os.path.split(zi.filename)[-1]: zf.open(zi) for zi in zf.filelist if
zi.filename.startswith(self._get_cvat_zip_folderpath(annotation_frmt))
Expand All @@ -143,20 +151,22 @@ def _upload_labeling_task_annotations(self, labeling_task: LabelingTask):
logger.info(media_assets_filenames)
logger.info(annotation_filenames)

missing_annotation_filenames = media_assets_filenames - set(annotation_filenames)
if missing_annotation_filenames:
raise Exception(f"Missing task annotations: {', '.join(sorted(missing_annotation_filenames))}")
if len(annotation_filenames) == 0:
raise NoAnnotationException(f"Missing all task annotations for task :{str(labeling_task.task_id)}")

for media_asset in media_assets:
try:
asset_filename = os.path.splitext(media_asset.filename)[0]
bucket_name = media_asset.dataset.bucket.name
s3_client = S3Client(bucket_name=bucket_name)
s3_client.upload_file_obj(
file_obj=annotation_filenames[f"{asset_filename}" + self._get_label_file_extension(annotation_frmt)],
bucket=bucket_name,
key=media_asset.full_label_path,
)
file_name = f"{asset_filename}" + self._get_label_file_extension(annotation_frmt)
if file_name in annotation_filenames:
file_object = annotation_filenames[f"{asset_filename}" + self._get_label_file_extension(annotation_frmt)]
s3_client.upload_file_obj(
file_obj=file_object,
bucket=bucket_name,
key=media_asset.full_label_path,
)
logger.info(f"Uploaded annotation '{media_asset.full_label_path}'")
except Exception as e:
raise Exception(f"Failed to upload task annotations: {e}")
16 changes: 4 additions & 12 deletions backend/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from io import BytesIO
from typing import Optional
from zipfile import ZipFile

from django.test import TestCase, TransactionTestCase
from rest_framework.test import APITestCase

from model_garden.constants import LabelingTaskStatus
from model_garden.models import Bucket, Dataset, MediaAsset, Labeler, LabelingTask
from .test_zip_file_content_factory import ZipFileContentFactory


class Factory:
Expand Down Expand Up @@ -76,30 +75,23 @@ def create_labeling_task(
error=error,
)

@staticmethod
def get_zip_file(*files):
file = BytesIO()
zip_file = ZipFile(file=file, mode='w')
for file_path, file_content in files:
zip_file.writestr(file_path, file_content)

zip_file.close()
return file.getvalue()


class BaseTestCase(TestCase):
    """Base class for ``TestCase``-based tests; provides the shared factories."""

    def setUp(self):
        """Attach a model factory and a zip-content factory to the test instance."""
        super().setUp()
        self.test_factory = Factory()
        self.test_zip_file_factory = ZipFileContentFactory()


class BaseAPITestCase(APITestCase):
    """Base class for ``APITestCase``-based tests; provides the shared factories."""

    def setUp(self):
        """Attach a model factory and a zip-content factory to the test instance."""
        super().setUp()
        self.test_factory = Factory()
        self.test_zip_file_factory = ZipFileContentFactory()


class BaseTransactionTestCase(TransactionTestCase):
    """Base class for ``TransactionTestCase``-based tests; provides the shared factories."""

    def setUp(self):
        """Attach a model factory and a zip-content factory to the test instance."""
        super().setUp()
        self.test_factory = Factory()
        self.test_zip_file_factory = ZipFileContentFactory()
36 changes: 30 additions & 6 deletions backend/tests/management/commands/test_process_task_statuses.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import tempfile
from unittest import mock

from django.core import management
Expand All @@ -16,9 +17,6 @@ def setUp(self):
}
self.s3_client_patcher = mock.patch('model_garden.management.commands.process_task_statuses.S3Client')
self.s3_client_mock = self.s3_client_patcher.start().return_value
self.cvat_service_mock.get_annotations.return_value = self.test_factory.get_zip_file(
('Annotations/test.xml', 'test'),
)

def tearDown(self):
self.s3_client_patcher.stop()
Expand All @@ -27,6 +25,9 @@ def tearDown(self):

def test_handle(self):
media_asset = self.test_factory.create_media_asset(filename='test.jpg', assigned=True)
self.cvat_service_mock.get_annotations.return_value = self.test_zip_file_factory.get_zip_file_content(
('Annotations/test.xml', 'test'),
)
labeling_task = media_asset.labeling_task

management.call_command('process_task_statuses')
Expand Down Expand Up @@ -74,6 +75,9 @@ def test_handle_get_annotations_error(self):
self.assertEqual(labeling_task.error, "Failed to get task annotations: CVAT error")

def test_handle_s3_upload_error(self):
self.cvat_service_mock.get_annotations.return_value = self.test_zip_file_factory.get_zip_file_content(
('Annotations/test.xml', 'test'),
)
self.s3_client_mock.upload_file_obj.side_effect = Exception('S3 error')
media_asset = self.test_factory.create_media_asset(filename='test.jpg', assigned=True)
labeling_task = media_asset.labeling_task
Expand All @@ -83,17 +87,37 @@ def test_handle_s3_upload_error(self):
labeling_task.refresh_from_db()
self.assertEqual(labeling_task.error, "Failed to upload task annotations: S3 error")

def test_handle_missing_task_annotations(self):
media_asset = self.test_factory.create_media_asset(filename='test2.jpg', assigned=True)
def test_handle_missing_one_annotation_filename(self):
self.cvat_service_mock.get_annotations.return_value = self.test_zip_file_factory.get_zip_file_content(
('Annotations/test.xml', 'test'),
)
media_asset = self.test_factory.create_media_asset(filename='test.jpg', assigned=True)
media_asset2 = self.test_factory.create_media_asset(filename='test2.jpg', assigned=True)
labeling_task = media_asset.labeling_task
labeling_task2 = media_asset2.labeling_task

management.call_command('process_task_statuses')

labeling_task.refresh_from_db()
self.assertEqual(labeling_task.error, "Missing task annotations: test2.xml")
labeling_task2.refresh_from_db()

self.assertEqual(labeling_task.error, None)
self.assertEqual(labeling_task2.error, None)

@mock.patch('model_garden.management.commands.process_task_statuses.logger')
def test_handle_no_pending_labeling_tasks(self, logger_mock):
    """When the test creates no labeling tasks, the command logs a single info notice."""
    management.call_command('process_task_statuses')

    logger_mock.info.assert_called_once_with('No pending labeling tasks found')

def test_handle_missing_all_annotation_filenames(self):
    """An archive holding only an empty folder records the 'missing all annotations' error."""
    # Build the archive bytes first; they are fully read before the temp file is removed.
    with tempfile.NamedTemporaryFile(mode="w+b", delete=True) as archive_file:
        empty_archive = self.test_zip_file_factory.get_empty_zip_file_content(archive_file, ["annotations/"])
        self.cvat_service_mock.get_annotations.return_value = empty_archive

    asset = self.test_factory.create_media_asset(filename='test.jpg', assigned=True)
    task = asset.labeling_task

    management.call_command('process_task_statuses')
    task.refresh_from_db()
    self.assertIn("Missing all task annotations for task :1", task.error)
27 changes: 27 additions & 0 deletions backend/tests/test_zip_file_content_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from io import BytesIO
from zipfile import ZipFile


class ZipFileContentFactory:
    """Build zip archive payloads (as raw bytes) for use in tests."""

    @staticmethod
    def get_zip_file_content(*files):
        """Return the bytes of an in-memory zip archive containing ``files``.

        Args:
            *files: ``(path_in_archive, content)`` pairs, each written with
                ``ZipFile.writestr``.

        Returns:
            bytes: The complete archive.
        """
        buffer = BytesIO()
        # The context manager closes the archive even when a write fails.
        # The central directory is only written on close, so the buffer is
        # read after the ``with`` block.
        with ZipFile(file=buffer, mode='w') as zip_file:
            for file_path, file_content in files:
                zip_file.writestr(file_path, file_content)
        return buffer.getvalue()

    @staticmethod
    def get_empty_zip_file_content(zip_file_obj, folder_subpathes):
        """Write folder-only entries into ``zip_file_obj`` and return its bytes.

        Args:
            zip_file_obj: A writable, readable and seekable binary file object
                that receives the archive (for example a ``w+b`` temporary
                file or a ``BytesIO``). It is left open for the caller.
            folder_subpathes: Iterable of folder names. A trailing ``/`` is
                appended to any name that lacks one, then an empty entry is
                written under that name.

        Returns:
            bytes: Everything read back from ``zip_file_obj`` after the
            archive has been closed.
        """
        with ZipFile(zip_file_obj, 'w') as zip_file:
            for folder in folder_subpathes:
                # ``endswith`` also copes with an empty name, where indexing
                # ``folder[-1]`` would raise IndexError.
                entry_name = folder if folder.endswith('/') else folder + '/'
                zip_file.writestr(entry_name, "")
        # The archive is finalized once the ``with`` block exits; rewind and
        # read the finished bytes. No further ``close()`` call is needed.
        zip_file_obj.seek(0)
        return zip_file_obj.read()

0 comments on commit 3b5b189

Please sign in to comment.