Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/save result into result store #1570

Merged
merged 22 commits into from
Jan 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions gateway/api/access_policies/jobs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""
Access policies implementation for Job access
"""
import logging
from django.contrib.auth import get_user_model
from api.models import Job

User = get_user_model()


logger = logging.getLogger("gateway")


class JobAccessPolocies:  # pylint: disable=too-few-public-methods
    """
    The main objective of this class is to manage the access for the user
    to the Job entities.
    """

    @staticmethod
    def can_access(user: User, job: Job) -> bool:
        """
        Checks if the user has access to a Job:

        When the job's program has a provider, access is granted if the user
        belongs to at least one of the provider's admin groups. Otherwise
        access is granted only to the author of the job.

        Args:
            user: Django user from the request
            job: Job instance against to check the access

        Returns:
            bool: True or False in case the user has access
        """

        is_provider_job = job.program and job.program.provider
        if is_provider_job:
            provider_groups = job.program.provider.admin_groups.all()
            author_groups = user.groups.all()
            has_access = any(group in provider_groups for group in author_groups)
        else:
            has_access = user.id == job.author.id

        if not has_access:
            # The placeholder names the job, so log the job id rather than
            # the job's author.
            logger.warning(
                "User [%s] has no access to job [%s].", user.username, job.id
            )
        return has_access

    @staticmethod
    def can_save_result(user: User, job: Job) -> bool:
        """
        Checks if the user has permissions to save the result of a job:

        Only the author of the job is allowed to save its result.

        Args:
            user: Django user from the request
            job: Job instance against to check the permission

        Returns:
            bool: True or False in case the user has permissions
        """

        has_access = user.id == job.author.id
        if not has_access:
            # The placeholder names the job, so log the job id rather than
            # the job's author.
            logger.warning(
                "User [%s] has no access to save the result of the job [%s].",
                user.username,
                job.id,
            )
        return has_access
31 changes: 31 additions & 0 deletions gateway/api/repositories/jobs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""
Repository implementation for Job model
"""
import logging
from api.models import Job

logger = logging.getLogger("gateway")


class JobsRepository:  # pylint: disable=too-few-public-methods
    """
    The main objective of this class is to manage the access to the Job model
    """

    def get_job_by_id(self, job_id: str) -> Job:
        """
        Returns the job for the given id:

        Args:
            job_id (str): id of the job

        Returns:
            Job | None: job with the requested id, or None when no job
            matches it
        """

        job = Job.objects.filter(id=job_id).first()

        if job is None:
            # Log the requested id; the bare name `id` would be the Python
            # builtin function, not the value that was looked up.
            logger.warning("Job [%s] was not found", job_id)

        return job
28 changes: 14 additions & 14 deletions gateway/api/services/result_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@
"""
import os
import logging
import mimetypes
from typing import Optional, Tuple
from wsgiref.util import FileWrapper
from typing import Optional
from django.conf import settings

logger = logging.getLogger("gateway")
Expand All @@ -24,13 +22,13 @@ def __init__(self, username: str):
)
os.makedirs(self.user_results_directory, exist_ok=True)

def __build_result_path(self, job_id: str) -> str:
def __get_result_path(self, job_id: str) -> str:
    """
    Build the full path of the result file that belongs to a job.

    The file name is the job id followed by `RESULT_FILE_EXTENSION`, and it
    is placed inside `user_results_directory`.
    """
    result_filename = f"{job_id}{self.RESULT_FILE_EXTENSION}"
    return os.path.join(self.user_results_directory, result_filename)

def get(self, job_id: str) -> Optional[Tuple[FileWrapper, str, int]]:
def get(self, job_id: str) -> Optional[str]:
"""
Retrieve a result file for the given job ID.

Expand All @@ -40,8 +38,7 @@ def get(self, job_id: str) -> Optional[Tuple[FileWrapper, str, int]]:
- File MIME type
- File size in bytes
"""
result_path = self.__build_result_path(job_id)

result_path = self.__get_result_path(job_id)
if not os.path.exists(result_path):
logger.warning(
"Result file for job ID '%s' not found in directory '%s'.",
Expand All @@ -50,13 +47,16 @@ def get(self, job_id: str) -> Optional[Tuple[FileWrapper, str, int]]:
)
return None

with open(result_path, "rb") as result_file:
file_wrapper = FileWrapper(result_file)
file_type = (
mimetypes.guess_type(result_path)[0] or "application/octet-stream"
try:
with open(result_path, "r", encoding="utf-8") as result_file:
return result_file.read()
except (UnicodeDecodeError, IOError) as e:
logger.error(
"Failed to read result file for job ID '%s': %s",
job_id,
str(e),
)
file_size = os.path.getsize(result_path)
return file_wrapper, file_type, file_size
return None

def save(self, job_id: str, result: str) -> None:
"""
Expand All @@ -67,7 +67,7 @@ def save(self, job_id: str, result: str) -> None:
name for the result file.
result (str): The job result content to be saved in the file.
"""
result_path = self.__build_result_path(job_id)
result_path = self.__get_result_path(job_id)

with open(result_path, "w", encoding=self.ENCODING) as result_file:
result_file.write(result)
Expand Down
11 changes: 11 additions & 0 deletions gateway/api/v1/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,17 @@ class Meta(serializers.JobSerializer.Meta):
fields = ["id", "result", "status", "program", "created"]


class JobSerializerWithoutResult(serializers.JobSerializer):
    """
    Job serializer first version that leaves the `result` field out.

    Exposes only the id, status, program and creation date of the job.
    """

    # The program is rendered as a single nested object.
    program = ProgramSerializer(many=False)

    class Meta(serializers.JobSerializer.Meta):
        # "result" is intentionally not part of this list.
        fields = ["id", "status", "program", "created"]


class RuntimeJobSerializer(serializers.RuntimeJobSerializer):
"""
Runtime job serializer first version. Serializer for the runtime job model.
Expand Down
112 changes: 76 additions & 36 deletions gateway/api/views/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@
import json
import logging
import os
import time

