From 51f550f54628e4668c2e0d92c719a862d1b68efb Mon Sep 17 00:00:00 2001 From: Octavio Garfio Date: Tue, 8 Sep 2020 19:20:11 +0000 Subject: [PATCH] 173 Not to fail at missing labeling files --- .../commands/process_task_statuses.py | 26 +++++++++----- backend/tests/__init__.py | 16 +++------ .../commands/test_process_task_statuses.py | 36 +++++++++++++++---- .../tests/test_zip_file_content_factory.py | 27 ++++++++++++++ 4 files changed, 79 insertions(+), 26 deletions(-) create mode 100644 backend/tests/test_zip_file_content_factory.py diff --git a/backend/model_garden/management/commands/process_task_statuses.py b/backend/model_garden/management/commands/process_task_statuses.py index 5bedefc1..a0fa4881 100644 --- a/backend/model_garden/management/commands/process_task_statuses.py +++ b/backend/model_garden/management/commands/process_task_statuses.py @@ -16,6 +16,10 @@ logger = logging.getLogger(__name__) +class NoAnnotationException(Exception): + pass + + class Command(BaseCommand): help = "Process tasks statuses" @@ -74,6 +78,9 @@ def _process_labeling_tasks(self, labeling_tasks: List[LabelingTask]) -> None: for labeling_task, result_future in self._upload_annotations(labeling_tasks=labeling_tasks_to_upload): try: result_future.result() + except NoAnnotationException as noAnnotationException: + logger.error(f"{noAnnotationException}") + labeling_task.set_failed(error=f"{noAnnotationException}") except Exception as e: logger.error(f"{e}") labeling_task.set_failed(error=f"{e}") @@ -131,6 +138,7 @@ def _upload_labeling_task_annotations(self, labeling_task: LabelingTask): zip_fp = BytesIO(annotations_content_zip) zf = ZipFile(file=zip_fp) + annotation_filenames = { os.path.split(zi.filename)[-1]: zf.open(zi) for zi in zf.filelist if zi.filename.startswith(self._get_cvat_zip_folderpath(annotation_frmt)) @@ -143,20 +151,22 @@ def _upload_labeling_task_annotations(self, labeling_task: LabelingTask): logger.info(media_assets_filenames) logger.info(annotation_filenames) - missing_annotation_filenames = media_assets_filenames - set(annotation_filenames) - if missing_annotation_filenames: - raise Exception(f"Missing task annotations: {', '.join(sorted(missing_annotation_filenames))}") + if len(annotation_filenames) == 0: + raise NoAnnotationException(f"Missing all task annotations for task :{str(labeling_task.task_id)}") for media_asset in media_assets: try: asset_filename = os.path.splitext(media_asset.filename)[0] bucket_name = media_asset.dataset.bucket.name s3_client = S3Client(bucket_name=bucket_name) - s3_client.upload_file_obj( - file_obj=annotation_filenames[f"{asset_filename}" + self._get_label_file_extension(annotation_frmt)], - bucket=bucket_name, - key=media_asset.full_label_path, - ) + file_name = f"{asset_filename}" + self._get_label_file_extension(annotation_frmt) + if file_name in annotation_filenames: + file_object = annotation_filenames[f"{asset_filename}" + self._get_label_file_extension(annotation_frmt)] + s3_client.upload_file_obj( + file_obj=file_object, + bucket=bucket_name, + key=media_asset.full_label_path, + ) logger.info(f"Uploaded annotation '{media_asset.full_label_path}'") except Exception as e: raise Exception(f"Failed to upload task annotations: {e}") diff --git a/backend/tests/__init__.py b/backend/tests/__init__.py index 3f75e51f..848737bf 100644 --- a/backend/tests/__init__.py +++ b/backend/tests/__init__.py @@ -1,12 +1,11 @@ -from io import BytesIO from typing import Optional -from zipfile import ZipFile from django.test import TestCase, TransactionTestCase from rest_framework.test import APITestCase from model_garden.constants import LabelingTaskStatus from model_garden.models import Bucket, Dataset, MediaAsset, Labeler, LabelingTask +from .test_zip_file_content_factory import ZipFileContentFactory class Factory: @@ -76,30 +75,23 @@ def create_labeling_task( error=error, ) - @staticmethod - def get_zip_file(*files): - file = BytesIO() - zip_file = ZipFile(file=file, mode='w') - for file_path, file_content in files: - zip_file.writestr(file_path, file_content) - - zip_file.close() - return file.getvalue() - class BaseTestCase(TestCase): def setUp(self): super().setUp() self.test_factory = Factory() + self.test_zip_file_factory = ZipFileContentFactory() class BaseAPITestCase(APITestCase): def setUp(self): super().setUp() self.test_factory = Factory() + self.test_zip_file_factory = ZipFileContentFactory() class BaseTransactionTestCase(TransactionTestCase): def setUp(self): super().setUp() self.test_factory = Factory() + self.test_zip_file_factory = ZipFileContentFactory() diff --git a/backend/tests/management/commands/test_process_task_statuses.py b/backend/tests/management/commands/test_process_task_statuses.py index 7285ccd6..8840b32d 100644 --- a/backend/tests/management/commands/test_process_task_statuses.py +++ b/backend/tests/management/commands/test_process_task_statuses.py @@ -1,3 +1,4 @@ +import tempfile from unittest import mock from django.core import management @@ -16,9 +17,6 @@ def setUp(self): } self.s3_client_patcher = mock.patch('model_garden.management.commands.process_task_statuses.S3Client') self.s3_client_mock = self.s3_client_patcher.start().return_value - self.cvat_service_mock.get_annotations.return_value = self.test_factory.get_zip_file( - ('Annotations/test.xml', 'test'), - ) def tearDown(self): self.s3_client_patcher.stop() @@ -27,6 +25,9 @@ def tearDown(self): def test_handle(self): media_asset = self.test_factory.create_media_asset(filename='test.jpg', assigned=True) + self.cvat_service_mock.get_annotations.return_value = self.test_zip_file_factory.get_zip_file_content( + ('Annotations/test.xml', 'test'), + ) labeling_task = media_asset.labeling_task management.call_command('process_task_statuses') @@ -74,6 +75,9 @@ def test_handle_get_annotations_error(self): self.assertEqual(labeling_task.error, "Failed to get task annotations: CVAT error") def test_handle_s3_upload_error(self): + self.cvat_service_mock.get_annotations.return_value = self.test_zip_file_factory.get_zip_file_content( + ('Annotations/test.xml', 'test'), + ) self.s3_client_mock.upload_file_obj.side_effect = Exception('S3 error') media_asset = self.test_factory.create_media_asset(filename='test.jpg', assigned=True) labeling_task = media_asset.labeling_task @@ -83,17 +87,37 @@ def test_handle_s3_upload_error(self): labeling_task.refresh_from_db() self.assertEqual(labeling_task.error, "Failed to upload task annotations: S3 error") - def test_handle_missing_task_annotations(self): - media_asset = self.test_factory.create_media_asset(filename='test2.jpg', assigned=True) + def test_handle_missing_one_annotation_filename(self): + self.cvat_service_mock.get_annotations.return_value = self.test_zip_file_factory.get_zip_file_content( + ('Annotations/test.xml', 'test'), + ) + media_asset = self.test_factory.create_media_asset(filename='test.jpg', assigned=True) + media_asset2 = self.test_factory.create_media_asset(filename='test2.jpg', assigned=True) labeling_task = media_asset.labeling_task + labeling_task2 = media_asset2.labeling_task management.call_command('process_task_statuses') labeling_task.refresh_from_db() - self.assertEqual(labeling_task.error, "Missing task annotations: test2.xml") + labeling_task2.refresh_from_db() + + self.assertEqual(labeling_task.error, None) + self.assertEqual(labeling_task2.error, None) @mock.patch('model_garden.management.commands.process_task_statuses.logger') def test_handle_no_pending_labeling_tasks(self, logger_mock): management.call_command('process_task_statuses') logger_mock.info.assert_called_once_with('No pending labeling tasks found') + + def test_handle_missing_all_annotation_filenames(self): + with tempfile.NamedTemporaryFile(mode="w+b", delete=True) as temporaryFile: + self.cvat_service_mock.get_annotations.return_value = ( + self.test_zip_file_factory.get_empty_zip_file_content(temporaryFile, ["annotations/"])) + + media_asset = self.test_factory.create_media_asset(filename='test.jpg', assigned=True) + labeling_task = media_asset.labeling_task + + management.call_command('process_task_statuses') + labeling_task.refresh_from_db() + self.assertIn("Missing all task annotations for task :1", labeling_task.error) diff --git a/backend/tests/test_zip_file_content_factory.py b/backend/tests/test_zip_file_content_factory.py new file mode 100644 index 00000000..2893ca44 --- /dev/null +++ b/backend/tests/test_zip_file_content_factory.py @@ -0,0 +1,27 @@ +from io import BytesIO +from zipfile import ZipFile + + +class ZipFileContentFactory: + + @staticmethod + def get_zip_file_content(*files): + file = BytesIO() + zip_file = ZipFile(file=file, mode='w') + for file_path, file_content in files: + zip_file.writestr(file_path, file_content) + + zip_file.close() + return file.getvalue() + + @staticmethod + def get_empty_zip_file_content(zip_file_obj, folder_subpathes): + with ZipFile(zip_file_obj, 'w') as zipFile: + for folder in folder_subpathes: + folder_with_slash = folder + "/" if folder[-1] != '/' else folder + zipFile.writestr(folder_with_slash, "") + zip_file_obj.seek(0) + file_content = zip_file_obj.read() + zipFile.close() + + return file_content