Skip to content

Commit

Permalink
Feat/save result into result store (#1570)
Browse files Browse the repository at this point in the history
* add job access polocies

* fix lint

* apply review suggestions

* fix: create job repository

* fix job id

* fix result asignment

* fix test and lint

* feat: retrieve job results from result storage

* remove result file

* retrieve job results

* apply review suggestions

* fix typo

* add test fake file

* fix test media path for retrieve result

* remove prints

* fix lint

* remove unnecesary file

* revert file deletion
  • Loading branch information
paaragon authored Jan 29, 2025
1 parent 521dd80 commit c92124e
Show file tree
Hide file tree
Showing 7 changed files with 277 additions and 63 deletions.
67 changes: 67 additions & 0 deletions gateway/api/access_policies/jobs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""
Access policies implementation for Job access
"""
import logging
from django.contrib.auth import get_user_model
from api.models import Job

User = get_user_model()


logger = logging.getLogger("gateway")


class JobAccessPolocies: # pylint: disable=too-few-public-methods
"""
The main objective of this class is to manage the access for the user
to the Job entities.
"""

@staticmethod
def can_access(user: User, job: Job) -> bool:
"""
Checks if the user has access to save the result of a Job:
Args:
user: Django user from the request
job: Job instance against to check the access
Returns:
bool: True or False in case the user has access
"""

is_provider_job = job.program and job.program.provider
if is_provider_job:
provider_groups = job.program.provider.admin_groups.all()
author_groups = user.groups.all()
has_access = any(group in provider_groups for group in author_groups)
else:
has_access = user.id == job.author.id

if not has_access:
logger.warning(
"User [%s] has no access to job [%s].", user.username, job.author
)
return has_access

@staticmethod
def can_save_result(user: User, job: Job) -> bool:
"""
Checks if the user has permissions to save the result of a job:
Args:
user: Django user from the request
job: Job instance against to check the permission
Returns:
bool: True or False in case the user has permissions
"""

has_access = user.id == job.author.id
if not has_access:
logger.warning(
"User [%s] has no access to save the result of the job [%s].",
user.username,
job.author,
)
return has_access
31 changes: 31 additions & 0 deletions gateway/api/repositories/jobs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""
Repository implementation for Job model
"""
import logging
from api.models import Job

logger = logging.getLogger("gateway")


class JobsRepository: # pylint: disable=too-few-public-methods
"""
The main objective of this class is to manage the access to the Job model
"""

def get_job_by_id(self, job_id: str) -> Job:
"""
Returns the job for the given id:
Args:
id (str): id of the job
Returns:
Job | None: job with the requested id
"""

result_queryset = Job.objects.filter(id=job_id).first()

if result_queryset is None:
logger.warning("Job [%s] was not found", id)

return result_queryset
28 changes: 14 additions & 14 deletions gateway/api/services/result_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@
"""
import os
import logging
import mimetypes
from typing import Optional, Tuple
from wsgiref.util import FileWrapper
from typing import Optional
from django.conf import settings

logger = logging.getLogger("gateway")
Expand All @@ -24,13 +22,13 @@ def __init__(self, username: str):
)
os.makedirs(self.user_results_directory, exist_ok=True)

def __build_result_path(self, job_id: str) -> str:
def __get_result_path(self, job_id: str) -> str:
"""Construct the full path for a result file."""
return os.path.join(
self.user_results_directory, f"{job_id}{self.RESULT_FILE_EXTENSION}"
)

def get(self, job_id: str) -> Optional[Tuple[FileWrapper, str, int]]:
def get(self, job_id: str) -> Optional[str]:
"""
Retrieve a result file for the given job ID.
Expand All @@ -40,8 +38,7 @@ def get(self, job_id: str) -> Optional[Tuple[FileWrapper, str, int]]:
- File MIME type
- File size in bytes
"""
result_path = self.__build_result_path(job_id)

result_path = self.__get_result_path(job_id)
if not os.path.exists(result_path):
logger.warning(
"Result file for job ID '%s' not found in directory '%s'.",
Expand All @@ -50,13 +47,16 @@ def get(self, job_id: str) -> Optional[Tuple[FileWrapper, str, int]]:
)
return None

with open(result_path, "rb") as result_file:
file_wrapper = FileWrapper(result_file)
file_type = (
mimetypes.guess_type(result_path)[0] or "application/octet-stream"
try:
with open(result_path, "r", encoding="utf-8") as result_file:
return result_file.read()
except (UnicodeDecodeError, IOError) as e:
logger.error(
"Failed to read result file for job ID '%s': %s",
job_id,
str(e),
)
file_size = os.path.getsize(result_path)
return file_wrapper, file_type, file_size
return None

def save(self, job_id: str, result: str) -> None:
"""
Expand All @@ -67,7 +67,7 @@ def save(self, job_id: str, result: str) -> None:
name for the result file.
result (str): The job result content to be saved in the file.
"""
result_path = self.__build_result_path(job_id)
result_path = self.__get_result_path(job_id)

with open(result_path, "w", encoding=self.ENCODING) as result_file:
result_file.write(result)
Expand Down
11 changes: 11 additions & 0 deletions gateway/api/v1/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,17 @@ class Meta(serializers.JobSerializer.Meta):
fields = ["id", "result", "status", "program", "created"]


class JobSerializerWithoutResult(serializers.JobSerializer):
"""
Job serializer first version. Include basic fields from the initial model.
"""

program = ProgramSerializer(many=False)

class Meta(serializers.JobSerializer.Meta):
fields = ["id", "status", "program", "created"]


class RuntimeJobSerializer(serializers.RuntimeJobSerializer):
"""
Runtime job serializer first version. Serializer for the runtime job model.
Expand Down
112 changes: 76 additions & 36 deletions gateway/api/views/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@
import json
import logging
import os
import time

from concurrency.exceptions import RecordModifiedError
from django.db.models import Q

# pylint: disable=duplicate-code
Expand All @@ -26,6 +24,10 @@
from api.models import Job, RuntimeJob
from api.ray import get_job_handler
from api.views.enums.type_filter import TypeFilter
from api.services.result_storage import ResultStorage
from api.access_policies.jobs import JobAccessPolocies
from api.repositories.jobs import JobsRepository
from api.v1 import serializers as v1_serializers

# pylint: disable=duplicate-code
logger = logging.getLogger("gateway")
Expand All @@ -51,10 +53,39 @@ class JobViewSet(viewsets.GenericViewSet):

BASE_NAME = "jobs"

jobs_repository = JobsRepository()

def get_serializer_class(self):
"""
Returns the default serializer class for the view.
"""
return self.serializer_class

@staticmethod
def get_serializer_job(*args, **kwargs):
"""
Returns a `JobSerializer` instance
"""
return v1_serializers.JobSerializer(*args, **kwargs)

@staticmethod
def get_serializer_job_without_result(*args, **kwargs):
"""
Returns a `JobSerializerWithoutResult` instance
"""
return v1_serializers.JobSerializerWithoutResult(*args, **kwargs)

def get_queryset(self):
"""
Returns a filtered queryset of `Job` objects based on the `filter` query parameter.
- If `filter=catalog`, returns jobs authored by the user with an existing provider.
- If `filter=serverless`, returns jobs authored by the user without a provider.
- Otherwise, returns all jobs authored by the user.
Returns:
QuerySet: A filtered queryset of `Job` objects ordered by creation date (descending).
"""
type_filter = self.request.query_params.get("filter")
if type_filter:
if type_filter == TypeFilter.CATALOG:
Expand All @@ -76,24 +107,36 @@ def retrieve(self, request, pk=None): # pylint: disable=unused-argument
tracer = trace.get_tracer("gateway.tracer")
ctx = TraceContextTextMapPropagator().extract(carrier=request.headers)
with tracer.start_as_current_span("gateway.job.retrieve", context=ctx):
job = Job.objects.filter(pk=pk).first()

author = self.request.user
job = self.jobs_repository.get_job_by_id(pk)
if job is None:
logger.info("Job [%s] nor found", pk)
return Response(
{"message": f"Job [{pk}] nor found"},
status=status.HTTP_404_NOT_FOUND,
)

if not JobAccessPolocies.can_access(author, job):
logger.warning("Job [%s] not found", pk)
return Response(
{"message": f"Job [{pk}] was not found."},
status=status.HTTP_404_NOT_FOUND,
)
author = self.request.user
if job.program and job.program.provider:
provider_groups = job.program.provider.admin_groups.all()
author_groups = author.groups.all()
has_access = any(group in provider_groups for group in author_groups)
if has_access:
serializer = self.get_serializer(job)
return Response(serializer.data)
instance = self.get_object()
serializer = self.get_serializer(instance)
return Response(serializer.data)

is_provider_job = job.program and job.program.provider
if is_provider_job:
serializer = self.get_serializer_job_without_result(job)
return Response(serializer.data)

result_store = ResultStorage(author.username)
result = result_store.get(str(job.id))
if result is not None:
job.result = result

serializer = self.get_serializer_job(job)

return Response(serializer.data)

def list(self, request):
"""List jobs:"""
Expand All @@ -116,29 +159,26 @@ def result(self, request, pk=None): # pylint: disable=invalid-name,unused-argum
tracer = trace.get_tracer("gateway.tracer")
ctx = TraceContextTextMapPropagator().extract(carrier=request.headers)
with tracer.start_as_current_span("gateway.job.result", context=ctx):
saved = False
attempts_left = 10
while not saved:
if attempts_left <= 0:
return Response(
{"error": "All attempts to save results failed."}, status=500
)
author = self.request.user
job = self.jobs_repository.get_job_by_id(pk)
if job is None:
logger.info("Job [%s] nor found", pk)
return Response(
{"message": f"Job [{pk}] nor found"},
status=status.HTTP_404_NOT_FOUND,
)

attempts_left -= 1

try:
job = self.get_object()
job.result = json.dumps(request.data.get("result"))
job.save()
saved = True
except RecordModifiedError:
logger.warning(
"Job [%s] record has not been updated due to lock. Retrying. Attempts left %s", # pylint: disable=line-too-long
job.id,
attempts_left,
)
continue
time.sleep(1)
can_access = JobAccessPolocies.can_save_result(author, job)
if not can_access:
logger.info("Job [%s] nor found for author %s", pk, author.username)
return Response(
{"message": f"Job [{job.id}] nor found"},
status=status.HTTP_404_NOT_FOUND,
)

job.result = json.dumps(request.data.get("result"))
result_storage = ResultStorage(author.username)
result_storage.save(job.id, job.result)

serializer = self.get_serializer(job)
return Response(serializer.data)
Expand Down
Loading

0 comments on commit c92124e

Please sign in to comment.