From 89a7128236dcc453dc0719fc9c2d157d90c1a178 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?V=C3=ADctor=20Fern=C3=A1ndez=20Poyatos?= Date: Fri, 29 Nov 2024 15:26:14 +0100 Subject: [PATCH] feat(Provider): add soft deletion for providers and related resources (#5956) --- api/src/backend/api/db_utils.py | 36 +++++++++++++-- .../backend/api/migrations/0001_initial.py | 3 +- api/src/backend/api/models.py | 45 +++++++++++++++++++ api/src/backend/api/v1/views.py | 5 ++- api/src/backend/tasks/jobs/deletion.py | 45 ++++++++++++++----- api/src/backend/tasks/tasks.py | 6 ++- api/src/backend/tasks/tests/test_deletion.py | 6 +-- 7 files changed, 124 insertions(+), 22 deletions(-) diff --git a/api/src/backend/api/db_utils.py b/api/src/backend/api/db_utils.py index d90c2c2340..d9255d193d 100644 --- a/api/src/backend/api/db_utils.py +++ b/api/src/backend/api/db_utils.py @@ -1,12 +1,13 @@ import secrets from contextlib import contextmanager -from datetime import datetime, timezone, timedelta +from datetime import datetime, timedelta, timezone from django.conf import settings from django.contrib.auth.models import BaseUserManager -from django.db import models, transaction, connection +from django.core.paginator import Paginator +from django.db import connection, models, transaction from psycopg2 import connect as psycopg2_connect -from psycopg2.extensions import new_type, register_type, register_adapter, AsIs +from psycopg2.extensions import AsIs, new_type, register_adapter, register_type DB_USER = settings.DATABASES["default"]["USER"] if not settings.TESTING else "test" DB_PASSWORD = ( @@ -88,6 +89,35 @@ def generate_random_token(length: int = 14, symbols: str | None = None) -> str: return "".join(secrets.choice(symbols or _symbols) for _ in range(length)) +def batch_delete(queryset, batch_size=5000): + """ + Deletes objects in batches and returns the total number of deletions and a summary. + + Args: + queryset (QuerySet): The queryset of objects to delete. + batch_size (int): The number of objects to delete in each batch. + + Returns: + tuple: (total_deleted, deletion_summary) + """ + total_deleted = 0 + deletion_summary = {} + + paginator = Paginator(queryset.order_by("id").only("id"), batch_size) + + for page_num in paginator.page_range: + batch_ids = [obj.id for obj in paginator.page(page_num).object_list] + + deleted_count, deleted_info = queryset.filter(id__in=batch_ids).delete() + + total_deleted += deleted_count + + for model_label, count in deleted_info.items(): + deletion_summary[model_label] = deletion_summary.get(model_label, 0) + count + + return total_deleted, deletion_summary + + # Postgres Enums diff --git a/api/src/backend/api/migrations/0001_initial.py b/api/src/backend/api/migrations/0001_initial.py index dc492f17c5..f2ec03ed5b 100644 --- a/api/src/backend/api/migrations/0001_initial.py +++ b/api/src/backend/api/migrations/0001_initial.py @@ -387,6 +387,7 @@ class Migration(migrations.Migration): ), ("inserted_at", models.DateTimeField(auto_now_add=True)), ("updated_at", models.DateTimeField(auto_now=True)), + ("is_deleted", models.BooleanField(default=False)), ( "provider", ProviderEnumField( @@ -1093,7 +1094,7 @@ class Migration(migrations.Migration): }, bases=(PostgresPartitionedModel,), managers=[ - ("objects", PostgresManager()), + ("objects", api.models.ActiveProviderPartitionedManager()), ], ), migrations.RunSQL( diff --git a/api/src/backend/api/models.py b/api/src/backend/api/models.py index 7cc7229961..4c106cd8c7 100644 --- a/api/src/backend/api/models.py +++ b/api/src/backend/api/models.py @@ -9,8 +9,10 @@ from django.contrib.postgres.search import SearchVector, SearchVectorField from django.core.validators import MinLengthValidator from django.db import models +from django.db.models import Q from django.utils.translation import gettext_lazy as _ from django_celery_results.models import TaskResult +from psqlextra.manager import PostgresManager from psqlextra.models import PostgresPartitionedModel from psqlextra.types import PostgresPartitioningMethod from uuid6 import uuid7 @@ -67,6 +69,24 @@ class StateChoices(models.TextChoices): CANCELLED = "cancelled", _("Cancelled") +class ActiveProviderManager(models.Manager): + def get_queryset(self): + return super().get_queryset().filter(self.active_provider_filter()) + + def active_provider_filter(self): + if self.model is Provider: + return Q(is_deleted=False) + elif self.model in [Finding, ComplianceOverview, ScanSummary]: + return Q(scan__provider__is_deleted=False) + else: + return Q(provider__is_deleted=False) + + +class ActiveProviderPartitionedManager(PostgresManager, ActiveProviderManager): + def get_queryset(self): + return super().get_queryset().filter(self.active_provider_filter()) + + class User(AbstractBaseUser): id = models.UUIDField(primary_key=True, default=uuid4, editable=False) name = models.CharField(max_length=150, validators=[MinLengthValidator(3)]) @@ -147,6 +167,9 @@ class JSONAPIMeta: class Provider(RowLevelSecurityProtectedModel): + objects = ActiveProviderManager() + all_objects = models.Manager() + class ProviderChoices(models.TextChoices): AWS = "aws", _("AWS") AZURE = "azure", _("Azure") @@ -202,6 +225,7 @@ def validate_kubernetes_uid(value): id = models.UUIDField(primary_key=True, default=uuid4, editable=False) inserted_at = models.DateTimeField(auto_now_add=True, editable=False) updated_at = models.DateTimeField(auto_now=True, editable=False) + is_deleted = models.BooleanField(default=False) provider = ProviderEnumField( choices=ProviderChoices.choices, default=ProviderChoices.AWS ) @@ -274,6 +298,9 @@ class JSONAPIMeta: class ProviderGroupMembership(RowLevelSecurityProtectedModel): + objects = ActiveProviderManager() + all_objects = models.Manager() + id = models.UUIDField(primary_key=True, default=uuid4, editable=False) provider = models.ForeignKey( Provider, @@ -338,6 +365,9 @@ class JSONAPIMeta: class Scan(RowLevelSecurityProtectedModel): + objects = ActiveProviderManager() + all_objects = models.Manager() + class TriggerChoices(models.TextChoices): SCHEDULED = "scheduled", _("Scheduled") MANUAL = "manual", _("Manual") @@ -435,6 +465,9 @@ class Meta(RowLevelSecurityProtectedModel.Meta): class Resource(RowLevelSecurityProtectedModel): + objects = ActiveProviderManager() + all_objects = models.Manager() + id = models.UUIDField(primary_key=True, default=uuid4, editable=False) inserted_at = models.DateTimeField(auto_now_add=True, editable=False) updated_at = models.DateTimeField(auto_now=True, editable=False) @@ -561,6 +594,9 @@ class Finding(PostgresPartitionedModel, RowLevelSecurityProtectedModel): Note when creating migrations, you must use `python manage.py pgmakemigrations` to create the migrations. """ + objects = ActiveProviderPartitionedManager() + all_objects = models.Manager() + class PartitioningMeta: method = PostgresPartitioningMethod.RANGE key = ["id"] @@ -712,6 +748,9 @@ class Meta(RowLevelSecurityProtectedModel.Meta): class ProviderSecret(RowLevelSecurityProtectedModel): + objects = ActiveProviderManager() + all_objects = models.Manager() + class TypeChoices(models.TextChoices): STATIC = "static", _("Key-value pairs") ROLE = "role", _("Role assumption") @@ -812,6 +851,9 @@ class JSONAPIMeta: class ComplianceOverview(RowLevelSecurityProtectedModel): + objects = ActiveProviderManager() + all_objects = models.Manager() + id = models.UUIDField(primary_key=True, default=uuid4, editable=False) inserted_at = models.DateTimeField(auto_now_add=True, editable=False) compliance_id = models.CharField(max_length=100, blank=False, null=False) @@ -861,6 +903,9 @@ class JSONAPIMeta: class ScanSummary(RowLevelSecurityProtectedModel): + objects = ActiveProviderManager() + all_objects = models.Manager() + id = models.UUIDField(primary_key=True, default=uuid4, editable=False) inserted_at = models.DateTimeField(auto_now_add=True, editable=False) check_id = models.CharField(max_length=100, blank=False, null=False) diff --git a/api/src/backend/api/v1/views.py b/api/src/backend/api/v1/views.py index 90db12150a..65f41cfeab 100644 --- a/api/src/backend/api/v1/views.py +++ b/api/src/backend/api/v1/views.py @@ -702,7 +702,10 @@ def connection(self, request, pk=None): ) def destroy(self, request, *args, pk=None, **kwargs): - get_object_or_404(Provider, pk=pk) + provider = get_object_or_404(Provider, pk=pk) + provider.is_deleted = True + provider.save() + with transaction.atomic(): task = delete_provider_task.delay( provider_id=pk, tenant_id=request.tenant_id diff --git a/api/src/backend/tasks/jobs/deletion.py b/api/src/backend/tasks/jobs/deletion.py index b203cf113e..1d730d22d8 100644 --- a/api/src/backend/tasks/jobs/deletion.py +++ b/api/src/backend/tasks/jobs/deletion.py @@ -1,25 +1,46 @@ from celery.utils.log import get_task_logger +from django.db import transaction + +from api.db_utils import batch_delete +from api.models import Finding, Provider, Resource, Scan logger = get_task_logger(__name__) -def delete_instance(model, pk: str): +def delete_provider(pk: str): """ - Deletes an instance of the specified model. - - This function retrieves an instance of the provided model using its primary key - and deletes it from the database. + Gracefully deletes an instance of a provider along with its related data. Args: - model (Model): The Django model class from which to delete an instance. - pk (str): The primary key of the instance to delete. + pk (str): The primary key of the Provider instance to delete. Returns: - tuple: A tuple containing the number of objects deleted and a dictionary - with the count of deleted objects per model, - including related models if applicable. + dict: A dictionary with the count of deleted objects per model, + including related models. Raises: - model.DoesNotExist: If no instance with the provided primary key exists. + Provider.DoesNotExist: If no instance with the provided primary key exists. """ - return model.objects.get(pk=pk).delete() + instance = Provider.all_objects.get(pk=pk) + deletion_summary = {} + + with transaction.atomic(): + # Delete Findings + findings_qs = Finding.all_objects.filter(scan__provider=instance) + _, findings_summary = batch_delete(findings_qs) + deletion_summary.update(findings_summary) + + # Delete Resources + resources_qs = Resource.all_objects.filter(provider=instance) + _, resources_summary = batch_delete(resources_qs) + deletion_summary.update(resources_summary) + + # Delete Scans + scans_qs = Scan.all_objects.filter(provider=instance) + _, scans_summary = batch_delete(scans_qs) + deletion_summary.update(scans_summary) + + provider_deleted_count, provider_summary = instance.delete() + deletion_summary.update(provider_summary) + + return deletion_summary diff --git a/api/src/backend/tasks/tasks.py b/api/src/backend/tasks/tasks.py index 1237b543c5..dc5f6ad158 100644 --- a/api/src/backend/tasks/tasks.py +++ b/api/src/backend/tasks/tasks.py @@ -1,7 +1,7 @@ from celery import shared_task from config.celery import RLSTask from tasks.jobs.connection import check_provider_connection -from tasks.jobs.deletion import delete_instance +from tasks.jobs.deletion import delete_provider from tasks.jobs.scan import aggregate_findings, perform_prowler_scan from api.db_utils import tenant_transaction @@ -32,6 +32,8 @@ def delete_provider_task(provider_id: str): """ Task to delete a specific Provider instance. + It will delete in batches all the related resources first. + Args: provider_id (str): The primary key of the `Provider` instance to be deleted. @@ -41,7 +43,7 @@ def delete_provider_task(provider_id: str): - A dictionary with the count of deleted instances per model, including related models if cascading deletes were triggered. """ - return delete_instance(model=Provider, pk=provider_id) + return delete_provider(pk=provider_id) @shared_task(base=RLSTask, name="scan-perform", queue="scans") diff --git a/api/src/backend/tasks/tests/test_deletion.py b/api/src/backend/tasks/tests/test_deletion.py index 630d1d1fa1..27bfd5a781 100644 --- a/api/src/backend/tasks/tests/test_deletion.py +++ b/api/src/backend/tasks/tests/test_deletion.py @@ -1,15 +1,15 @@ import pytest from django.core.exceptions import ObjectDoesNotExist +from tasks.jobs.deletion import delete_provider from api.models import Provider -from tasks.jobs.deletion import delete_instance @pytest.mark.django_db class TestDeleteInstance: def test_delete_instance_success(self, providers_fixture): instance = providers_fixture[0] - result = delete_instance(Provider, instance.id) + result = delete_provider(instance.id) assert result with pytest.raises(ObjectDoesNotExist): @@ -19,4 +19,4 @@ def test_delete_instance_does_not_exist(self): non_existent_pk = "babf6796-cfcc-4fd3-9dcf-88d012247645" with pytest.raises(ObjectDoesNotExist): - delete_instance(Provider, non_existent_pk) + delete_provider(non_existent_pk)