diff --git a/django_pgschemas/management/commands/__init__.py b/django_pgschemas/management/commands/__init__.py index a998989..08dad27 100644 --- a/django_pgschemas/management/commands/__init__.py +++ b/django_pgschemas/management/commands/__init__.py @@ -1,7 +1,6 @@ from django.conf import settings from django.core.management.base import BaseCommand, CommandError -from django.db.models import CharField, Q -from django.db.models import Value as V +from django.db.models import CharField, Q, Value as V from django.db.models.functions import Concat from django.db.utils import ProgrammingError @@ -66,6 +65,9 @@ def add_arguments(self, parser): dest="excluded_schemas", help="Schema(s) to exclude when executing the current command", ) + parser.add_argument( + "--sdb", nargs="?", dest="schema_database", default="default", help="Database to operate with the schema(s)" + ) parser.add_argument( "--parallel", dest="parallel", @@ -80,6 +82,7 @@ def add_arguments(self, parser): ) def get_schemas_from_options(self, **options): + database = options.get("database") or options.get("schema_database") skip_schema_creation = options.get("skip_schema_creation", False) try: schemas = self._get_schemas_from_options(**options) @@ -96,7 +99,7 @@ def get_schemas_from_options(self, **options): raise CommandError("This command can only run in %s" % self.specific_schemas) if not skip_schema_creation: for schema in schemas: - create_schema(schema, check_if_exists=True, sync_schema=False, verbosity=0) + create_schema(schema, database, check_if_exists=True, sync_schema=False, verbosity=0) return schemas def get_executor_from_options(self, **options): @@ -106,6 +109,7 @@ def get_scope_display(self): return "|".join(self.specific_schemas or []) or self.scope def _get_schemas_from_options(self, **options): + database = options.get("database") or options.get("schema_database") schemas = options.get("schemas") or [] excluded_schemas = options.get("excluded_schemas") or [] include_all_schemas = options.get("all_schemas") or False @@ -139,9 +143,19 @@ def _get_schemas_from_options(self, **options): raise CommandError("No schema provided") TenantModel = get_tenant_model() - static_schemas = [x for x in settings.TENANTS.keys() if x != "default"] if allow_static else [] + static_schemas = ( + [ + x + for x in settings.TENANTS.keys() + if x != "default" and database in (settings.TENANTS[x].get("DATABASES") or ["default"]) + ] + if allow_static + else [] + ) dynamic_schemas = ( - TenantModel.objects.values_list("schema_name", flat=True) if dynamic_ready and allow_dynamic else [] + [x.schema_name for x in TenantModel.objects.all() if x.get_database() == database] + if dynamic_ready and allow_dynamic + else [] ) if clone_reference and allow_static: static_schemas.append(clone_reference) @@ -172,8 +186,10 @@ def _get_schemas_from_options(self, **options): schemas_to_return.add(schema) elif schema == clone_reference: schemas_to_return.add(schema) - elif dynamic_ready and TenantModel.objects.filter(schema_name=schema).exists() and allow_dynamic: - schemas_to_return.add(schema) + elif dynamic_ready and allow_dynamic: + tenant = TenantModel.objects.filter(schema_name=schema).first() + if tenant and tenant.get_database() == database: + schemas_to_return.add(schema) schemas = list(set(schemas) - schemas_to_return) @@ -187,13 +203,18 @@ def _get_schemas_from_options(self, **options): and any([x for x in data["DOMAINS"] if x.startswith(schema)]) ] if dynamic_ready and allow_dynamic: - local += ( - TenantModel.objects.annotate( - route=Concat("domains__domain", V("/"), "domains__folder", output_field=CharField()) - ) - .filter(Q(schema_name=schema) | Q(domains__domain__istartswith=schema) | Q(route=schema)) - .distinct() - .values_list("schema_name", flat=True) + local += list( + { + x.schema_name + for x in ( + TenantModel.objects.annotate( + route=Concat("domains__domain", V("/"), "domains__folder", output_field=CharField()) + ) + .filter(Q(schema_name=schema) | Q(domains__domain__istartswith=schema) | Q(route=schema)) + .distinct() + ) + if x.get_database() == database + } ) if not local: raise CommandError("No schema found for '%s'" % schema) @@ -215,13 +236,18 @@ def _get_schemas_from_options(self, **options): if schema_name not in ["public", "default", clone_reference] and any([x for x in data["DOMAINS"] if x.startswith(schema)]) ] - local += ( - TenantModel.objects.annotate( - route=Concat("domains__domain", V("/"), "domains__folder", output_field=CharField()) - ) - .filter(Q(schema_name=schema) | Q(domains__domain__istartswith=schema) | Q(route=schema)) - .distinct() - .values_list("schema_name", flat=True) + local += list( + { + x.schema_name + for x in ( + TenantModel.objects.annotate( + route=Concat("domains__domain", V("/"), "domains__folder", output_field=CharField()) + ) + .filter(Q(schema_name=schema) | Q(domains__domain__istartswith=schema) | Q(route=schema)) + .distinct() + ) + if x.get_database() == database + } ) if not local: raise CommandError("No schema found for '%s' (excluded)" % schema) diff --git a/django_pgschemas/management/commands/cloneschema.py b/django_pgschemas/management/commands/cloneschema.py index 6b6c958..06c2614 100644 --- a/django_pgschemas/management/commands/cloneschema.py +++ b/django_pgschemas/management/commands/cloneschema.py @@ -19,6 +19,9 @@ def add_arguments(self, parser): super().add_arguments(parser) parser.add_argument("source", help="The name of the schema you want to clone") parser.add_argument("destination", help="The name of the schema you want to create as clone") + parser.add_argument( + "--sdb", nargs="?", dest="schema_database", default="default", help="Database to operate with the schema(s)" + ) parser.add_argument( "--noinput", "--no-input", @@ -104,7 +107,7 @@ def handle(self, *args, **options): if TenantModel.objects.filter(schema_name=options["source"]).exists(): tenant, domain = self.get_dynamic_tenant(**options) try: - clone_schema(options["source"], options["destination"], dry_run) + clone_schema(options["source"], options["destination"], options["schema_database"], dry_run) if tenant and domain: if options["verbosity"] >= 1: self.stdout.write("Schema cloned.") diff --git a/django_pgschemas/management/commands/createrefschema.py b/django_pgschemas/management/commands/createrefschema.py index 8b83869..221883d 100644 --- a/django_pgschemas/management/commands/createrefschema.py +++ b/django_pgschemas/management/commands/createrefschema.py @@ -1,3 +1,4 @@ +from django.conf import settings from django.core.checks import Tags, run_checks from django.core.management.base import BaseCommand, CommandError @@ -20,18 +21,22 @@ def handle(self, *args, **options): clone_reference = get_clone_reference() if not clone_reference: raise CommandError("There is no reference schema configured.") - if options.get("recreate", False): - drop_schema(clone_reference, check_if_exists=True, verbosity=options["verbosity"]) + for database in settings.DATABASES: + if options.get("recreate", False): + drop_schema(clone_reference, database=database, check_if_exists=True, verbosity=options["verbosity"]) + if options["verbosity"] >= 1: + self.stdout.write(f"[{database}] Destroyed existing reference schema.") + created = create_schema( + clone_reference, database=database, check_if_exists=True, verbosity=options["verbosity"] + ) if options["verbosity"] >= 1: - self.stdout.write("Destroyed existing reference schema.") - created = create_schema(clone_reference, check_if_exists=True, verbosity=options["verbosity"]) - if options["verbosity"] >= 1: - if created: - self.stdout.write("Reference schema successfully created!") - else: - self.stdout.write("Reference schema already exists.") - self.stdout.write( - self.style.WARNING( - "Run this command again with --recreate if you want to recreate the reference schema." + if created: + self.stdout.write(f"[{database}] Reference schema successfully created!") + else: + self.stdout.write(f"[{database}] Reference schema already exists.") + self.stdout.write( + self.style.WARNING( + f"[{database}] Run this command again with --recreate if you want to " + "recreate the reference schema." + ) ) - ) diff --git a/django_pgschemas/management/commands/runschema.py b/django_pgschemas/management/commands/runschema.py index 2f40aac..ca48e25 100644 --- a/django_pgschemas/management/commands/runschema.py +++ b/django_pgschemas/management/commands/runschema.py @@ -75,6 +75,7 @@ def handle(self, *args, **options): options.pop("static_schemas") options.pop("dynamic_schemas") options.pop("tenant_schemas") + options.pop("schema_database") options.pop("parallel") options.pop("skip_schema_creation") if self.allow_interactive: diff --git a/django_pgschemas/models.py b/django_pgschemas/models.py index 6385226..d380334 100644 --- a/django_pgschemas/models.py +++ b/django_pgschemas/models.py @@ -50,7 +50,7 @@ def save(self, verbosity=1, *args, **kwargs): elif is_new: # Although we are not using the schema functions directly, the signal might be registered by a listener schema_needs_sync.send(sender=TenantMixin, tenant=self.serializable_fields()) - elif not is_new and self.auto_create_schema and not schema_exists(self.schema_name): + elif not is_new and self.auto_create_schema and not schema_exists(self.schema_name, self.get_database()): # Create schemas for existing models, deleting only the schema on failure try: self.create_schema(verbosity=verbosity) @@ -81,13 +81,15 @@ def create_schema(self, sync_schema=True, verbosity=1): """ Creates or clones the schema ``schema_name`` for this tenant. """ - return create_or_clone_schema(self.schema_name, sync_schema, verbosity) + return create_or_clone_schema( + self.schema_name, database=self.get_database(), sync_schema=sync_schema, verbosity=verbosity + ) def drop_schema(self): """ Drops the schema. """ - return drop_schema(self.schema_name) + return drop_schema(self.schema_name, database=self.get_database()) def get_primary_domain(self): try: diff --git a/django_pgschemas/routers.py b/django_pgschemas/routers.py index d53624f..966a5f8 100644 --- a/django_pgschemas/routers.py +++ b/django_pgschemas/routers.py @@ -2,12 +2,13 @@ from django.conf import settings from .schema import schema_handler -from .utils import get_tenant_database_alias +from .utils import get_clone_reference class SyncRouter(object): """ A router to control which applications will be synced depending on the schema we're syncing. + It also controls database for read/write in a tenant sharding configuration. """ def app_in_list(self, app_label, app_list): @@ -15,16 +16,30 @@ def app_in_list(self, app_label, app_list): app_config_full_name = "{}.{}".format(app_config.__module__, app_config.__class__.__name__) return (app_config.name in app_list) or (app_config_full_name in app_list) + def db_for_read(self, model, **hints): + if not schema_handler.active or schema_handler.active.schema_name in ["public", get_clone_reference()]: + return None + return schema_handler.active.get_database() + + def db_for_write(self, model, **hints): + if not schema_handler.active or schema_handler.active.schema_name in ["public", get_clone_reference()]: + return None + return schema_handler.active.get_database() + def allow_migrate(self, db, app_label, model_name=None, **hints): - if db != get_tenant_database_alias() or not schema_handler.active: + if not schema_handler.active: return False app_list = [] + databases = [] if schema_handler.active.schema_name == "public": app_list = settings.TENANTS["public"]["APPS"] + databases = settings.TENANTS["public"].get("DATABASES") or ["default"] elif schema_handler.active.schema_name in settings.TENANTS: app_list = settings.TENANTS[schema_handler.active.schema_name]["APPS"] + databases = settings.TENANTS[schema_handler.active.schema_name].get("DATABASES") or ["default"] else: app_list = settings.TENANTS["default"]["APPS"] - if not app_list: + databases = settings.TENANTS["default"].get("DATABASES") or ["default"] + if not app_list or not databases: return None - return self.app_in_list(app_label, app_list) + return db in databases and self.app_in_list(app_label, app_list) diff --git a/django_pgschemas/schema.py b/django_pgschemas/schema.py index b525596..a755609 100644 --- a/django_pgschemas/schema.py +++ b/django_pgschemas/schema.py @@ -16,10 +16,14 @@ def set_schema(self, schema_descriptor): """ Main API method to set current schema. """ + from django.contrib.contenttypes.models import ContentType + assert isinstance( schema_descriptor, SchemaDescriptor ), "'set_schema' must be called with a SchemaDescriptor descendant" + schema_descriptor.ready = False # Defines whether search path has been set + ContentType.objects.clear_cache() # Attempting to catch change of database self.set_active_schema(schema_descriptor) def set_schema_to(self, schema_name, domain_url=None, folder=None): @@ -93,3 +97,9 @@ def get_primary_domain(self): if self.domain_url: return "/".join([self.domain_url, self.folder]) if self.folder else self.domain_url return None + + def get_database(self): + """ + Returns the database to use for this schema. + """ + return "default" diff --git a/django_pgschemas/signals.py b/django_pgschemas/signals.py index 2fc8957..7d382f2 100644 --- a/django_pgschemas/signals.py +++ b/django_pgschemas/signals.py @@ -17,6 +17,6 @@ def tenant_delete_callback(sender, instance, **kwargs): if not isinstance(instance, get_tenant_model()): return - if instance.auto_drop_schema and schema_exists(instance.schema_name): + if instance.auto_drop_schema and schema_exists(instance.schema_name, instance.get_database()): schema_pre_drop.send(sender=get_tenant_model(), tenant=instance.serializable_fields()) instance.drop_schema() diff --git a/django_pgschemas/utils.py b/django_pgschemas/utils.py index e96f199..bef6c41 100644 --- a/django_pgschemas/utils.py +++ b/django_pgschemas/utils.py @@ -17,10 +17,6 @@ def get_domain_model(require_ready=True): return apps.get_model(settings.TENANTS["default"]["DOMAIN_MODEL"], require_ready=require_ready) -def get_tenant_database_alias(): - return getattr(settings, "PGSCHEMAS_TENANT_DB_ALIAS", DEFAULT_DB_ALIAS) - - def get_limit_set_calls(): return getattr(settings, "PGSCHEMAS_LIMIT_SET_CALLS", False) @@ -83,7 +79,7 @@ def wrapper(*args, **kwargs): return wrapper -def schema_exists(schema_name): +def schema_exists(schema_name, database): "Checks if a schema exists in database." sql = """ SELECT EXISTS( @@ -92,7 +88,7 @@ def schema_exists(schema_name): WHERE LOWER(nspname) = LOWER(%s) ) """ - cursor = connection.cursor() + cursor = connections[database].cursor() cursor.execute(sql, (schema_name,)) row = cursor.fetchone() if row: @@ -120,32 +116,32 @@ def dynamic_models_exist(): @run_in_public_schema -def create_schema(schema_name, check_if_exists=False, sync_schema=True, verbosity=1): +def create_schema(schema_name, database, check_if_exists=False, sync_schema=True, verbosity=1): """ Creates the schema ``schema_name``. Optionally checks if the schema already exists before creating it. Returns ``True`` if the schema was created, ``False`` otherwise. """ check_schema_name(schema_name) - if check_if_exists and schema_exists(schema_name): + if check_if_exists and schema_exists(schema_name, database): return False - cursor = connection.cursor() + cursor = connections[database].cursor() cursor.execute("CREATE SCHEMA %s" % schema_name) cursor.close() if sync_schema: - call_command("migrateschema", schemas=[schema_name], verbosity=verbosity) + call_command("migrateschema", schemas=[schema_name], database=database, verbosity=verbosity) return True @run_in_public_schema -def drop_schema(schema_name, check_if_exists=True, verbosity=1): +def drop_schema(schema_name, database, check_if_exists=True, verbosity=1): """ Drops the schema. Optionally checks if the schema already exists before dropping it. """ - if check_if_exists and not schema_exists(schema_name): + if check_if_exists and not schema_exists(schema_name, database): return False - cursor = connection.cursor() + cursor = connections[database].cursor() cursor.execute("DROP SCHEMA %s CASCADE" % schema_name) cursor.close() return True @@ -359,31 +355,31 @@ class DryRunException(Exception): pass -def _create_clone_schema_function(): +def _create_clone_schema_function(database): """ Creates a postgres function `clone_schema` that copies a schema and its contents. Will replace any existing `clone_schema` functions owned by the `postgres` superuser. """ - cursor = connection.cursor() + cursor = connections[database].cursor() cursor.execute(CLONE_SCHEMA_FUNCTION) cursor.close() @run_in_public_schema -def clone_schema(base_schema_name, new_schema_name, dry_run=False): +def clone_schema(base_schema_name, new_schema_name, database, dry_run=False): """ Creates a new schema ``new_schema_name`` as a clone of an existing schema ``base_schema_name``. """ check_schema_name(new_schema_name) - cursor = connection.cursor() + cursor = connections[database].cursor() # check if the clone_schema function already exists in the db try: cursor.execute("SELECT 'clone_schema'::regproc") except ProgrammingError: # pragma: no cover - _create_clone_schema_function() + _create_clone_schema_function(database) transaction.commit() try: @@ -397,17 +393,19 @@ def clone_schema(base_schema_name, new_schema_name, dry_run=False): cursor.close() -def create_or_clone_schema(schema_name, sync_schema=True, verbosity=1): +def create_or_clone_schema(schema_name, database, sync_schema=True, verbosity=1): """ Creates the schema ``schema_name``. Optionally checks if the schema already exists before creating it. Returns ``True`` if the schema was created, ``False`` otherwise. """ check_schema_name(schema_name) - if schema_exists(schema_name): + if schema_exists(schema_name, database): return False clone_reference = get_clone_reference() - if clone_reference and schema_exists(clone_reference) and not django_is_in_test_mode(): # pragma: no cover - clone_schema(clone_reference, schema_name) + if ( + clone_reference and schema_exists(clone_reference, database) and not django_is_in_test_mode() + ): # pragma: no cover + clone_schema(clone_reference, schema_name, database) return True - return create_schema(schema_name, sync_schema=sync_schema, verbosity=verbosity) + return create_schema(schema_name, database, sync_schema=sync_schema, verbosity=verbosity) diff --git a/docs/settings.rst b/docs/settings.rst index d2f05fe..69e950f 100644 --- a/docs/settings.rst +++ b/docs/settings.rst @@ -93,13 +93,6 @@ control the max number of processes the parallel executor can spawn. By default, ``None`` means that the number of CPUs will be used. -``PGSCHEMAS_TENANT_DB_ALIAS`` ------------------------------ - -Default: ``"default"`` - -The database alias where the tenant configuration is going to take place. - ``PGSCHEMAS_PATHNAME_FUNCTION`` ------------------------------- diff --git a/dpgs_sandbox/tests/test_schema_creation_commands.py b/dpgs_sandbox/tests/test_schema_creation_commands.py index b9ea7b8..dade765 100644 --- a/dpgs_sandbox/tests/test_schema_creation_commands.py +++ b/dpgs_sandbox/tests/test_schema_creation_commands.py @@ -21,29 +21,29 @@ def test_cloneschema(self): @utils.run_in_public_schema def fixup(): - utils._create_clone_schema_function() + utils._create_clone_schema_function("default") fixup() - self.assertFalse(utils.schema_exists("cloned")) + self.assertFalse(utils.schema_exists("cloned", "default")) call_command("cloneschema", "sample", "cloned", verbosity=0) # All good - self.assertTrue(utils.schema_exists("cloned")) + self.assertTrue(utils.schema_exists("cloned", "default")) with self.assertRaises(CommandError): # Existing destination call_command("cloneschema", "sample", "cloned", verbosity=0) with self.assertRaises(CommandError): # Not existing source call_command("cloneschema", "nonexisting", "newschema", verbosity=0) - utils.drop_schema("cloned") + utils.drop_schema("cloned", "default") def test_createrefschema(self): "Tests 'createrefschema' command" - utils.drop_schema("cloned") + utils.drop_schema("cloned", "default") call_command("createrefschema", verbosity=0) # All good - self.assertTrue(utils.schema_exists("sample")) - utils.drop_schema("cloned") + self.assertTrue(utils.schema_exists("sample", "default")) + utils.drop_schema("cloned", "default") call_command("createrefschema", recreate=True, verbosity=0) # All good too - self.assertTrue(utils.schema_exists("sample")) - utils.drop_schema("cloned") + self.assertTrue(utils.schema_exists("sample", "default")) + utils.drop_schema("cloned", "default") call_command("createrefschema", recreate=True, verbosity=0) # All good too - self.assertTrue(utils.schema_exists("sample")) + self.assertTrue(utils.schema_exists("sample", "default")) class InteractiveCloneSchemaTestCase(TransactionTestCase): @@ -79,4 +79,4 @@ def patched_input(*args, **kwargs): with StringIO() as stdout: with StringIO() as stderr: call_command("cloneschema", "tenant1", "tenant2", verbosity=1, stdout=stdout, stderr=stderr) - self.assertTrue(utils.schema_exists("tenant2")) + self.assertTrue(utils.schema_exists("tenant2", "default")) diff --git a/dpgs_sandbox/tests/test_signals.py b/dpgs_sandbox/tests/test_signals.py index fe324e0..bb3482b 100644 --- a/dpgs_sandbox/tests/test_signals.py +++ b/dpgs_sandbox/tests/test_signals.py @@ -17,7 +17,7 @@ def test_tenant_delete_callback(self): tenant = TenantModel(schema_name="tenant1") tenant.save() tenant.create_schema(sync_schema=False) - self.assertTrue(schema_exists("tenant1")) + self.assertTrue(schema_exists("tenant1", "default")) TenantModel.objects.all().delete() - self.assertFalse(schema_exists("tenant1")) + self.assertFalse(schema_exists("tenant1", "default")) TenantModel.auto_create_schema, TenantModel.auto_drop_schema = backup_create, backup_drop diff --git a/dpgs_sandbox/tests/test_tenants.py b/dpgs_sandbox/tests/test_tenants.py index 92e1ec4..ddf111c 100644 --- a/dpgs_sandbox/tests/test_tenants.py +++ b/dpgs_sandbox/tests/test_tenants.py @@ -28,27 +28,27 @@ class TenantAutomaticTestCase(TransactionTestCase): def test_new_creation_deletion(self): "Tests automatic creation/deletion for new tenant's save/delete" - self.assertFalse(schema_exists("tenant1")) + self.assertFalse(schema_exists("tenant1", "default")) tenant = TenantModel(schema_name="tenant1") tenant.save(verbosity=0) - self.assertTrue(schema_exists("tenant1")) + self.assertTrue(schema_exists("tenant1", "default")) # Self-cleanup tenant.delete(force_drop=True) - self.assertFalse(schema_exists("tenant1")) + self.assertFalse(schema_exists("tenant1", "default")) def test_existing_creation(self): "Tests automatic creation for existing tenant's save" - self.assertFalse(schema_exists("tenant1")) + self.assertFalse(schema_exists("tenant1", "default")) tenant = TenantModel(schema_name="tenant1") tenant.auto_create_schema = False tenant.save(verbosity=0) - self.assertFalse(schema_exists("tenant1")) + self.assertFalse(schema_exists("tenant1", "default")) tenant.auto_create_schema = True tenant.save(verbosity=0) - self.assertTrue(schema_exists("tenant1")) + self.assertTrue(schema_exists("tenant1", "default")) # Self-cleanup tenant.delete(force_drop=True) - self.assertFalse(schema_exists("tenant1")) + self.assertFalse(schema_exists("tenant1", "default")) def test_new_aborted_creation(self): "Tests recovery on automatic creation for new tenant's save" @@ -56,12 +56,12 @@ def test_new_aborted_creation(self): def signal_receiver(*args, **kwargs): raise Exception - self.assertFalse(schema_exists("tenant1")) + self.assertFalse(schema_exists("tenant1", "default")) tenant = TenantModel(schema_name="tenant1") schema_post_sync.connect(signal_receiver) with self.assertRaises(Exception): tenant.save(verbosity=0) - self.assertFalse(schema_exists("tenant1")) + self.assertFalse(schema_exists("tenant1", "default")) self.assertEqual(0, TenantModel.objects.count()) schema_post_sync.disconnect(signal_receiver) @@ -71,7 +71,7 @@ def test_existing_aborted_creation(self): def signal_receiver(*args, **kwargs): raise Exception - self.assertFalse(schema_exists("tenant1")) + self.assertFalse(schema_exists("tenant1", "default")) tenant = TenantModel(schema_name="tenant1") tenant.auto_create_schema = False tenant.save(verbosity=0) @@ -79,7 +79,7 @@ def signal_receiver(*args, **kwargs): schema_post_sync.connect(signal_receiver) with self.assertRaises(Exception): tenant.save(verbosity=0) - self.assertFalse(schema_exists("tenant1")) + self.assertFalse(schema_exists("tenant1", "default")) self.assertEqual(1, TenantModel.objects.count()) schema_post_sync.disconnect(signal_receiver) # Self-cleanup @@ -122,8 +122,8 @@ def tearDownClass(cls): for key in settings.TENANTS: if key == "default": continue - drop_schema(key) - drop_schema("tenant") + drop_schema(key, "default") + drop_schema("tenant", "default") call_command("migrateschema", verbosity=0) @contextmanager diff --git a/dpgs_sandbox/tests/test_utils.py b/dpgs_sandbox/tests/test_utils.py index 9fdf928..bdcb25e 100644 --- a/dpgs_sandbox/tests/test_utils.py +++ b/dpgs_sandbox/tests/test_utils.py @@ -22,11 +22,6 @@ def test_get_tenant_model(self): def test_get_domain_model(self): self.assertEqual(utils.get_domain_model()._meta.model_name, "domain") - def test_get_tenant_database_alias(self): - self.assertEqual(utils.get_tenant_database_alias(), "default") - with override_settings(PGSCHEMAS_TENANT_DB_ALIAS="something"): - self.assertEqual(utils.get_tenant_database_alias(), "something") - def test_get_limit_set_calls(self): self.assertFalse(utils.get_limit_set_calls()) with override_settings(PGSCHEMAS_LIMIT_SET_CALLS=True): @@ -78,38 +73,38 @@ def inner(): cursor.close() def test_schema_exists(self): - self.assertTrue(utils.schema_exists("public")) - self.assertTrue(utils.schema_exists("www")) - self.assertTrue(utils.schema_exists("blog")) - self.assertTrue(utils.schema_exists("sample")) - self.assertFalse(utils.schema_exists("default")) - self.assertFalse(utils.schema_exists("tenant")) + self.assertTrue(utils.schema_exists("public", "default")) + self.assertTrue(utils.schema_exists("www", "default")) + self.assertTrue(utils.schema_exists("blog", "default")) + self.assertTrue(utils.schema_exists("sample", "default")) + self.assertFalse(utils.schema_exists("default", "default")) + self.assertFalse(utils.schema_exists("tenant", "default")) def test_dynamic_models_exist(self): self.assertTrue(utils.dynamic_models_exist()) - utils.drop_schema("public") + utils.drop_schema("public", "default") self.assertFalse(utils.dynamic_models_exist()) def test_create_drop_schema(self): - self.assertFalse(utils.create_schema("public", check_if_exists=True)) # Schema existed already - self.assertTrue(utils.schema_exists("public")) # Schema exists - self.assertTrue(utils.drop_schema("public")) # Schema was dropped - self.assertFalse(utils.drop_schema("public")) # Schema no longer exists - self.assertFalse(utils.schema_exists("public")) # Schema doesn't exist - self.assertTrue(utils.create_schema("public", sync_schema=False)) # Schema was created - self.assertTrue(utils.schema_exists("public")) # Schema exists + self.assertFalse(utils.create_schema("public", "default", check_if_exists=True)) # Schema existed already + self.assertTrue(utils.schema_exists("public", "default")) # Schema exists + self.assertTrue(utils.drop_schema("public", "default")) # Schema was dropped + self.assertFalse(utils.drop_schema("public", "default")) # Schema no longer exists + self.assertFalse(utils.schema_exists("public", "default")) # Schema doesn't exist + self.assertTrue(utils.create_schema("public", "default", sync_schema=False)) # Schema was created + self.assertTrue(utils.schema_exists("public", "default")) # Schema exists def test_clone_schema(self): with schema.SchemaDescriptor.create(schema_name="public"): - utils._create_clone_schema_function() - self.assertFalse(utils.schema_exists("sample2")) # Schema doesn't exist previously - utils.clone_schema("sample", "sample2", dry_run=True) # Dry run - self.assertFalse(utils.schema_exists("sample2")) # Schema won't exist, dry run - utils.clone_schema("sample", "sample2") # Real run, schema was cloned - self.assertTrue(utils.schema_exists("sample2")) # Schema exists + utils._create_clone_schema_function("default") + self.assertFalse(utils.schema_exists("sample2", "default")) # Schema doesn't exist previously + utils.clone_schema("sample", "sample2", "default", dry_run=True) # Dry run + self.assertFalse(utils.schema_exists("sample2", "default")) # Schema won't exist, dry run + utils.clone_schema("sample", "sample2", "default") # Real run, schema was cloned + self.assertTrue(utils.schema_exists("sample2", "default")) # Schema exists with self.assertRaises(InternalError): - utils.clone_schema("sample", "sample2") # Schema already exists, error - self.assertTrue(utils.schema_exists("sample2")) # Schema still exists + utils.clone_schema("sample", "sample2", "default") # Schema already exists, error + self.assertTrue(utils.schema_exists("sample2", "default")) # Schema still exists def test_create_or_clone_schema(self): - self.assertFalse(utils.create_or_clone_schema("sample")) # Schema existed + self.assertFalse(utils.create_or_clone_schema("sample", "default")) # Schema existed