Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Download end point refactor #1547

Merged
merged 11 commits into from
Dec 13, 2024
57 changes: 57 additions & 0 deletions gateway/api/services/file_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@
"""
import glob
import logging
import mimetypes
import os
from enum import Enum
from typing import Optional, Tuple
from wsgiref.util import FileWrapper

from django.conf import settings

Expand Down Expand Up @@ -41,6 +44,21 @@ class FileStorage: # pylint: disable=too-few-public-methods
provider_name (str | None): name of the provider in caseis needed to build the path
"""

@staticmethod
def file_extension_is_valid(file_name: str) -> bool:
Tansito marked this conversation as resolved.
Show resolved Hide resolved
"""
This method verifies if the extension of the file is valid.

Args:
file_name (str): file name to verify

Returns:
bool: True or False if it is valid or not
"""
return any(
file_name.endswith(extension) for extension in SUPPORTED_FILE_EXTENSIONS
)

def __init__(
self,
username: str,
Expand Down Expand Up @@ -122,3 +140,42 @@ def get_files(self) -> list[str]:
for extension in SUPPORTED_FILE_EXTENSIONS
for path in glob.glob(f"{self.file_path}/*{extension}")
]

def get_file(self, file_name: str) -> Optional[Tuple[FileWrapper, str, int]]:
"""
This method returns a file from file_name:
- Only files with supported extensions are available to download
- It returns only a file from a user or a provider file storage

Returns:
FileWrapper: the file itself
str: with the type of the file
int: with the size of the file
"""

file_name_path = os.path.basename(file_name)
path_to_file = sanitize_file_path(os.path.join(self.file_path, file_name_path))

if not os.path.exists(path_to_file):
logger.warning(
"Directory %s does not exist for file %s.",
path_to_file,
file_name_path,
)
return None

try:
with open(path_to_file, "rb") as file_object:
file_wrapper = FileWrapper(file_object)

file_type = mimetypes.guess_type(path_to_file)[0]
file_size = os.path.getsize(path_to_file)

return file_wrapper, file_type, file_size
except FileNotFoundError:
logger.warning(
"Directory %s does not exist for file %s.",
path_to_file,
file_name_path,
)
return None
Tansito marked this conversation as resolved.
Show resolved Hide resolved
11 changes: 11 additions & 0 deletions gateway/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,3 +448,14 @@ def create_gpujob_allowlist():
raise ValueError("Unable to decode gpujob allowlist") from e

return gpujobs


def sanitize_file_name(name: str | None):
"""Sanitize name of a file"""
if name:
sanitized_name = ""
for c in name:
if c.isalnum() or c in ["_", "-", "."]:
sanitized_name += c
return sanitized_name
return name
Tansito marked this conversation as resolved.
Show resolved Hide resolved
43 changes: 40 additions & 3 deletions gateway/api/v1/views/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,19 +64,26 @@ def provider_list(self, request):
return super().provider_list(request)

