diff --git a/neo4j-app/neo4j_app/core/neo4j/migrations/migrate.py b/neo4j-app/neo4j_app/core/neo4j/migrations/migrate.py index 1cf87131..66efb3e0 100644 --- a/neo4j-app/neo4j_app/core/neo4j/migrations/migrate.py +++ b/neo4j-app/neo4j_app/core/neo4j/migrations/migrate.py @@ -7,7 +7,8 @@ from datetime import datetime from distutils.version import StrictVersion from enum import Enum, unique -from typing import Any, Callable, List, Optional, Sequence +from inspect import signature +from typing import Any, Callable, List, Optional, Sequence, Union import neo4j from neo4j.exceptions import ConstraintError @@ -32,7 +33,9 @@ logger = logging.getLogger(__name__) -MigrationFn = Callable[[neo4j.AsyncTransaction], Coroutine] +TransactionFn = Callable[[neo4j.AsyncTransaction], Coroutine] +ExplicitTransactionFn = Callable[[neo4j.Session], Coroutine] +MigrationFn = Union[TransactionFn, ExplicitTransactionFn] _MIGRATION_TIMEOUT_MSG = """Migration timeout expired ! Please check that a migration is indeed in progress. If the application is in a \ @@ -118,7 +121,14 @@ async def _migrate_with_lock( ) # Then run to migration logger.debug("Acquired write lock for %s !", migration.label) - await project_session.execute_write(migration.migration_fn) + sig = signature(migration.migration_fn) + first_param = list(sig.parameters)[0] + if first_param == "tx": + await project_session.execute_write(migration.migration_fn) + elif first_param == "sess": + await migration.migration_fn(project_session) + else: + raise ValueError(f"Invalid migration function: {migration.migration_fn}") # Finally free the lock await registry_session.execute_write( complete_migration_tx, diff --git a/neo4j-app/neo4j_app/tests/core/neo4j/migrations/test_migrate.py b/neo4j-app/neo4j_app/tests/core/neo4j/migrations/test_migrate.py index 83a9940f..8970661e 100644 --- a/neo4j-app/neo4j_app/tests/core/neo4j/migrations/test_migrate.py +++ b/neo4j-app/neo4j_app/tests/core/neo4j/migrations/test_migrate.py @@ -49,18 +49,28 @@ async def _create_indexes_tx(tx: neo4j.AsyncTransaction): await tx.run(index_query_1) +async def _create_indexes(sess: neo4j.AsyncSession): + index_query_0 = "CREATE INDEX index0 IF NOT EXISTS FOR (n:Node) ON (n.attribute0)" + await sess.run(index_query_0) + index_query_1 = "CREATE INDEX index1 IF NOT EXISTS FOR (n:Node) ON (n.attribute1)" + await sess.run(index_query_1) + + async def _drop_constraint_tx(tx: neo4j.AsyncTransaction): drop_index_query = "DROP INDEX index0 IF EXISTS" await tx.run(drop_index_query) -# noinspection PyTypeChecker _MIGRATION_0 = Migration( version="0.2.0", label="create index and constraint", migration_fn=_create_indexes_tx, ) -# noinspection PyTypeChecker +_MIGRATION_0_EXPLICIT = Migration( + version="0.2.0", + label="create index and constraint", + migration_fn=_create_indexes, +) _MIGRATION_1 = Migration( version="0.3.0", label="drop constraint", @@ -75,6 +85,8 @@ async def _drop_constraint_tx(tx: neo4j.AsyncTransaction): ([], set(), set()), # Single ([_MIGRATION_0], {"index0", "index1"}, set()), + # Single as explicit_transaction + ([_MIGRATION_0_EXPLICIT], {"index0", "index1"}, set()), # Multiple ordered ([_MIGRATION_0, _MIGRATION_1], {"index1"}, {"index0"}), # Multiple unordered