diff --git a/gateway/api/services/file_storage.py b/gateway/api/services/file_storage.py
index 1deedf565..5ae0f4e70 100644
--- a/gateway/api/services/file_storage.py
+++ b/gateway/api/services/file_storage.py
@@ -3,8 +3,11 @@
 """
 import glob
 import logging
+import mimetypes
 import os
 from enum import Enum
+from typing import Optional, Tuple
+from wsgiref.util import FileWrapper
 
 from django.conf import settings
 
@@ -41,6 +44,21 @@ class FileStorage: # pylint: disable=too-few-public-methods
         provider_name (str | None): name of the provider in caseis needed to build the path
     """
 
+    @staticmethod
+    def is_valid_extension(file_name: str) -> bool:
+        """
+        This method verifies if the extension of the file is valid.
+
+        Args:
+            file_name (str): file name to verify
+
+        Returns:
+            bool: True or False if it is valid or not
+        """
+        return any(
+            file_name.endswith(extension) for extension in SUPPORTED_FILE_EXTENSIONS
+        )
+
     def __init__(
         self,
         username: str,
@@ -122,3 +140,42 @@ def get_files(self) -> list[str]:
             for extension in SUPPORTED_FILE_EXTENSIONS
             for path in glob.glob(f"{self.file_path}/*{extension}")
         ]
+
+    def get_file(self, file_name: str) -> Optional[Tuple[FileWrapper, str, int]]:
+        """
+        This method returns a file from file_name:
+        - Only files with supported extensions are available to download
+        - It returns only a file from a user or a provider file storage
+
+        Args:
+            file_name (str): name of the file, only its base name is used
+
+        Returns:
+            None: if the file does not exist in the storage path
+            FileWrapper: the file itself, still open so it can be streamed
+            str: with the type of the file
+            int: with the size of the file
+        """
+
+        file_name_path = os.path.basename(file_name)
+        path_to_file = sanitize_file_path(os.path.join(self.file_path, file_name_path))
+
+        if not os.path.isfile(path_to_file):
+            logger.warning(
+                "File %s does not exist in directory %s.",
+                file_name_path,
+                self.file_path,
+            )
+            return None
+
+        file_type = mimetypes.guess_type(path_to_file)[0]
+        file_size = os.path.getsize(path_to_file)
+
+        # A context manager can not be used here: the wrapper is read after
+        # this method returns, so the file must stay open. The caller hands it
+        # to a streaming response, which is expected to close it when done.
+        file_wrapper = FileWrapper(
+            open(path_to_file, "rb")  # pylint: disable=consider-using-with
+        )
+
+        return file_wrapper, file_type, file_size
diff --git a/gateway/api/utils.py b/gateway/api/utils.py
index 626938e6d..779568b15 100644
--- a/gateway/api/utils.py
+++ b/gateway/api/utils.py
@@ -421,13 +421,10 @@ def create_dependency_allowlist():
 
 def sanitize_name(name: str | None):
     """Sanitize name"""
-    if name:
-        sanitized_name = ""
-        for c in name:
-            if c.isalnum() or c in ["_", "-", "/"]:
-                sanitized_name += c
-        return sanitized_name
-    return name
+    if not name:
+        return name
+    # Remove all characters except alphanumeric, _, -, /
+    return re.sub(r"[^a-zA-Z0-9_\-/]", "", name)
 
 
 def create_gpujob_allowlist():
@@ -448,3 +445,11 @@ def create_gpujob_allowlist():
         raise ValueError("Unable to decode gpujob allowlist") from e
 
     return gpujobs
