diff --git a/tortoise/contrib/test/__init__.py b/tortoise/contrib/test/__init__.py index 8108bb3fb..e598e50a1 100644 --- a/tortoise/contrib/test/__init__.py +++ b/tortoise/contrib/test/__init__.py @@ -43,6 +43,7 @@ _LOOP: AbstractEventLoop = None # type: ignore _MODULES: Iterable[Union[str, ModuleType]] = [] _CONN_CONFIG: dict = {} +_APP_LABEL = None def getDBConfig(app_label: str, modules: Iterable[Union[str, ModuleType]]) -> dict: @@ -103,7 +104,9 @@ def initializer( global _TORTOISE_TEST_DB global _MODULES global _CONN_CONFIG + global _APP_LABEL _MODULES = modules + _APP_LABEL = app_label if db_url is not None: # pragma: nobranch _TORTOISE_TEST_DB = db_url _CONFIG = getDBConfig(app_label=app_label, modules=_MODULES) @@ -247,7 +250,7 @@ class IsolatedTestCase(SimpleTestCase): async def _setUpDB(self) -> None: await super()._setUpDB() - config = getDBConfig(app_label="models", modules=self.tortoise_test_modules or _MODULES) + config = getDBConfig(app_label=_APP_LABEL, modules=self.tortoise_test_modules or _MODULES) await Tortoise.init(config, _create_db=True) await Tortoise.generate_schemas(safe=False) @@ -327,7 +330,7 @@ class TestCase(TruncationTestCase): async def asyncSetUp(self) -> None: await super().asyncSetUp() - self._db = connections.get("models") + self._db = connections.get(_APP_LABEL) self._transaction = TransactionTestContext(self._db._in_transaction().connection) await self._transaction.__aenter__() # type: ignore