diff --git a/django_project/cplus_api/tasks/sync_default_layers.py b/django_project/cplus_api/tasks/sync_default_layers.py index 110c22c..3232ada 100644 --- a/django_project/cplus_api/tasks/sync_default_layers.py +++ b/django_project/cplus_api/tasks/sync_default_layers.py @@ -4,6 +4,7 @@ import tempfile import rasterio +from rasterio.errors import RasterioIOError from datetime import datetime from django.utils import timezone @@ -12,6 +13,7 @@ from storages.backends.s3 import S3Storage from django.conf import settings from django.core.files.storage import FileSystemStorage +from django.db.models import Q from cplus_api.models import ( select_input_layer_storage, @@ -122,18 +124,57 @@ def run(self): if isinstance(self.storage, FileSystemStorage): download_path = os.path.join(media_root, self.file['Key']) os.makedirs(os.path.dirname(download_path), exist_ok=True) - self.read_metadata(download_path) - os.remove(download_path) + iteration = 0 + while iteration < 3: + try: + self.read_metadata(download_path) + except RasterioIOError: + iteration += 1 + if iteration == 3 and ( + self.input_layer.name == '' or + self.input_layer.file is None + ): + self.input_layer.delete() + else: + os.remove(download_path) + break else: - with tempfile.NamedTemporaryFile() as tmpfile: - boto3_client = self.storage.connection.meta.client - boto3_client.download_file( - self.storage.bucket_name, - self.file['Key'], - tmpfile.name, - Config=settings.AWS_TRANSFER_CONFIG - ) - self.read_metadata(tmpfile.name) + iteration = 0 + while iteration < 3: + with tempfile.NamedTemporaryFile() as tmpfile: + boto3_client = self.storage.connection.meta.client + boto3_client.download_file( + self.storage.bucket_name, + self.file['Key'], + tmpfile.name, + Config=settings.AWS_TRANSFER_CONFIG + ) + try: + self.read_metadata(tmpfile.name) + except RasterioIOError: + iteration += 1 + if iteration == 3 and ( + self.input_layer.name == '' or + self.input_layer.file is None + ): + self.input_layer.delete() + else: + break + + +def delete_invalid_default_layers(): + """Delete invalid default layers in DB + + :return: None + :rtype: None + """ + common_layers = InputLayer.objects.filter( + privacy_type=InputLayer.PrivacyTypes.COMMON + ) + invalid_common_layers = common_layers.filter( + Q(name='') | Q(file='') + ) + invalid_common_layers.delete() @shared_task(name="sync_default_layers") @@ -142,6 +183,8 @@ def sync_default_layers(): Create Input Layers from default layers copied to S3/local directory """ + delete_invalid_default_layers() + storage = select_input_layer_storage() component_types = [c[0] for c in InputLayer.ComponentTypes.choices] admin_username = os.getenv('ADMIN_USERNAME') diff --git a/django_project/cplus_api/tests/test_sync_default_layers.py b/django_project/cplus_api/tests/test_sync_default_layers.py index 67f4f1b..5ddf9f8 100644 --- a/django_project/cplus_api/tests/test_sync_default_layers.py +++ b/django_project/cplus_api/tests/test_sync_default_layers.py @@ -1,20 +1,30 @@ import os +import tempfile +import time +from datetime import timedelta from shutil import copyfile +from unittest.mock import patch, MagicMock + +from django.utils import timezone +from rasterio.errors import RasterioIOError +from storages.backends.s3 import S3Storage from core.settings.utils import absolute_path from cplus_api.models.layer import ( InputLayer, COMMON_LAYERS_DIR ) -from cplus_api.tasks.sync_default_layers import sync_default_layers +from cplus_api.tasks.sync_default_layers import ( + sync_default_layers, + ProcessFile +) from cplus_api.tests.common import BaseAPIViewTransactionTest +from cplus_api.tests.factories import InputLayerF class TestSyncDefaultLayer(BaseAPIViewTransactionTest): def setUp(self, *args, **kwargs): super().setUp(*args, **kwargs) - # print(help(self)) - # breakpoint() self.superuser.username = os.getenv('ADMIN_USERNAME') self.superuser.save() @@ -73,6 +83,7 @@ def test_new_layer(self): def test_file_updated(self): input_layer, source_path, dest_path = self.base_run() + time.sleep(5) first_modified_on = input_layer.modified_on copyfile(source_path, dest_path) sync_default_layers() @@ -87,3 +98,128 @@ def test_file_updated(self): input_layer.refresh_from_db() self.assertEqual(input_layer.name, 'New Name') self.assertEqual(input_layer.description, 'New Description') + + def test_delete_invalid_layers(self): + input_layer, source_path, dest_path = self.base_run() + invalid_common_layer_1 = InputLayerF.create( + name='', + privacy_type=InputLayer.PrivacyTypes.COMMON, + file=input_layer.file + ) + invalid_common_layer_2 = InputLayerF.create( + name='invalid_common_layer_2', + privacy_type=InputLayer.PrivacyTypes.COMMON, + file=None + ) + private_layer_1 = InputLayerF.create( + name='', + privacy_type=InputLayer.PrivacyTypes.PRIVATE, + file=input_layer.file + ) + private_layer_2 = InputLayerF.create( + name='private_layer_2', + privacy_type=InputLayer.PrivacyTypes.PRIVATE + ) + + sync_default_layers() + + # Calling refresh_from_db() on these 2 variable would result + # in InputLayer.DoesNotExist as they have been deleted, + # because they are invalid common layers + with self.assertRaises(InputLayer.DoesNotExist): + invalid_common_layer_1.refresh_from_db() + with self.assertRaises(InputLayer.DoesNotExist): + invalid_common_layer_2.refresh_from_db() + + # These layers are not deleted, so we could still call refresh_from_db + private_layer_1.refresh_from_db() + private_layer_2.refresh_from_db() + + def test_invalid_input_layers_not_created(self): + source_path = absolute_path( + 'cplus_api', 'tests', 'data', + 'pathways', 'test_pathway_2.tif' + ) + dest_path = ( + f'/home/web/media/minio_test/{COMMON_LAYERS_DIR}/' + f'{InputLayer.ComponentTypes.NCS_PATHWAY}/test_pathway_2.tif' + ) + os.makedirs(os.path.dirname(dest_path), exist_ok=True) + copyfile(source_path, dest_path) + with patch.object( + ProcessFile, 'read_metadata', autospec=True + ) as mock_read_metadata: + mock_read_metadata.side_effect = [ + RasterioIOError('error'), + RasterioIOError('error'), + RasterioIOError('error') + ] + sync_default_layers() + + self.assertFalse(InputLayer.objects.exists()) + + def test_invalid_input_layers_not_deleted(self): + input_layer, source_path, dest_path = self.base_run() + time.sleep(5) + first_modified_on = input_layer.modified_on + copyfile(source_path, dest_path) + sync_default_layers() + with patch.object( + ProcessFile, 'read_metadata', autospec=True + ) as mock_read_metadata: + mock_read_metadata.side_effect = [ + RasterioIOError('error'), + RasterioIOError('error'), + RasterioIOError('error') + ] + sync_default_layers() + + self.assertTrue(InputLayer.objects.exists()) + + # Check modified_on is updated + input_layer.refresh_from_db() + self.assertNotEquals(input_layer.modified_on, first_modified_on) + + def run_s3(self, mock_storage, mock_named_tmp_file=None): + source_path = absolute_path( + 'cplus_api', 'tests', 'data', + 'pathways', 'test_pathway_2.tif' + ) + dest_path = ( + f'/home/web/media/minio_test/{COMMON_LAYERS_DIR}/' + f'{InputLayer.ComponentTypes.NCS_PATHWAY}/test_pathway_2.tif' + ) + os.makedirs(os.path.dirname(dest_path), exist_ok=True) + copyfile(source_path, dest_path) + + storage = S3Storage(bucket_name='test-bucket') + s3_client = MagicMock() + s3_client.list_objects.return_value = { + 'Contents': [ + { + 'Key': 'common_layers/ncs_pathway/test_pathway_2.tif', + 'LastModified': timezone.now() + timedelta(days=1) + } + ] + } + storage.connection.meta.client = s3_client + mock_storage.return_value = storage + if mock_named_tmp_file: + (mock_named_tmp_file.return_value. + __enter__.return_value).name = dest_path + sync_default_layers() + + @patch('cplus_api.tasks.sync_default_layers.select_input_layer_storage') + def test_invalid_input_layers_not_created_s3(self, mock_storage): + self.run_s3(mock_storage) + self.assertFalse(InputLayer.objects.exists()) + + @patch('cplus_api.tasks.sync_default_layers.select_input_layer_storage') + @patch.object(tempfile, 'NamedTemporaryFile') + def test_invalid_input_layers_created_s3( + self, + mock_named_tmp_file, + mock_storage + ): + self.run_s3(mock_storage, mock_named_tmp_file) + self.assertTrue(InputLayer.objects.exists())