from concurrency.exceptions import RecordModifiedError
from django.db.models import Q

# pylint: disable=duplicate-code
Expand All @@ -26,6 +24,10 @@
from api.models import Job, RuntimeJob
from api.ray import get_job_handler
from api.views.enums.type_filter import TypeFilter
from api.services.result_storage import ResultStorage
from api.access_policies.jobs import JobAccessPolocies
from api.repositories.jobs import JobsRepository
from api.v1 import serializers as v1_serializers

# pylint: disable=duplicate-code
logger = logging.getLogger("gateway")
Expand All @@ -51,10 +53,39 @@ class JobViewSet(viewsets.GenericViewSet):

BASE_NAME = "jobs"

jobs_repository = JobsRepository()

def get_serializer_class(self):
    """
    Gives back the serializer class configured on the view
    (its `serializer_class` attribute).
    """
    default_serializer = self.serializer_class
    return default_serializer

@staticmethod
def get_serializer_job(*args, **kwargs):
    """
    Builds a `JobSerializer` instance.

    Every positional and keyword argument is forwarded unchanged to the
    serializer constructor.
    """
    serializer_cls = v1_serializers.JobSerializer
    return serializer_cls(*args, **kwargs)

@staticmethod
def get_serializer_job_without_result(*args, **kwargs):
    """
    Builds a `JobSerializerWithoutResult` instance.

    Every positional and keyword argument is forwarded unchanged to the
    serializer constructor.
    """
    serializer_cls = v1_serializers.JobSerializerWithoutResult
    return serializer_cls(*args, **kwargs)

def get_queryset(self):
"""
Returns a filtered queryset of `Job` objects based on the `filter` query parameter.

- If `filter=catalog`, returns jobs authored by the user with an existing provider.
- If `filter=serverless`, returns jobs authored by the user without a provider.
- Otherwise, returns all jobs authored by the user.

Returns:
QuerySet: A filtered queryset of `Job` objects ordered by creation date (descending).
"""
type_filter = self.request.query_params.get("filter")
if type_filter:
if type_filter == TypeFilter.CATALOG:
Expand All @@ -76,24 +107,36 @@ def retrieve(self, request, pk=None): # pylint: disable=unused-argument
tracer = trace.get_tracer("gateway.tracer")
ctx = TraceContextTextMapPropagator().extract(carrier=request.headers)
with tracer.start_as_current_span("gateway.job.retrieve", context=ctx):
job = Job.objects.filter(pk=pk).first()

author = self.request.user
job = self.jobs_repository.get_job_by_id(pk)
if job is None:
logger.info("Job [%s] nor found", pk)
return Response(
{"message": f"Job [{pk}] nor found"},
status=status.HTTP_404_NOT_FOUND,
)

if not JobAccessPolocies.can_access(author, job):
logger.warning("Job [%s] not found", pk)
return Response(
{"message": f"Job [{pk}] was not found."},
status=status.HTTP_404_NOT_FOUND,
)
author = self.request.user
if job.program and job.program.provider:
provider_groups = job.program.provider.admin_groups.all()
author_groups = author.groups.all()
has_access = any(group in provider_groups for group in author_groups)
if has_access:
serializer = self.get_serializer(job)
return Response(serializer.data)
instance = self.get_object()
serializer = self.get_serializer(instance)
return Response(serializer.data)

is_provider_job = job.program and job.program.provider
if is_provider_job:
serializer = self.get_serializer_job_without_result(job)
return Response(serializer.data)

result_store = ResultStorage(author.username)
result = result_store.get(str(job.id))
if result is not None:
job.result = result

serializer = self.get_serializer_job(job)

return Response(serializer.data)

def list(self, request):
"""List jobs:"""
Expand All @@ -116,29 +159,26 @@ def result(self, request, pk=None): # pylint: disable=invalid-name,unused-argum
tracer = trace.get_tracer("gateway.tracer")
ctx = TraceContextTextMapPropagator().extract(carrier=request.headers)
with tracer.start_as_current_span("gateway.job.result", context=ctx):
saved = False
attempts_left = 10
while not saved:
if attempts_left <= 0:
return Response(
{"error": "All attempts to save results failed."}, status=500
)
author = self.request.user
job = self.jobs_repository.get_job_by_id(pk)
if job is None:
logger.info("Job [%s] nor found", pk)
return Response(
{"message": f"Job [{pk}] nor found"},
status=status.HTTP_404_NOT_FOUND,
)

attempts_left -= 1

try:
job = self.get_object()
job.result = json.dumps(request.data.get("result"))
job.save()
saved = True
except RecordModifiedError:
logger.warning(
"Job [%s] record has not been updated due to lock. Retrying. Attempts left %s", # pylint: disable=line-too-long
job.id,
attempts_left,
)
continue
time.sleep(1)
can_access = JobAccessPolocies.can_save_result(author, job)
if not can_access:
logger.info("Job [%s] nor found for author %s", pk, author.username)
return Response(
{"message": f"Job [{job.id}] nor found"},
status=status.HTTP_404_NOT_FOUND,
)

job.result = json.dumps(request.data.get("result"))
result_storage = ResultStorage(author.username)
result_storage.save(job.id, job.result)

serializer = self.get_serializer(job)
return Response(serializer.data)
Expand Down
Loading