@swagger_auto_schema(
operation_description="Download a specific file",
operation_description="Download a specific file in the user directory",
manual_parameters=[
openapi.Parameter(
"file",
openapi.IN_QUERY,
description="file name",
description="File name",
type=openapi.TYPE_STRING,
required=True,
),
openapi.Parameter(
"function",
openapi.IN_QUERY,
description="Qiskit Function title",
type=openapi.TYPE_STRING,
required=True,
),
openapi.Parameter(
"provider",
openapi.IN_QUERY,
description="provider name",
description="Provider name",
type=openapi.TYPE_STRING,
required=False,
),
Expand All @@ -86,6 +93,36 @@ def provider_list(self, request):
def download(self, request):
return super().download(request)

@swagger_auto_schema(
operation_description="Download a specific file in the provider directory",
manual_parameters=[
openapi.Parameter(
"file",
openapi.IN_QUERY,
description="File name",
type=openapi.TYPE_STRING,
required=True,
),
openapi.Parameter(
"function",
openapi.IN_QUERY,
description="Qiskit Function title",
type=openapi.TYPE_STRING,
required=True,
),
openapi.Parameter(
"provider",
openapi.IN_QUERY,
description="Provider name",
type=openapi.TYPE_STRING,
required=True,
),
],
)
@action(methods=["GET"], detail=False, url_path="provider/download")
def provider_download(self, request):
return super().provider_download(request)

@swagger_auto_schema(
operation_description="Deletes file uploaded or produced by the programs",
request_body=openapi.Schema(
Expand Down
198 changes: 157 additions & 41 deletions gateway/api/views/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
Version views inherit from the different views.
"""
import logging
import mimetypes
import os
from wsgiref.util import FileWrapper

from django.conf import settings
from django.http import StreamingHttpResponse
Expand All @@ -22,8 +20,8 @@
from rest_framework.decorators import action
from rest_framework.response import Response

from api.services.file_storage import FileStorage, WorkingDir
from api.utils import sanitize_name
from api.services.file_storage import SUPPORTED_FILE_EXTENSIONS, FileStorage, WorkingDir
from api.utils import sanitize_file_name, sanitize_name
from api.models import Provider, Program
from utils import sanitize_file_path

Expand Down Expand Up @@ -207,6 +205,12 @@ def list(self, request):
function_title = sanitize_name(request.query_params.get("function", None))
working_dir = WorkingDir.USER_STORAGE

if function_title is None:
return Response(
{"message": "Qiskit Function title is mandatory"},
status=status.HTTP_400_BAD_REQUEST,
)

function = self.get_function(
user=request.user,
function_title=function_title,
Expand Down Expand Up @@ -246,6 +250,14 @@ def provider_list(self, request):
function_title = sanitize_name(request.query_params.get("function"))
working_dir = WorkingDir.PROVIDER_STORAGE

if function_title is None or provider_name is None:
return Response(
{
"message": "File name, Qiskit Function title and Provider name are mandatory" # pylint: disable=line-too-long
},
status=status.HTTP_400_BAD_REQUEST,
)

if not self.user_has_provider_access(request.user, provider_name):
return Response(
{"message": f"Provider {provider_name} doesn't exist."},
Expand Down Expand Up @@ -276,48 +288,152 @@ def provider_list(self, request):
return Response({"results": files})

@action(methods=["GET"], detail=False)
def download(self, request): # pylint: disable=invalid-name
"""Download selected file."""
# default response for file not found, overwritten if file is found
response = Response(
{"message": "Requested file was not found."},
status=status.HTTP_404_NOT_FOUND,
)
def download(self, request):
"""
It returns a file from user paths:
- username/
- username/provider_name/function_title
"""
tracer = trace.get_tracer("gateway.tracer")
ctx = TraceContextTextMapPropagator().extract(carrier=request.headers)
with tracer.start_as_current_span("gateway.files.download", context=ctx):
requested_file_name = request.query_params.get("file")
provider_name = request.query_params.get("provider")
if requested_file_name is not None:
user_dir = request.user.username
if provider_name is not None:
if self.check_user_has_provider(request.user, provider_name):
user_dir = provider_name
else:
return response
# look for file in user's folder
filename = os.path.basename(requested_file_name)
user_dir = os.path.join(
sanitize_file_path(settings.MEDIA_ROOT),
sanitize_file_path(user_dir),
username = request.user.username
requested_file_name = sanitize_file_name(
request.query_params.get("file", None)
)
provider_name = sanitize_name(request.query_params.get("provider", None))
Tansito marked this conversation as resolved.
Show resolved Hide resolved
function_title = sanitize_name(request.query_params.get("function", None))
working_dir = WorkingDir.USER_STORAGE

if requested_file_name is None or function_title is None:
return Response(
{"message": "File name and Qiskit Function title are mandatory"},
status=status.HTTP_400_BAD_REQUEST,
)
file_path = os.path.join(
sanitize_file_path(user_dir), sanitize_file_path(filename)

if not FileStorage.file_extension_is_valid(requested_file_name):
extensions = ", ".join(SUPPORTED_FILE_EXTENSIONS)
return Response(
{
"message": f"File name needs to have a valid extension: {extensions}"
},
status=status.HTTP_400_BAD_REQUEST,
)
if os.path.exists(user_dir) and os.path.exists(file_path) and filename:
chunk_size = 8192
# note: we do not use with statements as Streaming response closing file itself.
response = StreamingHttpResponse(
FileWrapper(
open( # pylint: disable=consider-using-with
file_path, "rb"
),
chunk_size,
),
content_type=mimetypes.guess_type(file_path)[0],
)
response["Content-Length"] = os.path.getsize(file_path)
response["Content-Disposition"] = f"attachment; filename={filename}"

function = self.get_function(
user=request.user,
function_title=function_title,
provider_name=provider_name,
)
if not function:
if provider_name:
error_message = f"Qiskit Function {provider_name}/{function_title} doesn't exist." # pylint: disable=line-too-long
else:
error_message = f"Qiskit Function {function_title} doesn't exist."
return Response(
{"message": error_message},
status=status.HTTP_404_NOT_FOUND,
)

file_storage = FileStorage(
username=username,
working_dir=working_dir,
function_title=function_title,
provider_name=provider_name,
)
result = file_storage.get_file(file_name=requested_file_name)
if result is None:
return Response(
{"message": "Requested file was not found."},
status=status.HTTP_404_NOT_FOUND,
)

file_wrapper, file_type, file_size = result
response = StreamingHttpResponse(file_wrapper, content_type=file_type)
response["Content-Length"] = file_size
response[
"Content-Disposition"
] = f"attachment; filename={requested_file_name}"
return response

@action(methods=["GET"], detail=False, url_path="provider/download")
def provider_download(self, request):
"""
It returns a file from provider path:
- provider_name/function_title
"""
tracer = trace.get_tracer("gateway.tracer")
ctx = TraceContextTextMapPropagator().extract(carrier=request.headers)
with tracer.start_as_current_span(
"gateway.files.provider_download", context=ctx
):
username = request.user.username
requested_file_name = sanitize_file_name(
request.query_params.get("file", None)
)
provider_name = sanitize_name(request.query_params.get("provider", None))
function_title = sanitize_name(request.query_params.get("function", None))
working_dir = WorkingDir.PROVIDER_STORAGE

if (
requested_file_name is None
or function_title is None
or provider_name is None
):
Tansito marked this conversation as resolved.
Show resolved Hide resolved
return Response(
{
"message": "File name, Qiskit Function title and Provider name are mandatory" # pylint: disable=line-too-long
},
status=status.HTTP_400_BAD_REQUEST,
)

if not FileStorage.file_extension_is_valid(requested_file_name):
extensions = ", ".join(SUPPORTED_FILE_EXTENSIONS)
return Response(
{
"message": f"File name needs to have a valid extension: {extensions}"
},
status=status.HTTP_400_BAD_REQUEST,
)

if not self.user_has_provider_access(request.user, provider_name):
return Response(
{"message": f"Provider {provider_name} doesn't exist."},
status=status.HTTP_404_NOT_FOUND,
)

function = self.get_function(
user=request.user,
function_title=function_title,
provider_name=provider_name,
)
Tansito marked this conversation as resolved.
Show resolved Hide resolved
if not function:
return Response(
{
"message": f"Qiskit Function {provider_name}/{function_title} doesn't exist." # pylint: disable=line-too-long
},
status=status.HTTP_404_NOT_FOUND,
)

file_storage = FileStorage(
username=username,
working_dir=working_dir,
function_title=function_title,
provider_name=provider_name,
)
result = file_storage.get_file(file_name=requested_file_name)
if result is None:
return Response(
{"message": "Requested file was not found."},
status=status.HTTP_404_NOT_FOUND,
)

file_wrapper, file_type, file_size = result
response = StreamingHttpResponse(file_wrapper, content_type=file_type)
response["Content-Length"] = file_size
response[
"Content-Disposition"
] = f"attachment; filename={requested_file_name}"
Tansito marked this conversation as resolved.
Show resolved Hide resolved
return response

@action(methods=["DELETE"], detail=False)
Expand Down
Loading
Loading