+
+
+def sanitize_file_name(name: str | None):
+    """Sanitize the name of a file"""
+    if not name:
+        return name
+    # Remove all characters except alphanumeric, _, ., -
+    return re.sub(r"[^a-zA-Z0-9_\.\-]", "", name)
diff --git a/gateway/api/v1/views/files.py b/gateway/api/v1/views/files.py
index efa9417ee..9609b7ba8 100644
--- a/gateway/api/v1/views/files.py
+++ b/gateway/api/v1/views/files.py
@@ -64,19 +64,26 @@ def provider_list(self, request):
         return super().provider_list(request)
 
     @swagger_auto_schema(
-        operation_description="Download a specific file",
+        operation_description="Download a specific file in the user directory",
         manual_parameters=[
             openapi.Parameter(
                 "file",
                 openapi.IN_QUERY,
-                description="file name",
+                description="File name",
+                type=openapi.TYPE_STRING,
+                required=True,
+            ),
+            openapi.Parameter(
+                "function",
+                openapi.IN_QUERY,
+                description="Qiskit Function title",
                 type=openapi.TYPE_STRING,
                 required=True,
             ),
             openapi.Parameter(
                 "provider",
                 openapi.IN_QUERY,
-                description="provider name",
+                description="Provider name",
                 type=openapi.TYPE_STRING,
                 required=False,
             ),
@@ -86,6 +93,36 @@ def provider_list(self, request):
     def download(self, request):
         return super().download(request)
 
+    @swagger_auto_schema(
+        operation_description="Download a specific file in the provider directory",
+        manual_parameters=[
+            openapi.Parameter(
+                "file",
+                openapi.IN_QUERY,
+                description="File name",
+                type=openapi.TYPE_STRING,
+                required=True,
+            ),
+            openapi.Parameter(
+                "function",
+                openapi.IN_QUERY,
+                description="Qiskit Function title",
+                type=openapi.TYPE_STRING,
+                required=True,
+            ),
+            openapi.Parameter(
+                "provider",
+                openapi.IN_QUERY,
+                description="Provider name",
+                type=openapi.TYPE_STRING,
+                required=True,
+            ),
+        ],
+    )
+    @action(methods=["GET"], detail=False, url_path="provider/download")
+    def provider_download(self, request):
+        return super().provider_download(request)
+
     @swagger_auto_schema(
         operation_description="Deletes file uploaded or produced by the programs",
         request_body=openapi.Schema(
diff --git a/gateway/api/views/files.py b/gateway/api/views/files.py
index e441d7d42..4cb40b5fb 100644
--- a/gateway/api/views/files.py
+++ b/gateway/api/views/files.py
@@ -4,9 +4,7 @@
 Version views inherit from the different views.
 """
 import logging
-import mimetypes
 import os
-from wsgiref.util import FileWrapper
 
 from django.conf import settings
 from django.http import StreamingHttpResponse
@@ -22,8 +20,8 @@
 from rest_framework.decorators import action
 from rest_framework.response import Response
 
-from api.services.file_storage import FileStorage, WorkingDir
-from api.utils import sanitize_name
+from api.services.file_storage import SUPPORTED_FILE_EXTENSIONS, FileStorage, WorkingDir
+from api.utils import sanitize_file_name, sanitize_name
 from api.models import Provider, Program
 from utils import sanitize_file_path
 
@@ -207,6 +205,12 @@ def list(self, request):
             function_title = sanitize_name(request.query_params.get("function", None))
             working_dir = WorkingDir.USER_STORAGE
 
+            if not function_title:
+                return Response(
+                    {"message": "Qiskit Function title is mandatory"},
+                    status=status.HTTP_400_BAD_REQUEST,
+                )
+
             function = self.get_function(
                 user=request.user,
                 function_title=function_title,
@@ -246,6 +250,14 @@ def provider_list(self, request):
             function_title = sanitize_name(request.query_params.get("function"))
             working_dir = WorkingDir.PROVIDER_STORAGE
 
+            if not all([function_title, provider_name]):
+                return Response(
+                    {
+                        "message": "Qiskit Function title and Provider name are mandatory"
+                    },
+                    status=status.HTTP_400_BAD_REQUEST,
+                )
+
             if not self.user_has_provider_access(request.user, provider_name):
                 return Response(
                     {"message": f"Provider {provider_name} doesn't exist."},
@@ -276,48 +288,148 @@ def provider_list(self, request):
             return Response({"results": files})
 
     @action(methods=["GET"], detail=False)
-    def download(self, request):  # pylint: disable=invalid-name
-        """Download selected file."""
-        # default response for file not found, overwritten if file is found
-        response = Response(
-            {"message": "Requested file was not found."},
-            status=status.HTTP_404_NOT_FOUND,
-        )
+    def download(self, request):
+        """
+        It returns a file from user paths:
+        - username/
+        - username/provider_name/function_title
+        """
         tracer = trace.get_tracer("gateway.tracer")
         ctx = TraceContextTextMapPropagator().extract(carrier=request.headers)
         with tracer.start_as_current_span("gateway.files.download", context=ctx):
-            requested_file_name = request.query_params.get("file")
-            provider_name = request.query_params.get("provider")
-            if requested_file_name is not None:
-                user_dir = request.user.username
-                if provider_name is not None:
-                    if self.check_user_has_provider(request.user, provider_name):
-                        user_dir = provider_name
-                    else:
-                        return response
-                # look for file in user's folder
-                filename = os.path.basename(requested_file_name)
-                user_dir = os.path.join(
-                    sanitize_file_path(settings.MEDIA_ROOT),
-                    sanitize_file_path(user_dir),
+            username = request.user.username
+            requested_file_name = sanitize_file_name(
+                request.query_params.get("file", None)
+            )
+            provider_name = sanitize_name(request.query_params.get("provider", None))
+            function_title = sanitize_name(request.query_params.get("function", None))
+            working_dir = WorkingDir.USER_STORAGE
+
+            if not all([requested_file_name, function_title]):
+                return Response(
+                    {"message": "File name and Qiskit Function title are mandatory"},
+                    status=status.HTTP_400_BAD_REQUEST,
                 )
-                file_path = os.path.join(
-                    sanitize_file_path(user_dir), sanitize_file_path(filename)
+
+            if not FileStorage.is_valid_extension(requested_file_name):
+                extensions = ", ".join(SUPPORTED_FILE_EXTENSIONS)
+                return Response(
+                    {
+                        "message": f"File name needs to have a valid extension: {extensions}"
+                    },
+                    status=status.HTTP_400_BAD_REQUEST,
                 )
-                if os.path.exists(user_dir) and os.path.exists(file_path) and filename:
-                    chunk_size = 8192
-                    # note: we do not use with statements as Streaming response closing file itself.
-                    response = StreamingHttpResponse(
-                        FileWrapper(
-                            open(  # pylint: disable=consider-using-with
-                                file_path, "rb"
-                            ),
-                            chunk_size,
-                        ),
-                        content_type=mimetypes.guess_type(file_path)[0],
-                    )
-                    response["Content-Length"] = os.path.getsize(file_path)
-                    response["Content-Disposition"] = f"attachment; filename={filename}"
+
+            function = self.get_function(
+                user=request.user,
+                function_title=function_title,
+                provider_name=provider_name,
+            )
+            if not function:
+                if provider_name:
+                    error_message = f"Qiskit Function {provider_name}/{function_title} doesn't exist."  # pylint: disable=line-too-long
+                else:
+                    error_message = f"Qiskit Function {function_title} doesn't exist."
+                return Response(
+                    {"message": error_message},
+                    status=status.HTTP_404_NOT_FOUND,
+                )
+
+            file_storage = FileStorage(
+                username=username,
+                working_dir=working_dir,
+                function_title=function_title,
+                provider_name=provider_name,
+            )
+            result = file_storage.get_file(file_name=requested_file_name)
+            if result is None:
+                return Response(
+                    {"message": "Requested file was not found."},
+                    status=status.HTTP_404_NOT_FOUND,
+                )
+
+            file_wrapper, file_type, file_size = result
+            response = StreamingHttpResponse(file_wrapper, content_type=file_type)
+            response["Content-Length"] = file_size
+            response[
+                "Content-Disposition"
+            ] = f"attachment; filename={requested_file_name}"
+            return response
+
+    @action(methods=["GET"], detail=False, url_path="provider/download")
+    def provider_download(self, request):
+        """
+        It returns a file from provider path:
+        - provider_name/function_title
+        """
+        tracer = trace.get_tracer("gateway.tracer")
+        ctx = TraceContextTextMapPropagator().extract(carrier=request.headers)
+        with tracer.start_as_current_span(
+            "gateway.files.provider_download", context=ctx
+        ):
+            username = request.user.username
+            requested_file_name = sanitize_file_name(
+                request.query_params.get("file", None)
+            )
+            provider_name = sanitize_name(request.query_params.get("provider", None))
+            function_title = sanitize_name(request.query_params.get("function", None))
+            working_dir = WorkingDir.PROVIDER_STORAGE
+
+            if not all([requested_file_name, function_title, provider_name]):
+                return Response(
+                    {
+                        "message": "File name, Qiskit Function title and Provider name are mandatory"  # pylint: disable=line-too-long
+                    },
+                    status=status.HTTP_400_BAD_REQUEST,
+                )
+
+            if not FileStorage.is_valid_extension(requested_file_name):
+                extensions = ", ".join(SUPPORTED_FILE_EXTENSIONS)
+                return Response(
+                    {
+                        "message": f"File name needs to have a valid extension: {extensions}"
+                    },
+                    status=status.HTTP_400_BAD_REQUEST,
+                )
+
+            if not self.user_has_provider_access(request.user, provider_name):
+                return Response(
+                    {"message": f"Provider {provider_name} doesn't exist."},
+                    status=status.HTTP_404_NOT_FOUND,
+                )
+
+            function = self.get_function(
+                user=request.user,
+                function_title=function_title,
+                provider_name=provider_name,
+            )
+            if not function:
+                return Response(
+                    {
+                        "message": f"Qiskit Function {provider_name}/{function_title} doesn't exist."  # pylint: disable=line-too-long
+                    },
+                    status=status.HTTP_404_NOT_FOUND,
+                )
+
+            file_storage = FileStorage(
+                username=username,
+                working_dir=working_dir,
+                function_title=function_title,
+                provider_name=provider_name,
+            )
+            result = file_storage.get_file(file_name=requested_file_name)
+            if result is None:
+                return Response(
+                    {"message": "Requested file was not found."},
+                    status=status.HTTP_404_NOT_FOUND,
+                )
+
+            file_wrapper, file_type, file_size = result
+            response = StreamingHttpResponse(file_wrapper, content_type=file_type)
+            response["Content-Length"] = file_size
+            response[
+                "Content-Disposition"
+            ] = f"attachment; filename={requested_file_name}"
             return response
 
     @action(methods=["DELETE"], detail=False)
diff --git a/gateway/tests/api/test_files.py b/gateway/tests/api/test_v1_files.py
similarity index 80%
rename from gateway/tests/api/test_files.py
rename to gateway/tests/api/test_v1_files.py
index 8c32efc29..4524c5a60 100644
--- a/gateway/tests/api/test_files.py
+++ b/gateway/tests/api/test_v1_files.py
@@ -36,7 +36,7 @@ def test_files_list_with_empty_params(self):
         self.client.force_authenticate(user=user)
         url = reverse("v1:files-list")
         response = self.client.get(url, format="json")
-        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
 
     def test_files_list_from_user_working_dir(self):
         """Tests files list with working dir as user"""
@@ -184,18 +184,34 @@ def test_files_provider_list_with_a_user_that_has_no_access_to_provider(self):
 
     def test_non_existing_file_download(self):
         """Tests downloading non-existing file."""
-        user = models.User.objects.get(username="test_user")
-        self.client.force_authenticate(user=user)
-        url = reverse("v1:files-download")
-        response = self.client.get(
-            url, data={"file": "non_existing.tar"}, format="json"
+
+        media_root = os.path.join(
+            os.path.dirname(os.path.abspath(__file__)),
+            "..",
+            "resources",
+            "fake_media",
         )
-        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
-        response = self.client.get(url, format="json")
-        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+        media_root = os.path.normpath(os.path.join(os.getcwd(), media_root))
+
+        with self.settings(MEDIA_ROOT=media_root):
+            file = "non_existing_file.tar"
+            function = "personal-program"
+
+            user = models.User.objects.get(username="test_user_2")
+            self.client.force_authenticate(user=user)
+            url = reverse("v1:files-download")
+            response = self.client.get(
+                url,
+                {
+                    "file": file,
+                    "function": function,
+                },
+                format="json",
+            )
+            self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
 
     def test_file_download(self):
-        """Tests downloading non-existing file."""
+        """Tests downloading an existing file."""
         media_root = os.path.join(
             os.path.dirname(os.path.abspath(__file__)),
             "..",
@@ -205,17 +221,54 @@ def test_file_download(self):
         media_root = os.path.normpath(os.path.join(os.getcwd(), media_root))
 
         with self.settings(MEDIA_ROOT=media_root):
-            user = models.User.objects.get(username="test_user")
+            file = "artifact_2.tar"
+            function = "personal-program"
+
+            user = models.User.objects.get(username="test_user_2")
             self.client.force_authenticate(user=user)
             url = reverse("v1:files-download")
             response = self.client.get(
-                url, data={"file": "artifact.tar"}, format="json"
+                url,
+                {
+                    "file": file,
+                    "function": function,
+                },
+                format="json",
             )
             self.assertEqual(response.status_code, status.HTTP_200_OK)
             self.assertTrue(response.streaming)
 
+    def test_non_existing_provider_file_download(self):
+        """Tests downloading a non-existing file from a provider storage."""
+        media_root = os.path.join(
+            os.path.dirname(os.path.abspath(__file__)),
+            "..",
+            "resources",
+            "fake_media",
+        )
+        media_root = os.path.normpath(os.path.join(os.getcwd(), media_root))
+
+        with self.settings(MEDIA_ROOT=media_root):
+            file = "non-existing_artifact.tar"
+            provider = "default"
+            function = "Program"
+
+            user = models.User.objects.get(username="test_user_2")
+            self.client.force_authenticate(user=user)
+            url = reverse("v1:files-provider-download")
+            response = self.client.get(
+                url,
+                {
+                    "file": file,
+                    "provider": provider,
+                    "function": function,
+                },
+                format="json",
+            )
+            self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+
     def test_provider_file_download(self):
-        """Tests downloading non-existing file."""
+        """Tests downloading a file from a provider storage."""
         media_root = os.path.join(
             os.path.dirname(os.path.abspath(__file__)),
             "..",
@@ -225,12 +278,20 @@ def test_provider_file_download(self):
         media_root = os.path.normpath(os.path.join(os.getcwd(), media_root))
 
         with self.settings(MEDIA_ROOT=media_root):
+            file = "provider_program_artifact.tar"
+            provider = "default"
+            function = "Program"
+
             user = models.User.objects.get(username="test_user_2")
             self.client.force_authenticate(user=user)
-            url = reverse("v1:files-download")
+            url = reverse("v1:files-provider-download")
             response = self.client.get(
                 url,
-                data={"file": "provider_artifact.tar", "provider": "default"},
+                {
+                    "file": file,
+                    "provider": provider,
+                    "function": function,
+                },
                 format="json",
             )
             self.assertEqual(response.status_code, status.HTTP_200_OK)
@@ -367,15 +428,29 @@ def test_escape_directory(self):
                 "fake_media",
             )
         ):
-            user = models.User.objects.get(username="test_user")
+            file = "../test_user/artifact.tar"
+            function = "personal-program"
+
+            user = models.User.objects.get(username="test_user_2")
             self.client.force_authenticate(user=user)
             url = reverse("v1:files-download")
             response = self.client.get(
-                url, data={"file": "../test_user_2/artifact_2.tar"}, format="json"
+                url,
+                {
+                    "file": file,
+                    "function": function,
+                },
+                format="json",
             )
             self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
 
+            file = "../test_user/artifact.tar/"
             response = self.client.get(
-                url, data={"file": "../test_user_2/artifact_2.tar/"}, format="json"
+                url,
+                {
+                    "file": file,
+                    "function": function,
+                },
+                format="json",
             )
             self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)