Skip to content

Commit

Permalink
feat(Provider): add soft deletion for providers and related resources (
Browse files Browse the repository at this point in the history
  • Loading branch information
vicferpoy authored Nov 29, 2024
1 parent c1d6021 commit 89a7128
Show file tree
Hide file tree
Showing 7 changed files with 124 additions and 22 deletions.
36 changes: 33 additions & 3 deletions api/src/backend/api/db_utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import secrets
from contextlib import contextmanager
from datetime import datetime, timezone, timedelta
from datetime import datetime, timedelta, timezone

from django.conf import settings
from django.contrib.auth.models import BaseUserManager
from django.db import models, transaction, connection
from django.core.paginator import Paginator
from django.db import connection, models, transaction
from psycopg2 import connect as psycopg2_connect
from psycopg2.extensions import new_type, register_type, register_adapter, AsIs
from psycopg2.extensions import AsIs, new_type, register_adapter, register_type

DB_USER = settings.DATABASES["default"]["USER"] if not settings.TESTING else "test"
DB_PASSWORD = (
Expand Down Expand Up @@ -88,6 +89,35 @@ def generate_random_token(length: int = 14, symbols: str | None = None) -> str:
return "".join(secrets.choice(symbols or _symbols) for _ in range(length))


def batch_delete(queryset, batch_size=5000):
    """
    Deletes objects in batches and returns the total number of deletions and a summary.

    Each iteration re-queries the first `batch_size` ids still matching the
    queryset and deletes them. The loop ends when no ids are left.

    Args:
        queryset (QuerySet): The queryset of objects to delete.
        batch_size (int): The number of objects to delete in each batch.

    Returns:
        tuple: (total_deleted, deletion_summary), where `deletion_summary` maps
            each model label reported by `QuerySet.delete()` to the accumulated
            number of deleted rows.
    """
    total_deleted = 0
    deletion_summary = {}

    while True:
        # Always take the *first* batch of what is left. Paginating by page
        # number (OFFSET) while deleting shifts the remaining rows forward, so
        # every second page would be skipped and never deleted.
        batch_ids = list(
            queryset.order_by("id").values_list("id", flat=True)[:batch_size]
        )
        if not batch_ids:
            break

        deleted_count, deleted_info = queryset.filter(id__in=batch_ids).delete()
        total_deleted += deleted_count

        # Sum per-model counts across batches instead of overwriting them.
        for model_label, count in deleted_info.items():
            deletion_summary[model_label] = deletion_summary.get(model_label, 0) + count

    return total_deleted, deletion_summary


# Postgres Enums


Expand Down
3 changes: 2 additions & 1 deletion api/src/backend/api/migrations/0001_initial.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,7 @@ class Migration(migrations.Migration):
),
("inserted_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
("is_deleted", models.BooleanField(default=False)),
(
"provider",
ProviderEnumField(
Expand Down Expand Up @@ -1093,7 +1094,7 @@ class Migration(migrations.Migration):
},
bases=(PostgresPartitionedModel,),
managers=[
("objects", PostgresManager()),
("objects", api.models.ActiveProviderPartitionedManager()),
],
),
migrations.RunSQL(
Expand Down
45 changes: 45 additions & 0 deletions api/src/backend/api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
from django.contrib.postgres.search import SearchVector, SearchVectorField
from django.core.validators import MinLengthValidator
from django.db import models
from django.db.models import Q
from django.utils.translation import gettext_lazy as _
from django_celery_results.models import TaskResult
from psqlextra.manager import PostgresManager
from psqlextra.models import PostgresPartitionedModel
from psqlextra.types import PostgresPartitioningMethod
from uuid6 import uuid7
Expand Down Expand Up @@ -67,6 +69,24 @@ class StateChoices(models.TextChoices):
CANCELLED = "cancelled", _("Cancelled")


class ActiveProviderManager(models.Manager):
    """
    Manager whose default queryset excludes rows tied to a soft-deleted provider.

    The filter applied depends on the model the manager is attached to; see
    `active_provider_filter`.
    """

    def get_queryset(self):
        """Return the base queryset restricted by `active_provider_filter()`."""
        base_queryset = super().get_queryset()
        return base_queryset.filter(self.active_provider_filter())

    def active_provider_filter(self):
        """
        Build the Q object that selects rows whose provider is not soft-deleted.

        Returns:
            Q: `is_deleted=False` for Provider itself, a lookup through `scan`
                for Finding, ComplianceOverview and ScanSummary, and a lookup
                through `provider` for every other model.
        """
        # Provider carries the flag directly.
        if self.model is Provider:
            return Q(is_deleted=False)

        # These models are filtered through their `scan` relation.
        if self.model in [Finding, ComplianceOverview, ScanSummary]:
            return Q(scan__provider__is_deleted=False)

        # Any other model is filtered through its `provider` relation.
        return Q(provider__is_deleted=False)


class ActiveProviderPartitionedManager(PostgresManager, ActiveProviderManager):
    """
    Soft-deletion aware manager that also inherits from PostgresManager.

    Intended for models that need PostgresManager while still filtering by
    `active_provider_filter()`.
    """

    def get_queryset(self):
        """Return the parent queryset restricted by `active_provider_filter()`."""
        # NOTE(review): ActiveProviderManager.get_queryset is also in the MRO; if
        # PostgresManager does not override get_queryset, the filter may end up
        # applied twice (redundant but harmless) - confirm.
        parent_queryset = super().get_queryset()
        return parent_queryset.filter(self.active_provider_filter())


class User(AbstractBaseUser):
id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
name = models.CharField(max_length=150, validators=[MinLengthValidator(3)])
Expand Down Expand Up @@ -147,6 +167,9 @@ class JSONAPIMeta:


class Provider(RowLevelSecurityProtectedModel):
objects = ActiveProviderManager()
all_objects = models.Manager()

class ProviderChoices(models.TextChoices):
AWS = "aws", _("AWS")
AZURE = "azure", _("Azure")
Expand Down Expand Up @@ -202,6 +225,7 @@ def validate_kubernetes_uid(value):
id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
inserted_at = models.DateTimeField(auto_now_add=True, editable=False)
updated_at = models.DateTimeField(auto_now=True, editable=False)
is_deleted = models.BooleanField(default=False)
provider = ProviderEnumField(
choices=ProviderChoices.choices, default=ProviderChoices.AWS
)
Expand Down Expand Up @@ -274,6 +298,9 @@ class JSONAPIMeta:


class ProviderGroupMembership(RowLevelSecurityProtectedModel):
objects = ActiveProviderManager()
all_objects = models.Manager()

id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
provider = models.ForeignKey(
Provider,
Expand Down Expand Up @@ -338,6 +365,9 @@ class JSONAPIMeta:


class Scan(RowLevelSecurityProtectedModel):
objects = ActiveProviderManager()
all_objects = models.Manager()

class TriggerChoices(models.TextChoices):
SCHEDULED = "scheduled", _("Scheduled")
MANUAL = "manual", _("Manual")
Expand Down Expand Up @@ -435,6 +465,9 @@ class Meta(RowLevelSecurityProtectedModel.Meta):


class Resource(RowLevelSecurityProtectedModel):
objects = ActiveProviderManager()
all_objects = models.Manager()

id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
inserted_at = models.DateTimeField(auto_now_add=True, editable=False)
updated_at = models.DateTimeField(auto_now=True, editable=False)
Expand Down Expand Up @@ -561,6 +594,9 @@ class Finding(PostgresPartitionedModel, RowLevelSecurityProtectedModel):
Note when creating migrations, you must use `python manage.py pgmakemigrations` to create the migrations.
"""

objects = ActiveProviderPartitionedManager()
all_objects = models.Manager()

class PartitioningMeta:
method = PostgresPartitioningMethod.RANGE
key = ["id"]
Expand Down Expand Up @@ -712,6 +748,9 @@ class Meta(RowLevelSecurityProtectedModel.Meta):


class ProviderSecret(RowLevelSecurityProtectedModel):
objects = ActiveProviderManager()
all_objects = models.Manager()

class TypeChoices(models.TextChoices):
STATIC = "static", _("Key-value pairs")
ROLE = "role", _("Role assumption")
Expand Down Expand Up @@ -812,6 +851,9 @@ class JSONAPIMeta:


class ComplianceOverview(RowLevelSecurityProtectedModel):
objects = ActiveProviderManager()
all_objects = models.Manager()

id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
inserted_at = models.DateTimeField(auto_now_add=True, editable=False)
compliance_id = models.CharField(max_length=100, blank=False, null=False)
Expand Down Expand Up @@ -861,6 +903,9 @@ class JSONAPIMeta:


class ScanSummary(RowLevelSecurityProtectedModel):
objects = ActiveProviderManager()
all_objects = models.Manager()

id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
inserted_at = models.DateTimeField(auto_now_add=True, editable=False)
check_id = models.CharField(max_length=100, blank=False, null=False)
Expand Down
5 changes: 4 additions & 1 deletion api/src/backend/api/v1/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,10 @@ def connection(self, request, pk=None):
)

def destroy(self, request, *args, pk=None, **kwargs):
get_object_or_404(Provider, pk=pk)
provider = get_object_or_404(Provider, pk=pk)
provider.is_deleted = True
provider.save()

with transaction.atomic():
task = delete_provider_task.delay(
provider_id=pk, tenant_id=request.tenant_id
Expand Down
45 changes: 33 additions & 12 deletions api/src/backend/tasks/jobs/deletion.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,46 @@
from celery.utils.log import get_task_logger
from django.db import transaction

from api.db_utils import batch_delete
from api.models import Finding, Provider, Resource, Scan

logger = get_task_logger(__name__)


def delete_provider(pk: str):
    """
    Gracefully deletes an instance of a provider along with its related data.

    Findings, resources and scans of the provider are removed in batches
    before the provider row itself is deleted, all inside one transaction.

    Args:
        pk (str): The primary key of the Provider instance to delete.

    Returns:
        dict: A dictionary with the count of deleted objects per model,
            including related models. Counts for a model label reported by
            more than one deletion step are summed.

    Raises:
        Provider.DoesNotExist: If no instance with the provided primary key exists.
    """
    # `all_objects` is used throughout because the provider is presumably
    # already flagged as deleted and hidden from the default managers.
    instance = Provider.all_objects.get(pk=pk)
    deletion_summary = {}

    def _accumulate(partial_summary):
        # Add counts instead of dict.update(): the same model label can show
        # up in several steps (e.g. through cascades) and must not be overwritten.
        for model_label, count in partial_summary.items():
            deletion_summary[model_label] = deletion_summary.get(model_label, 0) + count

    with transaction.atomic():
        # Order matters: findings first, then resources, then scans.
        related_querysets = (
            Finding.all_objects.filter(scan__provider=instance),
            Resource.all_objects.filter(provider=instance),
            Scan.all_objects.filter(provider=instance),
        )
        for related_qs in related_querysets:
            _, partial_summary = batch_delete(related_qs)
            _accumulate(partial_summary)

        _, provider_summary = instance.delete()
        _accumulate(provider_summary)

    return deletion_summary
6 changes: 4 additions & 2 deletions api/src/backend/tasks/tasks.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from celery import shared_task
from config.celery import RLSTask
from tasks.jobs.connection import check_provider_connection
from tasks.jobs.deletion import delete_instance
from tasks.jobs.deletion import delete_provider
from tasks.jobs.scan import aggregate_findings, perform_prowler_scan

from api.db_utils import tenant_transaction
Expand Down Expand Up @@ -32,6 +32,8 @@ def delete_provider_task(provider_id: str):
"""
Task to delete a specific Provider instance.
It will delete in batches all the related resources first.
Args:
provider_id (str): The primary key of the `Provider` instance to be deleted.
Expand All @@ -41,7 +43,7 @@ def delete_provider_task(provider_id: str):
- A dictionary with the count of deleted instances per model,
including related models if cascading deletes were triggered.
"""
return delete_instance(model=Provider, pk=provider_id)
return delete_provider(pk=provider_id)


@shared_task(base=RLSTask, name="scan-perform", queue="scans")
Expand Down
6 changes: 3 additions & 3 deletions api/src/backend/tasks/tests/test_deletion.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import pytest
from django.core.exceptions import ObjectDoesNotExist
from tasks.jobs.deletion import delete_provider

from api.models import Provider
from tasks.jobs.deletion import delete_instance


@pytest.mark.django_db
class TestDeleteInstance:
def test_delete_instance_success(self, providers_fixture):
instance = providers_fixture[0]
result = delete_instance(Provider, instance.id)
result = delete_provider(instance.id)

assert result
with pytest.raises(ObjectDoesNotExist):
Expand All @@ -19,4 +19,4 @@ def test_delete_instance_does_not_exist(self):
non_existent_pk = "babf6796-cfcc-4fd3-9dcf-88d012247645"

with pytest.raises(ObjectDoesNotExist):
delete_instance(Provider, non_existent_pk)
delete_provider(non_existent_pk)

0 comments on commit 89a7128

Please sign in to comment.