diff --git a/django_pgschemas/management/commands/__init__.py b/django_pgschemas/management/commands/__init__.py index b20d6d1..a998989 100644 --- a/django_pgschemas/management/commands/__init__.py +++ b/django_pgschemas/management/commands/__init__.py @@ -5,7 +5,7 @@ from django.db.models.functions import Concat from django.db.utils import ProgrammingError -from ...schema import SchemaDescriptor +from ...schema import schema_handler from ...utils import create_schema, dynamic_models_exist, get_clone_reference, get_tenant_model from ._executors import parallel, sequential @@ -247,18 +247,8 @@ def handle(self, *args, **options): executor(schemas, self, "_raw_handle_tenant", args, options, pass_schema_in_kwargs=True) def _raw_handle_tenant(self, *args, **kwargs): - schema_name = kwargs.pop("schema_name") - if schema_name in settings.TENANTS: - domains = settings.TENANTS[schema_name].get("DOMAINS", []) - tenant = SchemaDescriptor.create(schema_name=schema_name, domain_url=domains[0] if domains else None) - self.handle_tenant(tenant, *args, **kwargs) - elif schema_name == get_clone_reference(): - tenant = SchemaDescriptor.create(schema_name=schema_name) - self.handle_tenant(tenant, *args, **kwargs) - else: - TenantModel = get_tenant_model() - tenant = TenantModel.objects.get(schema_name=schema_name) - self.handle_tenant(tenant, *args, **kwargs) + kwargs.pop("schema_name") + self.handle_tenant(schema_handler.active, *args, **kwargs) def handle_tenant(self, tenant, *args, **options): pass diff --git a/django_pgschemas/management/commands/_executors.py b/django_pgschemas/management/commands/_executors.py index 7d8cab6..0ce24d6 100644 --- a/django_pgschemas/management/commands/_executors.py +++ b/django_pgschemas/management/commands/_executors.py @@ -3,10 +3,11 @@ from django.conf import settings from django.core.management import call_command -from django.core.management.base import BaseCommand, CommandError, OutputWrapper +from django.core.management.base import BaseCommand, OutputWrapper from django.db import connection, connections, transaction -from ...schema import schema_handler +from ...schema import SchemaDescriptor, schema_handler +from ...utils import get_clone_reference, get_tenant_model def run_on_schema( @@ -54,7 +55,17 @@ def __call__(self, message): if fork_db: connections.close_all() - schema_handler.set_schema_to(schema_name) + + if schema_name in settings.TENANTS: + domains = settings.TENANTS[schema_name].get("DOMAINS", []) + schema = SchemaDescriptor.create(schema_name=schema_name, domain_url=domains[0] if domains else None) + elif schema_name == get_clone_reference(): + schema = SchemaDescriptor.create(schema_name=schema_name) + else: + TenantModel = get_tenant_model() + schema = TenantModel.objects.get(schema_name=schema_name) + + schema_handler.set_schema(schema) if pass_schema_in_kwargs: kwargs.update({"schema_name": schema_name})