From c92124e7ec63606a8bba5aa24679632848fb3b7d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Pablo=20Arag=C3=B3n?=
Date: Wed, 29 Jan 2025 18:02:19 +0100
Subject: [PATCH] Feat/save result into result store (#1570)

* add job access policies
* fix lint
* apply review suggestions
* fix: create job repository
* fix job id
* fix result assignment
* fix test and lint
* feat: retrieve job results from result storage
* remove result file
* retrieve job results
* apply review suggestions
* fix typo
* add test fake file
* fix test media path for retrieve result
* remove prints
* fix lint
* remove unnecessary file
* revert file deletion
---
 gateway/api/access_policies/jobs.py           |  67 +++++++++++
 gateway/api/repositories/jobs.py              |  31 +++++
 gateway/api/services/result_storage.py        |  28 ++---
 gateway/api/v1/serializers.py                 |  11 ++
 gateway/api/views/jobs.py                     | 112 ++++++++++++------
 gateway/tests/api/test_job.py                 |  90 ++++++++++++--
 .../8317718f-5c0d-4fb6-9947-72e480b8a348.json |   1 +
 7 files changed, 277 insertions(+), 63 deletions(-)
 create mode 100644 gateway/api/access_policies/jobs.py
 create mode 100644 gateway/api/repositories/jobs.py
 create mode 100644 gateway/tests/resources/fake_media/test_user/results/8317718f-5c0d-4fb6-9947-72e480b8a348.json

diff --git a/gateway/api/access_policies/jobs.py b/gateway/api/access_policies/jobs.py
new file mode 100644
index 000000000..499ff1d26
--- /dev/null
+++ b/gateway/api/access_policies/jobs.py
@@ -0,0 +1,67 @@
+"""
+Access policies implementation for Job access
+"""
+import logging
+from django.contrib.auth import get_user_model
+from api.models import Job
+
+User = get_user_model()
+
+
+logger = logging.getLogger("gateway")
+
+
+class JobAccessPolocies:  # pylint: disable=too-few-public-methods
+    """
+    The main objective of this class is to manage the access for the user
+    to the Job entities.
+    """
+
+    @staticmethod
+    def can_access(user: User, job: Job) -> bool:
+        """
+        Checks if the user can access a Job (provider admin or job author):
+
+        Args:
+            user: Django user from the request
+            job: Job instance against to check the access
+
+        Returns:
+            bool: True or False in case the user has access
+        """
+
+        is_provider_job = job.program and job.program.provider
+        if is_provider_job:
+            provider_groups = job.program.provider.admin_groups.all()
+            author_groups = user.groups.all()
+            has_access = any(group in provider_groups for group in author_groups)
+        else:
+            has_access = user.id == job.author.id
+
+        if not has_access:
+            logger.warning(
+                "User [%s] has no access to job [%s].", user.username, job.id
+            )
+        return has_access
+
+    @staticmethod
+    def can_save_result(user: User, job: Job) -> bool:
+        """
+        Checks if the user has permissions to save the result of a job:
+
+        Args:
+            user: Django user from the request
+            job: Job instance against to check the permission
+
+        Returns:
+            bool: True only when the user is the author of the job
+        """
+
+        has_access = user.id == job.author.id
+        if not has_access:
+            logger.warning(
+                "User [%s] has no access to save the result of the job [%s].",
+                user.username,
+                job.id,
+            )
+        return has_access
diff --git a/gateway/api/repositories/jobs.py b/gateway/api/repositories/jobs.py
new file mode 100644
index 000000000..39ba5ada9
--- /dev/null
+++ b/gateway/api/repositories/jobs.py
@@ -0,0 +1,31 @@
+"""
+Repository implementation for Job model
+"""
+import logging
+from api.models import Job
+
+logger = logging.getLogger("gateway")
+
+
+class JobsRepository:  # pylint: disable=too-few-public-methods
+    """
+    The main objective of this class is to manage the access to the Job model
+    """
+
+    def get_job_by_id(self, job_id: str) -> Job:
+        """
+        Returns the job for the given id:
+
+        Args:
+            job_id (str): id of the job
+
+        Returns:
+            Job | None: job with the requested id, None when it does not exist
+        """
+
+        result_queryset = Job.objects.filter(id=job_id).first()
+
+        if result_queryset is None:
+            logger.warning("Job [%s] was not found", job_id)
+
+        return result_queryset
diff --git a/gateway/api/services/result_storage.py b/gateway/api/services/result_storage.py
index 028dad187..ce740d881 100644
--- a/gateway/api/services/result_storage.py
+++ b/gateway/api/services/result_storage.py
@@ -3,9 +3,7 @@
 """
 import os
 import logging
-import mimetypes
-from typing import Optional, Tuple
-from wsgiref.util import FileWrapper
+from typing import Optional
 from django.conf import settings
 
 logger = logging.getLogger("gateway")
@@ -24,13 +22,13 @@ def __init__(self, username: str):
         )
         os.makedirs(self.user_results_directory, exist_ok=True)
 
-    def __build_result_path(self, job_id: str) -> str:
+    def __get_result_path(self, job_id: str) -> str:
         """Construct the full path for a result file."""
         return os.path.join(
             self.user_results_directory, f"{job_id}{self.RESULT_FILE_EXTENSION}"
         )
 
-    def get(self, job_id: str) -> Optional[Tuple[FileWrapper, str, int]]:
+    def get(self, job_id: str) -> Optional[str]:
         """
         Retrieve a result file for the given job ID.
 
@@ -40,8 +38,7 @@ def get(self, job_id: str) -> Optional[Tuple[FileWrapper, str, int]]:
                 - File MIME type
                 - File size in bytes
         """
-        result_path = self.__build_result_path(job_id)
-
+        result_path = self.__get_result_path(job_id)
         if not os.path.exists(result_path):
             logger.warning(
                 "Result file for job ID '%s' not found in directory '%s'.",
@@ -50,13 +47,16 @@ def get(self, job_id: str) -> Optional[Tuple[FileWrapper, str, int]]:
             )
             return None
 
-        with open(result_path, "rb") as result_file:
-            file_wrapper = FileWrapper(result_file)
-            file_type = (
-                mimetypes.guess_type(result_path)[0] or "application/octet-stream"
+        try:
+            with open(result_path, "r", encoding=self.ENCODING) as result_file:
+                return result_file.read()
+        except (UnicodeDecodeError, OSError) as e:
+            logger.error(
+                "Failed to read result file for job ID '%s': %s",
+                job_id,
+                str(e),
             )
-            file_size = os.path.getsize(result_path)
-            return file_wrapper, file_type, file_size
+            return None
 
     def save(self, job_id: str, result: str) -> None:
         """
@@ -67,7 +67,7 @@ def save(self, job_id: str, result: str) -> None:
                 name for the result file.
             result (str): The job result content to be saved in the file.
         """
-        result_path = self.__build_result_path(job_id)
+        result_path = self.__get_result_path(job_id)
 
         with open(result_path, "w", encoding=self.ENCODING) as result_file:
             result_file.write(result)
diff --git a/gateway/api/v1/serializers.py b/gateway/api/v1/serializers.py
index 65ae25409..d9d64fc4f 100644
--- a/gateway/api/v1/serializers.py
+++ b/gateway/api/v1/serializers.py
@@ -146,6 +146,17 @@ class Meta(serializers.JobSerializer.Meta):
         fields = ["id", "result", "status", "program", "created"]
 
 
+class JobSerializerWithoutResult(serializers.JobSerializer):
+    """
+    Job serializer first version that omits the `result` field of the job.
+    """
+
+    program = ProgramSerializer(many=False)
+
+    class Meta(serializers.JobSerializer.Meta):
+        fields = ["id", "status", "program", "created"]
+
+
 class RuntimeJobSerializer(serializers.RuntimeJobSerializer):
     """
     Runtime job serializer first version. Serializer for the runtime job model.
diff --git a/gateway/api/views/jobs.py b/gateway/api/views/jobs.py
index 45376e50d..a9bf743bc 100644
--- a/gateway/api/views/jobs.py
+++ b/gateway/api/views/jobs.py
@@ -6,9 +6,7 @@
 import json
 import logging
 import os
-import time
 
-from concurrency.exceptions import RecordModifiedError
 from django.db.models import Q
 
 # pylint: disable=duplicate-code
@@ -26,6 +24,10 @@
 from api.models import Job, RuntimeJob
 from api.ray import get_job_handler
 from api.views.enums.type_filter import TypeFilter
+from api.services.result_storage import ResultStorage
+from api.access_policies.jobs import JobAccessPolocies
+from api.repositories.jobs import JobsRepository
+from api.v1 import serializers as v1_serializers
 
 # pylint: disable=duplicate-code
 logger = logging.getLogger("gateway")
@@ -51,10 +53,39 @@ class JobViewSet(viewsets.GenericViewSet):
 
     BASE_NAME = "jobs"
 
+    jobs_repository = JobsRepository()
+
     def get_serializer_class(self):
+        """
+        Returns the default serializer class for the view.
+        """
         return self.serializer_class
 
+    @staticmethod
+    def get_serializer_job(*args, **kwargs):
+        """
+        Returns a `JobSerializer` instance
+        """
+        return v1_serializers.JobSerializer(*args, **kwargs)
+
+    @staticmethod
+    def get_serializer_job_without_result(*args, **kwargs):
+        """
+        Returns a `JobSerializerWithoutResult` instance
+        """
+        return v1_serializers.JobSerializerWithoutResult(*args, **kwargs)
+
     def get_queryset(self):
+        """
+        Returns a filtered queryset of `Job` objects based on the `filter` query parameter.
+
+        - If `filter=catalog`, returns jobs authored by the user with an existing provider.
+        - If `filter=serverless`, returns jobs authored by the user without a provider.
+        - Otherwise, returns all jobs authored by the user.
+
+        Returns:
+            QuerySet: A filtered queryset of `Job` objects ordered by creation date (descending).
+        """
         type_filter = self.request.query_params.get("filter")
         if type_filter:
             if type_filter == TypeFilter.CATALOG:
@@ -76,24 +107,36 @@ def retrieve(self, request, pk=None):  # pylint: disable=unused-argument
         tracer = trace.get_tracer("gateway.tracer")
         ctx = TraceContextTextMapPropagator().extract(carrier=request.headers)
         with tracer.start_as_current_span("gateway.job.retrieve", context=ctx):
-            job = Job.objects.filter(pk=pk).first()
+
+            author = self.request.user
+            job = self.jobs_repository.get_job_by_id(pk)
             if job is None:
+                logger.info("Job [%s] not found", pk)
+                return Response(
+                    {"message": f"Job [{pk}] was not found."},
+                    status=status.HTTP_404_NOT_FOUND,
+                )
+
+            if not JobAccessPolocies.can_access(author, job):
                 logger.warning("Job [%s] not found", pk)
                 return Response(
                     {"message": f"Job [{pk}] was not found."},
                     status=status.HTTP_404_NOT_FOUND,
                 )
-            author = self.request.user
-            if job.program and job.program.provider:
-                provider_groups = job.program.provider.admin_groups.all()
-                author_groups = author.groups.all()
-                has_access = any(group in provider_groups for group in author_groups)
-                if has_access:
-                    serializer = self.get_serializer(job)
-                    return Response(serializer.data)
-            instance = self.get_object()
-            serializer = self.get_serializer(instance)
-            return Response(serializer.data)
+
+            is_provider_job = job.program and job.program.provider
+            if is_provider_job:
+                serializer = self.get_serializer_job_without_result(job)
+                return Response(serializer.data)
+
+            result_store = ResultStorage(author.username)
+            result = result_store.get(str(job.id))
+            if result is not None:
+                job.result = result
+
+            serializer = self.get_serializer_job(job)
+
+            return Response(serializer.data)
 
     def list(self, request):
         """List jobs:"""
@@ -116,29 +159,26 @@ def result(self, request, pk=None):  # pylint: disable=invalid-name,unused-argum
         tracer = trace.get_tracer("gateway.tracer")
         ctx = TraceContextTextMapPropagator().extract(carrier=request.headers)
         with tracer.start_as_current_span("gateway.job.result", context=ctx):
-            saved = False
-            attempts_left = 10
-            while not saved:
-                if attempts_left <= 0:
-                    return Response(
-                        {"error": "All attempts to save results failed."}, status=500
-                    )
+            author = self.request.user
+            job = self.jobs_repository.get_job_by_id(pk)
+            if job is None:
+                logger.info("Job [%s] not found", pk)
+                return Response(
+                    {"message": f"Job [{pk}] not found"},
+                    status=status.HTTP_404_NOT_FOUND,
+                )
 
-                attempts_left -= 1
-
-                try:
-                    job = self.get_object()
-                    job.result = json.dumps(request.data.get("result"))
-                    job.save()
-                    saved = True
-                except RecordModifiedError:
-                    logger.warning(
-                        "Job [%s] record has not been updated due to lock. Retrying. Attempts left %s",  # pylint: disable=line-too-long
-                        job.id,
-                        attempts_left,
-                    )
-                    continue
-                time.sleep(1)
+            can_access = JobAccessPolocies.can_save_result(author, job)
+            if not can_access:
+                logger.info("Job [%s] not found for author %s", pk, author.username)
+                return Response(
+                    {"message": f"Job [{pk}] not found"},
+                    status=status.HTTP_404_NOT_FOUND,
+                )
+
+            job.result = json.dumps(request.data.get("result"))
+            result_storage = ResultStorage(author.username)
+            result_storage.save(str(job.id), job.result)
 
         serializer = self.get_serializer(job)
         return Response(serializer.data)
diff --git a/gateway/tests/api/test_job.py b/gateway/tests/api/test_job.py
index 110160f57..8a004eeca 100644
--- a/gateway/tests/api/test_job.py
+++ b/gateway/tests/api/test_job.py
@@ -1,11 +1,13 @@
 """Tests jobs APIs."""
 
+import os
 from django.urls import reverse
 from rest_framework import status
 from rest_framework.test import APITestCase
 
 from api.models import Job
 from django.contrib.auth import models
+from django.conf import settings
 
 
 class TestJobApi(APITestCase):
@@ -73,15 +75,47 @@ def test_job_serverless_list(self):
 
     def test_job_detail(self):
         """Tests job detail authorized."""
-        self._authorize()
+        media_root = os.path.join(
+            os.path.dirname(os.path.abspath(__file__)),
+            "..",
+            "resources",
+            "fake_media",
+        )
+        media_root = os.path.normpath(os.path.join(os.getcwd(), media_root))
 
-        jobs_response = self.client.get(
-            reverse("v1:jobs-detail", args=["57fc2e4d-267f-40c6-91a3-38153272e764"]),
-            format="json",
+        with self.settings(MEDIA_ROOT=media_root):
+            self._authorize()
+
+            jobs_response = self.client.get(
+                reverse(
+                    "v1:jobs-detail", args=["8317718f-5c0d-4fb6-9947-72e480b8a348"]
+                ),
+                format="json",
+            )
+            self.assertEqual(jobs_response.status_code, status.HTTP_200_OK)
+            self.assertEqual(jobs_response.data.get("result"), '{"ultimate": 42}')
+
+    def test_job_detail_without_result_file(self):
+        """Tests job detail authorized when the job has no result file."""
+        media_root = os.path.join(
+            os.path.dirname(os.path.abspath(__file__)),
+            "..",
+            "resources",
+            "fake_media",
         )
-        self.assertEqual(jobs_response.status_code, status.HTTP_200_OK)
-        self.assertEqual(jobs_response.data.get("status"), "SUCCEEDED")
-        self.assertEqual(jobs_response.data.get("result"), '{"somekey":1}')
+        media_root = os.path.normpath(os.path.join(os.getcwd(), media_root))
+
+        with self.settings(MEDIA_ROOT=media_root):
+            self._authorize()
+
+            jobs_response = self.client.get(
+                reverse(
+                    "v1:jobs-detail", args=["57fc2e4d-267f-40c6-91a3-38153272e764"]
+                ),
+                format="json",
+            )
+            self.assertEqual(jobs_response.status_code, status.HTTP_200_OK)
+            self.assertEqual(jobs_response.data.get("result"), '{"somekey":1}')
 
     def test_job_provider_detail(self):
         """Tests job detail authorized."""
@@ -94,7 +128,7 @@ def test_job_provider_detail(self):
         )
         self.assertEqual(jobs_response.status_code, status.HTTP_200_OK)
         self.assertEqual(jobs_response.data.get("status"), "QUEUED")
-        self.assertEqual(jobs_response.data.get("result"), '{"somekey":1}')
+        self.assertIsNone(jobs_response.data.get("result"))
 
     def test_not_authorized_job_detail(self):
         """Tests job detail fails trying to access to other user job."""
@@ -108,16 +142,46 @@ def test_not_authorized_job_detail(self):
 
     def test_job_save_result(self):
         """Tests job results save."""
-        self._authorize()
+        media_root = os.path.join(
+            os.path.dirname(os.path.abspath(__file__)),
+            "..",
+            "resources",
+            "fake_media",
+        )
+        media_root = os.path.normpath(os.path.join(os.getcwd(), media_root))
+
+        with self.settings(MEDIA_ROOT=media_root):
+            self._authorize()
+
+            job_id = "57fc2e4d-267f-40c6-91a3-38153272e764"
+            jobs_response = self.client.post(
+                reverse("v1:jobs-result", args=[job_id]),
+                format="json",
+                data={"result": {"ultimate": 42}},
+            )
+            self.assertEqual(jobs_response.status_code, status.HTTP_200_OK)
+            self.assertEqual(jobs_response.data.get("result"), '{"ultimate": 42}')
+            result_path = os.path.join(
+                settings.MEDIA_ROOT, "test_user", "results", f"{job_id}.json"
+            )
+            self.assertTrue(os.path.exists(result_path))
+            os.remove(result_path)
 
+    def test_not_authorized_job_save_result(self):
+        """Tests job results save is rejected with a 404 when not authorized."""
+        self._authorize()
+        job_id = "1a7947f9-6ae8-4e3d-ac1e-e7d608deec84"
         jobs_response = self.client.post(
-            reverse("v1:jobs-result", args=["57fc2e4d-267f-40c6-91a3-38153272e764"]),
+            reverse("v1:jobs-result", args=[job_id]),
             format="json",
             data={"result": {"ultimate": 42}},
         )
-        self.assertEqual(jobs_response.status_code, status.HTTP_200_OK)
-        self.assertEqual(jobs_response.data.get("status"), "SUCCEEDED")
-        self.assertEqual(jobs_response.data.get("result"), '{"ultimate": 42}')
+
+        self.assertEqual(jobs_response.status_code, status.HTTP_404_NOT_FOUND)
+        self.assertEqual(
+            jobs_response.data.get("message"),
+            f"Job [{job_id}] not found",
+        )
 
     def test_stop_job(self):
         """Tests job stop."""
diff --git a/gateway/tests/resources/fake_media/test_user/results/8317718f-5c0d-4fb6-9947-72e480b8a348.json b/gateway/tests/resources/fake_media/test_user/results/8317718f-5c0d-4fb6-9947-72e480b8a348.json
new file mode 100644
index 000000000..a132681f7
--- /dev/null
+++ b/gateway/tests/resources/fake_media/test_user/results/8317718f-5c0d-4fb6-9947-72e480b8a348.json
@@ -0,0 +1 @@
+{"ultimate": 42}
\ No newline at end of file