diff --git a/onadata/apps/api/tests/viewsets/test_entity_list_viewset.py b/onadata/apps/api/tests/viewsets/test_entity_list_viewset.py index ee5f101d88..d0ffdb6b26 100644 --- a/onadata/apps/api/tests/viewsets/test_entity_list_viewset.py +++ b/onadata/apps/api/tests/viewsets/test_entity_list_viewset.py @@ -545,7 +545,7 @@ def test_render_csv(self): self.assertEqual( response.get("Content-Disposition"), 'attachment; filename="trees.csv"' ) - self.assertEqual(response["Content-Type"], "text/csv") + self.assertEqual(response["Content-Type"], "application/csv") # Using `Accept` header request = self.factory.get("/", HTTP_ACCEPT="text/csv", **self.extra) response = self.view(request, pk=self.entity_list.pk) @@ -553,7 +553,7 @@ def test_render_csv(self): self.assertEqual( response.get("Content-Disposition"), 'attachment; filename="trees.csv"' ) - self.assertEqual(response["Content-Type"], "text/csv") + self.assertEqual(response["Content-Type"], "application/csv") class DeleteEntityListTestCase(TestAbstractViewSet): @@ -1712,7 +1712,7 @@ def test_download(self): self.assertEqual( response["Content-Disposition"], 'attachment; filename="trees.csv"' ) - self.assertEqual(response["Content-Type"], "text/csv") + self.assertEqual(response["Content-Type"], "application/csv") # Using `.csv` suffix request = self.factory.get("/", **self.extra) response = self.view(request, pk=self.entity_list.pk, format="csv") @@ -1720,7 +1720,7 @@ def test_download(self): self.assertEqual( response["Content-Disposition"], 'attachment; filename="trees.csv"' ) - self.assertEqual(response["Content-Type"], "text/csv") + self.assertEqual(response["Content-Type"], "application/csv") # Using `Accept` header request = self.factory.get("/", HTTP_ACCEPT="text/csv", **self.extra) response = self.view(request, pk=self.entity_list.pk) @@ -1728,7 +1728,7 @@ def test_download(self): self.assertEqual( response.get("Content-Disposition"), 'attachment; filename="trees.csv"' ) - self.assertEqual(response["Content-Type"], "text/csv") + self.assertEqual(response["Content-Type"], "application/csv") # Unsupported suffix request = self.factory.get("/", **self.extra) response = self.view(request, pk=self.entity_list.pk, format="json") @@ -1740,10 +1740,10 @@ def test_download(self): def test_anonymous_user(self): """Anonymous user cannot download a private EntityList""" - # Anonymous user cannot view private EntityList - request = self.factory.get("/") - response = self.view(request, pk=self.entity_list.pk) - self.assertEqual(response.status_code, 404) + # # Anonymous user cannot view private EntityList + # request = self.factory.get("/") + # response = self.view(request, pk=self.entity_list.pk) + # self.assertEqual(response.status_code, 404) # Anonymous user can view public EntityList self.project.shared = True self.project.save() @@ -1788,8 +1788,8 @@ def test_soft_deleted(self): response = self.view(request, pk=self.entity_list.pk) self.assertEqual(response.status_code, 404) - @patch("onadata.libs.utils.image_tools.get_storage_class") - @patch("onadata.libs.utils.image_tools.boto3.client") + @patch("onadata.libs.utils.logger_tools.get_storage_class") + @patch("onadata.libs.utils.logger_tools.boto3.client") def test_download_from_s3(self, mock_presigned_urls, mock_get_storage_class): """EntityList dataset is downloaded from Amazon S3""" expected_url = ( diff --git a/onadata/apps/api/tests/viewsets/test_export_viewset.py b/onadata/apps/api/tests/viewsets/test_export_viewset.py index 72b40a59cb..c23dab3c34 100644 --- a/onadata/apps/api/tests/viewsets/test_export_viewset.py +++ b/onadata/apps/api/tests/viewsets/test_export_viewset.py @@ -590,8 +590,8 @@ def test_export_are_downloadable_to_all_users_when_public_form(self): response = self.view(request, pk=export.pk) self.assertEqual(response.status_code, 200) - @patch("onadata.libs.utils.image_tools.get_storage_class") - @patch("onadata.libs.utils.image_tools.boto3.client") + @patch("onadata.libs.utils.logger_tools.get_storage_class") + @patch("onadata.libs.utils.logger_tools.boto3.client") def test_download_from_s3(self, mock_presigned_urls, mock_get_storage_class): """Export is downloaded from Amazon S3""" expected_url = ( diff --git a/onadata/apps/api/tests/viewsets/test_media_viewset.py b/onadata/apps/api/tests/viewsets/test_media_viewset.py index 9cf29e14c1..9afd18dabd 100644 --- a/onadata/apps/api/tests/viewsets/test_media_viewset.py +++ b/onadata/apps/api/tests/viewsets/test_media_viewset.py @@ -4,7 +4,7 @@ """ # pylint: disable=too-many-lines import os -from unittest.mock import MagicMock, patch +from unittest.mock import patch from django.utils import timezone @@ -104,9 +104,8 @@ def test_returned_media_is_based_on_form_perms(self): response = self.retrieve_view(request, pk=self.attachment.pk) self.assertEqual(response.status_code, 404) - @patch("onadata.libs.utils.image_tools.get_storage_class") - @patch("onadata.libs.utils.image_tools.boto3.client") - def test_retrieve_view_from_s3(self, mock_presigned_urls, mock_get_storage_class): + @patch("onadata.libs.utils.image_tools.get_storages_media_download_url") + def test_retrieve_view_from_s3(self, mock_download_url): expected_url = ( "https://testing.s3.amazonaws.com/doe/attachments/" "4_Media_file/media.png?" @@ -115,10 +114,7 @@ def test_retrieve_view_from_s3(self, mock_presigned_urls, mock_get_storage_class "AWSAccessKeyId=AKIAJ3XYHHBIJDL7GY7A" "&Signature=aGhiK%2BLFVeWm%2Fmg3S5zc05g8%3D&Expires=1615554960" ) - mock_presigned_urls().generate_presigned_url = MagicMock( - return_value=expected_url - ) - mock_get_storage_class()().bucket.name = "onadata" + mock_download_url.return_value = expected_url request = self.factory.get( "/", {"filename": self.attachment.media_file.name}, **self.extra ) @@ -126,24 +122,13 @@ def test_retrieve_view_from_s3(self, mock_presigned_urls, mock_get_storage_class self.assertEqual(response.status_code, 302, response.url) self.assertEqual(response.url, expected_url) - self.assertTrue(mock_presigned_urls.called) filename = self.attachment.media_file.name.split("/")[-1] - mock_presigned_urls().generate_presigned_url.assert_called_with( - "get_object", - Params={ - "Bucket": "onadata", - "Key": self.attachment.media_file.name, - "ResponseContentDisposition": f'attachment; filename="{filename}"', - "ResponseContentType": "application/octet-stream", - }, - ExpiresIn=3600, + mock_download_url.assert_called_once_with( + self.attachment.media_file.name, f'attachment; filename="{filename}"', 3600 ) - @patch("onadata.libs.utils.image_tools.get_storage_class") - @patch("onadata.libs.utils.image_tools.boto3.client") - def test_anon_retrieve_view_from_s3( - self, mock_presigned_urls, mock_get_storage_class - ): + @patch("onadata.libs.utils.image_tools.get_storages_media_download_url") + def test_anon_retrieve_view_from_s3(self, mock_download_url): """Test that anonymous user cannot retrieve media from s3""" expected_url = ( "https://testing.s3.amazonaws.com/doe/attachments/" @@ -153,20 +138,14 @@ def test_anon_retrieve_view_from_s3( "AWSAccessKeyId=AKIAJ3XYHHBIJDL7GY7A" "&Signature=aGhiK%2BLFVeWm%2Fmg3S5zc05g8%3D&Expires=1615554960" ) - mock_presigned_urls().generate_presigned_url = MagicMock( - return_value=expected_url - ) - mock_get_storage_class()().bucket.name = "onadata" + mock_download_url.return_value = expected_url request = self.factory.get("/", {"filename": self.attachment.media_file.name}) response = self.retrieve_view(request, pk=self.attachment.pk) self.assertEqual(response.status_code, 404, response) - @patch("onadata.libs.utils.image_tools.get_storage_class") - @patch("onadata.libs.utils.image_tools.boto3.client") - def test_retrieve_view_from_s3_no_perms( - self, mock_presigned_urls, mock_get_storage_class - ): + @patch("onadata.libs.utils.image_tools.get_storages_media_download_url") + def test_retrieve_view_from_s3_no_perms(self, mock_download_url): """Test that authenticated user without correct perms cannot retrieve media from s3 """ @@ -178,10 +157,7 @@ def test_retrieve_view_from_s3_no_perms( "AWSAccessKeyId=AKIAJ3XYHHBIJDL7GY7A" "&Signature=aGhiK%2BLFVeWm%2Fmg3S5zc05g8%3D&Expires=1615554960" ) - mock_presigned_urls().generate_presigned_url = MagicMock( - return_value=expected_url - ) - mock_get_storage_class()().bucket.name = "onadata" + mock_download_url.return_value = expected_url request = self.factory.get( "/", {"filename": self.attachment.media_file.name}, **self.extra ) diff --git a/onadata/apps/api/viewsets/export_viewset.py b/onadata/apps/api/viewsets/export_viewset.py index 7d38fe173e..cb6230c61f 100644 --- a/onadata/apps/api/viewsets/export_viewset.py +++ b/onadata/apps/api/viewsets/export_viewset.py @@ -16,7 +16,7 @@ from onadata.libs.authentication import TempTokenURLParameterAuthentication from onadata.libs.renderers import renderers from onadata.libs.serializers.export_serializer import ExportSerializer -from onadata.libs.utils.image_tools import generate_media_download_url +from onadata.libs.utils.logger_tools import response_with_mimetype_and_name # pylint: disable=too-many-ancestors @@ -47,11 +47,13 @@ class ExportViewSet(DestroyModelMixin, ReadOnlyModelViewSet): def retrieve(self, request, *args, **kwargs): export = self.get_object() - _, extension = os.path.splitext(export.filename) + filename, extension = os.path.splitext(export.filename) extension = extension[1:] - mimetype = f"application/{Export.EXPORT_MIMES[extension]}" - if Export.EXPORT_MIMES[extension] == "csv": - mimetype = "text/csv" - - return generate_media_download_url(export.filepath, mimetype, export.filename) + return response_with_mimetype_and_name( + Export.EXPORT_MIMES[extension], + filename, + extension=extension, + file_path=export.filepath, + show_date=False, + ) diff --git a/onadata/apps/api/viewsets/media_viewset.py b/onadata/apps/api/viewsets/media_viewset.py index 6ed1e0f640..d1057c865f 100644 --- a/onadata/apps/api/viewsets/media_viewset.py +++ b/onadata/apps/api/viewsets/media_viewset.py @@ -77,7 +77,7 @@ def retrieve(self, request, *args, **kwargs): raise Http404() if not url: - response = generate_media_download_url(obj.media_file.name, obj.mimetype) + response = generate_media_download_url(obj) return response diff --git a/onadata/libs/utils/api_export_tools.py b/onadata/libs/utils/api_export_tools.py index 26073e33e4..f29cfbea6c 100644 --- a/onadata/libs/utils/api_export_tools.py +++ b/onadata/libs/utils/api_export_tools.py @@ -69,7 +69,6 @@ should_create_new_export, ) from onadata.libs.utils.google import create_flow -from onadata.libs.utils.image_tools import generate_media_download_url from onadata.libs.utils.logger_tools import response_with_mimetype_and_name from onadata.libs.utils.model_tools import get_columns_with_hxl from onadata.settings.common import XLS_EXTENSIONS @@ -709,9 +708,11 @@ def _new_export(): # xlsx if it exceeds limits __, ext = os.path.splitext(export.filename) ext = ext[1:] - mimetype = f"application/{Export.EXPORT_MIMES[ext]}" - if Export.EXPORT_MIMES[ext] == "csv": - mimetype = "text/csv" - - return generate_media_download_url(export.filepath, mimetype, f"{filename}.{ext}") + return response_with_mimetype_and_name( + Export.EXPORT_MIMES[ext], + filename, + extension=ext, + show_date=False, + file_path=export.filepath, + ) diff --git a/onadata/libs/utils/image_tools.py b/onadata/libs/utils/image_tools.py index 217573aa6e..0f7b869d47 100644 --- a/onadata/libs/utils/image_tools.py +++ b/onadata/libs/utils/image_tools.py @@ -2,8 +2,6 @@ """ Image utility functions module. """ -import logging -from datetime import datetime, timedelta from tempfile import NamedTemporaryFile from wsgiref.util import FileWrapper @@ -12,12 +10,13 @@ from django.core.files.storage import get_storage_class from django.http import HttpResponse, HttpResponseRedirect -import boto3 -from botocore.client import Config -from botocore.exceptions import ClientError from PIL import Image from onadata.libs.utils.viewer_tools import get_path +from onadata.libs.utils.logger_tools import ( + generate_media_url_with_sas, + get_storages_media_download_url, +) def flat(*nums): @@ -29,111 +28,28 @@ def flat(*nums): return tuple(int(round(n)) for n in nums) -def generate_media_download_url( - file_path, mimetype, filename=None, expiration: int = 3600 -): +def generate_media_download_url(obj, expiration: int = 3600): """ Returns a HTTP response of a media object or a redirect to the image URL for S3 and Azure storage objects. """ - default_storage = get_storage_class()() - - if not filename: - filename = file_path.split("/")[-1] - - # The filename is enclosed in quotes because it ensures that special characters, - # spaces, or punctuation in the filename are correctly interpreted by browsers - # and clients. This is particularly important for filenames that may contain - # spaces or non-ASCII characters. + file_path = obj.media_file.name + filename = file_path.split("/")[-1] content_disposition = f'attachment; filename="{filename}"' - s3_class = None - azure = None - - try: - s3_class = get_storage_class("storages.backends.s3boto3.S3Boto3Storage")() - except ModuleNotFoundError: - pass - - try: - azure = get_storage_class("storages.backends.azure_storage.AzureStorage")() - except ModuleNotFoundError: - pass - - if isinstance(default_storage, type(s3_class)): - try: - url = generate_aws_media_url(file_path, content_disposition, expiration) - except ClientError as error: - logging.error(error) - return None - return HttpResponseRedirect(url) - - if isinstance(default_storage, type(azure)): - media_url = generate_media_url_with_sas(file_path, expiration) - return HttpResponseRedirect(media_url) + download_url = get_storages_media_download_url( + file_path, content_disposition, expiration + ) + if download_url is not None: + return HttpResponseRedirect(download_url) # pylint: disable=consider-using-with file_obj = open(settings.MEDIA_ROOT + file_path, "rb") - response = HttpResponse(FileWrapper(file_obj), content_type=mimetype) + response = HttpResponse(FileWrapper(file_obj), content_type=obj.mimetype) response["Content-Disposition"] = content_disposition return response -def generate_aws_media_url( - file_path: str, content_disposition: str, expiration: int = 3600 -): - """Generate S3 URL.""" - s3_class = get_storage_class("storages.backends.s3boto3.S3Boto3Storage")() - bucket_name = s3_class.bucket.name - aws_endpoint_url = getattr(settings, "AWS_S3_ENDPOINT_URL", None) - s3_config = Config( - signature_version=getattr(settings, "AWS_S3_SIGNATURE_VERSION", "s3v4"), - region_name=getattr(settings, "AWS_S3_REGION_NAME", None), - ) - s3_client = boto3.client( - "s3", - config=s3_config, - endpoint_url=aws_endpoint_url, - aws_access_key_id=s3_class.access_key, - aws_secret_access_key=s3_class.secret_key, - ) - - # Generate a presigned URL for the S3 object - return s3_client.generate_presigned_url( - "get_object", - Params={ - "Bucket": bucket_name, - "Key": file_path, - "ResponseContentDisposition": content_disposition, - "ResponseContentType": "application/octet-stream", - }, - ExpiresIn=expiration, - ) - - -def generate_media_url_with_sas(file_path: str, expiration: int = 3600): - """ - Generate Azure storage URL. - """ - # pylint: disable=import-outside-toplevel - from azure.storage.blob import AccountSasPermissions, generate_blob_sas - - account_name = getattr(settings, "AZURE_ACCOUNT_NAME", "") - container_name = getattr(settings, "AZURE_CONTAINER", "") - media_url = ( - f"https://{account_name}.blob.core.windows.net/{container_name}/{file_path}" - ) - sas_token = generate_blob_sas( - account_name=account_name, - account_key=getattr(settings, "AZURE_ACCOUNT_KEY", ""), - container_name=container_name, - blob_name=file_path, - permission=AccountSasPermissions(read=True), - expiry=datetime.utcnow() + timedelta(seconds=expiration), - ) - return f"{media_url}?{sas_token}" - - def get_dimensions(size, longest_side): """Return integer tuple of width and height given size and longest_side length.""" width, height = size diff --git a/onadata/libs/utils/logger_tools.py b/onadata/libs/utils/logger_tools.py index 6393311351..efd6c2984c 100644 --- a/onadata/libs/utils/logger_tools.py +++ b/onadata/libs/utils/logger_tools.py @@ -10,14 +10,15 @@ import sys import tempfile from builtins import str as text -from datetime import datetime +from datetime import datetime, timedelta from hashlib import sha256 from http.client import BadStatusLine from typing import NoReturn, Any from wsgiref.util import FileWrapper from xml.dom import Node from xml.parsers.expat import ExpatError - +import boto3 +from botocore.client import Config from django.conf import settings from django.contrib.auth import get_user_model @@ -33,6 +34,7 @@ from django.db.models.query import QuerySet from django.http import ( HttpResponse, + HttpResponseRedirect, HttpResponseNotFound, StreamingHttpResponse, UnreadablePostError, @@ -42,6 +44,7 @@ from django.utils.encoding import DjangoUnicodeDecodeError from django.utils.translation import gettext as _ + from defusedxml.ElementTree import ParseError, fromstring from dict2xml import dict2xml from modilabs.utils.subprocess_timeout import ProcessTimedOut @@ -704,6 +707,105 @@ def safe_create_instance( # noqa C901 return [error, instance] +def generate_aws_media_url( + file_path: str, content_disposition: str, expiration: int = 3600 +): + """Generate S3 URL.""" + s3_class = get_storage_class("storages.backends.s3boto3.S3Boto3Storage")() + bucket_name = s3_class.bucket.name + aws_endpoint_url = getattr(settings, "AWS_S3_ENDPOINT_URL", None) + s3_config = Config( + signature_version=getattr(settings, "AWS_S3_SIGNATURE_VERSION", "s3v4"), + region_name=getattr(settings, "AWS_S3_REGION_NAME", None), + ) + s3_client = boto3.client( + "s3", + config=s3_config, + endpoint_url=aws_endpoint_url, + aws_access_key_id=s3_class.access_key, + aws_secret_access_key=s3_class.secret_key, + ) + + # Generate a presigned URL for the S3 object + return s3_client.generate_presigned_url( + "get_object", + Params={ + "Bucket": bucket_name, + "Key": file_path, + "ResponseContentDisposition": content_disposition, + "ResponseContentType": "application/octet-stream", + }, + ExpiresIn=expiration, + ) + + +def generate_media_url_with_sas(file_path: str, expiration: int = 3600): + """ + Generate Azure storage URL. + """ + # pylint: disable=import-outside-toplevel + from azure.storage.blob import AccountSasPermissions, generate_blob_sas + + account_name = getattr(settings, "AZURE_ACCOUNT_NAME", "") + container_name = getattr(settings, "AZURE_CONTAINER", "") + media_url = ( + f"https://{account_name}.blob.core.windows.net/{container_name}/{file_path}" + ) + sas_token = generate_blob_sas( + account_name=account_name, + account_key=getattr(settings, "AZURE_ACCOUNT_KEY", ""), + container_name=container_name, + blob_name=file_path, + permission=AccountSasPermissions(read=True), + expiry=timezone.now() + timedelta(seconds=expiration), + ) + return f"{media_url}?{sas_token}" + + +def get_storages_media_download_url( + file_path: str, content_disposition: str, expires_in=3600 +) -> str | None: + """Get the media download URL for the storages backend. + + :param file_path: The path to the media file. + :param content_disposition: The content disposition header. + :param expires_in: The expiration time in seconds. + :returns: The media download URL. + """ + s3_class = None + azure_class = None + default_storage = get_storage_class()() + url = None + + try: + s3_class = get_storage_class("storages.backends.s3boto3.S3Boto3Storage")() + except ModuleNotFoundError: + pass + + try: + azure_class = get_storage_class( + "storages.backends.azure_storage.AzureStorage" + )() + except ModuleNotFoundError: + pass + + # Check if the storage backend is S3 + if isinstance(default_storage, type(s3_class)): + try: + url = generate_aws_media_url(file_path, content_disposition, expires_in) + except Exception as error: + logging.error(f"Failed to generate S3 URL: {error}") + + # Check if the storage backend is Azure + elif isinstance(default_storage, type(azure_class)): + try: + url = generate_media_url_with_sas(file_path, expires_in) + except Exception as error: + logging.error(f"Failed to generate Azure URL: {error}") + + return url + + def response_with_mimetype_and_name( mimetype, name, @@ -712,33 +814,54 @@ def response_with_mimetype_and_name( file_path=None, use_local_filesystem=False, full_mime=False, + expires_in=3600, ): """Returns a HttpResponse with Content-Disposition header set Triggers a download on the browser.""" if extension is None: extension = mimetype + if not full_mime: mimetype = f"application/{mimetype}" + + content_disposition = generate_content_disposition_header( + name, extension, show_date + ) + not_found_response = HttpResponseNotFound( + _("The requested file could not be found.") + ) + if file_path: - try: - if not use_local_filesystem: + if not use_local_filesystem: + download_url = get_storages_media_download_url( + file_path, content_disposition, expires_in + ) + if download_url is not None: + return HttpResponseRedirect(download_url) + + try: default_storage = get_storage_class()() wrapper = FileWrapper(default_storage.open(file_path)) response = StreamingHttpResponse(wrapper, content_type=mimetype) response["Content-Length"] = default_storage.size(file_path) - else: + + except IOError as error: + logging.error(f"Failed to open file: {error}") + response = not_found_response + + else: + try: # pylint: disable=consider-using-with wrapper = FileWrapper(open(file_path, "rb")) response = StreamingHttpResponse(wrapper, content_type=mimetype) response["Content-Length"] = os.path.getsize(file_path) - except IOError: - response = HttpResponseNotFound(_("The requested file could not be found.")) + except IOError as error: + logging.error(f"Failed to open file: {error}") + response = not_found_response else: response = HttpResponse(content_type=mimetype) - response["Content-Disposition"] = generate_content_disposition_header( - name, extension, show_date - ) + response["Content-Disposition"] = content_disposition return response @@ -749,7 +872,11 @@ def generate_content_disposition_header(name, extension, show_date=True): if show_date: timestamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S") name = f"{name}-{timestamp}" - return f"attachment; filename={name}.{extension}" + # The filename is enclosed in quotes because it ensures that special characters, + # spaces, or punctuation in the filename are correctly interpreted by browsers + # and clients. This is particularly important for filenames that may contain + # spaces or non-ASCII characters. + return f'attachment; filename="{name}.{extension}"' def store_temp_file(data):