diff --git a/src/aleph/api_entrypoint.py b/src/aleph/api_entrypoint.py index 106e98b83..44273babb 100644 --- a/src/aleph/api_entrypoint.py +++ b/src/aleph/api_entrypoint.py @@ -7,7 +7,7 @@ import aleph.config from aleph.chains.signature_verifier import SignatureVerifier -from aleph.db.connection import make_engine, make_session_factory +from aleph.db.connection import make_async_engine, make_async_session_factory from aleph.services.cache.node_cache import NodeCache from aleph.services.ipfs import IpfsService from aleph.services.p2p import init_p2p_client @@ -34,12 +34,12 @@ async def configure_aiohttp_app( with sentry_sdk.start_transaction(name="init-api-server"): p2p_client = await init_p2p_client(config, service_name="api-server-aiohttp") - engine = make_engine( + engine = make_async_engine( config, echo=config.logging.level.value == logging.DEBUG, application_name="aleph-api", ) - session_factory = make_session_factory(engine) + session_factory = make_async_session_factory(engine) node_cache = NodeCache( redis_host=config.redis.host.value, redis_port=config.redis.port.value diff --git a/src/aleph/chains/bsc.py b/src/aleph/chains/bsc.py index 56dff95c3..4e54ca68d 100644 --- a/src/aleph/chains/bsc.py +++ b/src/aleph/chains/bsc.py @@ -5,13 +5,13 @@ from aleph.chains.chain_data_service import PendingTxPublisher from aleph.chains.indexer_reader import AlephIndexerReader from aleph.types.chain_sync import ChainEventType -from aleph.types.db_session import DbSessionFactory +from aleph.types.db_session import AsyncDbSessionFactory class BscConnector(ChainReader): def __init__( self, - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, pending_tx_publisher: PendingTxPublisher, ): self.indexer_reader = AlephIndexerReader( diff --git a/src/aleph/chains/chain_data_service.py b/src/aleph/chains/chain_data_service.py index 0ce3b3e33..ccc62effc 100644 --- a/src/aleph/chains/chain_data_service.py +++ b/src/aleph/chains/chain_data_service.py @@ -31,7 +31,7 @@ from aleph.storage import StorageService from aleph.toolkit.timestamp import utc_now from aleph.types.chain_sync import ChainSyncProtocol -from aleph.types.db_session import DbSession, DbSessionFactory +from aleph.types.db_session import AsyncDbSession, AsyncDbSessionFactory from aleph.types.files import FileType from aleph.utils import get_sha256 @@ -39,14 +39,14 @@ class ChainDataService: def __init__( self, - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, storage_service: StorageService, ): self.session_factory = session_factory self.storage_service = storage_service async def prepare_sync_event_payload( - self, session: DbSession, messages: List[MessageDb] + self, session: AsyncDbSession, messages: List[MessageDb] ) -> OffChainSyncEventPayload: """ Returns the payload of a sync event to be published on chain. @@ -129,22 +129,22 @@ async def _get_tx_messages_off_chain_protocol( LOGGER.info("Got bulk data with %d items" % len(messages)) if config.ipfs.enabled.value: try: - with self.session_factory() as session: + async with self.session_factory() as session: # Some chain data files are duplicated, and can be treated in parallel, # hence the upsert. 
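Aside: `make_async_engine` and `make_async_session_factory` (imported at the top of this diff and used as `async with self.session_factory() as session` below) live in `src/aleph/db/connection.py`, which is not part of this changeset. A minimal sketch of what such factories look like with SQLAlchemy's asyncio extension, assuming the asyncpg driver; the `make_db_url` signature here is an assumption, not code from this PR:

```python
from sqlalchemy.ext.asyncio import (
    AsyncEngine,
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)


def make_async_engine(config, echo: bool, application_name: str) -> AsyncEngine:
    # Assumption: make_db_url can emit a postgresql+asyncpg URL from the node config.
    return create_async_engine(
        make_db_url(driver="postgresql+asyncpg", config=config),
        echo=echo,
        connect_args={"server_settings": {"application_name": application_name}},
    )


def make_async_session_factory(engine: AsyncEngine) -> async_sessionmaker[AsyncSession]:
    # expire_on_commit=False keeps ORM objects readable after `await session.commit()`,
    # a pattern the call sites in this diff rely on.
    return async_sessionmaker(engine, expire_on_commit=False)
```

The `AsyncDbSessionFactory` type used throughout the diff is presumably an alias for this `async_sessionmaker[AsyncSession]` in `aleph.types.db_session`.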
- upsert_file( + await upsert_file( session=session, file_hash=sync_file_content.hash, file_type=FileType.FILE, size=len(sync_file_content.raw_value), ) - upsert_tx_file_pin( + await upsert_tx_file_pin( session=session, file_hash=file_hash, tx_hash=tx.hash, created=utc_now(), ) - session.commit() + await session.commit() # Some IPFS fetches can take a while, hence the large timeout. await asyncio.wait_for( @@ -246,9 +246,9 @@ def __init__(self, pending_tx_exchange: aio_pika.abc.AbstractExchange): self.pending_tx_exchange = pending_tx_exchange @staticmethod - def add_pending_tx(session: DbSession, tx: ChainTxDb): - upsert_chain_tx(session=session, tx=tx) - upsert_pending_tx(session=session, tx_hash=tx.hash) + async def add_pending_tx(session: AsyncDbSession, tx: ChainTxDb): + await upsert_chain_tx(session=session, tx=tx) + await upsert_pending_tx(session=session, tx_hash=tx.hash) async def publish_pending_tx(self, tx: ChainTxDb): message = aio_pika.Message(body=tx.hash.encode("utf-8")) @@ -256,7 +256,7 @@ async def publish_pending_tx(self, tx: ChainTxDb): message=message, routing_key=f"{tx.chain.value}.{tx.publisher}.{tx.hash}" ) - async def add_and_publish_pending_tx(self, session: DbSession, tx: ChainTxDb): + async def add_and_publish_pending_tx(self, session: AsyncDbSession, tx: ChainTxDb): """ Add an event published on one of the supported chains. Adds the tx to the database, creates a pending tx entry in the pending tx table @@ -265,8 +265,8 @@ async def add_and_publish_pending_tx(self, session: DbSession, tx: ChainTxDb): Note that this function commits changes to the database for consistency between the DB and the message queue. """ - self.add_pending_tx(session=session, tx=tx) - session.commit() + await self.add_pending_tx(session=session, tx=tx) + await session.commit() await self.publish_pending_tx(tx) @classmethod diff --git a/src/aleph/chains/connector.py b/src/aleph/chains/connector.py index e612f079f..84ea654a4 100644 --- a/src/aleph/chains/connector.py +++ b/src/aleph/chains/connector.py @@ -5,7 +5,7 @@ from aleph_message.models import Chain from configmanager import Config -from aleph.types.db_session import DbSessionFactory +from aleph.types.db_session import AsyncDbSessionFactory from .abc import ChainReader, ChainWriter from .bsc import BscConnector @@ -29,7 +29,7 @@ class ChainConnector: def __init__( self, - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, pending_tx_publisher: PendingTxPublisher, chain_data_service: ChainDataService, ): diff --git a/src/aleph/chains/ethereum.py b/src/aleph/chains/ethereum.py index fe382595e..398be5cda 100644 --- a/src/aleph/chains/ethereum.py +++ b/src/aleph/chains/ethereum.py @@ -23,7 +23,7 @@ from aleph.schemas.chains.tx_context import TxContext from aleph.toolkit.timestamp import utc_now from aleph.types.chain_sync import ChainEventType -from aleph.types.db_session import DbSessionFactory +from aleph.types.db_session import AsyncDbSessionFactory from aleph.utils import run_in_executor from .abc import ChainWriter @@ -77,7 +77,7 @@ class EthereumVerifier(EVMVerifier): class EthereumConnector(ChainWriter): def __init__( self, - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, pending_tx_publisher: PendingTxPublisher, chain_data_service: ChainDataService, ): @@ -93,8 +93,8 @@ def __init__( async def get_last_height(self, sync_type: ChainEventType) -> int: """Returns the last height for which we already have the ethereum data.""" - with self.session_factory() as session: - 
last_height = get_last_height( + async with self.session_factory() as session: + last_height = await get_last_height( session=session, chain=Chain.ETH, sync_type=sync_type ) @@ -209,15 +209,15 @@ async def _request_transactions( # block height to do next requests from there. last_height = event_data.blockNumber if last_height: - with self.session_factory() as session: - upsert_chain_sync_status( + async with self.session_factory() as session: + await upsert_chain_sync_status( session=session, chain=Chain.ETH, sync_type=ChainEventType.SYNC, height=last_height, update_datetime=utc_now(), ) - session.commit() + await session.commit() async def fetch_ethereum_sync_events(self, config: Config): last_stored_height = await self.get_last_height(sync_type=ChainEventType.SYNC) @@ -236,11 +236,11 @@ async def fetch_ethereum_sync_events(self, config: Config): config, web3, contract, abi, last_stored_height ): tx = ChainTxDb.from_sync_tx_context(tx_context=context, tx_data=jdata) - with self.session_factory() as session: + async with self.session_factory() as session: await self.pending_tx_publisher.add_and_publish_pending_tx( session=session, tx=tx ) - session.commit() + await session.commit() async def fetch_sync_events_task(self, config: Config): while True: @@ -295,10 +295,10 @@ async def packer(self, config: Config): i = 0 gas_price = web3.eth.generate_gas_price() while True: - with self.session_factory() as session: + async with self.session_factory() as session: # Wait for sync operations to complete - if (count_pending_txs(session=session, chain=Chain.ETH)) or ( - count_pending_messages(session=session, chain=Chain.ETH) + if (await count_pending_txs(session=session, chain=Chain.ETH)) or ( + await count_pending_messages(session=session, chain=Chain.ETH) ) > 1000: await asyncio.sleep(30) continue @@ -317,7 +317,7 @@ async def packer(self, config: Config): nonce = web3.eth.get_transaction_count(account.address) messages = list( - get_unconfirmed_messages( + await get_unconfirmed_messages( session=session, limit=10000, chain=Chain.ETH ) ) @@ -332,7 +332,7 @@ async def packer(self, config: Config): ) ) # Required to apply update to the files table in get_chaindata - session.commit() + await session.commit() response = await run_in_executor( None, self._broadcast_content, diff --git a/src/aleph/chains/indexer_reader.py b/src/aleph/chains/indexer_reader.py index 0e9f81f72..e5a5bcc91 100644 --- a/src/aleph/chains/indexer_reader.py +++ b/src/aleph/chains/indexer_reader.py @@ -38,7 +38,7 @@ from aleph.toolkit.range import MultiRange, Range from aleph.toolkit.timestamp import timestamp_to_datetime from aleph.types.chain_sync import ChainEventType, ChainSyncProtocol -from aleph.types.db_session import DbSession, DbSessionFactory +from aleph.types.db_session import AsyncDbSession, AsyncDbSessionFactory LOGGER = logging.getLogger(__name__) @@ -246,7 +246,7 @@ class AlephIndexerReader: def __init__( self, chain: Chain, - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, pending_tx_publisher: PendingTxPublisher, ): self.chain = chain @@ -257,7 +257,7 @@ def __init__( async def fetch_range( self, - session: DbSession, + session: AsyncDbSession, indexer_client: AlephIndexerClient, chain: Chain, event_type: ChainEventType, @@ -295,7 +295,9 @@ async def fetch_range( LOGGER.info("%d new txs", len(txs)) # Events are listed in reverse order in the indexer response for tx in txs: - self.pending_tx_publisher.add_pending_tx(session=session, tx=tx) + await 
self.pending_tx_publisher.add_pending_tx( + session=session, tx=tx + ) if nb_events_fetched >= limit: last_event_datetime = txs[-1].datetime @@ -320,7 +322,7 @@ async def fetch_range( str(synced_range), ) - add_indexer_range( + await add_indexer_range( session=session, chain=chain, event_type=event_type, @@ -329,7 +331,7 @@ async def fetch_range( # Committing periodically reduces the size of DB transactions for large numbers # of events. - session.commit() + await session.commit() # Now that the txs are committed to the DB, add them to the pending tx message queue for tx in txs: @@ -347,7 +349,7 @@ async def fetch_range( async def fetch_new_events( self, - session: DbSession, + session: AsyncDbSession, indexer_url: str, smart_contract_address: str, event_type: ChainEventType, @@ -372,7 +374,7 @@ async def fetch_new_events( ] ) - multirange_to_sync = get_missing_indexer_datetime_multirange( + multirange_to_sync = await get_missing_indexer_datetime_multirange( session=session, chain=self.chain, event_type=event_type, @@ -399,14 +401,14 @@ async def fetcher( ): while True: try: - with self.session_factory() as session: + async with self.session_factory() as session: await self.fetch_new_events( session=session, indexer_url=indexer_url, smart_contract_address=smart_contract_address, event_type=event_type, ) - session.commit() + await session.commit() except Exception: LOGGER.exception( "An unexpected exception occurred, " diff --git a/src/aleph/chains/nuls2.py b/src/aleph/chains/nuls2.py index 6062b5b0a..d592105b2 100644 --- a/src/aleph/chains/nuls2.py +++ b/src/aleph/chains/nuls2.py @@ -29,7 +29,7 @@ from aleph.schemas.chains.tx_context import TxContext from aleph.schemas.pending_messages import BasePendingMessage from aleph.toolkit.timestamp import utc_now -from aleph.types.db_session import DbSessionFactory +from aleph.types.db_session import AsyncDbSessionFactory from aleph.utils import run_in_executor from ..db.models import ChainTxDb @@ -80,7 +80,7 @@ async def verify_signature(self, message: BasePendingMessage) -> bool: class Nuls2Connector(ChainWriter): def __init__( self, - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, pending_tx_publisher: PendingTxPublisher, chain_data_service: ChainDataService, ): @@ -90,8 +90,8 @@ def __init__( async def get_last_height(self, sync_type: ChainEventType) -> int: """Returns the last height for which we already have the nuls data.""" - with self.session_factory() as session: - last_height = get_last_height( + async with self.session_factory() as session: + last_height = await get_last_height( session=session, chain=Chain.NULS2, sync_type=sync_type ) @@ -133,15 +133,15 @@ async def _request_transactions( LOGGER.info("Incoming logic data is not JSON, ignoring. 
%r" % ldata) if last_height: - with self.session_factory() as session: - upsert_chain_sync_status( + async with self.session_factory() as session: + await upsert_chain_sync_status( session=session, chain=Chain.NULS2, sync_type=ChainEventType.SYNC, height=last_height, update_datetime=utc_now(), ) - session.commit() + await session.commit() async def fetcher(self, config: Config): last_stored_height = await self.get_last_height(sync_type=ChainEventType.SYNC) @@ -158,11 +158,11 @@ async def fetcher(self, config: Config): tx = ChainTxDb.from_sync_tx_context( tx_context=context, tx_data=jdata ) - with self.session_factory() as db_session: + async with self.session_factory() as db_session: await self.pending_tx_publisher.add_and_publish_pending_tx( session=db_session, tx=tx ) - db_session.commit() + await db_session.commit() await asyncio.sleep(10) @@ -182,9 +182,9 @@ async def packer(self, config: Config): nonce = await get_nonce(server, address, chain_id) while True: - with self.session_factory() as session: - if (count_pending_txs(session=session, chain=Chain.NULS2)) or ( - count_pending_messages(session=session, chain=Chain.NULS2) + async with self.session_factory() as session: + if (await count_pending_txs(session=session, chain=Chain.NULS2)) or ( + await count_pending_messages(session=session, chain=Chain.NULS2) ): await asyncio.sleep(30) continue @@ -195,7 +195,7 @@ async def packer(self, config: Config): i = 0 messages = list( - get_unconfirmed_messages( + await get_unconfirmed_messages( session=session, limit=10000, chain=Chain.ETH ) ) @@ -208,7 +208,7 @@ async def packer(self, config: Config): ) ) # Required to apply update to the files table in get_chaindata - session.commit() + await session.commit() content = sync_event_payload.json() tx = await prepare_transfer_tx( diff --git a/src/aleph/chains/tezos.py b/src/aleph/chains/tezos.py index 6cb34e6aa..249d8e151 100644 --- a/src/aleph/chains/tezos.py +++ b/src/aleph/chains/tezos.py @@ -24,7 +24,7 @@ from aleph.schemas.pending_messages import BasePendingMessage from aleph.toolkit.timestamp import utc_now from aleph.types.chain_sync import ChainEventType, ChainSyncProtocol -from aleph.types.db_session import DbSession, DbSessionFactory +from aleph.types.db_session import AsyncDbSession, AsyncDbSessionFactory LOGGER = logging.getLogger(__name__) @@ -249,7 +249,7 @@ async def verify_signature(self, message: BasePendingMessage) -> bool: class TezosConnector(ChainReader): def __init__( self, - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, pending_tx_publisher: PendingTxPublisher, ): self.session_factory = session_factory @@ -257,8 +257,8 @@ def __init__( async def get_last_height(self, sync_type: ChainEventType) -> int: """Returns the last height for which we already have the ethereum data.""" - with self.session_factory() as session: - last_height = get_last_height( + async with self.session_factory() as session: + last_height = await get_last_height( session=session, chain=Chain.TEZOS, sync_type=sync_type ) @@ -269,7 +269,7 @@ async def get_last_height(self, sync_type: ChainEventType) -> int: return last_height async def fetch_incoming_messages( - self, session: DbSession, indexer_url: str, sync_contract_address: str + self, session: AsyncDbSession, indexer_url: str, sync_contract_address: str ) -> None: """ Fetch the latest message events from the Aleph sync smart contract. 
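Aside: the accessors in this migration consume results with `(await session.execute(stmt)).scalars()`, which is safe because `AsyncSession.execute()` returns a fully buffered `Result`. For very large result sets, SQLAlchemy's `AsyncSession` also offers a streaming form; a sketch of that alternative (not used by this PR):

```python
from sqlalchemy import select

from aleph.db.models import MessageDb


async def iter_all_messages(session):
    # stream_scalars() yields ORM objects as rows arrive instead of
    # buffering the entire result set in memory.
    result = await session.stream_scalars(select(MessageDb))
    async for message in result:
        yield message
```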
@@ -324,7 +324,7 @@ async def fetch_incoming_messages( break finally: - upsert_chain_sync_status( + await upsert_chain_sync_status( session=session, chain=Chain.TEZOS, sync_type=ChainEventType.MESSAGE, @@ -335,13 +335,13 @@ async def fetch_incoming_messages( async def fetcher(self, config: Config): while True: try: - with self.session_factory() as session: + async with self.session_factory() as session: await self.fetch_incoming_messages( session=session, indexer_url=config.tezos.indexer_url.value, sync_contract_address=config.tezos.sync_contract.value, ) - session.commit() + await session.commit() except Exception: LOGGER.exception( "An unexpected exception occurred, " diff --git a/src/aleph/commands.py b/src/aleph/commands.py index 31c8878e5..8361a7fc8 100644 --- a/src/aleph/commands.py +++ b/src/aleph/commands.py @@ -26,7 +26,11 @@ from aleph.chains.chain_data_service import ChainDataService, PendingTxPublisher from aleph.chains.connector import ChainConnector from aleph.cli.args import parse_args -from aleph.db.connection import make_db_url, make_engine, make_session_factory +from aleph.db.connection import ( + make_async_engine, + make_async_session_factory, + make_db_url, +) from aleph.exceptions import InvalidConfigException, KeyNotFoundException from aleph.jobs import start_jobs from aleph.network import listener_tasks @@ -122,13 +126,12 @@ async def main(args: List[str]) -> None: run_db_migrations(config) LOGGER.info("Database initialized.") - engine = make_engine( + engine = make_async_engine( config, echo=args.loglevel == logging.DEBUG, application_name="aleph-conn-manager", ) - session_factory = make_session_factory(engine) - + session_factory = make_async_session_factory(engine) setup_logging(args.loglevel) mq_conn = await make_mq_conn(config) diff --git a/src/aleph/db/accessors/aggregates.py b/src/aleph/db/accessors/aggregates.py index ac35fc400..7b0ddae67 100644 --- a/src/aleph/db/accessors/aggregates.py +++ b/src/aleph/db/accessors/aggregates.py @@ -18,7 +18,7 @@ from aleph.cache import cache from aleph.db.models import AggregateDb, AggregateElementDb -from aleph.types.db_session import DbSession +from aleph.types.db_session import AsyncDbSession logger = logging.getLogger(__name__) @@ -29,8 +29,8 @@ def prune_cache_for_updated_aggregates(mapper, connection, target): cache.delete_namespace(f"aggregates_by_owner:{target.owner}") -def aggregate_exists(session: DbSession, key: str, owner: str) -> bool: - return AggregateDb.exists( +async def aggregate_exists(session: AsyncDbSession, key: str, owner: str) -> bool: + return await AggregateDb.exists( session=session, where=(AggregateDb.key == key) & (AggregateDb.owner == owner), ) @@ -41,8 +41,8 @@ def aggregate_exists(session: DbSession, key: str, owner: str) -> bool: @overload -def get_aggregates_by_owner( - session: Any, +async def get_aggregates_by_owner( + session: AsyncDbSession, owner: str, with_info: Literal[False], keys: Optional[Sequence[str]] = None, @@ -50,8 +50,8 @@ def get_aggregates_by_owner( @overload -def get_aggregates_by_owner( - session: Any, +async def get_aggregates_by_owner( + session: AsyncDbSession, owner: str, with_info: Literal[True], keys: Optional[Sequence[str]] = None, @@ -59,12 +59,15 @@ def get_aggregates_by_owner( @overload -def get_aggregates_by_owner( - session, owner: str, with_info: bool, keys: Optional[Sequence[str]] = None +async def get_aggregates_by_owner( + session: AsyncDbSession, + owner: str, + with_info: bool, + keys: Optional[Sequence[str]] = None, ) -> Union[AggregateContent, 
AggregateContentWithInfo]: ... -def get_aggregates_by_owner(session, owner, with_info, keys=None): +async def get_aggregates_by_owner(session: AsyncDbSession, owner, with_info, keys=None): cache_key = f"{with_info} {keys}" if ( @@ -77,8 +80,8 @@ def get_aggregates_by_owner(session, owner, with_info, keys=None): if keys: where_clause = where_clause & AggregateDb.key.in_(keys) if with_info: - query = ( - session.query( + stmt = ( + select( AggregateDb.key, AggregateDb.content, AggregateDb.creation_datetime.label("created"), @@ -93,18 +96,18 @@ def get_aggregates_by_owner(session, owner, with_info, keys=None): .filter(AggregateDb.owner == owner) ) else: - query = ( - session.query(AggregateDb.key, AggregateDb.content) + stmt = ( + select(AggregateDb.key, AggregateDb.content) .filter(where_clause) .order_by(AggregateDb.key) ) - result = query.all() - cache.set(cache_key, result, namespace="aggregates_by_owner:{owner}") + result = (await session.execute(stmt)).all() + cache.set(cache_key, result, namespace=f"aggregates_by_owner:{owner}") return result -def get_aggregate_by_key( - session: DbSession, +async def get_aggregate_by_key( + session: AsyncDbSession, owner: str, key: str, with_content: bool = True, @@ -118,7 +121,7 @@ def get_aggregate_by_key( (AggregateDb.owner == owner) & (AggregateDb.key == key) ) return ( - session.execute( + await session.execute( select_stmt.options( *options, selectinload(AggregateDb.last_revision), @@ -127,29 +130,29 @@ def get_aggregate_by_key( ).scalar() -def get_aggregate_content_keys( - session: DbSession, owner: str, key: str +async def get_aggregate_content_keys( + session: AsyncDbSession, owner: str, key: str ) -> Iterable[str]: - return AggregateDb.jsonb_keys( + return await AggregateDb.jsonb_keys( session=session, column=AggregateDb.content, where=(AggregateDb.key == key) & (AggregateDb.owner == owner), ) -def get_aggregate_elements( - session: DbSession, owner: str, key: str +async def get_aggregate_elements( + session: AsyncDbSession, owner: str, key: str ) -> Iterable[AggregateElementDb]: select_stmt = ( select(AggregateElementDb) .where((AggregateElementDb.key == key) & (AggregateElementDb.owner == owner)) .order_by(AggregateElementDb.creation_datetime) ) - return (session.execute(select_stmt)).scalars() + return (await session.execute(select_stmt)).scalars() -def insert_aggregate( - session: DbSession, +async def insert_aggregate( + session: AsyncDbSession, key: str, owner: str, content: Dict[str, Any], @@ -164,11 +167,11 @@ def insert_aggregate( last_revision_hash=last_revision_hash, dirty=False, ) - session.execute(insert_stmt) + await session.execute(insert_stmt) -def update_aggregate( - session: DbSession, +async def update_aggregate( + session: AsyncDbSession, key: str, owner: str, content: Dict[str, Any], @@ -189,11 +192,11 @@ def update_aggregate( ) .where((AggregateDb.key == key) & (AggregateDb.owner == owner)) ) - session.execute(update_stmt) + await session.execute(update_stmt) -def insert_aggregate_element( - session: DbSession, +async def insert_aggregate_element( + session: AsyncDbSession, item_hash: str, key: str, owner: str, @@ -207,33 +210,48 @@ def insert_aggregate_element( content=content, creation_datetime=creation_datetime, ) - session.execute(insert_stmt) + # Use on_conflict_do_nothing to handle duplicate item_hash values + upsert_stmt = insert_stmt.on_conflict_do_nothing( + constraint="aggregate_elements_pkey" + ) + await session.execute(upsert_stmt) -def count_aggregate_elements(session: DbSession, owner: str, key: str) -> 
int: +async def count_aggregate_elements( + session: AsyncDbSession, owner: str, key: str +) -> int: select_stmt = select(AggregateElementDb).where( (AggregateElementDb.key == key) & (AggregateElementDb.owner == owner) ) - return session.execute(select(func.count()).select_from(select_stmt)).scalar_one() + result = await session.execute(select(func.count()).select_from(select_stmt)) + return result.scalar_one() -def merge_aggregate_elements(elements: Iterable[AggregateElementDb]) -> Dict: - content = {} - for element in elements: - content.update(element.content) - return content +async def merge_aggregate_elements(elements: Iterable[AggregateElementDb]) -> Dict: + """Asynchronously merge aggregate elements by offloading CPU-intensive operation to a thread pool.""" + from aleph.utils import run_in_executor + def _merge(elements): + content = {} + for element in elements: + content.update(element.content) + return content -def mark_aggregate_as_dirty(session: DbSession, owner: str, key: str) -> None: + return await run_in_executor(None, _merge, elements) + + +async def mark_aggregate_as_dirty( + session: AsyncDbSession, owner: str, key: str +) -> None: update_stmt = ( update(AggregateDb) .values(dirty=True) .where((AggregateDb.key == key) & (AggregateDb.owner == owner)) ) - session.execute(update_stmt) + await session.execute(update_stmt) -def refresh_aggregate(session: DbSession, owner: str, key: str) -> None: +async def refresh_aggregate(session: AsyncDbSession, owner: str, key: str) -> None: # Step 1: use a group by to retrieve the aggregate content. This uses a custom # aggregate function (see 78dd67881db4_jsonb_merge_aggregate.py). select_merged_aggregate_subquery = ( @@ -291,18 +309,18 @@ def refresh_aggregate(session: DbSession, owner: str, key: str) -> None: }, ) - session.execute(upsert_aggregate_stmt) + await session.execute(upsert_aggregate_stmt) -def delete_aggregate(session: DbSession, owner: str, key: str) -> None: +async def delete_aggregate(session: AsyncDbSession, owner: str, key: str) -> None: delete_aggregate_stmt = delete(AggregateDb).where( (AggregateDb.key == key) & (AggregateDb.owner == owner) ) - session.execute(delete_aggregate_stmt) + await session.execute(delete_aggregate_stmt) -def delete_aggregate_element(session: DbSession, item_hash: str) -> None: +async def delete_aggregate_element(session: AsyncDbSession, item_hash: str) -> None: delete_element_stmt = delete(AggregateElementDb).where( AggregateElementDb.item_hash == item_hash ) - session.execute(delete_element_stmt) + await session.execute(delete_element_stmt) diff --git a/src/aleph/db/accessors/balances.py b/src/aleph/db/accessors/balances.py index e1e9b1eed..6aee302bb 100644 --- a/src/aleph/db/accessors/balances.py +++ b/src/aleph/db/accessors/balances.py @@ -2,28 +2,31 @@ from io import StringIO from typing import Dict, Mapping, Optional, Sequence +import asyncpg from aleph_message.models import Chain from sqlalchemy import func, select +from sqlalchemy.ext.asyncio import AsyncConnection from sqlalchemy.sql import Select from aleph.db.models import AlephBalanceDb -from aleph.types.db_session import DbSession +from aleph.types.db_session import AsyncDbSession -def get_balance_by_chain( - session: DbSession, address: str, chain: Chain, dapp: Optional[str] = None +async def get_balance_by_chain( + session: AsyncDbSession, address: str, chain: Chain, dapp: Optional[str] = None ) -> Optional[Decimal]: - return session.execute( - select(AlephBalanceDb.balance).where( - (AlephBalanceDb.address == address) - & 
(AlephBalanceDb.chain == chain.value)
-            & (AlephBalanceDb.dapp == dapp)
+    return (
+        await session.execute(
+            select(AlephBalanceDb.balance).where(
+                (AlephBalanceDb.address == address)
+                & (AlephBalanceDb.chain == chain.value)
+                & (AlephBalanceDb.dapp == dapp)
+            )
         )
     ).scalar()
 
 
 def make_balances_by_chain_query(
-    session: DbSession,
     chains: Optional[Sequence[Chain]] = None,
     page: int = 1,
     pagination: int = 100,
@@ -46,21 +49,23 @@ def make_balances_by_chain_query(
     return query
 
 
-def get_balances_by_chain(session: DbSession, **kwargs):
-    select_stmt = make_balances_by_chain_query(session=session, **kwargs)
-    return (session.execute(select_stmt)).all()
+async def get_balances_by_chain(session: AsyncDbSession, **kwargs):
+    select_stmt = make_balances_by_chain_query(**kwargs)
+    return (await session.execute(select_stmt)).all()
 
 
-def count_balances_by_chain(session: DbSession, pagination: int = 0, **kwargs):
+async def count_balances_by_chain(
+    session: AsyncDbSession, pagination: int = 0, **kwargs
+):
     select_stmt = make_balances_by_chain_query(
-        session=session, pagination=0, **kwargs
+        pagination=pagination, **kwargs
     ).subquery()
     select_count_stmt = select(func.count()).select_from(select_stmt)
-    return session.execute(select_count_stmt).scalar_one()
+    return (await session.execute(select_count_stmt)).scalar_one()
 
 
-def get_total_balance(
-    session: DbSession, address: str, include_dapps: bool = False
+async def get_total_balance(
+    session: AsyncDbSession, address: str, include_dapps: bool = False
 ) -> Decimal:
     where_clause = AlephBalanceDb.address == address
     if not include_dapps:
@@ -73,12 +78,12 @@
         .group_by(AlephBalanceDb.address)
     )
 
-    result = session.execute(select_stmt).one_or_none()
+    result = (await session.execute(select_stmt)).one_or_none()
     return Decimal(0) if result is None else result.balance or Decimal(0)
 
 
-def get_total_detailed_balance(
-    session: DbSession,
+async def get_total_detailed_balance(
+    session: AsyncDbSession,
     address: str,
     chain: Optional[str] = None,
     include_dapps: bool = False,
@@ -94,7 +99,7 @@
         .group_by(AlephBalanceDb.address)
     )
 
-    result = session.execute(query).first()
+    result = (await session.execute(query)).first()
     return result[0] if result is not None else Decimal(0), {}
 
     query = (
@@ -108,7 +113,7 @@
 
     balances_by_chain = {
         row.chain: row.balance or Decimal(0)
-        for row in session.execute(query).fetchall()
+        for row in (await session.execute(query)).fetchall()
     }
 
     query = (
@@ -120,12 +125,12 @@
         .group_by(AlephBalanceDb.address)
     )
 
-    result = session.execute(query).first()
+    result = (await session.execute(query)).first()
     return result[0] if result is not None else Decimal(0), balances_by_chain
 
 
-def update_balances(
-    session: DbSession,
+async def update_balances(
+    session: AsyncDbSession,
     chain: Chain,
     dapp: Optional[str],
     eth_height: int,
@@ -140,12 +145,14 @@
     table from the temporary one.
     """
-    session.execute(
+    await session.execute(
         "CREATE TEMPORARY TABLE temp_balances AS SELECT * FROM balances WITH NO DATA"  # type: ignore[arg-type]
     )
 
-    conn = session.connection().connection
-    cursor = conn.cursor()
+    # session.connection() returns a SQLAlchemy AsyncConnection; the actual
+    # asyncpg connection is the driver connection behind its pooled DBAPI proxy.
+    raw_conn: AsyncConnection = await session.connection()
+    asyncpg_conn: asyncpg.Connection = (
+        await raw_conn.get_raw_connection()
+    ).driver_connection
 
     # Prepare an in-memory CSV file for use with the COPY operator
     csv_balances = StringIO(
@@ -156,11 +163,11 @@
             ]
         )
     )
-    cursor.copy_expert(
-        "COPY temp_balances(address, chain, dapp, balance, eth_height) FROM STDIN WITH CSV DELIMITER ';'",
-        csv_balances,
+    # asyncpg does not support COPY ... FROM STDIN through execute(); use its
+    # dedicated copy_to_table() helper, which accepts an async iterable of bytes.
+    async def _csv_chunks():
+        yield csv_balances.getvalue().encode("utf-8")
+
+    await asyncpg_conn.copy_to_table(
+        "temp_balances",
+        source=_csv_chunks(),
+        columns=["address", "chain", "dapp", "balance", "eth_height"],
+        format="csv",
+        delimiter=";",
     )
-    session.execute(
+    await session.execute(
         """
 INSERT INTO balances(address, chain, dapp, balance, eth_height)
     (SELECT address, chain, dapp, balance, eth_height FROM temp_balances)
@@ -173,4 +180,4 @@
     # Temporary tables are dropped at the same time as the connection, but SQLAlchemy
     # tends to reuse connections. Dropping the table here guarantees it will not be present
     # on the next run.
-    session.execute("DROP TABLE temp_balances")  # type: ignore[arg-type]
+    await session.execute("DROP TABLE temp_balances")  # type: ignore[arg-type]
diff --git a/src/aleph/db/accessors/chains.py b/src/aleph/db/accessors/chains.py
index 9da9bbbfa..e8dcbed65 100644
--- a/src/aleph/db/accessors/chains.py
+++ b/src/aleph/db/accessors/chains.py
@@ -9,16 +9,16 @@
 from aleph.toolkit.range import MultiRange, Range
 from aleph.toolkit.timestamp import utc_now
 from aleph.types.chain_sync import ChainEventType
-from aleph.types.db_session import DbSession
+from aleph.types.db_session import AsyncDbSession
 
 from ..models.chains import ChainSyncStatusDb, ChainTxDb, IndexerSyncStatusDb
 
 
-def get_last_height(
-    session: DbSession, chain: Chain, sync_type: ChainEventType
+async def get_last_height(
+    session: AsyncDbSession, chain: Chain, sync_type: ChainEventType
 ) -> Optional[int]:
     height = (
-        session.execute(
+        await session.execute(
             select(ChainSyncStatusDb.height).where(
                 (ChainSyncStatusDb.chain == chain)
                 & (ChainSyncStatusDb.type == sync_type)
@@ -28,7 +28,7 @@ def get_last_height(
     return height
 
 
-def upsert_chain_tx(session: DbSession, tx: ChainTxDb) -> None:
+async def upsert_chain_tx(session: AsyncDbSession, tx: ChainTxDb) -> None:
     insert_stmt = insert(ChainTxDb).values(
         hash=tx.hash,
         chain=tx.chain,
@@ -40,11 +40,11 @@ def upsert_chain_tx(session: DbSession, tx: ChainTxDb) -> None:
         content=tx.content,
     )
     upsert_stmt = insert_stmt.on_conflict_do_nothing()
-    session.execute(upsert_stmt)
+    await session.execute(upsert_stmt)
 
 
-def upsert_chain_sync_status(
-    session: DbSession,
+async def upsert_chain_sync_status(
+    session: AsyncDbSession,
     chain: Chain,
     sync_type: ChainEventType,
     height: int,
@@ -58,7 +58,7 @@ def upsert_chain_sync_status(
             set_={"height": height, "last_update": update_datetime},
         )
     )
-    session.execute(upsert_stmt)
+    await session.execute(upsert_stmt)
 
 
 @dataclass
@@ -71,8 +71,8 @@ def iter_ranges(self) -> Iterable[Range[dt.datetime]]:
         return self.datetime_multirange.ranges
 
 
-def get_indexer_multirange(
-    session: DbSession, chain: Chain, event_type: ChainEventType
+async def get_indexer_multirange(
+    session: AsyncDbSession, chain: Chain, event_type: ChainEventType
 ) -> IndexerMultiRange:
     """
     Returns the already synced indexer
ranges for the specified chain and event type. @@ -92,7 +92,7 @@ def get_indexer_multirange( .order_by(IndexerSyncStatusDb.start_block_datetime) ) - rows = session.execute(select_stmt).scalars() + rows = (await session.execute(select_stmt)).scalars() datetime_multirange: MultiRange[dt.datetime] = MultiRange() @@ -106,26 +106,29 @@ def get_indexer_multirange( ) -def get_missing_indexer_datetime_multirange( - session: DbSession, chain: Chain, event_type: ChainEventType, indexer_multirange +async def get_missing_indexer_datetime_multirange( + session: AsyncDbSession, + chain: Chain, + event_type: ChainEventType, + indexer_multirange, ) -> MultiRange[dt.datetime]: # TODO: this query is inefficient (too much data retrieved, too many rows, code manipulation. # replace it with the range/multirange operations of PostgreSQL 14+ once the MongoDB # version is out the window. - db_multiranges = get_indexer_multirange( + db_multiranges = await get_indexer_multirange( session=session, chain=chain, event_type=event_type ) return indexer_multirange - db_multiranges.datetime_multirange -def update_indexer_multirange( - session: DbSession, indexer_multirange: IndexerMultiRange +async def update_indexer_multirange( + session: AsyncDbSession, indexer_multirange: IndexerMultiRange ): chain = indexer_multirange.chain event_type = indexer_multirange.event_type # For now, just delete all matching entries and rewrite them. - session.execute( + await session.execute( delete(IndexerSyncStatusDb).where( (IndexerSyncStatusDb.chain == chain) & (IndexerSyncStatusDb.event_type == event_type) @@ -133,7 +136,7 @@ def update_indexer_multirange( ) update_time = utc_now() for datetime_range in indexer_multirange.iter_ranges(): - session.execute( + await session.execute( insert(IndexerSyncStatusDb).values( chain=chain, event_type=event_type, @@ -146,15 +149,17 @@ def update_indexer_multirange( ) -def add_indexer_range( - session: DbSession, +async def add_indexer_range( + session: AsyncDbSession, chain: Chain, event_type: ChainEventType, datetime_range: Range[dt.datetime], ): - indexer_multirange = get_indexer_multirange( + indexer_multirange = await get_indexer_multirange( session=session, chain=chain, event_type=event_type ) indexer_multirange.datetime_multirange += datetime_range - update_indexer_multirange(session=session, indexer_multirange=indexer_multirange) + await update_indexer_multirange( + session=session, indexer_multirange=indexer_multirange + ) diff --git a/src/aleph/db/accessors/cost.py b/src/aleph/db/accessors/cost.py index 8f51e433d..0ffc434f1 100644 --- a/src/aleph/db/accessors/cost.py +++ b/src/aleph/db/accessors/cost.py @@ -8,11 +8,11 @@ from aleph.db.models.account_costs import AccountCostsDb from aleph.toolkit.costs import format_cost -from aleph.types.db_session import DbSession +from aleph.types.db_session import AsyncDbSession -def get_total_cost_for_address( - session: DbSession, +async def get_total_cost_for_address( + session: AsyncDbSession, address: str, payment_type: Optional[PaymentType] = PaymentType.hold, ) -> Decimal: @@ -31,13 +31,15 @@ def get_total_cost_for_address( ) ) - total_cost = session.execute(select_stmt).scalar() + total_cost = (await session.execute(select_stmt)).scalar() return format_cost(Decimal(total_cost or 0)) -def get_message_costs(session: DbSession, item_hash: str) -> Iterable[AccountCostsDb]: +async def get_message_costs( + session: AsyncDbSession, item_hash: str +) -> Iterable[AccountCostsDb]: select_stmt = select(AccountCostsDb).where(AccountCostsDb.item_hash == 
item_hash) - return (session.execute(select_stmt)).scalars().all() + return (await session.execute(select_stmt)).scalars().all() def make_costs_upsert_query(costs: List[AccountCostsDb]) -> Insert: diff --git a/src/aleph/db/accessors/files.py b/src/aleph/db/accessors/files.py index a073098e2..ac3e7ca1e 100644 --- a/src/aleph/db/accessors/files.py +++ b/src/aleph/db/accessors/files.py @@ -5,7 +5,7 @@ from sqlalchemy.dialects.postgresql import insert from sqlalchemy.engine import Row -from aleph.types.db_session import DbSession +from aleph.types.db_session import AsyncDbSession from aleph.types.files import FileTag, FileType from aleph.types.sort_order import SortOrder @@ -21,11 +21,13 @@ ) -def is_pinned_file(session: DbSession, file_hash: str) -> bool: - return FilePinDb.exists(session=session, where=FilePinDb.file_hash == file_hash) +async def is_pinned_file(session: AsyncDbSession, file_hash: str) -> bool: + return await FilePinDb.exists( + session=session, where=FilePinDb.file_hash == file_hash + ) -def get_unpinned_files(session: DbSession) -> Iterable[StoredFileDb]: +async def get_unpinned_files(session: AsyncDbSession) -> Iterable[StoredFileDb]: """ Returns the list of files that are not pinned by a message or an on-chain transaction. """ @@ -34,11 +36,11 @@ def get_unpinned_files(session: DbSession) -> Iterable[StoredFileDb]: .join(FilePinDb, StoredFileDb.hash == FilePinDb.file_hash, isouter=True) .where(FilePinDb.id.is_(None)) ) - return session.execute(select_stmt).scalars() + return (await session.execute(select_stmt)).scalars() -def upsert_tx_file_pin( - session: DbSession, file_hash: str, tx_hash: str, created: dt.datetime +async def upsert_tx_file_pin( + session: AsyncDbSession, file_hash: str, tx_hash: str, created: dt.datetime ) -> None: upsert_stmt = ( insert(TxFilePinDb) @@ -47,11 +49,11 @@ def upsert_tx_file_pin( ) .on_conflict_do_nothing() ) - session.execute(upsert_stmt) + await session.execute(upsert_stmt) -def insert_content_file_pin( - session: DbSession, +async def insert_content_file_pin( + session: AsyncDbSession, file_hash: str, owner: str, item_hash: str, @@ -64,11 +66,11 @@ def insert_content_file_pin( type=FilePinType.CONTENT, created=created, ) - session.execute(insert_stmt) + await session.execute(insert_stmt) -def insert_message_file_pin( - session: DbSession, +async def insert_message_file_pin( + session: AsyncDbSession, file_hash: str, owner: str, item_hash: str, @@ -83,32 +85,34 @@ def insert_message_file_pin( ref=ref, created=created, ) - session.execute(insert_stmt) + await session.execute(insert_stmt) -def count_file_pins(session: DbSession, file_hash: str) -> int: +async def count_file_pins(session: AsyncDbSession, file_hash: str) -> int: select_count_stmt = select(func.count()).select_from( select(FilePinDb).where(FilePinDb.file_hash == file_hash).subquery() ) - return session.execute(select_count_stmt).scalar_one() + return (await session.execute(select_count_stmt)).scalar_one() -def find_file_pins(session: DbSession, item_hashes: Collection[str]) -> Iterable[str]: +async def find_file_pins( + session: AsyncDbSession, item_hashes: Collection[str] +) -> Iterable[str]: select_stmt = select(MessageFilePinDb.item_hash).where( MessageFilePinDb.item_hash.in_(item_hashes) ) - return session.execute(select_stmt).scalars() + return (await session.execute(select_stmt)).scalars() -def delete_file_pin(session: DbSession, item_hash: str) -> None: +async def delete_file_pin(session: AsyncDbSession, item_hash: str) -> None: delete_stmt = 
delete(MessageFilePinDb).where( MessageFilePinDb.item_hash == item_hash ) - session.execute(delete_stmt) + await session.execute(delete_stmt) -def insert_grace_period_file_pin( - session: DbSession, +async def insert_grace_period_file_pin( + session: AsyncDbSession, file_hash: str, created: dt.datetime, delete_by: dt.datetime, @@ -119,25 +123,31 @@ def insert_grace_period_file_pin( type=FilePinType.GRACE_PERIOD, delete_by=delete_by, ) - session.execute(insert_stmt) + await session.execute(insert_stmt) -def delete_grace_period_file_pins(session: DbSession, datetime: dt.datetime) -> None: +async def delete_grace_period_file_pins( + session: AsyncDbSession, datetime: dt.datetime +) -> None: delete_stmt = delete(GracePeriodFilePinDb).where( GracePeriodFilePinDb.delete_by < datetime ) - session.execute(delete_stmt) + await session.execute(delete_stmt) -def get_message_file_pin( - session: DbSession, item_hash: str +async def get_message_file_pin( + session: AsyncDbSession, item_hash: str ) -> Optional[MessageFilePinDb]: - return session.execute( - select(MessageFilePinDb).where(MessageFilePinDb.item_hash == item_hash) + return ( + await session.execute( + select(MessageFilePinDb).where(MessageFilePinDb.item_hash == item_hash) + ) ).scalar_one_or_none() -def get_address_files_stats(session: DbSession, owner: str) -> Tuple[int, int]: +async def get_address_files_stats( + session: AsyncDbSession, owner: str +) -> Tuple[int, int]: select_stmt = ( select( func.count().label("nb_files"), @@ -147,12 +157,12 @@ def get_address_files_stats(session: DbSession, owner: str) -> Tuple[int, int]: .join(StoredFileDb, MessageFilePinDb.file_hash == StoredFileDb.hash) .where(MessageFilePinDb.owner == owner) ) - result = session.execute(select_stmt).one() + result = (await session.execute(select_stmt)).one() return result.nb_files, result.total_size -def get_address_files_for_api( - session: DbSession, +async def get_address_files_for_api( + session: AsyncDbSession, owner: str, pagination: int = 0, page: int = 1, @@ -186,44 +196,48 @@ def get_address_files_for_api( select_stmt = select_stmt.order_by(*order_by_columns) - return session.execute(select_stmt).all() + return (await session.execute(select_stmt)).all() -def upsert_file(session: DbSession, file_hash: str, size: int, file_type: FileType): +async def upsert_file( + session: AsyncDbSession, file_hash: str, size: int, file_type: FileType +): upsert_file_stmt = ( insert(StoredFileDb) .values(hash=file_hash, size=size, type=file_type) .on_conflict_do_nothing(constraint="files_pkey") ) - session.execute(upsert_file_stmt) + await session.execute(upsert_file_stmt) -def get_file(session: DbSession, file_hash: str) -> Optional[StoredFileDb]: +async def get_file(session: AsyncDbSession, file_hash: str) -> Optional[StoredFileDb]: select_stmt = select(StoredFileDb).where(StoredFileDb.hash == file_hash) - return session.execute(select_stmt).scalar_one_or_none() + return (await session.execute(select_stmt)).scalar_one_or_none() -def delete_file(session: DbSession, file_hash: str) -> None: +async def delete_file(session: AsyncDbSession, file_hash: str) -> None: delete_stmt = delete(StoredFileDb).where(StoredFileDb.hash == file_hash) - session.execute(delete_stmt) + await session.execute(delete_stmt) -def get_file_tag(session: DbSession, tag: FileTag) -> Optional[FileTagDb]: +async def get_file_tag(session: AsyncDbSession, tag: FileTag) -> Optional[FileTagDb]: select_stmt = select(FileTagDb).where(FileTagDb.tag == tag) - return session.execute(select_stmt).scalar() + 
return (await session.execute(select_stmt)).scalar() -def file_tag_exists(session: DbSession, tag: FileTag) -> bool: - return FileTagDb.exists(session=session, where=FileTagDb.tag == tag) +async def file_tag_exists(session: AsyncDbSession, tag: FileTag) -> bool: + return await FileTagDb.exists(session=session, where=FileTagDb.tag == tag) -def find_file_tags(session: DbSession, tags: Collection[FileTag]) -> Iterable[FileTag]: +async def find_file_tags( + session: AsyncDbSession, tags: Collection[FileTag] +) -> Iterable[FileTag]: select_stmt = select(FileTagDb.tag).where(FileTagDb.tag.in_(tags)) - return session.execute(select_stmt).scalars() + return (await session.execute(select_stmt)).scalars() -def upsert_file_tag( - session: DbSession, +async def upsert_file_tag( + session: AsyncDbSession, tag: FileTag, owner: str, file_hash: str, @@ -237,10 +251,10 @@ def upsert_file_tag( set_={"file_hash": file_hash, "last_updated": last_updated}, where=FileTagDb.last_updated < last_updated, ) - session.execute(upsert_stmt) + await session.execute(upsert_stmt) -def refresh_file_tag(session: DbSession, tag: FileTag) -> None: +async def refresh_file_tag(session: AsyncDbSession, tag: FileTag) -> None: coalesced_ref = func.coalesce(MessageFilePinDb.ref, MessageFilePinDb.item_hash) select_latest_file_pin_stmt = ( select( @@ -271,5 +285,5 @@ def refresh_file_tag(session: DbSession, tag: FileTag) -> None: "last_updated": insert_stmt.excluded.last_updated, }, ) - session.execute(delete(FileTagDb).where(FileTagDb.tag == tag)) - session.execute(upsert_stmt) + await session.execute(delete(FileTagDb).where(FileTagDb.tag == tag)) + await session.execute(upsert_stmt) diff --git a/src/aleph/db/accessors/messages.py b/src/aleph/db/accessors/messages.py index 6ab4a6175..78631b2d5 100644 --- a/src/aleph/db/accessors/messages.py +++ b/src/aleph/db/accessors/messages.py @@ -11,7 +11,7 @@ from aleph.toolkit.timestamp import coerce_to_datetime, utc_now from aleph.types.channel import Channel -from aleph.types.db_session import DbSession +from aleph.types.db_session import AsyncDbSession from aleph.types.message_status import ( ErrorCode, MessageProcessingException, @@ -31,19 +31,19 @@ from .pending_messages import delete_pending_message -def get_message_by_item_hash( - session: DbSession, item_hash: ItemHash +async def get_message_by_item_hash( + session: AsyncDbSession, item_hash: ItemHash ) -> Optional[MessageDb]: select_stmt = ( select(MessageDb) .where(MessageDb.item_hash == item_hash) .options(selectinload(MessageDb.confirmations)) ) - return (session.execute(select_stmt)).scalar() + return (await session.execute(select_stmt)).scalar() -def message_exists(session: DbSession, item_hash: str) -> bool: - return MessageDb.exists( +async def message_exists(session: AsyncDbSession, item_hash: str) -> bool: + return await MessageDb.exists( session=session, where=MessageDb.item_hash == item_hash, ) @@ -187,8 +187,8 @@ def make_matching_messages_query( return select_stmt -def count_matching_messages( - session: DbSession, +async def count_matching_messages( + session: AsyncDbSession, start_date: float = 0.0, end_date: float = 0.0, sort_by: SortBy = SortBy.TIME, @@ -210,24 +210,24 @@ def count_matching_messages( pagination=0, ).subquery() select_count_stmt = select(func.count()).select_from(select_stmt) - return session.execute(select_count_stmt).scalar_one() + return (await session.execute(select_count_stmt)).scalar_one() - return MessageDb.fast_count(session=session) + return await MessageDb.fast_count(session=session) -def 
get_matching_messages( - session: DbSession, +async def get_matching_messages( + session: AsyncDbSession, **kwargs, # Same as make_matching_messages_query ) -> Iterable[MessageDb]: """ Applies the specified filters on the message table and returns matching entries. """ select_stmt = make_matching_messages_query(**kwargs) - return (session.execute(select_stmt)).scalars() + return (await session.execute(select_stmt)).scalars() -def get_message_stats_by_address( - session: DbSession, +async def get_message_stats_by_address( + session: AsyncDbSession, addresses: Optional[Sequence[str]] = None, ): """ @@ -250,19 +250,19 @@ def get_message_stats_by_address( select_stmt += " where address in :addresses" parameters = {"addresses": addresses_tuple} - return session.execute(text(select_stmt), parameters).all() + return (await session.execute(text(select_stmt), parameters)).all() -def refresh_address_stats_mat_view(session: DbSession) -> None: - session.execute( +async def refresh_address_stats_mat_view(session: AsyncDbSession) -> None: + await session.execute( text("refresh materialized view concurrently address_stats_mat_view") ) # TODO: declare a type that will match the result (something like UnconfirmedMessageDb) # and translate the time field to epoch. -def get_unconfirmed_messages( - session: DbSession, limit: int = 100, chain: Optional[Chain] = None +async def get_unconfirmed_messages( + session: AsyncDbSession, limit: int = 100, chain: Optional[Chain] = None ) -> Iterable[MessageDb]: if chain is None: @@ -288,7 +288,7 @@ def get_unconfirmed_messages( .order_by(MessageStatusDb.reception_time.asc()) ) - return (session.execute(select_stmt.limit(limit))).scalars() + return (await session.execute(select_stmt.limit(limit))).scalars() def make_message_upsert_query(message: MessageDb) -> Insert: @@ -310,23 +310,23 @@ def make_confirmation_upsert_query(item_hash: str, tx_hash: str) -> Insert: ) -def get_message_status( - session: DbSession, item_hash: ItemHash +async def get_message_status( + session: AsyncDbSession, item_hash: ItemHash ) -> Optional[MessageStatusDb]: return ( - session.execute( + await session.execute( select(MessageStatusDb).where(MessageStatusDb.item_hash == str(item_hash)) ) ).scalar() -def get_rejected_message( - session: DbSession, item_hash: str +async def get_rejected_message( + session: AsyncDbSession, item_hash: str ) -> Optional[RejectedMessageDb]: select_stmt = select(RejectedMessageDb).where( RejectedMessageDb.item_hash == item_hash ) - return session.execute(select_stmt).scalar() + return (await session.execute(select_stmt)).scalar() # TODO typing: Find a correct type for `where` @@ -349,21 +349,23 @@ def make_message_status_upsert_query( ) -def get_distinct_channels(session: DbSession) -> Iterable[Channel]: +async def get_distinct_channels(session: AsyncDbSession) -> Iterable[Channel]: select_stmt = select(MessageDb.channel).distinct().order_by(MessageDb.channel) - return session.execute(select_stmt).scalars() + return (await session.execute(select_stmt)).scalars() -def get_forgotten_message( - session: DbSession, item_hash: str +async def get_forgotten_message( + session: AsyncDbSession, item_hash: str ) -> Optional[ForgottenMessageDb]: - return session.execute( - select(ForgottenMessageDb).where(ForgottenMessageDb.item_hash == item_hash) + return ( + await session.execute( + select(ForgottenMessageDb).where(ForgottenMessageDb.item_hash == item_hash) + ) ).scalar() -def forget_message( - session: DbSession, item_hash: str, forget_message_hash: str +async def 
forget_message( + session: AsyncDbSession, item_hash: str, forget_message_hash: str ) -> None: """ Marks a processed message as forgotten. @@ -400,22 +402,22 @@ def forget_message( literal(f"{{{forget_message_hash}}}"), ).where(MessageDb.item_hash == item_hash), ) - session.execute(copy_row_stmt) - session.execute( + await session.execute(copy_row_stmt) + await session.execute( update(MessageStatusDb) .values(status=MessageStatus.FORGOTTEN) .where(MessageStatusDb.item_hash == item_hash) ) - session.execute( + await session.execute( delete(message_confirmations).where( message_confirmations.c.item_hash == item_hash ) ) - session.execute(delete(MessageDb).where(MessageDb.item_hash == item_hash)) + await session.execute(delete(MessageDb).where(MessageDb.item_hash == item_hash)) -def append_to_forgotten_by( - session: DbSession, forgotten_message_hash: str, forget_message_hash: str +async def append_to_forgotten_by( + session: AsyncDbSession, forgotten_message_hash: str, forget_message_hash: str ) -> None: update_stmt = ( update(ForgottenMessageDb) @@ -426,7 +428,7 @@ def append_to_forgotten_by( ) ) ) - session.execute(update_stmt, {"forget_hash": forget_message_hash}) + await session.execute(update_stmt, {"forget_hash": forget_message_hash}) def make_upsert_rejected_message_statement( @@ -484,8 +486,8 @@ def ensure_serializable(obj): return upsert_rejected_message_stmt -def mark_pending_message_as_rejected( - session: DbSession, +async def mark_pending_message_as_rejected( + session: AsyncDbSession, item_hash: str, pending_message_dict: Mapping[str, Any], exception: BaseException, @@ -533,8 +535,8 @@ def mark_pending_message_as_rejected( tx_hash=tx_hash, ) - session.execute(upsert_status_stmt) - session.execute(upsert_rejected_message_stmt) + await session.execute(upsert_status_stmt) + await session.execute(upsert_rejected_message_stmt) return RejectedMessageDb( item_hash=item_hash, @@ -547,8 +549,8 @@ def mark_pending_message_as_rejected( @overload -def reject_new_pending_message( - session: DbSession, +async def reject_new_pending_message( + session: AsyncDbSession, pending_message: Mapping[str, Any], exception: BaseException, tx_hash: Optional[str], @@ -556,16 +558,16 @@ def reject_new_pending_message( @overload -def reject_new_pending_message( - session: DbSession, +async def reject_new_pending_message( + session: AsyncDbSession, pending_message: PendingMessageDb, exception: BaseException, tx_hash: Optional[str], ) -> None: ... -def reject_new_pending_message( - session: DbSession, +async def reject_new_pending_message( + session: AsyncDbSession, pending_message: Union[Mapping[str, Any], PendingMessageDb], exception: BaseException, tx_hash: Optional[str], @@ -600,12 +602,12 @@ def reject_new_pending_message( # Just do nothing if that is the case. We just consider the case where a previous # message with the same item hash was already sent to replace the error message # (ex: someone is retrying a message after fixing an error). 
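Aside: a sketch of how a message-processing job might drive the rejection helpers below once they are async. The job code itself is outside this diff; the commit placement mirrors the call sites shown elsewhere in this PR:

```python
async def handle_invalid_message(session_factory, pending_message, error):
    async with session_factory() as session:
        # reject_existing_pending_message (below) deletes the pending row itself
        # and returns the RejectedMessageDb entry, or None if the message was
        # already processed under another status.
        rejected = await reject_existing_pending_message(
            session=session,
            pending_message=pending_message,
            exception=error,
        )
        await session.commit()
        return rejected
```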
- message_status = get_message_status(session=session, item_hash=item_hash) + message_status = await get_message_status(session=session, item_hash=item_hash) if message_status: if message_status.status != MessageStatus.REJECTED: return None - return mark_pending_message_as_rejected( + return await mark_pending_message_as_rejected( session=session, item_hash=item_hash, pending_message_dict=pending_message_dict, @@ -614,8 +616,8 @@ def reject_new_pending_message( ) -def reject_existing_pending_message( - session: DbSession, +async def reject_existing_pending_message( + session: AsyncDbSession, pending_message: PendingMessageDb, exception: BaseException, ) -> Optional[RejectedMessageDb]: @@ -623,10 +625,14 @@ def reject_existing_pending_message( # The message may already be processed and someone is sending invalid copies. # Just drop the pending message. - message_status = get_message_status(session=session, item_hash=ItemHash(item_hash)) + message_status = await get_message_status( + session=session, item_hash=ItemHash(item_hash) + ) if message_status: if message_status.status not in (MessageStatus.PENDING, MessageStatus.REJECTED): - delete_pending_message(session=session, pending_message=pending_message) + await delete_pending_message( + session=session, pending_message=pending_message + ) return None # TODO: use Pydantic schema @@ -643,18 +649,20 @@ def reject_existing_pending_message( ) pending_message_dict["time"] = pending_message_dict["time"].timestamp() - rejected_message = mark_pending_message_as_rejected( + rejected_message = await mark_pending_message_as_rejected( session=session, item_hash=item_hash, pending_message_dict=pending_message_dict, exception=exception, tx_hash=pending_message.tx_hash, ) - delete_pending_message(session=session, pending_message=pending_message) + await delete_pending_message(session=session, pending_message=pending_message) return rejected_message -def get_programs_triggered_by_messages(session: DbSession, sort_order: SortOrder): +async def get_programs_triggered_by_messages( + session: AsyncDbSession, sort_order: SortOrder +): time_column = MessageDb.time order_by_column = ( time_column.desc() if sort_order == SortOrder.DESCENDING else time_column.asc() @@ -674,7 +682,7 @@ def get_programs_triggered_by_messages(session: DbSession, sort_order: SortOrder .order_by(order_by_column) ) - return session.execute(select_stmt).all() + return (await session.execute(select_stmt)).all() def make_matching_hashes_query( @@ -724,19 +732,19 @@ def make_matching_hashes_query( return select_stmt -def get_matching_hashes( - session: DbSession, +async def get_matching_hashes( + session: AsyncDbSession, **kwargs, # Same as make_matching_hashes_query ): select_stmt = make_matching_hashes_query(**kwargs) - return (session.execute(select_stmt)).scalars() + return (await session.execute(select_stmt)).scalars() -def count_matching_hashes( - session: DbSession, +async def count_matching_hashes( + session: AsyncDbSession, pagination: int = 0, **kwargs, ) -> Select: - select_stmt = make_matching_hashes_query(pagination=0, **kwargs).subquery() + select_stmt = make_matching_hashes_query(pagination=pagination, **kwargs).subquery() select_count_stmt = select(func.count()).select_from(select_stmt) - return session.execute(select_count_stmt).scalar_one() + return (await session.execute(select_count_stmt)).scalar_one() diff --git a/src/aleph/db/accessors/metrics.py b/src/aleph/db/accessors/metrics.py index e8bf05900..10402d816 100644 --- a/src/aleph/db/accessors/metrics.py +++ 
b/src/aleph/db/accessors/metrics.py @@ -2,9 +2,8 @@ from typing import Optional from sqlalchemy import select, text -from sqlalchemy.orm.session import Session -from aleph.types.db_session import DbSession +from aleph.types.db_session import AsyncDbSession def _parse_ccn_result(result): @@ -60,8 +59,8 @@ def _build_metric_filter(select_stmt, node_id, start_date, end_date, sort_order) return select_stmt -def query_metric_ccn( - session: Session, +async def query_metric_ccn( + session: AsyncDbSession, node_id: Optional[str] = None, start_date: Optional[float] = None, end_date: Optional[float] = None, @@ -95,13 +94,13 @@ def query_metric_ccn( sort_order=sort_order, ) - result = session.execute(select_stmt).fetchall() + result = (await session.execute(select_stmt)).fetchall() return _parse_ccn_result(result=result) -def query_metric_crn( - session: DbSession, +async def query_metric_crn( + session: AsyncDbSession, node_id: str, start_date: Optional[float] = None, end_date: Optional[float] = None, @@ -132,6 +131,6 @@ def query_metric_crn( sort_order=sort_order, ) - result = session.execute(select_stmt).fetchall() + result = (await session.execute(select_stmt)).fetchall() return _parse_crn_result(result=result) diff --git a/src/aleph/db/accessors/peers.py b/src/aleph/db/accessors/peers.py index 5d30533f9..6822ae76f 100644 --- a/src/aleph/db/accessors/peers.py +++ b/src/aleph/db/accessors/peers.py @@ -5,22 +5,22 @@ from sqlalchemy.dialects.postgresql import insert from aleph.toolkit.timestamp import utc_now -from aleph.types.db_session import DbSession +from aleph.types.db_session import AsyncDbSession from ..models.peers import PeerDb, PeerType -def get_all_addresses_by_peer_type( - session: DbSession, peer_type: PeerType +async def get_all_addresses_by_peer_type( + session: AsyncDbSession, peer_type: PeerType ) -> Sequence[str]: select_peers_stmt = select(PeerDb.address).where(PeerDb.peer_type == peer_type) - addresses = session.execute(select_peers_stmt) + addresses = await session.execute(select_peers_stmt) return addresses.scalars().all() -def upsert_peer( - session: DbSession, +async def upsert_peer( + session: AsyncDbSession, peer_id: str, peer_type: PeerType, address: str, @@ -43,4 +43,4 @@ set_={"address": address, "source": source, "last_seen": last_seen}, ) ) - session.execute(upsert_stmt) + await session.execute(upsert_stmt) diff --git a/src/aleph/db/accessors/pending_messages.py b/src/aleph/db/accessors/pending_messages.py index 50b1c0a48..f4a6a5f62 100644 --- a/src/aleph/db/accessors/pending_messages.py +++ b/src/aleph/db/accessors/pending_messages.py @@ -1,17 +1,17 @@ import datetime as dt -from typing import Any, Collection, Dict, Iterable, Optional, Sequence +from typing import Any, Collection, Dict, Iterable, List, Optional, Sequence, Set from aleph_message.models import Chain -from sqlalchemy import delete, func, select, update +from sqlalchemy import delete, func, select, text, update from sqlalchemy.orm import selectinload from sqlalchemy.sql import Update from aleph.db.models import ChainTxDb, PendingMessageDb -from aleph.types.db_session import DbSession +from aleph.types.db_session import AsyncDbSession, DbSession -def get_next_pending_message( - session: DbSession, +async def get_next_pending_message( + session: AsyncDbSession, current_time: dt.datetime, offset: int = 0, fetched: Optional[bool] = None, @@ -34,11 +34,11 @@ ) select_stmt = select_stmt.limit(1) - return
(session.execute(select_stmt)).scalar_one_or_none() + return (await session.execute(select_stmt)).scalar_one_or_none() -def get_next_pending_messages( - session: DbSession, +async def get_next_pending_messages( + session: AsyncDbSession, current_time: dt.datetime, limit: int = 10000, offset: int = 0, @@ -62,18 +62,146 @@ def get_next_pending_messages( ) select_stmt = select_stmt.limit(limit) - return (session.execute(select_stmt)).scalars() + return (await session.execute(select_stmt)).scalars() -def get_pending_messages( - session: DbSession, item_hash: str +async def get_pending_messages( + session: AsyncDbSession, item_hash: str ) -> Iterable[PendingMessageDb]: select_stmt = ( select(PendingMessageDb) .order_by(PendingMessageDb.time) .where(PendingMessageDb.item_hash == item_hash) ) - return session.execute(select_stmt).scalars() + return (await session.execute(select_stmt)).scalars() + + +async def get_next_pending_messages_by_address( + session: AsyncDbSession, + current_time: dt.datetime, + fetched: Optional[bool] = None, + exclude_item_hashes: Optional[Set[str]] = None, + exclude_addresses: Optional[Set[str]] = None, + batch_size: int = 100, +) -> List[PendingMessageDb]: + # Step 1: Get the earliest pending message + base_stmt = ( + select(PendingMessageDb) + .where(PendingMessageDb.next_attempt <= current_time) + .order_by(PendingMessageDb.next_attempt.asc()) + .options(selectinload(PendingMessageDb.tx)) + .limit(1) + ) + + if fetched is not None: + base_stmt = base_stmt.where(PendingMessageDb.fetched == fetched) + + if exclude_item_hashes: # a non-empty set() + base_stmt = base_stmt.where( + PendingMessageDb.item_hash.not_in(exclude_item_hashes) + ) + + if exclude_addresses: + base_stmt = base_stmt.where( + PendingMessageDb.content["address"].astext.not_in(list(exclude_addresses)) + ) + + first_message = (await session.execute(base_stmt)).scalar_one_or_none() + + if ( + not first_message + or not first_message.content + or "address" not in first_message.content + ): + return [] + + address = first_message.content["address"] + + # Step 2: Get a batch of messages with that same address in content + match_stmt = ( + select(PendingMessageDb) + .where( + PendingMessageDb.next_attempt <= current_time, + PendingMessageDb.content["address"].astext == address, + ) + .order_by(PendingMessageDb.next_attempt.asc()) + .limit(batch_size) + ) + + if fetched is not None: + match_stmt = match_stmt.where(PendingMessageDb.fetched == fetched) + + if exclude_item_hashes: + match_stmt = match_stmt.where( + PendingMessageDb.item_hash.not_in(exclude_item_hashes) + ) + + return (await session.execute(match_stmt)).scalars().all() + + +async def async_get_next_pending_messages_by_address( + session: AsyncDbSession, + current_time: dt.datetime, + fetched: Optional[bool] = None, + exclude_item_hashes: Optional[Set[str]] = None, + exclude_addresses: Optional[Set[str]] = None, + batch_size: int = 100, +) -> List[PendingMessageDb]: + # Step 1: Get the earliest pending message + base_stmt = ( + select(PendingMessageDb) + .where(PendingMessageDb.next_attempt <= current_time) + .order_by(PendingMessageDb.next_attempt.asc()) + .options(selectinload(PendingMessageDb.tx)) + .limit(1) + ) + + if fetched is not None: + base_stmt = base_stmt.where(PendingMessageDb.fetched == fetched) + + if exclude_item_hashes: + base_stmt = base_stmt.where( + PendingMessageDb.item_hash.not_in(exclude_item_hashes) + ) + + if exclude_addresses: + base_stmt = base_stmt.where( + 
PendingMessageDb.content["address"].astext.not_in(list(exclude_addresses)) + ) + + result = await session.execute(base_stmt) + first_message = result.scalar_one_or_none() + + if ( + not first_message + or not first_message.content + or "address" not in first_message.content + ): + return [] + + address = first_message.content["address"] + + # Step 2: Get a batch of messages with that same address in content + match_stmt = ( + select(PendingMessageDb) + .where( + PendingMessageDb.next_attempt <= current_time, + PendingMessageDb.content["address"].astext == address, + ) + .order_by(PendingMessageDb.next_attempt.asc()) + .limit(batch_size) + ) + + if fetched is not None: + match_stmt = match_stmt.where(PendingMessageDb.fetched == fetched) + + if exclude_item_hashes: + match_stmt = match_stmt.where( + PendingMessageDb.item_hash.not_in(exclude_item_hashes) + ) + + result = await session.execute(match_stmt) + return result.scalars().all() def get_pending_message( @@ -85,7 +213,9 @@ def get_pending_message( return session.execute(select_stmt).scalar_one_or_none() -def count_pending_messages(session: DbSession, chain: Optional[Chain] = None) -> int: +async def count_pending_messages( + session: AsyncDbSession, chain: Optional[Chain] = None +) -> int: """ Counts pending messages. @@ -99,7 +229,7 @@ def count_pending_messages(session: DbSession, chain: Optional[Chain] = None) -> ChainTxDb, PendingMessageDb.tx_hash == ChainTxDb.hash ) - return (session.execute(select_stmt)).scalar_one() + return (await session.execute(select_stmt)).scalar_one() def make_pending_message_fetched_statement( @@ -113,20 +243,123 @@ def make_pending_message_fetched_statement( return update_stmt -def set_next_retry( - session: DbSession, pending_message: PendingMessageDb, next_attempt: dt.datetime +async def set_next_retry( + session: AsyncDbSession, + pending_message: PendingMessageDb, + next_attempt: dt.datetime, ) -> None: update_stmt = ( update(PendingMessageDb) .where(PendingMessageDb.id == pending_message.id) .values(retries=PendingMessageDb.retries + 1, next_attempt=next_attempt) ) - session.execute(update_stmt) + await session.execute(update_stmt) -def delete_pending_message( - session: DbSession, pending_message: PendingMessageDb +async def delete_pending_message( + session: AsyncDbSession, pending_message: PendingMessageDb ) -> None: - session.execute( + await session.execute( delete(PendingMessageDb).where(PendingMessageDb.id == pending_message.id) ) + + +async def async_get_next_pending_messages_from_different_senders( + session: AsyncDbSession, + current_time: dt.datetime, + fetched: bool = True, + exclude_item_hashes: Optional[Set[str]] = None, + exclude_addresses: Optional[Set[str]] = None, + limit: int = 40, # Maximum number of distinct senders to process in parallel +) -> List[PendingMessageDb]: + """ + Get pending messages from different senders to process in parallel. + + This optimized function maximizes parallelism by fetching messages with distinct + sender addresses directly from the database using JSONB operators, avoiding + additional processing in Python. 
+ + Args: + session: Database session + current_time: Current time + fetched: Whether to only return messages that have been fetched + exclude_item_hashes: Item hashes to exclude (already being processed) + exclude_addresses: Sender addresses to exclude (already being processed) + limit: Maximum number of messages to return (one per unique sender) + + Returns: + List of pending messages from different senders, one per unique sender + """ + # In PostgreSQL, DISTINCT ON requires the first ORDER BY expression to match exactly + # Let's use a text-based SQL approach to ensure identical expressions + + # Build the SQL query directly for more precise control + sql_query = """ + SELECT DISTINCT ON (jsonb_extract_path_text(content, 'address')) * + FROM pending_messages + WHERE next_attempt <= :current_time + AND fetched = :fetched + AND content IS NOT NULL + AND jsonb_extract_path_text(content, 'address') IS NOT NULL + """ + + # Add exclusions if needed + params = {"current_time": current_time, "fetched": fetched} + + if exclude_item_hashes and len(exclude_item_hashes) > 0: + placeholder_names = [f":exclude_hash_{i}" for i in range(len(exclude_item_hashes))] + sql_query += f" AND item_hash NOT IN ({', '.join(placeholder_names)})" + for i, hash_value in enumerate(exclude_item_hashes): + params[f"exclude_hash_{i}"] = hash_value + + if exclude_addresses and len(exclude_addresses) > 0: + placeholder_names = [f":exclude_addr_{i}" for i in range(len(exclude_addresses))] + sql_query += f" AND jsonb_extract_path_text(content, 'address') NOT IN ({', '.join(placeholder_names)})" + for i, addr in enumerate(exclude_addresses): + params[f"exclude_addr_{i}"] = addr + + # Add the ORDER BY clause - the first expression MUST match the DISTINCT ON expression exactly + sql_query += """ + ORDER BY jsonb_extract_path_text(content, 'address'), next_attempt + LIMIT :limit + """ + params["limit"] = limit + + # Execute the raw SQL query + result = await session.execute(select(PendingMessageDb).from_statement(text(sql_query)).params(**params)) + return result.scalars().all() + + +async def async_get_next_pending_message( + session: AsyncDbSession, + current_time: dt.datetime, + offset: int = 0, + fetched: Optional[bool] = None, + exclude_item_hashes: Optional[Sequence[str]] = None, + exclude_addresses: Optional[Sequence[str]] = None, +) -> Optional[PendingMessageDb]: + select_stmt = ( + select(PendingMessageDb) + .order_by(PendingMessageDb.next_attempt.asc()) + .offset(offset) + .options(selectinload(PendingMessageDb.tx)) + .where(PendingMessageDb.next_attempt <= current_time) + ) + + if fetched is not None: + select_stmt = select_stmt.where(PendingMessageDb.fetched == fetched) + + if exclude_item_hashes: + select_stmt = select_stmt.where( + PendingMessageDb.item_hash.not_in(exclude_item_hashes) + ) + + if exclude_addresses: + select_stmt = select_stmt.where( + PendingMessageDb.content["address"].astext.not_in(list(exclude_addresses)) + ) + + select_stmt = select_stmt.limit(1) + + result = await session.execute(select_stmt) + return result.scalar_one_or_none() diff --git a/src/aleph/db/accessors/pending_txs.py b/src/aleph/db/accessors/pending_txs.py index 5d1ce9f83..8082a5baa 100644 --- a/src/aleph/db/accessors/pending_txs.py +++ b/src/aleph/db/accessors/pending_txs.py @@ -6,19 +6,23 @@ from sqlalchemy.orm import selectinload from aleph.db.models import ChainTxDb, PendingTxDb -from aleph.types.db_session import DbSession +from aleph.types.db_session import AsyncDbSession -def get_pending_tx(session: DbSession, tx_hash: 
str) -> Optional[PendingTxDb]: +async def get_pending_tx( + session: AsyncDbSession, tx_hash: str +) -> Optional[PendingTxDb]: select_stmt = ( select(PendingTxDb) .where(PendingTxDb.tx_hash == tx_hash) .options(selectinload(PendingTxDb.tx)) ) - return (session.execute(select_stmt)).scalar_one_or_none() + return (await session.execute(select_stmt)).scalar_one_or_none() -def get_pending_txs(session: DbSession, limit: int = 200) -> Iterable[PendingTxDb]: +async def get_pending_txs( + session: AsyncDbSession, limit: int = 200 +) -> Iterable[PendingTxDb]: select_stmt = ( select(PendingTxDb) .join(ChainTxDb, PendingTxDb.tx_hash == ChainTxDb.hash) @@ -26,24 +30,26 @@ def get_pending_txs(session: DbSession, limit: int = 200) -> Iterable[PendingTxD .limit(limit) .options(selectinload(PendingTxDb.tx)) ) - return (session.execute(select_stmt)).scalars() + return (await session.execute(select_stmt)).scalars() -def count_pending_txs(session: DbSession, chain: Optional[Chain] = None) -> int: +async def count_pending_txs( + session: AsyncDbSession, chain: Optional[Chain] = None +) -> int: select_stmt = select(func.count(PendingTxDb.tx_hash)) if chain: select_stmt = select_stmt.join( ChainTxDb, PendingTxDb.tx_hash == ChainTxDb.hash ).where(ChainTxDb.chain == chain) - return (session.execute(select_stmt)).scalar_one() + return (await session.execute(select_stmt)).scalar_one() -def upsert_pending_tx(session: DbSession, tx_hash: str) -> None: +async def upsert_pending_tx(session: AsyncDbSession, tx_hash: str) -> None: upsert_stmt = insert(PendingTxDb).values(tx_hash=tx_hash).on_conflict_do_nothing() - session.execute(upsert_stmt) + await session.execute(upsert_stmt) -def delete_pending_tx(session: DbSession, tx_hash: str) -> None: +async def delete_pending_tx(session: AsyncDbSession, tx_hash: str) -> None: delete_stmt = delete(PendingTxDb).where(PendingTxDb.tx_hash == tx_hash) - session.execute(delete_stmt) + await session.execute(delete_stmt) diff --git a/src/aleph/db/accessors/posts.py b/src/aleph/db/accessors/posts.py index cc1b8f57f..66b0e86bc 100644 --- a/src/aleph/db/accessors/posts.py +++ b/src/aleph/db/accessors/posts.py @@ -33,7 +33,7 @@ from aleph.db.models.posts import PostDb from aleph.toolkit.timestamp import coerce_to_datetime from aleph.types.channel import Channel -from aleph.types.db_session import DbSession +from aleph.types.db_session import AsyncDbSession from aleph.types.sort_order import SortBy, SortOrder @@ -154,18 +154,20 @@ def make_select_merged_post_with_message_info_stmt() -> Select: return select_merged_post_stmt -def get_post(session: DbSession, item_hash: str) -> Optional[MergedPost]: +async def get_post(session: AsyncDbSession, item_hash: str) -> Optional[MergedPost]: select_stmt = make_select_merged_post_stmt() select_stmt = select_stmt.where(Original.item_hash == str(item_hash)) - return session.execute(select_stmt).one_or_none() + return (await session.execute(select_stmt)).one_or_none() -def get_original_post(session: DbSession, item_hash: str) -> Optional[PostDb]: +async def get_original_post( + session: AsyncDbSession, item_hash: str +) -> Optional[PostDb]: select_stmt = select(PostDb).where(PostDb.item_hash == item_hash) - return session.execute(select_stmt).scalar() + return (await session.execute(select_stmt)).scalar_one_or_none() -def refresh_latest_amend(session: DbSession, item_hash: str) -> None: +async def refresh_latest_amend(session: AsyncDbSession, item_hash: str) -> None: select_latest_amend = ( select( PostDb.amends, 
func.max(PostDb.creation_datetime).label("creation_datetime") @@ -181,7 +183,7 @@ def refresh_latest_amend(session: DbSession, item_hash: str) -> None: & (PostDb.creation_datetime == select_latest_amend.c.creation_datetime), ) - latest_amend_hash = session.execute(select_stmt).scalar() + latest_amend_hash = (await session.execute(select_stmt)).scalar() update_stmt = ( update(PostDb) @@ -189,7 +191,7 @@ def refresh_latest_amend(session: DbSession, item_hash: str) -> None: .values(latest_amend=latest_amend_hash) ) - session.execute(update_stmt) + await session.execute(update_stmt) def filter_post_select_stmt( @@ -294,8 +296,8 @@ def filter_post_select_stmt( return select_stmt -def count_matching_posts( - session: DbSession, +async def count_matching_posts( + session: AsyncDbSession, page: int = 1, pagination: int = 0, sort_by: SortBy = SortBy.TIME, @@ -322,34 +324,36 @@ def count_matching_posts( select_stmt = select(PostDb).where(PostDb.amends.is_(None)).subquery() select_count_stmt = select(func.count()).select_from(select_stmt) - return session.execute(select_count_stmt).scalar_one() + return (await session.execute(select_count_stmt)).scalar_one() -def get_matching_posts_legacy( - session: DbSession, +async def get_matching_posts_legacy( + session: AsyncDbSession, # Same as make_matching_posts_query **kwargs, ) -> List[MergedPostV0]: select_stmt = make_select_merged_post_with_message_info_stmt() filtered_select_stmt = filter_post_select_stmt(select_stmt, **kwargs) - return cast(List[MergedPostV0], session.execute(filtered_select_stmt).all()) + return cast(List[MergedPostV0], (await session.execute(filtered_select_stmt)).all()) -def get_matching_posts( - session: DbSession, +async def get_matching_posts( + session: AsyncDbSession, # Same as make_matching_posts_query **kwargs, ) -> List[MergedPost]: select_stmt = make_select_merged_post_stmt() filtered_select_stmt = filter_post_select_stmt(select_stmt, **kwargs) - return cast(List[MergedPost], session.execute(filtered_select_stmt).all()) + return cast(List[MergedPost], (await session.execute(filtered_select_stmt)).all()) -def delete_amends(session: DbSession, item_hash: str) -> Iterable[str]: - return session.execute( - delete(PostDb).where(PostDb.amends == item_hash).returning(PostDb.item_hash) +async def delete_amends(session: AsyncDbSession, item_hash: str) -> Iterable[str]: + return ( + await session.execute( + delete(PostDb).where(PostDb.amends == item_hash).returning(PostDb.item_hash) + ) ).scalars() -def delete_post(session: DbSession, item_hash: str) -> None: - session.execute(delete(PostDb).where(PostDb.item_hash == item_hash)) +async def delete_post(session: AsyncDbSession, item_hash: str) -> None: + await session.execute(delete(PostDb).where(PostDb.item_hash == item_hash)) diff --git a/src/aleph/db/accessors/vms.py b/src/aleph/db/accessors/vms.py index 480d9ce7b..ab23a3560 100644 --- a/src/aleph/db/accessors/vms.py +++ b/src/aleph/db/accessors/vms.py @@ -15,54 +15,60 @@ VmInstanceDb, VmVersionDb, ) -from aleph.types.db_session import DbSession +from aleph.types.db_session import AsyncDbSession from aleph.types.vms import VmVersion -def get_instance(session: DbSession, item_hash: str) -> Optional[VmInstanceDb]: +async def get_instance( + session: AsyncDbSession, item_hash: str +) -> Optional[VmInstanceDb]: select_stmt = select(VmInstanceDb).where(VmInstanceDb.item_hash == item_hash) - return session.execute(select_stmt).scalar_one_or_none() + return (await session.execute(select_stmt)).scalar_one_or_none() -def get_program(session: 
DbSession, item_hash: str) -> Optional[ProgramDb]: +async def get_program(session: AsyncDbSession, item_hash: str) -> Optional[ProgramDb]: select_stmt = select(ProgramDb).where(ProgramDb.item_hash == item_hash) - return session.execute(select_stmt).scalar_one_or_none() + return (await session.execute(select_stmt)).scalar_one_or_none() -def is_vm_amend_allowed(session: DbSession, vm_hash: str) -> Optional[bool]: +async def is_vm_amend_allowed(session: AsyncDbSession, vm_hash: str) -> Optional[bool]: select_stmt = ( select(VmBaseDb.allow_amend) .select_from(VmVersionDb) .join(VmBaseDb, VmVersionDb.current_version == VmBaseDb.item_hash) .where(VmVersionDb.vm_hash == vm_hash) ) - return session.execute(select_stmt).scalar_one_or_none() + return (await session.execute(select_stmt)).scalar_one_or_none() -def _delete_vm(session: DbSession, where) -> Iterable[str]: +async def _delete_vm(session: AsyncDbSession, where) -> Iterable[str]: # Deletion of volumes is managed automatically by the DB # using an "on delete cascade" foreign key. - return session.execute( - delete(VmBaseDb).where(where).returning(VmBaseDb.item_hash) + return ( + await session.execute( + delete(VmBaseDb).where(where).returning(VmBaseDb.item_hash) + ) ).scalars() -def delete_vm(session: DbSession, vm_hash: str) -> None: - _ = _delete_vm(session=session, where=VmBaseDb.item_hash == vm_hash) +async def delete_vm(session: AsyncDbSession, vm_hash: str) -> None: + _ = await _delete_vm(session=session, where=VmBaseDb.item_hash == vm_hash) -def delete_vm_updates(session: DbSession, vm_hash: str) -> Iterable[str]: - return _delete_vm(session=session, where=VmBaseDb.replaces == vm_hash) +async def delete_vm_updates(session: AsyncDbSession, vm_hash: str) -> Iterable[str]: + return await _delete_vm(session=session, where=VmBaseDb.replaces == vm_hash) -def get_vm_version(session: DbSession, vm_hash: str) -> Optional[VmVersionDb]: - return session.execute( - select(VmVersionDb).where(VmVersionDb.vm_hash == vm_hash) +async def get_vm_version( + session: AsyncDbSession, vm_hash: str +) -> Optional[VmVersionDb]: + return ( + await session.execute(select(VmVersionDb).where(VmVersionDb.vm_hash == vm_hash)) ).scalar_one_or_none() -def get_vms_dependent_volumes( - session: DbSession, volume_hash: str +async def get_vms_dependent_volumes( + session: AsyncDbSession, volume_hash: str ) -> Optional[VmBaseDb]: statement = ( select(VmBaseDb) @@ -93,11 +99,11 @@ def get_vms_dependent_volumes( ) ) ) - return session.execute(statement).scalar_one_or_none() + return (await session.execute(statement)).scalar_one_or_none() -def upsert_vm_version( - session: DbSession, +async def upsert_vm_version( + session: AsyncDbSession, vm_hash: str, owner: str, current_version: VmVersion, @@ -114,10 +120,10 @@ def upsert_vm_version( set_={"current_version": current_version, "last_updated": last_updated}, where=VmVersionDb.last_updated < last_updated, ) - session.execute(upsert_stmt) + await session.execute(upsert_stmt) -def refresh_vm_version(session: DbSession, vm_hash: str) -> None: +async def refresh_vm_version(session: AsyncDbSession, vm_hash: str) -> None: coalesced_ref = func.coalesce(VmBaseDb.replaces, VmBaseDb.item_hash) select_latest_revision_stmt = ( select( @@ -151,5 +157,5 @@ def refresh_vm_version(session: DbSession, vm_hash: str) -> None: "last_updated": insert_stmt.excluded.last_updated, }, ) - session.execute(delete(VmVersionDb).where(VmVersionDb.vm_hash == vm_hash)) - session.execute(upsert_stmt) + await 
session.execute(delete(VmVersionDb).where(VmVersionDb.vm_hash == vm_hash)) + await session.execute(upsert_stmt) diff --git a/src/aleph/db/connection.py b/src/aleph/db/connection.py index 85f2cad6e..2a36ebaca 100644 --- a/src/aleph/db/connection.py +++ b/src/aleph/db/connection.py @@ -68,10 +68,16 @@ def make_async_engine( echo: bool = False, application_name: Optional[str] = None, ) -> AsyncEngine: + if config is None: + config = get_config() + return create_async_engine( - make_db_url(driver="asyncpg", config=config, application_name=application_name), + make_db_url(driver="asyncpg", config=config), future=True, echo=echo, + pool_size=40, + max_overflow=20, + pool_recycle=3600, ) diff --git a/src/aleph/db/models/base.py b/src/aleph/db/models/base.py index 7dfe16781..84f4af9f5 100644 --- a/src/aleph/db/models/base.py +++ b/src/aleph/db/models/base.py @@ -3,7 +3,7 @@ from sqlalchemy import Column, Table, exists, func, select, text from sqlalchemy.orm import declarative_base -from aleph.types.db_session import DbSession +from aleph.types.db_session import AsyncDbSession class AugmentedBase: @@ -20,13 +20,13 @@ def to_dict(self, exclude: Optional[Set[str]] = None) -> Dict[str, Any]: } @classmethod - def count(cls, session: DbSession) -> int: + async def count(cls, session: AsyncDbSession) -> int: return ( - session.execute(text(f"SELECT COUNT(*) FROM {cls.__tablename__}")) + await session.execute(text(f"SELECT COUNT(*) FROM {cls.__tablename__}")) ).scalar_one() @classmethod - def estimate_count(cls, session: DbSession) -> int: + async def estimate_count(cls, session: AsyncDbSession) -> int: """ Returns an approximation of the number of rows in a table. @@ -39,37 +39,41 @@ def estimate_count(cls, session: DbSession) -> int: has never been analyzed or vacuumed. """ - return session.execute( - text( - f"SELECT reltuples::bigint FROM pg_class WHERE relname = '{cls.__tablename__}'" + return ( + await session.execute( + text( + f"SELECT reltuples::bigint FROM pg_class WHERE relname = '{cls.__tablename__}'" + ) ) ).scalar_one() @classmethod - def fast_count(cls, session: DbSession) -> int: + async def fast_count(cls, session: AsyncDbSession) -> int: """ :param session: DB session. :return: The estimate count of the table if available from pg_class, otherwise the real count of rows. 
""" - estimate_count = cls.estimate_count(session) + estimate_count = await cls.estimate_count(session) if estimate_count == -1: - return cls.count(session) + return await cls.count(session) return estimate_count # TODO: set type of "where" to the SQLA boolean expression class @classmethod - def exists(cls, session: DbSession, where) -> bool: + async def exists(cls, session: AsyncDbSession, where) -> bool: exists_stmt = exists(text("1")).select().where(where) - result = (session.execute(exists_stmt)).scalar() + result = (await session.execute(exists_stmt)).scalar() return result is not None @classmethod - def jsonb_keys(cls, session: DbSession, column: Column, where) -> Iterable[str]: + async def jsonb_keys( + cls, session: AsyncDbSession, column: Column, where + ) -> Iterable[str]: select_stmt = select(func.jsonb_object_keys(column)).where(where) - return session.execute(select_stmt).scalars() + return (await session.execute(select_stmt)).scalars() Base = declarative_base(cls=AugmentedBase) diff --git a/src/aleph/handlers/content/aggregate.py b/src/aleph/handlers/content/aggregate.py index 32a5859f8..1d12b6933 100644 --- a/src/aleph/handlers/content/aggregate.py +++ b/src/aleph/handlers/content/aggregate.py @@ -20,7 +20,7 @@ from aleph.db.models import AggregateDb, AggregateElementDb, MessageDb from aleph.handlers.content.content_handler import ContentHandler from aleph.toolkit.timestamp import timestamp_to_datetime -from aleph.types.db_session import DbSession +from aleph.types.db_session import AsyncDbSession from aleph.types.message_status import InvalidMessageFormat LOGGER = logging.getLogger(__name__) @@ -37,13 +37,13 @@ def _get_aggregate_content(message: MessageDb) -> AggregateContent: class AggregateMessageHandler(ContentHandler): async def fetch_related_content( - self, session: DbSession, message: MessageDb + self, session: AsyncDbSession, message: MessageDb ) -> None: # Nothing to do, aggregates are independent of one another return @staticmethod - async def _insert_aggregate_element(session: DbSession, message: MessageDb): + async def _insert_aggregate_element(session: AsyncDbSession, message: MessageDb): content = cast(AggregateContent, message.parsed_content) aggregate_element = AggregateElementDb( item_hash=message.item_hash, @@ -53,7 +53,7 @@ async def _insert_aggregate_element(session: DbSession, message: MessageDb): creation_datetime=timestamp_to_datetime(message.parsed_content.time), ) - insert_aggregate_element( + await insert_aggregate_element( session=session, item_hash=aggregate_element.item_hash, key=aggregate_element.key, @@ -66,13 +66,13 @@ async def _insert_aggregate_element(session: DbSession, message: MessageDb): @staticmethod async def _append_to_aggregate( - session: DbSession, + session: AsyncDbSession, aggregate: AggregateDb, elements: Sequence[AggregateElementDb], ): - new_content = merge_aggregate_elements(elements) + new_content = await merge_aggregate_elements(elements) - update_aggregate( + await update_aggregate( session=session, key=aggregate.key, owner=aggregate.owner, @@ -83,13 +83,13 @@ async def _append_to_aggregate( @staticmethod async def _prepend_to_aggregate( - session: DbSession, + session: AsyncDbSession, aggregate: AggregateDb, elements: Sequence[AggregateElementDb], ): - new_content = merge_aggregate_elements(elements) + new_content = await merge_aggregate_elements(elements) - update_aggregate( + await update_aggregate( session=session, key=aggregate.key, owner=aggregate.owner, @@ -101,7 +101,7 @@ async def _prepend_to_aggregate( async 
def _update_aggregate( self, - session: DbSession, + session: AsyncDbSession, key: str, owner: str, elements: Sequence[AggregateElementDb], @@ -122,15 +122,15 @@ async def _update_aggregate( dirty_threshold = 1000 - aggregate_metadata = get_aggregate_by_key( + aggregate_metadata = await get_aggregate_by_key( session=session, owner=owner, key=key, with_content=False ) if not aggregate_metadata: LOGGER.info("%s/%s does not exist, creating it", key, owner) - content = merge_aggregate_elements(elements) - insert_aggregate( + content = await merge_aggregate_elements(elements) + await insert_aggregate( session=session, key=key, owner=owner, @@ -168,7 +168,9 @@ async def _update_aggregate( # Last chance before a full refresh, check the keys of the aggregate # and determine if there's a conflict. - keys = set(get_aggregate_content_keys(session=session, key=key, owner=owner)) + keys = set( + await get_aggregate_content_keys(session=session, key=key, owner=owner) + ) - new_keys = set(itertools.chain(element.content.keys for element in elements)) + new_keys = set( + itertools.chain.from_iterable(element.content.keys() for element in elements) + ) conflicting_keys = keys & new_keys @@ -189,11 +191,11 @@ async def _update_aggregate( return if ( - count_aggregate_elements(session=session, owner=owner, key=key) + await count_aggregate_elements(session=session, owner=owner, key=key) > dirty_threshold ): - LOGGER.info("%s/%s: too many elements, marking as dirty") + LOGGER.info("%s/%s: too many elements, marking as dirty", key, owner) - mark_aggregate_as_dirty(session=session, owner=owner, key=key) + await mark_aggregate_as_dirty(session=session, owner=owner, key=key) return # Out of order insertions. Here, we need to get all the elements in the database # and recompute the aggregate from scratch. This can be expensive for # large aggregates, so we do it as a last resort. # Expect the new elements to already be added to the current session. # We flush it to make them accessible from the current transaction.
- session.flush() - refresh_aggregate(session=session, owner=owner, key=key) + await session.flush() + await refresh_aggregate(session=session, owner=owner, key=key) - async def process(self, session: DbSession, messages: List[MessageDb]) -> None: + async def process(self, session: AsyncDbSession, messages: List[MessageDb]) -> None: sorted_messages = sorted( messages, key=lambda m: (m.parsed_content.key, m.parsed_content.address, m.time), @@ -223,16 +225,18 @@ async def process(self, session: DbSession, messages: List[MessageDb]) -> None: session=session, key=key, owner=owner, elements=aggregate_elements ) - async def forget_message(self, session: DbSession, message: MessageDb) -> Set[str]: + async def forget_message( + self, session: AsyncDbSession, message: MessageDb + ) -> Set[str]: content = _get_aggregate_content(message) owner = content.address key = content.key LOGGER.debug("Deleting aggregate element %s...", message.item_hash) - delete_aggregate(session=session, owner=owner, key=str(key)) - delete_aggregate_element(session=session, item_hash=message.item_hash) + await delete_aggregate(session=session, owner=owner, key=str(key)) + await delete_aggregate_element(session=session, item_hash=message.item_hash) LOGGER.debug("Refreshing aggregate %s/%s...", owner, key) - refresh_aggregate(session=session, owner=owner, key=str(key)) + await refresh_aggregate(session=session, owner=owner, key=str(key)) return set() diff --git a/src/aleph/handlers/content/content_handler.py b/src/aleph/handlers/content/content_handler.py index c055752c6..9edcad8aa 100644 --- a/src/aleph/handlers/content/content_handler.py +++ b/src/aleph/handlers/content/content_handler.py @@ -4,12 +4,12 @@ from aleph.db.models import MessageDb from aleph.db.models.account_costs import AccountCostsDb from aleph.permissions import check_sender_authorization -from aleph.types.db_session import DbSession +from aleph.types.db_session import AsyncDbSession class ContentHandler(abc.ABC): async def fetch_related_content( - self, session: DbSession, message: MessageDb + self, session: AsyncDbSession, message: MessageDb ) -> None: """ Fetch additional content from the network based on the content of a message. @@ -24,7 +24,7 @@ async def fetch_related_content( pass async def is_related_content_fetched( - self, session: DbSession, message: MessageDb + self, session: AsyncDbSession, message: MessageDb ) -> bool: """ Check whether the additional network content mentioned in the message @@ -38,7 +38,7 @@ async def is_related_content_fetched( return True @abc.abstractmethod - async def process(self, session: DbSession, messages: List[MessageDb]) -> None: + async def process(self, session: AsyncDbSession, messages: List[MessageDb]) -> None: """ Process several messages of the same type and applies the resulting changes. @@ -49,7 +49,7 @@ async def process(self, session: DbSession, messages: List[MessageDb]) -> None: pass async def check_balance( - self, session: DbSession, message: MessageDb + self, session: AsyncDbSession, message: MessageDb ) -> List[AccountCostsDb] | None: """ Checks whether the user has enough Aleph tokens to process the message. @@ -62,7 +62,9 @@ async def check_balance( """ pass - async def check_dependencies(self, session: DbSession, message: MessageDb) -> None: + async def check_dependencies( + self, session: AsyncDbSession, message: MessageDb + ) -> None: """ Check dependencies of a message. 
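Since every `ContentHandler` hook now takes an `AsyncDbSession`, all concrete handlers follow the same shape. Below is a minimal sketch of a subclass under the new interface; the handler name and method bodies are illustrative only, not part of this changeset:

```python
from typing import List, Set

from aleph.db.models import MessageDb
from aleph.handlers.content.content_handler import ContentHandler
from aleph.types.db_session import AsyncDbSession


class ExampleHandler(ContentHandler):
    async def process(self, session: AsyncDbSession, messages: List[MessageDb]) -> None:
        # ORM-level calls such as session.add() stay synchronous (no I/O);
        # only execute()/flush()/commit() are awaited on an AsyncDbSession.
        for message in messages:
            session.add(message)

    async def forget_message(
        self, session: AsyncDbSession, message: MessageDb
    ) -> Set[str]:
        return set()  # no type-specific cleanup in this sketch
```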
@@ -76,7 +78,9 @@ async def check_dependencies(self, session: DbSession, message: MessageDb) -> No """ pass - async def check_permissions(self, session: DbSession, message: MessageDb) -> None: + async def check_permissions( + self, session: AsyncDbSession, message: MessageDb + ) -> None: """ Check user permissions. @@ -91,7 +95,9 @@ async def check_permissions(self, session: DbSession, message: MessageDb) -> Non await check_sender_authorization(session=session, message=message) @abc.abstractmethod - async def forget_message(self, session: DbSession, message: MessageDb) -> Set[str]: + async def forget_message( + self, session: AsyncDbSession, message: MessageDb + ) -> Set[str]: """ Clean up message-type specific objects when forgetting a message. diff --git a/src/aleph/handlers/content/forget.py b/src/aleph/handlers/content/forget.py index 99dfefa83..bc003983d 100644 --- a/src/aleph/handlers/content/forget.py +++ b/src/aleph/handlers/content/forget.py @@ -17,7 +17,7 @@ from aleph.db.accessors.vms import get_vms_dependent_volumes from aleph.db.models import AggregateElementDb, MessageDb from aleph.handlers.content.content_handler import ContentHandler -from aleph.types.db_session import DbSession +from aleph.types.db_session import AsyncDbSession from aleph.types.message_status import ( CannotForgetForgetMessage, ForgetNotAllowed, @@ -40,7 +40,9 @@ def __init__( self.content_handlers = content_handlers self.content_handlers[MessageType.forget] = self - async def check_dependencies(self, session: DbSession, message: MessageDb) -> None: + async def check_dependencies( + self, session: AsyncDbSession, message: MessageDb + ) -> None: """ We only consider FORGETs as fetched if the messages / aggregates they target already exist. Otherwise, we retry them later. 
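The hunks above and below apply the mechanical rule used throughout this diff: only `session.execute()` performs I/O and must be awaited, while the `Result` it returns is an ordinary buffered object, so `scalars()`, `all()` and `scalar_one_or_none()` are called without `await`. A self-contained illustration with generic names (not code from this changeset):

```python
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from aleph.db.models import MessageDb


async def fetch_one(session: AsyncSession, item_hash: str):
    # The round-trip to the database happens here and is awaited...
    result = await session.execute(
        select(MessageDb).where(MessageDb.item_hash == item_hash)
    )
    # ...while the Result is already buffered, so no await on this call.
    return result.scalar_one_or_none()
```

This is why the diff wraps calls as `(await session.execute(stmt)).scalars()` rather than awaiting the whole chain. It is also why relationships such as `PendingMessageDb.tx` are eagerly loaded with `selectinload()` in the accessors above: an implicit lazy load would require I/O outside an `await` point, which an async session will generally refuse at runtime.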
@@ -57,12 +59,12 @@ async def check_dependencies(self, session: DbSession, message: MessageDb) -> No raise NoForgetTarget() for item_hash in content.hashes: - if not message_exists(session=session, item_hash=item_hash): + if not await message_exists(session=session, item_hash=item_hash): raise ForgetTargetNotFound(item_hash) - # Check file references, on VM volumes, as data volume and as code volume - # to block the deletion if we found ones + # Check file references on VM volumes (as data or code volumes) and + # block the deletion if any are found - dependent_volumes = get_vms_dependent_volumes( + dependent_volumes = await get_vms_dependent_volumes( session=session, volume_hash=item_hash ) - print(dependent_volumes, item_hash) @@ -72,33 +74,33 @@ async def check_dependencies(self, session: DbSession, message: MessageDb) -> No ) for aggregate_key in content.aggregates: - if not aggregate_exists( + if not await aggregate_exists( session=session, key=aggregate_key, owner=content.address ): raise ForgetTargetNotFound(aggregate_key=aggregate_key) @staticmethod async def _list_target_messages( - session: DbSession, forget_message: MessageDb + session: AsyncDbSession, forget_message: MessageDb ) -> Sequence[ItemHash]: content = cast(ForgetContent, forget_message.parsed_content) aggregate_messages_to_forget: List[ItemHash] = [] for aggregate in content.aggregates: # TODO: write accessor + result = await session.execute( + select(AggregateElementDb.item_hash).where( + (AggregateElementDb.key == aggregate) + & (AggregateElementDb.owner == content.address) + ) + ) aggregate_messages_to_forget.extend( - ItemHash(value) - for value in session.execute( - select(AggregateElementDb.item_hash).where( - (AggregateElementDb.key == aggregate) - & (AggregateElementDb.owner == content.address) - ) - ).scalars() + ItemHash(value) for value in result.scalars() ) return content.hashes + aggregate_messages_to_forget - async def check_permissions(self, session: DbSession, message: MessageDb): + async def check_permissions(self, session: AsyncDbSession, message: MessageDb): await super().check_permissions(session=session, message=message) # Check that the sender owns the objects it is attempting to forget @@ -106,7 +108,9 @@ session=session, forget_message=message ) for target_hash in target_hashes: - target_status = get_message_status(session=session, item_hash=target_hash) + target_status = await get_message_status( + session=session, item_hash=target_hash + ) if not target_status: raise ForgetTargetNotFound(target_hash=target_hash) @@ -119,7 +123,7 @@ if target_status.status != MessageStatus.PROCESSED: raise ForgetTargetNotFound(target_hash=target_hash) - target_message = get_message_by_item_hash( + target_message = await get_message_by_item_hash( session=session, item_hash=target_hash ) if not target_message: @@ -139,7 +143,7 @@ ) async def _forget_by_message_type( - self, session: DbSession, message: MessageDb + self, session: AsyncDbSession, message: MessageDb ) -> Set[str]: """ When processing a FORGET message, performs additional cleanup depending @@ -149,10 +153,10 @@ return await content_handler.forget_message(session=session, message=message) async def _forget_message( - self, session: DbSession, message: MessageDb, forgotten_by: MessageDb + self, session: AsyncDbSession, message: MessageDb, forgotten_by: MessageDb ): # Mark the message as forgotten -
forget_message( + await forget_message( session=session, item_hash=message.item_hash, forget_message_hash=forgotten_by.item_hash, @@ -162,16 +166,16 @@ async def _forget_message( session=session, message=message ) for item_hash in additional_messages_to_forget: - forget_message( + await forget_message( session=session, item_hash=item_hash, forget_message_hash=forgotten_by.item_hash, ) async def _forget_item_hash( - self, session: DbSession, item_hash: str, forgotten_by: MessageDb + self, session: AsyncDbSession, item_hash: str, forgotten_by: MessageDb ): - message_status = get_message_status( + message_status = await get_message_status( session=session, item_hash=ItemHash(item_hash) ) if not message_status: @@ -181,7 +185,7 @@ async def _forget_item_hash( logger.info("Message %s was rejected, nothing to do.", item_hash) if message_status.status == MessageStatus.FORGOTTEN: logger.info("Message %s is already forgotten, nothing to do.", item_hash) - append_to_forgotten_by( + await append_to_forgotten_by( session=session, forgotten_message_hash=item_hash, forget_message_hash=forgotten_by.item_hash, @@ -196,7 +200,7 @@ async def _forget_item_hash( ) raise ForgetTargetNotFound(item_hash) - message = get_message_by_item_hash( + message = await get_message_by_item_hash( session=session, item_hash=ItemHash(item_hash) ) if not message: @@ -214,7 +218,9 @@ async def _forget_item_hash( forgotten_by=forgotten_by, ) - async def _process_forget_message(self, session: DbSession, message: MessageDb): + async def _process_forget_message( + self, session: AsyncDbSession, message: MessageDb + ): hashes_to_forget = await self._list_target_messages( session=session, forget_message=message @@ -225,7 +231,7 @@ async def _process_forget_message(self, session: DbSession, message: MessageDb): session=session, item_hash=item_hash, forgotten_by=message ) - async def process(self, session: DbSession, messages: List[MessageDb]) -> None: + async def process(self, session: AsyncDbSession, messages: List[MessageDb]) -> None: # FORGET: # 0. 
Check permissions: separate step now @@ -237,5 +243,7 @@ async def process(self, session: DbSession, messages: List[MessageDb]) -> None: for message in messages: await self._process_forget_message(session=session, message=message) - async def forget_message(self, session: DbSession, message: MessageDb) -> Set[str]: + async def forget_message( + self, session: AsyncDbSession, message: MessageDb + ) -> Set[str]: raise CannotForgetForgetMessage(target_hash=message.item_hash) diff --git a/src/aleph/handlers/content/post.py b/src/aleph/handlers/content/post.py index 3d306e1f0..aa33d03a8 100644 --- a/src/aleph/handlers/content/post.py +++ b/src/aleph/handlers/content/post.py @@ -15,7 +15,7 @@ from aleph.db.models.messages import MessageDb from aleph.db.models.posts import PostDb from aleph.toolkit.timestamp import timestamp_to_datetime -from aleph.types.db_session import DbSession +from aleph.types.db_session import AsyncDbSession from aleph.types.message_status import ( AmendTargetNotFound, CannotAmendAmend, @@ -38,7 +38,7 @@ def get_post_content(message: MessageDb) -> PostContent: return content -def update_balances(session: DbSession, content: Mapping[str, Any]) -> None: +async def update_balances(session: AsyncDbSession, content: Mapping[str, Any]) -> None: try: chain = Chain(content["chain"]) height = content["main_height"] @@ -51,7 +51,7 @@ def update_balances(session: DbSession, content: Mapping[str, Any]) -> None: LOGGER.info("Updating balances for %s (dapp: %s)", chain, dapp) balances: Dict[str, float] = content["balances"] - update_balances_db( + await update_balances_db( session=session, chain=chain, dapp=dapp, @@ -85,7 +85,7 @@ def __init__(self, balances_addresses: List[str], balances_post_type: str): self.balances_addresses = balances_addresses self.balances_post_type = balances_post_type - async def check_dependencies(self, session: DbSession, message: MessageDb): + async def check_dependencies(self, session: AsyncDbSession, message: MessageDb): content = get_post_content(message) # For amends, ensure that the original message exists @@ -95,14 +95,14 @@ async def check_dependencies(self, session: DbSession, message: MessageDb): if ref is None: raise NoAmendTarget() - original_post = get_original_post(session=session, item_hash=ref) + original_post = await get_original_post(session=session, item_hash=ref) if not original_post: raise AmendTargetNotFound() if original_post.type == "amend": raise CannotAmendAmend() - async def process_post(self, session: DbSession, message: MessageDb): + async def process_post(self, session: AsyncDbSession, message: MessageDb): content = get_post_content(message) creation_datetime = timestamp_to_datetime(content.time) @@ -121,10 +121,10 @@ async def process_post(self, session: DbSession, message: MessageDb): session.add(post) if content.type == "amend": - [amended_post] = get_matching_posts(session=session, hashes=[ref]) + [amended_post] = await get_matching_posts(session=session, hashes=[ref]) if amended_post.last_updated < creation_datetime: - session.execute( + await session.execute( update(PostDb) .where(PostDb.item_hash == ref) .values(latest_amend=message.item_hash) @@ -136,29 +136,31 @@ async def process_post(self, session: DbSession, message: MessageDb): and content.content ): LOGGER.info("Updating balances...") - update_balances(session=session, content=content.content) + await update_balances(session=session, content=content.content) LOGGER.info("Done updating balances") - async def process(self, session: DbSession, messages: 
List[MessageDb]) -> None: + async def process(self, session: AsyncDbSession, messages: List[MessageDb]) -> None: for message in messages: await self.process_post(session=session, message=message) - async def forget_message(self, session: DbSession, message: MessageDb) -> Set[str]: + async def forget_message( + self, session: AsyncDbSession, message: MessageDb + ) -> Set[str]: content = get_post_content(message) LOGGER.debug("Deleting post %s...", message.item_hash) - amend_hashes = delete_amends(session=session, item_hash=message.item_hash) - delete_post(session=session, item_hash=message.item_hash) + amend_hashes = await delete_amends(session=session, item_hash=message.item_hash) + await delete_post(session=session, item_hash=message.item_hash) if content.type == "amend": - original_post = get_original_post(session, str(content.ref)) + original_post = await get_original_post(session, str(content.ref)) if original_post is None: raise InternalError( - f"Could not find original post ({content.ref} for amend ({message.item_hash})." + f"Could not find original post ({content.ref}) for amend ({message.item_hash})." ) if original_post.latest_amend == message.item_hash: - refresh_latest_amend(session, original_post.item_hash) + await refresh_latest_amend(session, original_post.item_hash) return set(amend_hashes) diff --git a/src/aleph/handlers/content/store.py b/src/aleph/handlers/content/store.py index 55ab678f1..0a58846cb 100644 --- a/src/aleph/handlers/content/store.py +++ b/src/aleph/handlers/content/store.py @@ -36,7 +36,7 @@ from aleph.toolkit.constants import MAX_UNAUTHENTICATED_UPLOAD_FILE_SIZE, MiB from aleph.toolkit.costs import are_store_and_program_free from aleph.toolkit.timestamp import timestamp_to_datetime, utc_now -from aleph.types.db_session import DbSession +from aleph.types.db_session import AsyncDbSession from aleph.types.files import FileTag, FileType from aleph.types.message_status import ( FileUnavailable, @@ -96,7 +96,7 @@ def __init__(self, storage_service: StorageService, grace_period: int): self.grace_period = grace_period async def is_related_content_fetched( - self, session: DbSession, message: MessageDb + self, session: AsyncDbSession, message: MessageDb ) -> bool: content = message.parsed_content assert isinstance(content, StoreContent) @@ -105,7 +105,7 @@ return await self.storage_service.storage_engine.exists(file_hash) async def fetch_related_content( - self, session: DbSession, message: MessageDb + self, session: AsyncDbSession, message: MessageDb ) -> None: # TODO: simplify this function, it's overly complicated for no good reason.
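To recap the amend bookkeeping the post handler relies on: forgetting an amend deletes its `PostDb` row and, when it was the latest amend, recomputes the original post's `latest_amend` pointer. A condensed sketch of that flow using the accessors converted above (illustrative wiring, not code from this changeset):

```python
from typing import Set

from aleph.db.accessors.posts import (
    delete_amends,
    delete_post,
    get_original_post,
    refresh_latest_amend,
)
from aleph.types.db_session import AsyncDbSession


async def forget_amend(
    session: AsyncDbSession, amend_hash: str, original_hash: str
) -> Set[str]:
    # Delete amends targeting this post and keep their hashes so the caller
    # can forget those messages too, then drop the post row itself.
    cascade_hashes = set(await delete_amends(session=session, item_hash=amend_hash))
    await delete_post(session=session, item_hash=amend_hash)

    original = await get_original_post(session=session, item_hash=original_hash)
    if original is not None and original.latest_amend == amend_hash:
        # Re-select max(creation_datetime) among the remaining amends and
        # rewrite PostDb.latest_amend accordingly.
        await refresh_latest_amend(session, original.item_hash)
    return cascade_hashes
```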
@@ -199,7 +199,7 @@ async def fetch_related_content( else: size = -1 - upsert_file( + await upsert_file( session=session, file_hash=item_hash, file_type=FileType.DIRECTORY if is_folder else FileType.FILE, @@ -207,26 +207,28 @@ async def fetch_related_content( ) async def check_balance( - self, session: DbSession, message: MessageDb + self, session: AsyncDbSession, message: MessageDb ) -> List[AccountCostsDb]: content = _get_store_content(message) - message_cost, costs = get_total_and_detailed_costs( + message_cost, costs = await get_total_and_detailed_costs( session, content, message.item_hash ) if are_store_and_program_free(message): return costs - storage_size_mib = calculate_storage_size(session, content) + storage_size_mib = await calculate_storage_size(session, content) if storage_size_mib and storage_size_mib <= ( MAX_UNAUTHENTICATED_UPLOAD_FILE_SIZE / MiB ): return costs - current_balance = get_total_balance(address=content.address, session=session) - current_cost = get_total_cost_for_address( + current_balance = await get_total_balance( + address=content.address, session=session + ) + current_cost = await get_total_cost_for_address( session=session, address=content.address ) @@ -240,7 +242,9 @@ async def check_balance( return costs - async def check_dependencies(self, session: DbSession, message: MessageDb) -> None: + async def check_dependencies( + self, session: AsyncDbSession, message: MessageDb + ) -> None: content = _get_store_content(message) if content.ref is None: return @@ -261,7 +265,9 @@ async def check_dependencies(self, session: DbSession, message: MessageDb) -> No if not ref_is_hash: return - ref_file_pin_db = get_message_file_pin(session=session, item_hash=content.ref) + ref_file_pin_db = await get_message_file_pin( + session=session, item_hash=content.ref + ) if ref_file_pin_db is None: raise StoreRefNotFound(content.ref) @@ -269,7 +275,7 @@ async def check_dependencies(self, session: DbSession, message: MessageDb) -> No if ref_file_pin_db.ref is not None: raise StoreCannotUpdateStoreWithRef() - async def check_permissions(self, session: DbSession, message: MessageDb): + async def check_permissions(self, session: AsyncDbSession, message: MessageDb): await super().check_permissions(session=session, message=message) content = _get_store_content(message) if content.ref is None: @@ -279,7 +285,7 @@ async def check_permissions(self, session: DbSession, message: MessageDb): file_tag = make_file_tag( owner=owner, ref=content.ref, item_hash=message.item_hash ) - file_tag_db = get_file_tag(session=session, tag=file_tag) + file_tag_db = await get_file_tag(session=session, tag=file_tag) if not file_tag_db: return @@ -289,13 +295,13 @@ async def check_permissions(self, session: DbSession, message: MessageDb): f"{message.item_hash} attempts to update a file tag belonging to another user" ) - async def _pin_and_tag_file(self, session: DbSession, message: MessageDb): + async def _pin_and_tag_file(self, session: AsyncDbSession, message: MessageDb): content = _get_store_content(message) file_hash = content.item_hash owner = content.address - insert_message_file_pin( + await insert_message_file_pin( session=session, file_hash=file_hash, owner=owner, @@ -307,7 +313,7 @@ async def _pin_and_tag_file(self, session: DbSession, message: MessageDb): file_tag = make_file_tag( owner=content.address, ref=content.ref, item_hash=message.item_hash ) - upsert_file_tag( + await upsert_file_tag( session=session, tag=file_tag, owner=owner, @@ -315,12 +321,12 @@ async def _pin_and_tag_file(self, 
session: DbSession, message: MessageDb): last_updated=timestamp_to_datetime(content.time), ) - async def process(self, session: DbSession, messages: List[MessageDb]) -> None: + async def process(self, session: AsyncDbSession, messages: List[MessageDb]) -> None: for message in messages: await self._pin_and_tag_file(session=session, message=message) async def _check_remaining_pins( - self, session: DbSession, storage_hash: str, storage_type: ItemType + self, session: AsyncDbSession, storage_hash: str, storage_type: ItemType ): """ If a file is not pinned anymore, mark it as pickable by the garbage collector. @@ -333,7 +339,7 @@ async def _check_remaining_pins( """ LOGGER.debug(f"Garbage collecting {storage_hash}") - if is_pinned_file(session=session, file_hash=storage_hash): + if await is_pinned_file(session=session, file_hash=storage_hash): LOGGER.debug(f"File {storage_hash} has at least one reference left") return @@ -350,18 +356,21 @@ async def _check_remaining_pins( current_datetime = utc_now() delete_by = current_datetime + dt.timedelta(hours=self.grace_period) - insert_grace_period_file_pin( + await insert_grace_period_file_pin( session=session, file_hash=storage_hash, created=utc_now(), delete_by=delete_by, ) - async def forget_message(self, session: DbSession, message: MessageDb) -> Set[str]: + async def forget_message( + self, session: AsyncDbSession, message: MessageDb + ) -> Set[str]: content = _get_store_content(message) - delete_file_pin(session=session, item_hash=message.item_hash) - refresh_file_tag( + await delete_file_pin(session=session, item_hash=message.item_hash) + + await refresh_file_tag( session=session, tag=make_file_tag( owner=content.address, diff --git a/src/aleph/handlers/content/vm.py b/src/aleph/handlers/content/vm.py index f7b00da79..56141f119 100644 --- a/src/aleph/handlers/content/vm.py +++ b/src/aleph/handlers/content/vm.py @@ -54,7 +54,7 @@ from aleph.services.cost import get_total_and_detailed_costs from aleph.toolkit.costs import are_store_and_program_free from aleph.toolkit.timestamp import timestamp_to_datetime -from aleph.types.db_session import DbSession +from aleph.types.db_session import AsyncDbSession from aleph.types.files import FileTag from aleph.types.message_status import ( InsufficientBalanceException, @@ -241,8 +241,8 @@ def vm_message_to_db(message: MessageDb) -> VmBaseDb: return vm -def find_missing_volumes( - session: DbSession, content: ExecutableContent +async def find_missing_volumes( + session: AsyncDbSession, content: ExecutableContent ) -> Set[FileTag]: tags_to_check = set() pins_to_check = set() @@ -273,18 +273,18 @@ def add_ref_to_check(_volume): # For each volume, if use_latest is set check the tags and otherwise check # the file pins. 
- file_tags_db = set(find_file_tags(session=session, tags=tags_to_check)) - file_pins_db = set(find_file_pins(session=session, item_hashes=pins_to_check)) + file_tags_db = set(await find_file_tags(session=session, tags=tags_to_check)) + file_pins_db = set(await find_file_pins(session=session, item_hashes=pins_to_check)) return (pins_to_check - file_pins_db) | (tags_to_check - file_tags_db) -def check_parent_volumes_size_requirements( - session: DbSession, content: ExecutableContent +async def check_parent_volumes_size_requirements( + session: AsyncDbSession, content: ExecutableContent ) -> None: - def _get_parent_volume_file(_parent: ParentVolume) -> StoredFileDb: + async def _get_parent_volume_file(_parent: ParentVolume) -> StoredFileDb: if _parent.use_latest: - file_tag = get_file_tag(session=session, tag=FileTag(_parent.ref)) + file_tag = await get_file_tag(session=session, tag=FileTag(_parent.ref)) if file_tag is None: raise InternalError( f"Could not find latest version of parent volume {_parent.ref}" @@ -292,7 +292,7 @@ def _get_parent_volume_file(_parent: ParentVolume) -> StoredFileDb: return file_tag.file - file_pin = get_message_file_pin(session=session, item_hash=_parent.ref) + file_pin = await get_message_file_pin(session=session, item_hash=_parent.ref) if file_pin is None: raise InternalError( f"Could not find original version of parent volume {_parent.ref}" @@ -315,7 +315,7 @@ class HasParent(Protocol): for volume in volumes_with_parent: if volume.parent: - volume_metadata = _get_parent_volume_file(volume.parent) + volume_metadata = await _get_parent_volume_file(volume.parent) volume_size = volume.size_mib * 1024 * 1024 if volume_size < volume_metadata.size: raise VmVolumeTooSmall( @@ -337,11 +337,11 @@ class VmMessageHandler(ContentHandler): """ async def check_balance( - self, session: DbSession, message: MessageDb + self, session: AsyncDbSession, message: MessageDb ) -> List[AccountCostsDb]: content = _get_vm_content(message) - message_cost, costs = get_total_and_detailed_costs( + message_cost, costs = await get_total_and_detailed_costs( session, content, message.item_hash ) @@ -357,8 +357,10 @@ async def check_balance( return costs # NOTE: Instances and persistent Programs being paid by HOLD are the only ones being checked for now - current_balance = get_total_balance(address=content.address, session=session) - current_cost = get_total_cost_for_address( + current_balance = await get_total_balance( + address=content.address, session=session + ) + current_cost = await get_total_cost_for_address( session=session, address=content.address ) @@ -372,25 +374,27 @@ async def check_balance( return costs - async def check_dependencies(self, session: DbSession, message: MessageDb) -> None: + async def check_dependencies( + self, session: AsyncDbSession, message: MessageDb + ) -> None: content = _get_vm_content(message) - missing_volumes = find_missing_volumes(session=session, content=content) + missing_volumes = await find_missing_volumes(session=session, content=content) if missing_volumes: raise VmVolumeNotFound([volume for volume in missing_volumes]) - check_parent_volumes_size_requirements(session=session, content=content) + await check_parent_volumes_size_requirements(session=session, content=content) # Check dependencies if the message updates an existing instance/program if (ref := content.replaces) is not None: - original_program = get_program(session=session, item_hash=ref) + original_program = await get_program(session=session, item_hash=ref) if original_program is None: 
raise VmRefNotFound(ref) if original_program.replaces is not None: raise VmCannotUpdateUpdate() - is_amend_allowed = is_vm_amend_allowed(session=session, vm_hash=ref) + is_amend_allowed = await is_vm_amend_allowed(session=session, vm_hash=ref) if is_amend_allowed is None: raise InternalError(f"Could not find current version of program {ref}") @@ -398,12 +402,12 @@ async def check_dependencies(self, session: DbSession, message: MessageDb) -> No raise VmUpdateNotAllowed() @staticmethod - async def process_vm_message(session: DbSession, message: MessageDb): + async def process_vm_message(session: AsyncDbSession, message: MessageDb): vm = vm_message_to_db(message) session.add(vm) program_ref = vm.replaces or vm.item_hash - upsert_vm_version( + await upsert_vm_version( session=session, vm_hash=vm.item_hash, owner=vm.owner, @@ -411,24 +415,26 @@ async def process_vm_message(session: DbSession, message: MessageDb): last_updated=vm.created, ) - async def process(self, session: DbSession, messages: List[MessageDb]) -> None: + async def process(self, session: AsyncDbSession, messages: List[MessageDb]) -> None: for message in messages: await self.process_vm_message(session=session, message=message) - async def forget_message(self, session: DbSession, message: MessageDb) -> Set[str]: + async def forget_message( + self, session: AsyncDbSession, message: MessageDb + ) -> Set[str]: content = _get_vm_content(message) LOGGER.debug("Deleting program %s...", message.item_hash) - delete_vm(session=session, vm_hash=message.item_hash) + await delete_vm(session=session, vm_hash=message.item_hash) if content.replaces: update_hashes = set() else: update_hashes = set( - delete_vm_updates(session=session, vm_hash=message.item_hash) + await delete_vm_updates(session=session, vm_hash=message.item_hash) ) - refresh_vm_version(session=session, vm_hash=message.item_hash) + await refresh_vm_version(session=session, vm_hash=message.item_hash) return update_hashes diff --git a/src/aleph/handlers/message_handler.py b/src/aleph/handlers/message_handler.py index f8e6f8b15..a088edb50 100644 --- a/src/aleph/handlers/message_handler.py +++ b/src/aleph/handlers/message_handler.py @@ -8,6 +8,7 @@ from aleph_message.models import ItemHash, ItemType, MessageType from configmanager import Config from pydantic import ValidationError +from sqlalchemy import select from sqlalchemy.dialects.postgresql import insert from aleph.chains.signature_verifier import SignatureVerifier @@ -39,7 +40,7 @@ from aleph.schemas.pending_messages import parse_message from aleph.storage import StorageService from aleph.toolkit.timestamp import timestamp_to_datetime -from aleph.types.db_session import DbSession, DbSessionFactory +from aleph.types.db_session import AsyncDbSession, AsyncDbSessionFactory from aleph.types.files import FileType from aleph.types.message_processing_result import ProcessedMessage, RejectedMessage from aleph.types.message_status import ( @@ -122,7 +123,7 @@ async def fetch_pending_message( return validated_message - async def fetch_related_content(self, session: DbSession, message: MessageDb): + async def fetch_related_content(self, session: AsyncDbSession, message: MessageDb): content_handler = self.get_content_handler(message.type) try: @@ -135,7 +136,7 @@ async def fetch_related_content(self, session: DbSession, message: MessageDb): ) from e async def load_fetched_content( - self, session: DbSession, pending_message: PendingMessageDb + self, session: AsyncDbSession, pending_message: PendingMessageDb ) -> PendingMessageDb: if 
pending_message.item_type != ItemType.inline:
             pending_message.fetched = False
@@ -161,7 +162,7 @@ class MessagePublisher(BaseMessageHandler):
 
     def __init__(
         self,
-        session_factory: DbSessionFactory,
+        session_factory: AsyncDbSessionFactory,
         storage_service: StorageService,
         config: Config,
         pending_message_exchange: aio_pika.abc.AbstractExchange,
@@ -191,19 +192,19 @@ async def add_pending_message(
         origin: Optional[MessageOrigin] = MessageOrigin.P2P,
     ) -> Optional[PendingMessageDb]:
         # TODO: this implementation is just messy, improve it.
-        with self.session_factory() as session:
+        async with self.session_factory() as session:
             try:
                 # we don't check signatures yet.
                 message = parse_message(message_dict)
             except InvalidMessageException as e:
                 LOGGER.warning(e)
-                reject_new_pending_message(
+                await reject_new_pending_message(
                     session=session,
                     pending_message=message_dict,
                     exception=e,
                     tx_hash=tx_hash,
                 )
-                session.commit()
+                await session.commit()
                 return None
 
             pending_message = PendingMessageDb.from_obj(
@@ -220,25 +221,24 @@ async def add_pending_message(
                 )
             except InvalidMessageException as e:
                 LOGGER.warning("Invalid message: %s - %s", message.item_hash, str(e))
-                reject_new_pending_message(
+                await reject_new_pending_message(
                     session=session,
                     pending_message=message_dict,
                     exception=e,
                     tx_hash=tx_hash,
                 )
-                session.commit()
+                await session.commit()
                 return None
 
             # Check if there is already an existing record
-            existing_message = (
-                session.query(PendingMessageDb)
-                .filter_by(
-                    sender=pending_message.sender,
-                    item_hash=pending_message.item_hash,
-                    signature=pending_message.signature,
-                )
-                .one_or_none()
+            stmt = select(PendingMessageDb).where(
+                (PendingMessageDb.sender == pending_message.sender)
+                & (PendingMessageDb.item_hash == pending_message.item_hash)
+                & (PendingMessageDb.signature == pending_message.signature)
             )
+            result = await session.execute(stmt)
+            existing_message = result.scalar_one_or_none()
+
             if existing_message:
                 return existing_message
 
@@ -255,9 +255,9 @@ async def add_pending_message(
             )
 
             try:
-                session.execute(upsert_message_status_stmt)
-                session.execute(insert_pending_message_stmt)
-                session.commit()
+                await session.execute(upsert_message_status_stmt)
+                await session.execute(insert_pending_message_stmt)
+                await session.commit()
             except sqlalchemy.exc.IntegrityError:
                 # Handle the unique constraint violation.
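Under asyncpg, a concurrent insert of the same pending message surfaces as sqlalchemy.exc.IntegrityError when the unique constraint fires, and the session cannot issue further statements until it has been rolled back. The generic shape of this race-safe insert, as a sketch (names are illustrative, not part of the diff):

    import sqlalchemy.exc

    async def insert_ignoring_duplicates(session, stmt) -> bool:
        # Returns True if this worker inserted the row,
        # False if another worker won the race.
        try:
            await session.execute(stmt)
            await session.commit()
            return True
        except sqlalchemy.exc.IntegrityError:
            await session.rollback()  # the session is unusable until rolled back
            return False

The warning below logs the duplicate and the handler falls through without failing the request.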
LOGGER.warning("Duplicate pending message detected trying to save it.") @@ -269,14 +269,14 @@ async def add_pending_message( pending_message.item_hash, str(e), ) - session.rollback() - reject_new_pending_message( + await session.rollback() + await reject_new_pending_message( session=session, pending_message=message_dict, exception=e, tx_hash=tx_hash, ) - session.commit() + await session.commit() return None await self._publish_pending_message(pending_message) @@ -307,16 +307,16 @@ async def verify_signature(self, pending_message: PendingMessageDb): @staticmethod async def confirm_existing_message( - session: DbSession, + session: AsyncDbSession, existing_message: MessageDb, pending_message: PendingMessageDb, ): if pending_message.signature != existing_message.signature: raise InvalidSignature(f"Invalid signature for {pending_message.item_hash}") - delete_pending_message(session=session, pending_message=pending_message) + await delete_pending_message(session=session, pending_message=pending_message) if tx_hash := pending_message.tx_hash: - session.execute( + await session.execute( make_confirmation_upsert_query( item_hash=pending_message.item_hash, tx_hash=tx_hash ) @@ -324,27 +324,30 @@ async def confirm_existing_message( @staticmethod async def confirm_existing_forgotten_message( - session: DbSession, + session: AsyncDbSession, forgotten_message: ForgottenMessageDb, pending_message: PendingMessageDb, ): if pending_message.signature != forgotten_message.signature: raise InvalidSignature(f"Invalid signature for {pending_message.item_hash}") - delete_pending_message(session=session, pending_message=pending_message) + await delete_pending_message(session=session, pending_message=pending_message) async def insert_message( - self, session: DbSession, pending_message: PendingMessageDb, message: MessageDb + self, + session: AsyncDbSession, + pending_message: PendingMessageDb, + message: MessageDb, ): - session.execute(make_message_upsert_query(message)) + await session.execute(make_message_upsert_query(message)) if message.item_type != ItemType.inline: - upsert_file( + await upsert_file( session=session, file_hash=message.item_hash, size=message.size, file_type=FileType.FILE, ) - insert_content_file_pin( + await insert_content_file_pin( session=session, file_hash=message.item_hash, owner=message.sender, @@ -352,8 +355,8 @@ async def insert_message( created=timestamp_to_datetime(message.content["time"]), ) - delete_pending_message(session=session, pending_message=pending_message) - session.execute( + await delete_pending_message(session=session, pending_message=pending_message) + await session.execute( make_message_status_upsert_query( item_hash=message.item_hash, new_status=MessageStatus.PROCESSED, @@ -363,21 +366,21 @@ async def insert_message( ) if tx_hash := pending_message.tx_hash: - session.execute( + await session.execute( make_confirmation_upsert_query( item_hash=message.item_hash, tx_hash=tx_hash ) ) async def insert_costs( - self, session: DbSession, costs: List[AccountCostsDb], message: MessageDb + self, session: AsyncDbSession, costs: List[AccountCostsDb], message: MessageDb ): if len(costs) > 0: insert_stmt = make_costs_upsert_query(costs) - session.execute(insert_stmt) + await session.execute(insert_stmt) async def verify_and_fetch( - self, session: DbSession, pending_message: PendingMessageDb + self, session: AsyncDbSession, pending_message: PendingMessageDb ) -> MessageDb: await self.verify_signature(pending_message=pending_message) validated_message = await 
self.fetch_pending_message(
@@ -387,7 +390,7 @@ async def verify_and_fetch(
         return validated_message
 
     async def process(
-        self, session: DbSession, pending_message: PendingMessageDb
+        self, session: AsyncDbSession, pending_message: PendingMessageDb
     ) -> ProcessedMessage | RejectedMessage:
         """
         Process a pending message.
@@ -403,7 +406,7 @@ async def process(
         """
 
         # Note: Check if message already exists (and confirm it)
-        existing_message = get_message_by_item_hash(
+        existing_message = await get_message_by_item_hash(
             session=session, item_hash=ItemHash(pending_message.item_hash)
         )
         if existing_message:
@@ -416,7 +419,7 @@ async def process(
 
         # Note: Check if message is already forgotten (and confirm it)
        # this is to avoid race conditions when a confirmation arrives after the FORGET message has been processed
-        forgotten_message = get_forgotten_message(
+        forgotten_message = await get_forgotten_message(
             session=session, item_hash=ItemHash(pending_message.item_hash)
         )
         if forgotten_message:
@@ -433,7 +436,6 @@ async def process(
         message = await self.verify_and_fetch(
             session=session, pending_message=pending_message
         )
-
         content_handler = self.get_content_handler(message.type)
 
         await content_handler.check_dependencies(session=session, message=message)
         await content_handler.check_permissions(session=session, message=message)
diff --git a/src/aleph/jobs/__init__.py b/src/aleph/jobs/__init__.py
index 70a436adb..00fc89c94 100644
--- a/src/aleph/jobs/__init__.py
+++ b/src/aleph/jobs/__init__.py
@@ -10,14 +10,14 @@
 from aleph.jobs.process_pending_txs import handle_txs_task, pending_txs_subprocess
 from aleph.jobs.reconnect_ipfs import reconnect_ipfs_job
 from aleph.services.ipfs import IpfsService
-from aleph.types.db_session import DbSessionFactory
+from aleph.types.db_session import AsyncDbSessionFactory
 
 LOGGER = logging.getLogger("jobs")
 
 
 def start_jobs(
     config,
-    session_factory: DbSessionFactory,
+    session_factory: AsyncDbSessionFactory,
     ipfs_service: IpfsService,
     use_processes=True,
 ) -> List[Coroutine]:
diff --git a/src/aleph/jobs/fetch_pending_messages.py b/src/aleph/jobs/fetch_pending_messages.py
index f20616281..0a4744c06 100644
--- a/src/aleph/jobs/fetch_pending_messages.py
+++ b/src/aleph/jobs/fetch_pending_messages.py
@@ -15,7 +15,7 @@
     get_next_pending_messages,
     make_pending_message_fetched_statement,
 )
-from aleph.db.connection import make_engine, make_session_factory
+from aleph.db.connection import make_async_engine, make_async_session_factory
 from aleph.db.models import MessageDb, PendingMessageDb
 from aleph.handlers.message_handler import MessageHandler
 from aleph.services.cache.node_cache import NodeCache
@@ -25,7 +25,7 @@
 from aleph.toolkit.logging import setup_logging
 from aleph.toolkit.monitoring import setup_sentry
 from aleph.toolkit.timestamp import utc_now
-from aleph.types.db_session import DbSessionFactory
+from aleph.types.db_session import AsyncDbSessionFactory
 
 from ..toolkit.rabbitmq import make_mq_conn
 from .job_utils import MessageJob, make_pending_message_queue, prepare_loop
@@ -39,7 +39,7 @@ class PendingMessageFetcher(MessageJob):
     def __init__(
         self,
-        session_factory: DbSessionFactory,
+        session_factory: AsyncDbSessionFactory,
         message_handler: MessageHandler,
         max_retries: int,
         pending_message_queue: aio_pika.abc.AbstractQueue,
@@ -53,27 +53,27 @@ def __init__(
         self.pending_message_queue = pending_message_queue
 
     async def fetch_pending_message(self, pending_message: PendingMessageDb):
-        with self.session_factory() as session:
+        async with self.session_factory() as session:
            try:
                 message = await self.message_handler.verify_and_fetch(
                     session=session, pending_message=pending_message
                 )
-                session.execute(
+                await session.execute(
                     make_pending_message_fetched_statement(
                         pending_message, message.content
                     )
                 )
-                session.commit()
+                await session.commit()
                 return message
             except Exception as e:
-                session.rollback()
+                await session.rollback()
                 _ = await self.handle_processing_error(
                     session=session,
                     pending_message=pending_message,
                     exception=e,
                 )
-                session.commit()
+                await session.commit()
                 return None
 
     async def fetch_pending_messages(
@@ -91,7 +91,7 @@ async def fetch_pending_messages(
         fetched_messages: List[MessageDb] = []
 
         while True:
-            with self.session_factory() as session:
+            async with self.session_factory() as session:
                 if fetch_tasks:
                     finished_tasks, fetch_tasks = await asyncio.wait(
                         fetch_tasks, return_when=asyncio.FIRST_COMPLETED
                     )
                     await node_cache.decr(retry_messages_cache_key)
 
                 if len(fetch_tasks) < max_concurrent_tasks:
-                    pending_messages = get_next_pending_messages(
+                    pending_messages = await get_next_pending_messages(
                         session=session,
                         current_time=utc_now(),
                         limit=max_concurrent_tasks - len(fetch_tasks),
@@ -133,7 +133,8 @@ async def fetch_pending_messages(
                     yield fetched_messages
                     fetched_messages = []
 
-                if not PendingMessageDb.count(session):
+                pending_count = await PendingMessageDb.count(session)
+                if not pending_count:
                     # If not in loop mode, stop if there are no more pending messages
                     if not loop:
                         break
@@ -158,8 +159,8 @@ def make_pipeline(
 
 
 async def fetch_messages_task(config: Config):
-    engine = make_engine(config=config, application_name="aleph-fetch")
-    session_factory = make_session_factory(engine)
+    engine = make_async_engine(config=config, application_name="aleph-fetch")
+    async_session_factory = make_async_session_factory(engine)
 
     mq_conn = await make_mq_conn(config=config)
     mq_channel = await mq_conn.channel()
@@ -186,7 +187,7 @@ async def fetch_messages_task(config: Config):
         config=config,
     )
     fetcher = PendingMessageFetcher(
-        session_factory=session_factory,
+        session_factory=async_session_factory,
         message_handler=message_handler,
         max_retries=config.aleph.jobs.pending_messages.max_retries.value,
         pending_message_queue=pending_message_queue,
diff --git a/src/aleph/jobs/job_utils.py b/src/aleph/jobs/job_utils.py
index 92453cbe7..a37691200 100644
--- a/src/aleph/jobs/job_utils.py
+++ b/src/aleph/jobs/job_utils.py
@@ -13,7 +13,7 @@
 from aleph.db.models import PendingMessageDb
 from aleph.handlers.message_handler import MessageHandler
 from aleph.toolkit.timestamp import utc_now
-from aleph.types.db_session import DbSession, DbSessionFactory
+from aleph.types.db_session import AsyncDbSession, AsyncDbSessionFactory
 from aleph.types.message_processing_result import RejectedMessage, WillRetryMessage
 from aleph.types.message_status import (
     ErrorCode,
@@ -95,8 +95,8 @@ def compute_next_retry_interval(attempts: int) -> dt.timedelta:
     return dt.timedelta(seconds=min(seconds, MAX_RETRY_INTERVAL))
 
 
-def schedule_next_attempt(
-    session: DbSession, pending_message: PendingMessageDb
+async def schedule_next_attempt(
+    session: AsyncDbSession, pending_message: PendingMessageDb
 ) -> None:
     """
     Schedules the next attempt time for a failed pending message.
@@ -115,7 +115,7 @@ def schedule_next_attempt(
     # are processed in the right order while leaving enough time for the issue that
     # caused the original message to be rescheduled to get resolved.
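The interval grows with the number of attempts and is clamped to MAX_RETRY_INTERVAL, as the compute_next_retry_interval signature above shows. The exact growth law is not visible in this hunk; a plausible sketch, assuming exponential backoff with a cap:

    import datetime as dt

    MAX_RETRY_INTERVAL = 300  # assumed cap, in seconds

    def compute_next_retry_interval(attempts: int) -> dt.timedelta:
        # Exponential growth (base 2 assumed here), clamped to the maximum interval.
        seconds = 2**attempts
        return dt.timedelta(seconds=min(seconds, MAX_RETRY_INTERVAL))

The assignment that follows stamps the pending message with this computed next_attempt time.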
next_attempt = utc_now() + compute_next_retry_interval(pending_message.retries) - set_next_retry( + await set_next_retry( session=session, pending_message=pending_message, next_attempt=next_attempt ) pending_message.next_attempt = next_attempt @@ -181,7 +181,7 @@ async def ready(self): class MessageJob(MqWatcher): def __init__( self, - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, message_handler: MessageHandler, max_retries: int, pending_message_queue: aio_pika.abc.AbstractQueue, @@ -193,12 +193,12 @@ def __init__( self.max_retries = max_retries @staticmethod - def _handle_rejection( - session: DbSession, + async def _handle_rejection( + session: AsyncDbSession, pending_message: PendingMessageDb, exception: BaseException, ) -> RejectedMessage: - rejected_message_db = reject_existing_pending_message( + rejected_message_db = await reject_existing_pending_message( session=session, pending_message=pending_message, exception=exception, @@ -216,7 +216,7 @@ def _handle_rejection( async def _handle_retry( self, - session: DbSession, + session: AsyncDbSession, pending_message: PendingMessageDb, exception: BaseException, ) -> Union[RejectedMessage, WillRetryMessage]: @@ -227,7 +227,7 @@ async def _handle_retry( str(exception), ) error_code = exception.error_code - session.execute( + await session.execute( update(PendingMessageDb) .where(PendingMessageDb.id == pending_message.id) .values(fetched=False) @@ -240,7 +240,9 @@ async def _handle_retry( pending_message.item_hash, ) error_code = exception.error_code - schedule_next_attempt(session=session, pending_message=pending_message) + await schedule_next_attempt( + session=session, pending_message=pending_message + ) else: LOGGER.exception( "Unexpected error while fetching message", exc_info=exception @@ -251,20 +253,22 @@ async def _handle_retry( "Rejecting pending message: %s - too many retries", pending_message.item_hash, ) - return self._handle_rejection( + return await self._handle_rejection( session=session, pending_message=pending_message, exception=exception, ) else: - schedule_next_attempt(session=session, pending_message=pending_message) + await schedule_next_attempt( + session=session, pending_message=pending_message + ) return WillRetryMessage( pending_message=pending_message, error_code=error_code ) async def handle_processing_error( self, - session: DbSession, + session: AsyncDbSession, pending_message: PendingMessageDb, exception: BaseException, ) -> Union[RejectedMessage, WillRetryMessage]: @@ -274,7 +278,7 @@ async def handle_processing_error( pending_message.item_hash, str(exception), ) - return self._handle_rejection( + return await self._handle_rejection( session=session, pending_message=pending_message, exception=exception ) else: diff --git a/src/aleph/jobs/process_pending_messages.py b/src/aleph/jobs/process_pending_messages.py index 08bb99397..e4839760c 100644 --- a/src/aleph/jobs/process_pending_messages.py +++ b/src/aleph/jobs/process_pending_messages.py @@ -4,7 +4,7 @@ import asyncio from logging import getLogger -from typing import AsyncIterator, Dict, Sequence +from typing import AsyncIterator, Dict, List, Sequence, Set import aio_pika.abc from configmanager import Config @@ -12,8 +12,10 @@ import aleph.toolkit.json as aleph_json from aleph.chains.signature_verifier import SignatureVerifier -from aleph.db.accessors.pending_messages import get_next_pending_message -from aleph.db.connection import make_engine, make_session_factory +from aleph.db.accessors.pending_messages import ( + 
    async_get_next_pending_messages_from_different_senders,
+)
+from aleph.db.connection import make_async_engine, make_async_session_factory
 from aleph.handlers.message_handler import MessageHandler
 from aleph.services.cache.node_cache import NodeCache
 from aleph.services.ipfs import IpfsService
@@ -22,9 +24,10 @@
 from aleph.toolkit.logging import setup_logging
 from aleph.toolkit.monitoring import setup_sentry
 from aleph.toolkit.timestamp import utc_now
-from aleph.types.db_session import DbSessionFactory
+from aleph.types.db_session import AsyncDbSessionFactory
 from aleph.types.message_processing_result import MessageProcessingResult
 
+from ..db.models import PendingMessageDb
 from ..types.message_status import MessageOrigin
 from .job_utils import MessageJob, prepare_loop
 
@@ -34,7 +37,7 @@ class PendingMessageProcessor(MessageJob):
     def __init__(
         self,
-        session_factory: DbSessionFactory,
+        session_factory: AsyncDbSessionFactory,
         message_handler: MessageHandler,
         max_retries: int,
         mq_conn: aio_pika.abc.AbstractConnection,
@@ -51,10 +54,16 @@ def __init__(
         self.mq_conn = mq_conn
         self.mq_message_exchange = mq_message_exchange
 
+        # Reduced from 100 to prevent overwhelming the event loop
+        self.max_parallel = 25
+        self.processed_hashes: Set[str] = set()
+        self.current_address: Set[str] = set()
+        self._task: Dict[str, asyncio.Task] = {}
+
     @classmethod
     async def new(
         cls,
-        session_factory: DbSessionFactory,
+        session_factory: AsyncDbSessionFactory,
         message_handler: MessageHandler,
         max_retries: int,
         mq_host: str,
@@ -97,35 +106,157 @@ async def new(
     async def close(self):
         await self.mq_conn.close()
 
+    async def process_message(
+        self, pending_message: PendingMessageDb
+    ) -> MessageProcessingResult:
+        async with self.session_factory() as session:
+            try:
+                LOGGER.info(f"Processing {pending_message.item_hash}: {pending_message.content['address']}")
+                result: MessageProcessingResult = await self.message_handler.process(
+                    session=session, pending_message=pending_message
+                )
+                await session.commit()
+            except Exception as e:
+                await session.rollback()
+                result = await self.handle_processing_error(
+                    session=session,
+                    pending_message=pending_message,
+                    exception=e,
+                )
+                await session.commit()
+
+            # Clean up tracking after processing is complete
+            address = pending_message.content.get("address")
+            if address in self.current_address:
+                self.current_address.remove(address)
+
+            if pending_message.item_hash in self.processed_hashes:
+                self.processed_hashes.remove(pending_message.item_hash)
+
+            return result
+
     async def process_messages(
         self,
     ) -> AsyncIterator[Sequence[MessageProcessingResult]]:
+        """
+        Process pending messages in parallel, up to max_parallel tasks at once.
+
+        This method has been improved to:
+        1. Better handle task completion and cleanup
+        2. Process completed results more efficiently
+        3. Ensure the event loop doesn't get blocked
+        4.
Add more robust error handling + + Returns: + An async iterator yielding sequences of message processing results + """ + # Keep track of active tasks and results + completed_results: List[MessageProcessingResult] = [] + + # Track the last time we yielded results + last_yield_time = asyncio.get_event_loop().time() + + # Define a callback function when a task completes + def task_done_callback(task: asyncio.Task): + try: + # The callback just ensures exceptions are logged + # Actual result processing is done in the main loop + task.result() + except asyncio.CancelledError: + pass + except Exception as e: + LOGGER.exception(f"Task failed with exception: {e}") + while True: - with self.session_factory() as session: - pending_message = get_next_pending_message( - current_time=utc_now(), session=session, fetched=True - ) - if not pending_message: - break + # Yield control back to the event loop to prevent blocking + await asyncio.sleep(0.1) + + current_time = asyncio.get_event_loop().time() + # First clean up completed tasks to free up slots + completed_count = 0 + for item_hash, task in list(self._task.items()): + if task.done(): + try: + result = task.result() + if result is not None: + LOGGER.info(f"Message {item_hash} processed with status: {result.status}") + completed_results.append(result) + completed_count += 1 + except asyncio.CancelledError: + LOGGER.debug(f"Task for {item_hash} was cancelled") + except Exception as e: + LOGGER.exception(f"Error getting task result for {item_hash}: {e}") + + # Remove from task dictionary and tracking sets + self._task.pop(item_hash) + + # Remove from tracking sets (clean up happens in process_message, + # but we also do it here as a safeguard) + if item_hash in self.processed_hashes: + self.processed_hashes.remove(item_hash) + + # Only fetch new messages if we're below max_parallel + if len(self._task) < self.max_parallel: + available_slots = self.max_parallel - len(self._task) + try: - result: MessageProcessingResult = ( - await self.message_handler.process( - session=session, pending_message=pending_message + async with self.session_factory() as session: + messages: List[PendingMessageDb] = ( + await async_get_next_pending_messages_from_different_senders( + session=session, + current_time=utc_now(), + fetched=True, + exclude_item_hashes=self.processed_hashes, + exclude_addresses=self.current_address, + limit=available_slots, + ) ) - ) - session.commit() + + if messages: + LOGGER.info(f"Fetched: {len(messages)} messages") + + # Create tasks for new messages + for message in messages: + if ( + not message.content + or not isinstance(message.content, dict) + or "address" not in message.content + ): + continue + # Track processed hashes and addresses + item_hash = message.item_hash + address = message.content.get("address") + + # Add to tracking sets + self.processed_hashes.add(item_hash) + self.current_address.add(address) + + LOGGER.info(f"Processing {item_hash}, {address}") + + # Create task and add callback + task = asyncio.create_task(self.process_message(message)) + task.add_done_callback(task_done_callback) + + # Store in active tasks dictionary + self._task[item_hash] = task except Exception as e: - session.rollback() - result = await self.handle_processing_error( - session=session, - pending_message=pending_message, - exception=e, - ) - session.commit() + LOGGER.exception(f"Error fetching pending messages: {e}") + # Sleep a bit longer if we hit an error + await asyncio.sleep(1.0) - yield [result] + # Yield completed results if we have any or if 
enough time has passed + if completed_results and (completed_count > 0 or current_time - last_yield_time > 5.0): + LOGGER.info(f"Yielding {len(completed_results)} completed results") + yield completed_results + completed_results = [] + last_yield_time = current_time + + # If we have no active tasks and no results, sleep a bit longer + # to avoid spinning the CPU when there's nothing to do + if not self._task and not completed_results: + await asyncio.sleep(0.5) async def publish_to_mq( self, message_iterator: AsyncIterator[Sequence[MessageProcessingResult]] @@ -149,8 +281,8 @@ def make_pipeline(self) -> AsyncIterator[Sequence[MessageProcessingResult]]: async def fetch_and_process_messages_task(config: Config): - engine = make_engine(config=config, application_name="aleph-process") - session_factory = make_session_factory(engine) + engine = make_async_engine(config=config, application_name="aleph-process") + session_factory = make_async_session_factory(engine) async with ( NodeCache( @@ -183,7 +315,7 @@ async def fetch_and_process_messages_task(config: Config): async with pending_message_processor: while True: - with session_factory() as session: + async with session_factory() as session: try: message_processing_pipeline = ( pending_message_processor.make_pipeline() @@ -196,7 +328,7 @@ async def fetch_and_process_messages_task(config: Config): except Exception: LOGGER.exception("Error in pending messages job") - session.rollback() + await session.rollback() LOGGER.info("Waiting for new pending messages...") # We still loop periodically for retried messages as we do not bother sending a message diff --git a/src/aleph/jobs/process_pending_txs.py b/src/aleph/jobs/process_pending_txs.py index 3c4d3e774..c36463115 100644 --- a/src/aleph/jobs/process_pending_txs.py +++ b/src/aleph/jobs/process_pending_txs.py @@ -12,7 +12,7 @@ from aleph.chains.chain_data_service import ChainDataService from aleph.db.accessors.pending_txs import delete_pending_tx, get_pending_txs -from aleph.db.connection import make_engine, make_session_factory +from aleph.db.connection import make_async_engine, make_async_session_factory from aleph.db.models import PendingTxDb from aleph.handlers.message_handler import MessagePublisher from aleph.services.cache.node_cache import NodeCache @@ -24,7 +24,7 @@ from aleph.toolkit.rabbitmq import make_mq_conn from aleph.toolkit.timestamp import utc_now from aleph.types.chain_sync import ChainSyncProtocol -from aleph.types.db_session import DbSessionFactory +from aleph.types.db_session import AsyncDbSessionFactory from ..types.message_status import MessageOrigin from .job_utils import MqWatcher, make_pending_tx_queue, prepare_loop @@ -35,7 +35,7 @@ class PendingTxProcessor(MqWatcher): def __init__( self, - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, message_publisher: MessagePublisher, chain_data_service: ChainDataService, pending_tx_queue: aio_pika.abc.AbstractQueue, @@ -73,9 +73,9 @@ async def handle_pending_tx( ) # bogus or handled, we remove it. 
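Deleting a handled (or bogus) pending tx is a single DELETE under the async session. A sketch of what the awaited accessor can look like (the real delete_pending_tx lives in aleph.db.accessors.pending_txs; the column name is an assumption):

    from sqlalchemy import delete

    from aleph.db.models import PendingTxDb
    from aleph.types.db_session import AsyncDbSession

    async def delete_pending_tx(session: AsyncDbSession, tx_hash: str) -> None:
        # Idempotent: deleting an already-removed pending tx is a no-op.
        await session.execute(
            delete(PendingTxDb).where(PendingTxDb.tx_hash == tx_hash)
        )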
- with self.session_factory() as session: - delete_pending_tx(session=session, tx_hash=tx.hash) - session.commit() + async with self.session_factory() as session: + await delete_pending_tx(session=session, tx_hash=tx.hash) + await session.commit() else: LOGGER.debug("TX contains no message") @@ -90,8 +90,8 @@ async def process_pending_txs(self, max_concurrent_tasks: int): seen_offchain_hashes = set() seen_ids: Set[str] = set() LOGGER.info("handling TXs") - with self.session_factory() as session: - for pending_tx in get_pending_txs(session): + async with self.session_factory() as session: + for pending_tx in await get_pending_txs(session): # TODO: remove this feature? It doesn't seem necessary. if pending_tx.tx.protocol == ChainSyncProtocol.OFF_CHAIN_SYNC: if pending_tx.tx.content in seen_offchain_hashes: @@ -118,8 +118,8 @@ async def process_pending_txs(self, max_concurrent_tasks: int): async def handle_txs_task(config: Config): max_concurrent_tasks = config.aleph.jobs.pending_txs.max_concurrency.value - engine = make_engine(config=config, application_name="aleph-txs") - session_factory = make_session_factory(engine) + async_engine = make_async_engine(config=config, application_name="aleph-txs") + async_session_factory = make_async_session_factory(engine=async_engine) mq_conn = await make_mq_conn(config=config) mq_channel = await mq_conn.channel() @@ -142,17 +142,18 @@ async def handle_txs_task(config: Config): ipfs_service=ipfs_service, node_cache=node_cache, ) + message_publisher = MessagePublisher( - session_factory=session_factory, + session_factory=async_session_factory, storage_service=storage_service, config=config, pending_message_exchange=pending_message_exchange, ) chain_data_service = ChainDataService( - session_factory=session_factory, storage_service=storage_service + session_factory=async_session_factory, storage_service=storage_service ) pending_tx_processor = PendingTxProcessor( - session_factory=session_factory, + session_factory=async_session_factory, message_publisher=message_publisher, chain_data_service=chain_data_service, pending_tx_queue=pending_tx_queue, diff --git a/src/aleph/jobs/reconnect_ipfs.py b/src/aleph/jobs/reconnect_ipfs.py index ae8cdb5ae..6962e6833 100644 --- a/src/aleph/jobs/reconnect_ipfs.py +++ b/src/aleph/jobs/reconnect_ipfs.py @@ -11,13 +11,13 @@ from aleph.db.accessors.peers import get_all_addresses_by_peer_type from aleph.db.models import PeerType from aleph.services.ipfs import IpfsService -from aleph.types.db_session import DbSessionFactory +from aleph.types.db_session import AsyncDbSessionFactory LOGGER = logging.getLogger("jobs.reconnect_ipfs") async def reconnect_ipfs_job( - config: Config, session_factory: DbSessionFactory, ipfs_service: IpfsService + config: Config, session_factory: AsyncDbSessionFactory, ipfs_service: IpfsService ): from aleph.services.utils import get_IP @@ -34,8 +34,8 @@ async def reconnect_ipfs_job( except aioipfs.APIError: LOGGER.warning("Can't reconnect to %s" % peer) - with session_factory() as session: - peers = get_all_addresses_by_peer_type( + async with session_factory() as session: + peers = await get_all_addresses_by_peer_type( session=session, peer_type=PeerType.IPFS ) diff --git a/src/aleph/network.py b/src/aleph/network.py index 2470bb0ff..606a92569 100644 --- a/src/aleph/network.py +++ b/src/aleph/network.py @@ -13,7 +13,7 @@ from aleph.services.ipfs.pubsub import incoming_channel as incoming_ipfs_channel from aleph.services.storage.fileystem_engine import FileSystemStorageEngine from aleph.storage import 
StorageService -from aleph.types.db_session import DbSessionFactory +from aleph.types.db_session import AsyncDbSessionFactory from aleph.types.message_status import InvalidMessageFormat LOGGER = logging.getLogger(__name__) @@ -38,7 +38,7 @@ async def decode_pubsub_message(message_data: bytes) -> Dict[str, Any]: async def listener_tasks( config, - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, node_cache: NodeCache, p2p_client: AlephP2PServiceClient, mq_channel: aio_pika.abc.AbstractChannel, diff --git a/src/aleph/permissions.py b/src/aleph/permissions.py index 1b50070b1..b1481656f 100644 --- a/src/aleph/permissions.py +++ b/src/aleph/permissions.py @@ -2,10 +2,12 @@ from aleph.db.accessors.aggregates import get_aggregate_by_key from aleph.db.models import MessageDb -from aleph.types.db_session import DbSession +from aleph.types.db_session import AsyncDbSession -async def check_sender_authorization(session: DbSession, message: MessageDb) -> bool: +async def check_sender_authorization( + session: AsyncDbSession, message: MessageDb +) -> bool: """Checks a content against a message to verify if sender is authorized. TODO: implement "security" aggregate key check. @@ -20,7 +22,7 @@ async def check_sender_authorization(session: DbSession, message: MessageDb) -> if sender == address: return True - aggregate = get_aggregate_by_key( + aggregate = await get_aggregate_by_key( session=session, key="security", owner=address ) # do we need anything else here? diff --git a/src/aleph/services/cache/materialized_views.py b/src/aleph/services/cache/materialized_views.py index 796daa7f7..b9697bf55 100644 --- a/src/aleph/services/cache/materialized_views.py +++ b/src/aleph/services/cache/materialized_views.py @@ -2,12 +2,14 @@ import logging from aleph.db.accessors.messages import refresh_address_stats_mat_view -from aleph.types.db_session import DbSessionFactory +from aleph.types.db_session import AsyncDbSessionFactory LOGGER = logging.getLogger(__name__) -async def refresh_cache_materialized_views(session_factory: DbSessionFactory) -> None: +async def refresh_cache_materialized_views( + session_factory: AsyncDbSessionFactory, +) -> None: """ Refresh DB materialized views used as caches, periodically. 
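Each pass opens a fresh async session, refreshes the view, and commits. The accessor itself is not shown in this diff; a sketch, assuming it wraps a raw REFRESH MATERIALIZED VIEW statement (the view name is an assumption):

    from sqlalchemy import text

    from aleph.types.db_session import AsyncDbSession

    async def refresh_address_stats_mat_view(session: AsyncDbSession) -> None:
        # CONCURRENTLY keeps the view readable while it is rebuilt;
        # it requires a unique index on the materialized view.
        await session.execute(
            text("REFRESH MATERIALIZED VIEW CONCURRENTLY address_stats_mat_view")
        )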
@@ -19,9 +21,9 @@ async def refresh_cache_materialized_views(session_factory: DbSessionFactory) -> while True: try: - with session_factory() as session: - refresh_address_stats_mat_view(session) - session.commit() + async with session_factory() as session: + await refresh_address_stats_mat_view(session) + await session.commit() LOGGER.info("Refreshed address stats materialized view") except Exception: diff --git a/src/aleph/services/cost.py b/src/aleph/services/cost.py index 7be692dc1..668825da8 100644 --- a/src/aleph/services/cost.py +++ b/src/aleph/services/cost.py @@ -48,7 +48,7 @@ RefVolume, SizedVolume, ) -from aleph.types.db_session import DbSession +from aleph.types.db_session import AsyncDbSession from aleph.types.files import FileTag from aleph.types.settings import Settings @@ -67,8 +67,8 @@ # TODO: Cache aggregate for 5 min -def _get_settings_aggregate(session: DbSession) -> Union[AggregateDb, dict]: - aggregate = get_aggregate_by_key( +async def _get_settings_aggregate(session: AsyncDbSession) -> Union[AggregateDb, dict]: + aggregate = await get_aggregate_by_key( session=session, owner=SETTINGS_AGGREGATE_OWNER, key=SETTINGS_AGGREGATE_KEY ) @@ -78,8 +78,8 @@ def _get_settings_aggregate(session: DbSession) -> Union[AggregateDb, dict]: return aggregate -def _get_settings(session: DbSession) -> Settings: - aggregate = _get_settings_aggregate(session) +async def _get_settings(session: AsyncDbSession) -> Settings: + aggregate = await _get_settings_aggregate(session) return Settings.from_aggregate(aggregate) @@ -167,8 +167,8 @@ def _get_product_price_type( # TODO: Cache aggregate for 5 min -def _get_price_aggregate(session: DbSession) -> Union[AggregateDb, dict]: - aggregate = get_aggregate_by_key( +async def _get_price_aggregate(session: AsyncDbSession) -> Union[AggregateDb, dict]: + aggregate = await get_aggregate_by_key( session=session, owner=PRICE_AGGREGATE_OWNER, key=PRICE_AGGREGATE_KEY ) @@ -178,24 +178,24 @@ def _get_price_aggregate(session: DbSession) -> Union[AggregateDb, dict]: return aggregate -def _get_product_price( - session: DbSession, content: CostComputableContent, settings: Settings +async def _get_product_price( + session: AsyncDbSession, content: CostComputableContent, settings: Settings ) -> ProductPricing: - price_aggregate = _get_price_aggregate(session) + price_aggregate = await _get_price_aggregate(session) price_type = _get_product_price_type(content, settings, price_aggregate) return ProductPricing.from_aggregate(price_type, price_aggregate) -def _get_file_from_ref( - session: DbSession, ref: str, use_latest: bool +async def _get_file_from_ref( + session: AsyncDbSession, ref: str, use_latest: bool ) -> Optional[StoredFileDb]: tag_or_pin: Optional[Union[MessageFilePinDb, FileTagDb]] if use_latest: - tag_or_pin = get_file_tag(session=session, tag=FileTag(ref)) + tag_or_pin = await get_file_tag(session=session, tag=FileTag(ref)) else: - tag_or_pin = get_message_file_pin(session=session, item_hash=ref) + tag_or_pin = await get_message_file_pin(session=session, item_hash=ref) if tag_or_pin: return tag_or_pin.file @@ -233,8 +233,8 @@ def _get_compute_unit_multiplier(content: CostComputableContent) -> int: return compute_unit_multiplier -def _get_volumes_costs( - session: DbSession, +async def _get_volumes_costs( + session: AsyncDbSession, volumes: List[RefVolume | SizedVolume], payment_type: PaymentType, price_per_mib: Decimal, @@ -248,7 +248,7 @@ def _get_volumes_costs( if isinstance(volume, SizedVolume): storage_mib = Decimal(volume.size_mib) elif 
isinstance(volume, RefVolume): - file = _get_file_from_ref( + file = await _get_file_from_ref( session=session, ref=volume.ref, use_latest=volume.use_latest ) @@ -280,8 +280,8 @@ def _get_volumes_costs( return costs -def _get_execution_volumes_costs( - session: DbSession, +async def _get_execution_volumes_costs( + session: AsyncDbSession, content: CostComputableExecutableContent, pricing: ProductPricing, payment_type: PaymentType, @@ -408,7 +408,7 @@ def _get_execution_volumes_costs( price_per_mib = pricing.price.storage.holding price_per_mib_second = pricing.price.storage.payg / HOUR - return _get_volumes_costs( + return await _get_volumes_costs( session, volumes, payment_type, @@ -419,15 +419,15 @@ def _get_execution_volumes_costs( ) -def _get_additional_storage_price( - session: DbSession, +async def _get_additional_storage_price( + session: AsyncDbSession, content: CostComputableExecutableContent, pricing: ProductPricing, payment_type: PaymentType, item_hash: str, ) -> List[AccountCostsDb]: # EXECUTION VOLUMES COSTS - costs = _get_execution_volumes_costs( + costs = await _get_execution_volumes_costs( session, content, pricing, payment_type, item_hash ) @@ -470,8 +470,8 @@ def _get_additional_storage_price( return costs -def _calculate_executable_costs( - session: DbSession, +async def _calculate_executable_costs( + session: AsyncDbSession, content: CostComputableExecutableContent, pricing: ProductPricing, item_hash: str, @@ -520,22 +520,22 @@ def _calculate_executable_costs( ) costs: List[AccountCostsDb] = [execution_cost] - costs += _get_additional_storage_price( + costs += await _get_additional_storage_price( session, content, pricing, payment_type, item_hash ) return costs -def _calculate_storage_costs( - session: DbSession, +async def _calculate_storage_costs( + session: AsyncDbSession, content: CostEstimationStoreContent | StoreContent, pricing: ProductPricing, item_hash: str, ) -> List[AccountCostsDb]: payment_type = get_payment_type(content) - storage_mib = calculate_storage_size(session, content) + storage_mib = await calculate_storage_size(session, content) if not storage_mib: return [] @@ -545,7 +545,7 @@ def _calculate_storage_costs( price_per_mib = pricing.price.storage.holding price_per_mib_second = pricing.price.storage.payg / HOUR - return _get_volumes_costs( + return await _get_volumes_costs( session, [volume], payment_type, @@ -556,15 +556,15 @@ def _calculate_storage_costs( ) -def calculate_storage_size( - session: DbSession, +async def calculate_storage_size( + session: AsyncDbSession, content: CostEstimationStoreContent | StoreContent, ) -> Optional[Decimal]: if isinstance(content, CostEstimationStoreContent) and content.estimated_size_mib: storage_mib = Decimal(content.estimated_size_mib) else: - file = get_file(session, content.item_hash) + file = await get_file(session, content.item_hash) if not file: return None storage_mib = Decimal(file.size / MiB) @@ -572,30 +572,30 @@ def calculate_storage_size( return storage_mib -def get_detailed_costs( - session: DbSession, +async def get_detailed_costs( + session: AsyncDbSession, content: CostComputableContent, item_hash: str, pricing: Optional[ProductPricing] = None, settings: Optional[Settings] = None, ) -> List[AccountCostsDb]: - settings = settings or _get_settings(session) - pricing = pricing or _get_product_price(session, content, settings) + settings = settings or await _get_settings(session) + pricing = pricing or await _get_product_price(session, content, settings) if isinstance(content, StoreContent): - return 
_calculate_storage_costs(session, content, pricing, item_hash) + return await _calculate_storage_costs(session, content, pricing, item_hash) else: - return _calculate_executable_costs(session, content, pricing, item_hash) + return await _calculate_executable_costs(session, content, pricing, item_hash) -def get_total_and_detailed_costs( - session: DbSession, +async def get_total_and_detailed_costs( + session: AsyncDbSession, content: CostComputableContent, item_hash: str, ) -> Tuple[Decimal, List[AccountCostsDb]]: payment_type = get_payment_type(content) - costs = get_detailed_costs(session, content, item_hash) + costs = await get_detailed_costs(session, content, item_hash) cost = format_cost( reduce(lambda x, y: x + y.cost_stream, costs, Decimal(0)) if payment_type == PaymentType.superfluid @@ -605,14 +605,14 @@ def get_total_and_detailed_costs( return Decimal(cost), list(costs) -def get_total_and_detailed_costs_from_db( - session: DbSession, +async def get_total_and_detailed_costs_from_db( + session: AsyncDbSession, content: CostComputableContent, item_hash: str, ) -> Tuple[Decimal, List[AccountCostsDb]]: payment_type = get_payment_type(content) - costs = get_message_costs(session, item_hash) + costs = await get_message_costs(session, item_hash) cost = format_cost( reduce(lambda x, y: x + y.cost_stream, costs, Decimal(0)) if payment_type == PaymentType.superfluid diff --git a/src/aleph/services/p2p/__init__.py b/src/aleph/services/p2p/__init__.py index 32464e83d..3b752ae37 100644 --- a/src/aleph/services/p2p/__init__.py +++ b/src/aleph/services/p2p/__init__.py @@ -4,7 +4,7 @@ from configmanager import Config from aleph.services.ipfs import IpfsService -from aleph.types.db_session import DbSessionFactory +from aleph.types.db_session import AsyncDbSessionFactory from ..cache.node_cache import NodeCache from .manager import initialize_host @@ -28,7 +28,7 @@ async def init_p2p_client(config: Config, service_name: str) -> AlephP2PServiceC async def init_p2p( config: Config, - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, service_name: str, ipfs_service: IpfsService, node_cache: NodeCache, diff --git a/src/aleph/services/p2p/jobs.py b/src/aleph/services/p2p/jobs.py index 2203345dd..afd5cbad1 100644 --- a/src/aleph/services/p2p/jobs.py +++ b/src/aleph/services/p2p/jobs.py @@ -8,7 +8,7 @@ from aleph.db.accessors.peers import get_all_addresses_by_peer_type from aleph.db.models import PeerType -from aleph.types.db_session import DbSessionFactory +from aleph.types.db_session import AsyncDbSessionFactory from ..cache.node_cache import NodeCache from .http import api_get_request @@ -26,7 +26,9 @@ class PeerStatus: async def reconnect_p2p_job( - config: Config, session_factory: DbSessionFactory, p2p_client: AlephP2PServiceClient + config: Config, + session_factory: AsyncDbSessionFactory, + p2p_client: AlephP2PServiceClient, ) -> None: await asyncio.sleep(2) @@ -34,9 +36,9 @@ async def reconnect_p2p_job( try: peers = set(config.p2p.peers.value) - with session_factory() as session: + async with session_factory() as session: peers |= set( - get_all_addresses_by_peer_type( + await get_all_addresses_by_peer_type( session=session, peer_type=PeerType.P2P ) ) @@ -66,7 +68,7 @@ async def check_peer(peer_uri: str, timeout: int = 1) -> PeerStatus: async def tidy_http_peers_job( - config: Config, session_factory: DbSessionFactory, node_cache: NodeCache + config: Config, session_factory: AsyncDbSessionFactory, node_cache: NodeCache ) -> None: """Check that HTTP peers are reachable, 
else remove them from the list""" from aleph.services.utils import get_IP @@ -78,8 +80,8 @@ async def tidy_http_peers_job( jobs = [] try: - with session_factory() as session: - peers = get_all_addresses_by_peer_type( + async with session_factory() as session: + peers = await get_all_addresses_by_peer_type( session=session, peer_type=PeerType.HTTP ) diff --git a/src/aleph/services/p2p/manager.py b/src/aleph/services/p2p/manager.py index fdf8970db..ed268ccc3 100644 --- a/src/aleph/services/p2p/manager.py +++ b/src/aleph/services/p2p/manager.py @@ -10,14 +10,14 @@ from aleph.services.peers.monitor import monitor_hosts_ipfs, monitor_hosts_p2p from aleph.services.peers.publish import publish_host from aleph.services.utils import get_IP -from aleph.types.db_session import DbSessionFactory +from aleph.types.db_session import AsyncDbSessionFactory LOGGER = logging.getLogger(__name__) async def initialize_host( config: Config, - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, node_cache: NodeCache, p2p_client: AlephP2PServiceClient, ipfs_service: IpfsService, diff --git a/src/aleph/services/peers/monitor.py b/src/aleph/services/peers/monitor.py index 0a6593288..c721087ff 100644 --- a/src/aleph/services/peers/monitor.py +++ b/src/aleph/services/peers/monitor.py @@ -9,13 +9,13 @@ from aleph.db.models import PeerType from aleph.services.ipfs import IpfsService from aleph.toolkit.timestamp import utc_now -from aleph.types.db_session import DbSessionFactory +from aleph.types.db_session import AsyncDbSessionFactory LOGGER = logging.getLogger(__name__) async def handle_incoming_host( - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, data: bytes, sender: str, source: PeerType, @@ -34,8 +34,8 @@ async def handle_incoming_host( # TODO: handle interests and save it - with session_factory() as session: - upsert_peer( + async with session_factory() as session: + await upsert_peer( session=session, peer_id=sender, peer_type=peer_type, @@ -43,7 +43,7 @@ async def handle_incoming_host( source=source, last_seen=utc_now(), ) - session.commit() + await session.commit() except Exception as e: if isinstance(e, ValueError): @@ -54,7 +54,7 @@ async def handle_incoming_host( async def monitor_hosts_p2p( p2p_client: AlephP2PServiceClient, - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, alive_topic: str, ) -> None: while True: @@ -76,7 +76,7 @@ async def monitor_hosts_p2p( async def monitor_hosts_ipfs( - ipfs_service: IpfsService, session_factory: DbSessionFactory, alive_topic: str + ipfs_service: IpfsService, session_factory: AsyncDbSessionFactory, alive_topic: str ): while True: try: diff --git a/src/aleph/services/storage/fileystem_engine.py b/src/aleph/services/storage/fileystem_engine.py index 4a209f4c1..48dd389ac 100644 --- a/src/aleph/services/storage/fileystem_engine.py +++ b/src/aleph/services/storage/fileystem_engine.py @@ -1,11 +1,15 @@ from pathlib import Path from typing import Optional, Union +import aiofiles +import aiofiles.ospath + from .engine import StorageEngine class FileSystemStorageEngine(StorageEngine): def __init__(self, folder: Union[Path, str]): + self.folder = folder if isinstance(folder, Path) else Path(folder) if self.folder.exists() and not self.folder.is_dir(): @@ -16,19 +20,27 @@ def __init__(self, folder: Union[Path, str]): async def read(self, filename: str) -> Optional[bytes]: file_path = self.folder / filename - if not file_path.is_file(): + if not await aiofiles.ospath.isfile(file_path): return None 
-
-        return file_path.read_bytes()
+        async with aiofiles.open(file_path, "rb") as f:
+            return await f.read()
 
     async def write(self, filename: str, content: bytes):
         file_path = self.folder / filename
-        file_path.write_bytes(content)
+        async with aiofiles.open(file_path, "wb") as f:
+            await f.write(content)
 
     async def delete(self, filename: str):
         file_path = self.folder / filename
-        file_path.unlink(missing_ok=True)
+        # aiofiles does not wrap Path.unlink, so we wrap it manually.
+        async_unlink = aiofiles.ospath.wrap(Path.unlink)
+        await async_unlink(file_path, missing_ok=True)
 
     async def exists(self, filename: str) -> bool:
         file_path = self.folder / filename
-        return file_path.exists()
+        # aiofiles.ospath.exists wraps os.path.exists into an async call.
+        return await aiofiles.ospath.exists(file_path)
diff --git a/src/aleph/services/storage/garbage_collector.py b/src/aleph/services/storage/garbage_collector.py
index 49cb2df81..d97887c75 100644
--- a/src/aleph/services/storage/garbage_collector.py
+++ b/src/aleph/services/storage/garbage_collector.py
@@ -10,14 +10,14 @@
 from aleph.db.accessors.files import delete_grace_period_file_pins, get_unpinned_files
 from aleph.storage import StorageService
 from aleph.toolkit.timestamp import utc_now
-from aleph.types.db_session import DbSessionFactory
+from aleph.types.db_session import AsyncDbSessionFactory
 
 LOGGER = logging.getLogger(__name__)
 
 
 class GarbageCollector:
     def __init__(
-        self, session_factory: DbSessionFactory, storage_service: StorageService
+        self, session_factory: AsyncDbSessionFactory, storage_service: StorageService
     ):
         self.session_factory = session_factory
         self.storage_service = storage_service
@@ -43,23 +43,23 @@ async def _delete_from_local_storage(self, file_hash: ItemHash):
         LOGGER.debug(f"Removed from local storage: {file_hash}")
 
     async def collect(self, datetime: dt.datetime):
-        with self.session_factory() as session:
+        async with self.session_factory() as session:
             # Delete outdated grace period file pins
-            delete_grace_period_file_pins(session=session, datetime=datetime)
-            session.commit()
+            await delete_grace_period_file_pins(session=session, datetime=datetime)
+            await session.commit()
 
             # Delete files without pins
-            files_to_delete = list(get_unpinned_files(session))
+            files_to_delete = list(await get_unpinned_files(session))
             LOGGER.info("Found %d files to delete", len(files_to_delete))
 
         for file_to_delete in files_to_delete:
-            with self.session_factory() as session:
+            async with self.session_factory() as session:
                 try:
                     file_hash = ItemHash(file_to_delete.hash)
                     LOGGER.info("Deleting %s...", file_hash)
-                    delete_file_db(session=session, file_hash=file_hash)
-                    session.commit()
+                    await delete_file_db(session=session, file_hash=file_hash)
+                    await session.commit()
 
                     if file_hash.item_type == ItemType.ipfs:
                         await self._delete_from_ipfs(file_hash)
@@ -69,7 +69,7 @@ async def collect(self, datetime: dt.datetime):
                     LOGGER.info("Deleted %s", file_hash)
                 except Exception as err:
                     LOGGER.error("Failed to delete file %s: %s", file_hash, str(err))
-                    session.rollback()
+                    await session.rollback()
 
 
 async def garbage_collector_task(
diff --git a/src/aleph/storage.py b/src/aleph/storage.py
index d9a4d19f2..6906383e3 100644
--- a/src/aleph/storage.py
+++ b/src/aleph/storage.py
@@ -22,7 +22,7 @@
 from aleph.services.ipfs.common import get_cid_version
 from aleph.services.p2p.http import request_hash as p2p_http_request_hash
 from aleph.services.storage.engine import StorageEngine
-from aleph.types.db_session import DbSession
+from aleph.types.db_session import AsyncDbSession
from aleph.types.files import FileType from aleph.utils import get_sha256 @@ -251,7 +251,7 @@ async def pin_hash(self, chash: str, timeout: int = 30, tries: int = 1): await self.ipfs_service.pin_add(cid=chash, timeout=timeout, tries=tries) async def add_json( - self, session: DbSession, value: Any, engine: ItemType = ItemType.ipfs + self, session: AsyncDbSession, value: Any, engine: ItemType = ItemType.ipfs ) -> str: content = aleph_json.dumps(value) @@ -263,7 +263,7 @@ async def add_json( raise NotImplementedError("storage engine %s not supported" % engine) await self.storage_engine.write(filename=chash, content=content) - upsert_file( + await upsert_file( session=session, file_hash=chash, size=len(content), @@ -273,10 +273,10 @@ async def add_json( return chash async def add_file_content_to_local_storage( - self, session: DbSession, file_content: bytes, file_hash: str + self, session: AsyncDbSession, file_content: bytes, file_hash: str ) -> None: await self.storage_engine.write(filename=file_hash, content=file_content) - upsert_file( + await upsert_file( session=session, file_hash=file_hash, size=len(file_content), @@ -284,7 +284,10 @@ async def add_file_content_to_local_storage( ) async def add_file( - self, session: DbSession, file_content: bytes, engine: ItemType = ItemType.ipfs + self, + session: AsyncDbSession, + file_content: bytes, + engine: ItemType = ItemType.ipfs, ) -> str: if engine == ItemType.ipfs: output = await self.ipfs_service.add_file(file_content) diff --git a/src/aleph/web/controllers/accounts.py b/src/aleph/web/controllers/accounts.py index a58e2ae84..7653085da 100644 --- a/src/aleph/web/controllers/accounts.py +++ b/src/aleph/web/controllers/accounts.py @@ -24,7 +24,7 @@ GetAccountQueryParams, GetBalancesChainsQueryParams, ) -from aleph.types.db_session import DbSessionFactory +from aleph.types.db_session import AsyncDbSessionFactory from aleph.web.controllers.app_state_getters import get_session_factory_from_request @@ -49,10 +49,10 @@ async def addresses_stats_view(request: web.Request): """Returns the stats of some addresses.""" addresses: List[str] = request.query.getall("addresses[]", []) - session_factory: DbSessionFactory = request.app["session_factory"] + session_factory: AsyncDbSessionFactory = request.app["session_factory"] - with session_factory() as session: - stats = get_message_stats_by_address(session=session, addresses=addresses) + async with session_factory() as session: + stats = await get_message_stats_by_address(session=session, addresses=addresses) stats_dict = make_stats_dict(stats) @@ -82,12 +82,12 @@ async def get_account_balance(request: web.Request): except ValidationError as e: raise web.HTTPUnprocessableEntity(text=e.json()) - session_factory: DbSessionFactory = get_session_factory_from_request(request) - with session_factory() as session: - balance, details = get_total_detailed_balance( + session_factory: AsyncDbSessionFactory = get_session_factory_from_request(request) + async with session_factory() as session: + balance, details = await get_total_detailed_balance( session=session, address=address, chain=query_params.chain ) - total_cost = get_total_cost_for_address(session=session, address=address) + total_cost = await get_total_cost_for_address(session=session, address=address) return web.json_response( text=GetAccountBalanceResponse( address=address, balance=balance, locked_amount=total_cost, details=details @@ -103,9 +103,9 @@ async def get_chain_balances(request: web.Request) -> web.Response: find_filters = 
query_params.model_dump(exclude_none=True)
 
-    session_factory: DbSessionFactory = get_session_factory_from_request(request)
-    with session_factory() as session:
-        balances = get_balances_by_chain(session, **find_filters)
+    session_factory: AsyncDbSessionFactory = get_session_factory_from_request(request)
+    async with session_factory() as session:
+        balances = await get_balances_by_chain(session, **find_filters)
 
         formatted_balances = [
             AddressBalanceResponse.model_validate(b) for b in balances
@@ -134,11 +134,11 @@ async def get_account_files(request: web.Request) -> web.Response:
     except ValidationError as e:
         raise web.HTTPUnprocessableEntity(text=e.json())
 
-    session_factory: DbSessionFactory = get_session_factory_from_request(request)
+    session_factory: AsyncDbSessionFactory = get_session_factory_from_request(request)
 
-    with session_factory() as session:
+    async with session_factory() as session:
         file_pins = list(
-            get_address_files_for_api(
+            await get_address_files_for_api(
                 session=session,
                 owner=address,
                 pagination=query_params.pagination,
@@ -146,7 +146,9 @@ async def get_account_files(request: web.Request) -> web.Response:
                 sort_order=query_params.sort_order,
             )
         )
-        nb_files, total_size = get_address_files_stats(session=session, owner=address)
+        nb_files, total_size = await get_address_files_stats(
+            session=session, owner=address
+        )
 
         if not file_pins:
             raise web.HTTPNotFound()
diff --git a/src/aleph/web/controllers/aggregates.py b/src/aleph/web/controllers/aggregates.py
index c9f61c3be..42b88ed5e 100644
--- a/src/aleph/web/controllers/aggregates.py
+++ b/src/aleph/web/controllers/aggregates.py
@@ -43,21 +43,21 @@ async def address_aggregate(request: web.Request) -> web.Response:
             text=e.json(), content_type="application/json"
         )
 
     session_factory = request.app["session_factory"]
-    with session_factory() as session:
-        dirty_aggregates = session.execute(
-            select(AggregateDb.key).where(
-                (AggregateDb.owner == address)
-                & (AggregateDb.owner == address)
-                & AggregateDb.dirty
+    async with session_factory() as session:
+        dirty_aggregates = (
+            await session.execute(
+                select(AggregateDb.key).where(
+                    (AggregateDb.owner == address) & AggregateDb.dirty
+                )
             )
         ).scalars()
         for key in dirty_aggregates:
             LOGGER.info("Refreshing dirty aggregate %s/%s", address, key)
-            refresh_aggregate(session=session, owner=address, key=key)
-            session.commit()
+            await refresh_aggregate(session=session, owner=address, key=key)
+            await session.commit()
 
         aggregates = list(
-            get_aggregates_by_owner(
+            await get_aggregates_by_owner(
                 session=session,
                 owner=address,
                 with_info=query_params.with_info,
diff --git a/src/aleph/web/controllers/app_state_getters.py b/src/aleph/web/controllers/app_state_getters.py
index 1277388e9..876e6820c 100644
--- a/src/aleph/web/controllers/app_state_getters.py
+++ b/src/aleph/web/controllers/app_state_getters.py
@@ -16,7 +16,7 @@
 from aleph.services.cache.node_cache import NodeCache
 from aleph.services.ipfs import IpfsService
 from aleph.storage import StorageService
-from aleph.types.db_session import DbSessionFactory
+from aleph.types.db_session import AsyncDbSessionFactory
 
 APP_STATE_CONFIG = "config"
 APP_STATE_MQ_CONN = "mq_conn"
@@ -99,8 +99,8 @@ def get_p2p_client_from_request(request: web.Request) -> AlephP2PServiceClient:
     return cast(AlephP2PServiceClient, request.app[APP_STATE_P2P_CLIENT])
 
 
-def get_session_factory_from_request(request: web.Request) -> DbSessionFactory:
-    return cast(DbSessionFactory, request.app[APP_STATE_SESSION_FACTORY])
+def
get_session_factory_from_request(request: web.Request) -> AsyncDbSessionFactory: + return cast(AsyncDbSessionFactory, request.app[APP_STATE_SESSION_FACTORY]) def get_storage_service_from_request(request: web.Request) -> StorageService: diff --git a/src/aleph/web/controllers/channels.py b/src/aleph/web/controllers/channels.py index 47e65947d..9c222fa54 100644 --- a/src/aleph/web/controllers/channels.py +++ b/src/aleph/web/controllers/channels.py @@ -5,13 +5,13 @@ from aleph.db.accessors.messages import get_distinct_channels from aleph.types.channel import Channel -from aleph.types.db_session import DbSession +from aleph.types.db_session import AsyncDbSession from aleph.web.controllers.app_state_getters import get_session_factory_from_request @cached(ttl=60 * 120, cache=SimpleMemoryCache, timeout=120) -async def get_channels(session: DbSession) -> List[Channel]: - channels = get_distinct_channels(session) +async def get_channels(session: AsyncDbSession) -> List[Channel]: + channels = await get_distinct_channels(session) return list(channels) @@ -23,7 +23,7 @@ async def used_channels(request: web.Request) -> web.Response: session_factory = get_session_factory_from_request(request) - with session_factory() as session: + async with session_factory() as session: channels = await get_channels(session) response = web.json_response({"channels": channels}) diff --git a/src/aleph/web/controllers/ipfs.py b/src/aleph/web/controllers/ipfs.py index 3e38f9150..86fb41f2f 100644 --- a/src/aleph/web/controllers/ipfs.py +++ b/src/aleph/web/controllers/ipfs.py @@ -58,15 +58,17 @@ async def ipfs_add_file(request: web.Request): ) size = stats["Size"] - with session_factory() as session: - upsert_file( + async with session_factory() as session: + await upsert_file( session=session, file_hash=cid, size=size, file_type=FileType.FILE, ) - add_grace_period_for_file(session=session, file_hash=cid, hours=grace_period) - session.commit() + await add_grace_period_for_file( + session=session, file_hash=cid, hours=grace_period + ) + await session.commit() output = { "status": "success", diff --git a/src/aleph/web/controllers/main.py b/src/aleph/web/controllers/main.py index 5b5c5b79e..3ebac7fe7 100644 --- a/src/aleph/web/controllers/main.py +++ b/src/aleph/web/controllers/main.py @@ -8,7 +8,7 @@ from pydantic import BaseModel from aleph.db.accessors.metrics import query_metric_ccn, query_metric_crn -from aleph.types.db_session import DbSessionFactory +from aleph.types.db_session import AsyncDbSessionFactory from aleph.web.controllers.app_state_getters import ( get_node_cache_from_request, get_session_factory_from_request, @@ -22,9 +22,9 @@ async def index(request: web.Request) -> Dict: """Index of aleph.""" - session_factory: DbSessionFactory = get_session_factory_from_request(request) + session_factory: AsyncDbSessionFactory = get_session_factory_from_request(request) node_cache = get_node_cache_from_request(request) - with session_factory() as session: + async with session_factory() as session: return asdict(await get_metrics(session=session, node_cache=node_cache)) @@ -32,12 +32,12 @@ async def status_ws(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) - session_factory: DbSessionFactory = get_session_factory_from_request(request) + session_factory: AsyncDbSessionFactory = get_session_factory_from_request(request) node_cache = get_node_cache_from_request(request) previous_status = None while True: - with session_factory() as session: + async with session_factory() as 
session: status = await get_metrics(session=session, node_cache=node_cache) if status != previous_status: @@ -58,10 +58,10 @@ async def metrics(request: web.Request) -> web.Response: Naming convention: https://prometheus.io/docs/practices/naming/ """ - session_factory = get_session_factory_from_request(request) + session_factory: AsyncDbSessionFactory = get_session_factory_from_request(request) node_cache = get_node_cache_from_request(request) - with session_factory() as session: + async with session_factory() as session: return web.Response( text=format_dataclass_for_prometheus( await get_metrics(session=session, node_cache=node_cache) @@ -71,10 +71,10 @@ async def metrics(request: web.Request) -> web.Response: async def metrics_json(request: web.Request) -> web.Response: """JSON version of the Prometheus metrics.""" - session_factory: DbSessionFactory = get_session_factory_from_request(request) + session_factory: AsyncDbSessionFactory = get_session_factory_from_request(request) node_cache = get_node_cache_from_request(request) - with session_factory() as session: + async with session_factory() as session: return web.Response( text=(await get_metrics(session=session, node_cache=node_cache)).to_json(), content_type="application/json", @@ -97,13 +97,13 @@ def _get_node_id_from_request(request: web.Request) -> str: async def ccn_metric(request: web.Request) -> web.Response: """Fetch metrics for CCN node id""" - session_factory: DbSessionFactory = get_session_factory_from_request(request) + session_factory: AsyncDbSessionFactory = get_session_factory_from_request(request) query_params = Metrics.model_validate(request.query) node_id = _get_node_id_from_request(request) - with session_factory() as session: - ccn = query_metric_ccn( + async with session_factory() as session: + ccn = await query_metric_ccn( session, node_id=node_id, start_date=query_params.start_date, @@ -123,13 +123,13 @@ async def ccn_metric(request: web.Request) -> web.Response: async def crn_metric(request: web.Request) -> web.Response: """Fetch Metric for crn.""" - session_factory: DbSessionFactory = get_session_factory_from_request(request) + session_factory: AsyncDbSessionFactory = get_session_factory_from_request(request) query_params = Metrics.model_validate(request.query) node_id = _get_node_id_from_request(request) - with session_factory() as session: - crn = query_metric_crn( + async with session_factory() as session: + crn = await query_metric_crn( session, node_id=node_id, start_date=query_params.start_date, diff --git a/src/aleph/web/controllers/messages.py b/src/aleph/web/controllers/messages.py index ca4535728..cb84c962d 100644 --- a/src/aleph/web/controllers/messages.py +++ b/src/aleph/web/controllers/messages.py @@ -44,7 +44,7 @@ format_message_dict, ) from aleph.toolkit.shield import shielded -from aleph.types.db_session import DbSession, DbSessionFactory +from aleph.types.db_session import AsyncDbSession, AsyncDbSessionFactory from aleph.types.message_status import MessageStatus from aleph.types.sort_order import SortBy, SortOrder from aleph.web.controllers.app_state_getters import ( @@ -313,11 +313,11 @@ async def view_messages_list(request: web.Request) -> web.Response: pagination_per_page = query_params.pagination session_factory = get_session_factory_from_request(request) - with session_factory() as session: - messages = get_matching_messages( + async with session_factory() as session: + messages = await get_matching_messages( session, include_confirmations=True, **find_filters ) - total_msgs = 
count_matching_messages(session, **find_filters) + total_msgs = await count_matching_messages(session, **find_filters) return format_response( messages, @@ -329,12 +329,12 @@ async def view_messages_list(request: web.Request) -> web.Response: async def _send_history_to_ws( ws: aiohttp.web_ws.WebSocketResponse, - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, history: int, query_params: WsMessageQueryParams, ) -> None: - with session_factory() as session: - messages = get_matching_messages( + async with session_factory() as session: + messages = await get_matching_messages( session=session, pagination=history, include_confirmations=True, @@ -505,15 +505,17 @@ async def messages_ws(request: web.Request) -> web.WebSocketResponse: return ws -def _get_message_with_status( - session: DbSession, status_db: MessageStatusDb +async def _get_message_with_status( + session: AsyncDbSession, status_db: MessageStatusDb ) -> MessageWithStatus: status = status_db.status item_hash = status_db.item_hash reception_time = status_db.reception_time if status == MessageStatus.PENDING: # There may be several instances of the same pending message, return the first. - pending_messages_db = get_pending_messages(session=session, item_hash=item_hash) + pending_messages_db = await get_pending_messages( + session=session, item_hash=item_hash + ) pending_messages = [ PendingMessage.model_validate(m) for m in pending_messages_db ] @@ -525,7 +527,7 @@ def _get_message_with_status( ) if status == MessageStatus.PROCESSED: - message_db = get_message_by_item_hash( + message_db = await get_message_by_item_hash( session=session, item_hash=ItemHash(item_hash) ) if not message_db: @@ -539,7 +541,7 @@ def _get_message_with_status( ) if status == MessageStatus.FORGOTTEN: - forgotten_message_db = get_forgotten_message( + forgotten_message_db = await get_forgotten_message( session=session, item_hash=item_hash ) if not forgotten_message_db: @@ -553,7 +555,9 @@ def _get_message_with_status( ) if status == MessageStatus.REJECTED: - rejected_message_db = get_rejected_message(session=session, item_hash=item_hash) + rejected_message_db = await get_rejected_message( + session=session, item_hash=item_hash + ) if not rejected_message_db: raise web.HTTPNotFound() @@ -578,12 +582,14 @@ async def view_message(request: web.Request): except ValueError: raise web.HTTPBadRequest(body=f"Invalid message hash: {item_hash_str}") - session_factory: DbSessionFactory = request.app["session_factory"] - with session_factory() as session: - message_status_db = get_message_status(session=session, item_hash=item_hash) + session_factory: AsyncDbSessionFactory = request.app["session_factory"] + async with session_factory() as session: + message_status_db = await get_message_status( + session=session, item_hash=item_hash + ) if message_status_db is None: raise web.HTTPNotFound() - message_with_status = _get_message_with_status( + message_with_status = await _get_message_with_status( session=session, status_db=message_status_db ) @@ -600,12 +606,14 @@ async def view_message_content(request: web.Request): except ValueError: raise web.HTTPBadRequest(body=f"Invalid message hash: {item_hash_str}") - session_factory: DbSessionFactory = request.app["session_factory"] - with session_factory() as session: - message_status_db = get_message_status(session=session, item_hash=item_hash) + session_factory: AsyncDbSessionFactory = request.app["session_factory"] + async with session_factory() as session: + message_status_db = await get_message_status( 
+ session=session, item_hash=item_hash + ) if message_status_db is None: raise web.HTTPNotFound() - message_with_status = _get_message_with_status( + message_with_status = await _get_message_with_status( session=session, status_db=message_status_db ) @@ -639,9 +647,9 @@ async def view_message_status(request: web.Request): except ValueError: raise web.HTTPBadRequest(body=f"Invalid message hash: {item_hash_str}") - session_factory: DbSessionFactory = request.app["session_factory"] - with session_factory() as session: - message_status = get_message_status(session=session, item_hash=item_hash) + session_factory: AsyncDbSessionFactory = request.app["session_factory"] + async with session_factory() as session: + message_status = await get_message_status(session=session, item_hash=item_hash) if message_status is None: raise web.HTTPNotFound() @@ -661,15 +669,15 @@ async def view_message_hashes(request: web.Request): pagination_per_page = query_params.pagination session_factory = get_session_factory_from_request(request) - with session_factory() as session: - hashes = get_matching_hashes(session, **find_filters) + async with session_factory() as session: + hashes = await get_matching_hashes(session, **find_filters) if find_filters["hash_only"]: formatted_hashes = [h for h in hashes] else: formatted_hashes = [MessageHashes.model_validate(h) for h in hashes] - total_hashes = count_matching_hashes(session, **find_filters) + total_hashes = await count_matching_hashes(session, **find_filters) response = { "hashes": formatted_hashes, "pagination_per_page": pagination_per_page, diff --git a/src/aleph/web/controllers/metrics.py b/src/aleph/web/controllers/metrics.py index c9ed4ec19..70daa67cd 100644 --- a/src/aleph/web/controllers/metrics.py +++ b/src/aleph/web/controllers/metrics.py @@ -19,7 +19,7 @@ from aleph.db.models import FilePinDb, MessageDb, PeerDb, PendingMessageDb, PendingTxDb from aleph.services.cache.node_cache import NodeCache from aleph.types.chain_sync import ChainEventType -from aleph.types.db_session import DbSession +from aleph.types.db_session import AsyncDbSession LOGGER = getLogger("WEB.metrics") @@ -147,16 +147,16 @@ async def fetch_eth_height() -> Optional[int]: # Cache metrics for 10 seconds by default @cached(ttl=config.aleph.cache.ttl.metrics.value) -async def get_metrics(session: DbSession, node_cache: NodeCache) -> Metrics: +async def get_metrics(session: AsyncDbSession, node_cache: NodeCache) -> Metrics: sync_messages_reference_total = await fetch_reference_total_messages() eth_reference_height = await fetch_eth_height() # select count(*) can be slow but estimates are not refreshed often enough to give # a meaningful sense of progress for the node main page on /. 
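+    # The *Db.count() helpers are coroutines in the async API, so each count
+    # below is awaited and the queries run one at a time on this session. If
+    # that ever becomes a bottleneck they could be gathered concurrently, but
+    # only with a separate AsyncSession per query (a session is not safe for
+    # concurrent use), e.g. something like:
+    #     totals = await asyncio.gather(
+    #         MessageDb.count(session=s1), PeerDb.count(session=s2)
+    #     )
+    # where s1 and s2 are two independently created sessions.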
- sync_messages_total: int = MessageDb.count(session=session) - peers_count = PeerDb.count(session=session) + sync_messages_total: int = await MessageDb.count(session=session) + peers_count = await PeerDb.count(session=session) - eth_last_committed_height = get_last_height( + eth_last_committed_height = await get_last_height( session=session, chain=Chain.ETH, sync_type=ChainEventType.SYNC ) @@ -182,13 +182,15 @@ async def get_metrics(session: DbSession, node_cache: NodeCache) -> Metrics: pyaleph_status_peers_total=peers_count, pyaleph_processing_pending_messages_tasks_total=nb_message_jobs, pyaleph_status_sync_messages_total=sync_messages_total, - pyaleph_status_sync_permanent_files_total=FilePinDb.count(session=session), + pyaleph_status_sync_permanent_files_total=await FilePinDb.count( + session=session + ), pyaleph_status_sync_messages_reference_total=sync_messages_reference_total, pyaleph_status_sync_messages_remaining_total=sync_messages_remaining_total, - pyaleph_status_sync_pending_messages_total=PendingMessageDb.count( + pyaleph_status_sync_pending_messages_total=await PendingMessageDb.count( session=session ), - pyaleph_status_sync_pending_txs_total=PendingTxDb.count(session=session), + pyaleph_status_sync_pending_txs_total=await PendingTxDb.count(session=session), pyaleph_status_chain_eth_last_committed_height=eth_last_committed_height, pyaleph_status_chain_eth_height_reference_total=eth_reference_height, pyaleph_status_chain_eth_height_remaining_total=eth_remaining_height, diff --git a/src/aleph/web/controllers/posts.py b/src/aleph/web/controllers/posts.py index b2e4ab7c8..cb64a9b35 100644 --- a/src/aleph/web/controllers/posts.py +++ b/src/aleph/web/controllers/posts.py @@ -20,7 +20,7 @@ get_matching_posts_legacy, ) from aleph.db.models import ChainTxDb, message_confirmations -from aleph.types.db_session import DbSession, DbSessionFactory +from aleph.types.db_session import AsyncDbSession, AsyncDbSessionFactory from aleph.types.sort_order import SortBy, SortOrder from aleph.web.controllers.utils import ( DEFAULT_MESSAGES_PER_PAGE, @@ -122,8 +122,8 @@ def merged_post_to_dict(merged_post: MergedPost) -> Dict[str, Any]: } -def get_post_confirmations( - session: DbSession, post: MergedPostV0 +async def get_post_confirmations( + session: AsyncDbSession, post: MergedPostV0 ) -> List[Dict[str, Any]]: select_stmt = ( select( @@ -136,17 +136,17 @@ def get_post_confirmations( .where(message_confirmations.c.item_hash == post.item_hash) ) - results = session.execute(select_stmt).all() + results = (await session.execute(select_stmt)).all() return [ {"chain": row.chain, "hash": row.hash, "height": row.height} for row in results ] -def merged_post_v0_to_dict( - session: DbSession, merged_post: MergedPostV0 +async def merged_post_v0_to_dict( + session: AsyncDbSession, merged_post: MergedPostV0 ) -> Dict[str, Any]: - confirmations = get_post_confirmations(session, merged_post) + confirmations = await get_post_confirmations(session, merged_post) return { "chain": merged_post.chain, @@ -193,12 +193,12 @@ async def view_posts_list_v0(request: web.Request) -> web.Response: pagination_page = query_params.page pagination_per_page = query_params.pagination - session_factory: DbSessionFactory = request.app["session_factory"] + session_factory: AsyncDbSessionFactory = request.app["session_factory"] - with session_factory() as session: - total_posts = count_matching_posts(session=session, **find_filters) - results = get_matching_posts_legacy(session=session, **find_filters) - posts = 
[merged_post_v0_to_dict(session, post) for post in results] + async with session_factory() as session: + total_posts = await count_matching_posts(session=session, **find_filters) + results = await get_matching_posts_legacy(session=session, **find_filters) + posts = [await merged_post_v0_to_dict(session, post) for post in results] context: Dict[str, Any] = {"posts": posts} @@ -244,10 +244,10 @@ async def view_posts_list_v1(request) -> web.Response: pagination_page = query_params.page pagination_per_page = query_params.pagination - session_factory: DbSessionFactory = request.app["session_factory"] - with session_factory() as session: - total_posts = count_matching_posts(session=session, **find_filters) - results = get_matching_posts(session=session, **find_filters) + session_factory: AsyncDbSessionFactory = request.app["session_factory"] + async with session_factory() as session: + total_posts = await count_matching_posts(session=session, **find_filters) + results = await get_matching_posts(session=session, **find_filters) posts = [merged_post_to_dict(post) for post in results] context: Dict[str, Any] = {"posts": posts} diff --git a/src/aleph/web/controllers/prices.py b/src/aleph/web/controllers/prices.py index 5d22ed3ce..591e2b66d 100644 --- a/src/aleph/web/controllers/prices.py +++ b/src/aleph/web/controllers/prices.py @@ -23,7 +23,7 @@ get_total_and_detailed_costs_from_db, ) from aleph.toolkit.costs import format_cost_str -from aleph.types.db_session import DbSession +from aleph.types.db_session import AsyncDbSession from aleph.types.message_status import MessageStatus from aleph.web.controllers.app_state_getters import ( get_session_factory_from_request, @@ -56,7 +56,9 @@ class MessagePrice(DataClassJsonMixin): required_tokens: Optional[Decimal] = None -async def get_executable_message(session: DbSession, item_hash_str: str) -> MessageDb: +async def get_executable_message( + session: AsyncDbSession, item_hash_str: str +) -> MessageDb: """Attempt to get an executable message from the database. Raises an HTTP exception if the message is not found, not processed or is not an executable message. 
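+
+    Expects an AsyncDbSession: both the status lookup and the message lookup
+    below are awaited accessor calls.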
""" @@ -68,7 +70,7 @@ async def get_executable_message(session: DbSession, item_hash_str: str) -> Mess raise web.HTTPBadRequest(body=f"Invalid message hash: {item_hash_str}") # Get the message status from the database - message_status_db = get_message_status(session=session, item_hash=item_hash) + message_status_db = await get_message_status(session=session, item_hash=item_hash) if not message_status_db: raise web.HTTPNotFound(body=f"Message not found with hash: {item_hash}") # Loop through the status_exceptions to find a match and raise the corresponding exception @@ -78,7 +80,7 @@ async def get_executable_message(session: DbSession, item_hash_str: str) -> Mess assert message_status_db.status == MessageStatus.PROCESSED # Get the message from the database - message: Optional[MessageDb] = get_message_by_item_hash(session, item_hash) + message: Optional[MessageDb] = await get_message_by_item_hash(session, item_hash) if not message: raise web.HTTPNotFound(body="Message not found, despite appearing as processed") if message.type not in ( @@ -97,14 +99,14 @@ async def message_price(request: web.Request): """Returns the price of an executable message.""" session_factory = get_session_factory_from_request(request) - with session_factory() as session: + async with session_factory() as session: item_hash = request.match_info["item_hash"] message = await get_executable_message(session, item_hash) content: ExecutableContent = message.parsed_content try: payment_type = get_payment_type(content) - required_tokens, costs = get_total_and_detailed_costs_from_db( + required_tokens, costs = await get_total_and_detailed_costs_from_db( session, content, item_hash ) @@ -133,7 +135,7 @@ async def message_price_estimate(request: web.Request): session_factory = get_session_factory_from_request(request) storage_service = get_storage_service_from_request(request) - with session_factory() as session: + async with session_factory() as session: parsed_body = PubMessageRequest.model_validate(await request.json()) message = validate_cost_estimation_message_dict(parsed_body.message_dict) content = await validate_cost_estimation_message_content( @@ -143,7 +145,7 @@ async def message_price_estimate(request: web.Request): try: payment_type = get_payment_type(content) - required_tokens, costs = get_total_and_detailed_costs( + required_tokens, costs = await get_total_and_detailed_costs( session, content, item_hash ) diff --git a/src/aleph/web/controllers/programs.py b/src/aleph/web/controllers/programs.py index 408178fc6..6047bcb35 100644 --- a/src/aleph/web/controllers/programs.py +++ b/src/aleph/web/controllers/programs.py @@ -2,7 +2,7 @@ from pydantic import BaseModel, ConfigDict, ValidationError from aleph.db.accessors.messages import get_programs_triggered_by_messages -from aleph.types.db_session import DbSessionFactory +from aleph.types.db_session import AsyncDbSessionFactory from aleph.types.sort_order import SortOrder @@ -20,9 +20,9 @@ async def get_programs_on_message(request: web.Request) -> web.Response: data=error.json(), status=web.HTTPBadRequest.status_code ) - session_factory: DbSessionFactory = request.app["session_factory"] + session_factory: AsyncDbSessionFactory = request.app["session_factory"] - with session_factory() as session: + async with session_factory() as session: messages = [ { "item_hash": result.item_hash, @@ -30,7 +30,7 @@ async def get_programs_on_message(request: web.Request) -> web.Response: "on": {"message": result.message_subscriptions}, }, } - for result in 
get_programs_triggered_by_messages( + for result in await get_programs_triggered_by_messages( session=session, sort_order=query.sort_order ) ] diff --git a/src/aleph/web/controllers/storage.py b/src/aleph/web/controllers/storage.py index 2a8bf104b..f188205aa 100644 --- a/src/aleph/web/controllers/storage.py +++ b/src/aleph/web/controllers/storage.py @@ -31,7 +31,7 @@ MAX_UNAUTHENTICATED_UPLOAD_FILE_SIZE, MiB, ) -from aleph.types.db_session import DbSession +from aleph.types.db_session import AsyncDbSession from aleph.types.message_status import InvalidSignature from aleph.utils import item_type_from_hash, run_in_executor from aleph.web.controllers.app_state_getters import ( @@ -59,17 +59,17 @@ async def add_ipfs_json_controller(request: web.Request): grace_period = config.storage.grace_period.value data = await request.json() - with session_factory() as session: + async with session_factory() as session: output = { "status": "success", "hash": await storage_service.add_json( session=session, value=data, engine=ItemType.ipfs ), } - add_grace_period_for_file( + await add_grace_period_for_file( session=session, file_hash=output["hash"], hours=grace_period ) - session.commit() + await session.commit() return web.json_response(output) @@ -82,17 +82,17 @@ async def add_storage_json_controller(request: web.Request): grace_period = config.storage.grace_period.value data = await request.json() - with session_factory() as session: + async with session_factory() as session: output = { "status": "success", "hash": await storage_service.add_json( session=session, value=data, engine=ItemType.storage ), } - add_grace_period_for_file( + await add_grace_period_for_file( session=session, file_hash=output["hash"], hours=grace_period ) - session.commit() + await session.commit() return web.json_response(output) @@ -107,16 +107,18 @@ async def _verify_message_signature( async def _verify_user_balance( - session: DbSession, content: CostEstimationStoreContent + session: AsyncDbSession, content: CostEstimationStoreContent ) -> None: if content.estimated_size_mib and content.estimated_size_mib > ( MAX_UNAUTHENTICATED_UPLOAD_FILE_SIZE / MiB ): - current_balance = get_total_balance(session=session, address=content.address) - current_cost = get_total_cost_for_address( + current_balance = await get_total_balance( session=session, address=content.address ) - message_cost, _ = get_total_and_detailed_costs(session, content, "") + current_cost = await get_total_cost_for_address( + session=session, address=content.address + ) + message_cost, _ = await get_total_and_detailed_costs(session, content, "") required_balance = current_cost + message_cost @@ -217,7 +219,7 @@ async def _read_chunks(self, chunk_size): async def _check_and_add_file( - session: DbSession, + session: AsyncDbSession, signature_verifier: SignatureVerifier, storage_service: StorageService, message: Optional[PendingStoreMessage], @@ -271,7 +273,7 @@ async def _check_and_add_file( # For files uploaded without authenticated upload, add a grace period of 1 day. 
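+    # message_content is None exactly in that unauthenticated case: no signed
+    # STORE message accompanies the upload, so the file is only pinned
+    # temporarily and may be garbage-collected once the grace period expires.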
if message_content is None: - add_grace_period_for_file( + await add_grace_period_for_file( session=session, file_hash=file_hash, hours=grace_period ) return file_hash @@ -350,7 +352,7 @@ async def storage_add_file(request: web.Request): message = None sync = False - with session_factory() as session: + async with session_factory() as session: file_hash = await _check_and_add_file( session=session, signature_verifier=signature_verifier, @@ -359,7 +361,7 @@ async def storage_add_file(request: web.Request): uploaded_file=uploaded_file, grace_period=grace_period, ) - session.commit() + await session.commit() if message: broadcast_status = await broadcast_and_process_message( pending_message=message, sync=sync, request=request, logger=logger @@ -374,12 +376,12 @@ async def storage_add_file(request: web.Request): await uploaded_file.cleanup() -def assert_file_is_downloadable(session: DbSession, file_hash: str) -> None: +async def assert_file_is_downloadable(session: AsyncDbSession, file_hash: str) -> None: """ Check if the file is on the aleph.im network and can be downloaded from the API. This filters out requests for files outside the network / nonexistent files. """ - file_metadata = get_file(session=session, file_hash=file_hash) + file_metadata = await get_file(session=session, file_hash=file_hash) if not file_metadata: raise web.HTTPNotFound(text="Not found") @@ -404,8 +406,8 @@ async def get_hash(request): return web.HTTPBadRequest(text="Invalid hash provided") session_factory = get_session_factory_from_request(request) - with session_factory() as session: - assert_file_is_downloadable(session=session, file_hash=item_hash) + async with session_factory() as session: + await assert_file_is_downloadable(session=session, file_hash=item_hash) storage_service = get_storage_service_from_request(request) @@ -446,8 +448,8 @@ async def get_raw_hash(request): raise web.HTTPBadRequest(text="Invalid hash") session_factory = get_session_factory_from_request(request) - with session_factory() as session: - assert_file_is_downloadable(session=session, file_hash=item_hash) + async with session_factory() as session: + await assert_file_is_downloadable(session=session, file_hash=item_hash) storage_service = get_storage_service_from_request(request) @@ -475,6 +477,6 @@ async def get_file_pins_count(request: web.Request) -> web.Response: raise web.HTTPBadRequest(text="No hash provided") session_factory = get_session_factory_from_request(request) - with session_factory() as session: - count = count_file_pins(session=session, file_hash=item_hash) + async with session_factory() as session: + count = await count_file_pins(session=session, file_hash=item_hash) return web.json_response(data=count) diff --git a/src/aleph/web/controllers/utils.py b/src/aleph/web/controllers/utils.py index 7e0a4ac7b..1eef8a8f5 100644 --- a/src/aleph/web/controllers/utils.py +++ b/src/aleph/web/controllers/utils.py @@ -22,7 +22,7 @@ from aleph.services.p2p.pubsub import publish as pub_p2p from aleph.toolkit.shield import shielded from aleph.toolkit.timestamp import utc_now -from aleph.types.db_session import DbSession +from aleph.types.db_session import AsyncDbSession from aleph.types.message_status import ( InvalidMessageException, MessageProcessingStatus, @@ -405,10 +405,12 @@ async def broadcast_and_process_message( return status -def add_grace_period_for_file(session: DbSession, file_hash: str, hours: int): +async def add_grace_period_for_file( + session: AsyncDbSession, file_hash: str, hours: int +): current_datetime = utc_now() 
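+    # This helper is async only because insert_grace_period_file_pin is
+    # awaited; it does not commit, callers are expected to await
+    # session.commit() themselves (see the ipfs and storage controllers).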
     delete_by = current_datetime + dt.timedelta(hours=hours)
-    insert_grace_period_file_pin(
+    await insert_grace_period_file_pin(
         session=session,
         file_hash=file_hash,
         created=utc_now(),
diff --git a/tests/api/conftest.py b/tests/api/conftest.py
index b45680b9a..01fa68f52 100644
--- a/tests/api/conftest.py
+++ b/tests/api/conftest.py
@@ -23,23 +23,23 @@
 from aleph.jobs.process_pending_messages import PendingMessageProcessor
 from aleph.storage import StorageService
 from aleph.toolkit.timestamp import timestamp_to_datetime
-from aleph.types.db_session import DbSessionFactory
+from aleph.types.db_session import AsyncDbSessionFactory
 
 
 # TODO: remove the raw parameter, it's just to avoid larger refactorings
 async def _load_fixtures(
-    session_factory: DbSessionFactory, filename: str, raw: bool = True
+    session_factory: AsyncDbSessionFactory, filename: str, raw: bool = True
 ) -> Sequence[Dict[str, Any]]:
     fixtures_dir = Path(__file__).parent / "fixtures"
     fixtures_file = fixtures_dir / filename
 
     with fixtures_file.open() as f:
         messages_json = json.load(f)
 
     messages = []
     tx_hashes = set()
-    with session_factory() as session:
+    async with session_factory() as session:
         for message_dict in messages_json:
             message_db = MessageDb.from_message_dict(message_dict)
             messages.append(message_db)
@@ -50,20 +50,20 @@
                 tx_hashes.add(tx_hash)
                 session.add(chain_tx_db)
 
-        session.flush()
-        session.execute(
+        await session.flush()
+        await session.execute(
             insert(message_confirmations).values(
                 item_hash=message_db.item_hash, tx_hash=tx_hash
             )
         )
-        session.commit()
+        await session.commit()
 
     return messages_json if raw else messages
 
 
 @pytest_asyncio.fixture
 async def fixture_messages(
-    session_factory: DbSessionFactory,
+    session_factory: AsyncDbSessionFactory,
 ) -> Sequence[Dict[str, Any]]:
     return await _load_fixtures(session_factory, "fixture_messages.json")
@@ -83,23 +83,23 @@ def make_aggregate_element(message: MessageDb) -> AggregateElementDb:
 
 @pytest_asyncio.fixture
 async def fixture_aggregate_messages(
-    session_factory: DbSessionFactory,
+    session_factory: AsyncDbSessionFactory,
 ) -> Sequence[MessageDb]:
     messages = await _load_fixtures(
         session_factory, "fixture_aggregates.json", raw=False
     )
     aggregate_keys = set()
-    with session_factory() as session:
+    async with session_factory() as session:
         for message in messages:
             aggregate_element = make_aggregate_element(message)  # type: ignore
             session.add(aggregate_element)
             aggregate_keys.add((aggregate_element.owner, aggregate_element.key))
-        session.commit()
+        await session.commit()
 
         for owner, key in aggregate_keys:
-            refresh_aggregate(session=session, owner=owner, key=key)
+            await refresh_aggregate(session=session, owner=owner, key=key)
 
-        session.commit()
+        await session.commit()
 
     return messages  # type: ignore
@@ -120,14 +120,14 @@ def make_post_db(message: MessageDb) -> PostDb:
 
 @pytest_asyncio.fixture
 async def fixture_posts(
-    session_factory: DbSessionFactory,
+    session_factory: AsyncDbSessionFactory,
 ) -> Sequence[PostDb]:
     messages = await _load_fixtures(session_factory, "fixture_posts.json", raw=False)
     posts = [make_post_db(message) for message in messages]  # type: ignore
 
-    with session_factory() as session:
+    async with session_factory() as session:
         session.add_all(posts)
-        session.commit()
+        await session.commit()
 
     return posts
@@ -197,7 +197,9 @@ def amended_post_with_refs_and_tags(post_with_refs_and_tags: Tuple[MessageDb, Po
 
 @pytest.fixture
-def message_processor(mocker, mock_config: Config, 
session_factory: DbSessionFactory): +def message_processor( + mocker, mock_config: Config, session_factory: AsyncDbSessionFactory +): storage_engine = InMemoryStorageEngine(files={}) storage_service = StorageService( storage_engine=storage_engine, diff --git a/tests/api/test_get_message.py b/tests/api/test_get_message.py index 095ffc5f8..555e1705e 100644 --- a/tests/api/test_get_message.py +++ b/tests/api/test_get_message.py @@ -2,6 +2,7 @@ from typing import Any, Mapping, Sequence import pytest +import pytest_asyncio import pytz from aleph_message.models import Chain, ItemType, MessageType @@ -20,7 +21,7 @@ ) from aleph.toolkit.timestamp import timestamp_to_datetime from aleph.types.channel import Channel -from aleph.types.db_session import DbSessionFactory +from aleph.types.db_session import AsyncDbSessionFactory from aleph.types.message_status import ErrorCode, MessageStatus MESSAGE_URI = "/api/v0/messages/{}" @@ -30,9 +31,9 @@ RECEPTION_DATETIME = pytz.utc.localize(dt.datetime(2023, 1, 1)) -@pytest.fixture -def fixture_messages_with_status( - session_factory: DbSessionFactory, +@pytest_asyncio.fixture +async def fixture_messages_with_status( + session_factory: AsyncDbSessionFactory, ) -> Mapping[MessageStatus, Sequence[Any]]: pending_messages = [ @@ -171,7 +172,7 @@ def fixture_messages_with_status( MessageStatus.REJECTED: rejected_messages, } - with session_factory() as session: + async with session_factory() as session: for status, messages in messages_dict.items(): for message in messages: session.add(message) @@ -182,7 +183,7 @@ def fixture_messages_with_status( reception_time=RECEPTION_DATETIME, ) ) - session.commit() + await session.commit() return messages_dict diff --git a/tests/api/test_list_messages.py b/tests/api/test_list_messages.py index 0db1bc379..77418b96b 100644 --- a/tests/api/test_list_messages.py +++ b/tests/api/test_list_messages.py @@ -21,7 +21,7 @@ from aleph.db.models import MessageDb, PostDb from aleph.toolkit.timestamp import timestamp_to_datetime from aleph.types.channel import Channel -from aleph.types.db_session import DbSessionFactory +from aleph.types.db_session import AsyncDbSessionFactory from .utils import get_messages_by_keys @@ -170,7 +170,7 @@ async def test_get_messages_multiple_hashes(fixture_messages, ccn_api_client): async def test_get_messages_filter_by_tags( fixture_messages, ccn_api_client, - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, post_with_refs_and_tags: Tuple[MessageDb, PostDb], amended_post_with_refs_and_tags: Tuple[MessageDb, PostDb], ): @@ -182,9 +182,9 @@ async def test_get_messages_filter_by_tags( message_db, _ = post_with_refs_and_tags amend_message_db, _ = amended_post_with_refs_and_tags - with session_factory() as session: + async with session_factory() as session: session.add_all([message_db, amend_message_db]) - session.commit() + await session.commit() # Matching tag for both messages response = await ccn_api_client.get(MESSAGES_URI, params={"tags": "mainnet"}) @@ -552,11 +552,11 @@ def instance_message_fixture() -> MessageDb: async def test_get_instance( ccn_api_client, instance_message_fixture: MessageDb, - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, ): - with session_factory() as session: + async with session_factory() as session: session.add(instance_message_fixture) - session.commit() + await session.commit() response = await ccn_api_client.get( MESSAGES_URI, params={"hashes": instance_message_fixture.item_hash} diff --git a/tests/api/test_new_metric.py 
b/tests/api/test_new_metric.py index 5d25ca2fc..b795688a0 100644 --- a/tests/api/test_new_metric.py +++ b/tests/api/test_new_metric.py @@ -4,7 +4,7 @@ import pytest import pytest_asyncio -from aleph.types.db_session import DbSessionFactory +from aleph.types.db_session import AsyncDbSessionFactory from .conftest import _load_fixtures @@ -15,7 +15,7 @@ def _generate_uri(node_type: str, node_id: str) -> str: @pytest_asyncio.fixture async def fixture_metrics_messages( - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, ) -> Sequence[Dict[str, Any]]: return await _load_fixtures(session_factory, "test-metric.json") diff --git a/tests/api/test_posts.py b/tests/api/test_posts.py index bd0e7b440..48d84a3bc 100644 --- a/tests/api/test_posts.py +++ b/tests/api/test_posts.py @@ -5,7 +5,7 @@ from aleph.db.models import MessageDb from aleph.db.models.posts import PostDb -from aleph.types.db_session import DbSessionFactory +from aleph.types.db_session import AsyncDbSessionFactory POSTS_URI = "/api/v1/posts.json" @@ -56,17 +56,17 @@ async def test_get_posts(ccn_api_client, fixture_posts: Sequence[PostDb]): @pytest.mark.asyncio async def test_get_posts_refs( ccn_api_client, - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, fixture_posts: Sequence[PostDb], post_with_refs_and_tags: Tuple[MessageDb, PostDb], ): message_db, post_db = post_with_refs_and_tags - with session_factory() as session: + async with session_factory() as session: session.add_all(fixture_posts) session.add(message_db) session.add(post_db) - session.commit() + await session.commit() # Match the ref response = await ccn_api_client.get( @@ -111,7 +111,7 @@ async def test_get_posts_refs( @pytest.mark.asyncio async def test_get_amended_posts_refs( ccn_api_client, - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, fixture_posts: Sequence[PostDb], post_with_refs_and_tags: Tuple[MessageDb, PostDb], amended_post_with_refs_and_tags: Tuple[MessageDb, PostDb], @@ -121,13 +121,13 @@ async def test_get_amended_posts_refs( original_post_db.latest_amend = amend_post_db.item_hash - with session_factory() as session: + async with session_factory() as session: session.add_all(fixture_posts) session.add(original_message_db) session.add(original_post_db) session.add(amend_message_db) session.add(amend_post_db) - session.commit() + await session.commit() # Match the ref response = await ccn_api_client.get( @@ -172,17 +172,17 @@ async def test_get_amended_posts_refs( @pytest.mark.asyncio async def test_get_posts_tags( ccn_api_client, - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, fixture_posts: Sequence[PostDb], post_with_refs_and_tags: Tuple[MessageDb, PostDb], ): message_db, post_db = post_with_refs_and_tags - with session_factory() as session: + async with session_factory() as session: session.add_all(fixture_posts) session.add(message_db) session.add(post_db) - session.commit() + await session.commit() # Match one tag response = await ccn_api_client.get( @@ -242,7 +242,7 @@ async def test_get_posts_tags( @pytest.mark.asyncio async def test_get_amended_posts_tags( ccn_api_client, - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, fixture_posts: Sequence[PostDb], post_with_refs_and_tags: Tuple[MessageDb, PostDb], amended_post_with_refs_and_tags: Tuple[MessageDb, PostDb], @@ -252,13 +252,13 @@ async def test_get_amended_posts_tags( original_post_db.latest_amend = amend_post_db.item_hash - with session_factory() as 
session: + async with session_factory() as session: session.add_all(fixture_posts) session.add(original_message_db) session.add(original_post_db) session.add(amend_message_db) session.add(amend_post_db) - session.commit() + await session.commit() # Match one tag response = await ccn_api_client.get("/api/v0/posts.json", params={"tags": "amend"}) diff --git a/tests/api/test_storage.py b/tests/api/test_storage.py index 0e15de67e..bda8dca59 100644 --- a/tests/api/test_storage.py +++ b/tests/api/test_storage.py @@ -14,7 +14,7 @@ from aleph.db.accessors.files import get_file from aleph.db.models import AlephBalanceDb from aleph.storage import StorageService -from aleph.types.db_session import DbSessionFactory +from aleph.types.db_session import AsyncDbSessionFactory from aleph.types.files import FileType from aleph.types.message_status import MessageStatus from aleph.web.controllers.app_state_getters import ( @@ -97,7 +97,7 @@ async def api_client(ccn_test_aiohttp_app, mocker, aiohttp_client): async def add_file_raw_upload( api_client, - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, uri: str, file_content: bytes, expected_file_hash: str, @@ -118,8 +118,8 @@ async def add_file_raw_upload( response_data = await get_file_response.read() # Check that the file appears in the DB - with session_factory() as session: - file = get_file(session=session, file_hash=file_hash) + async with session_factory() as session: + file = await get_file(session=session, file_hash=file_hash) assert file is not None assert file.hash == file_hash assert file.type == FileType.FILE @@ -130,7 +130,7 @@ async def add_file_raw_upload( async def add_file( api_client, - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, uri: str, file_content: bytes, expected_file_hash: str, @@ -152,8 +152,8 @@ async def add_file( response_data = await get_file_response.read() # Check that the file appears in the DB - with session_factory() as session: - file = get_file(session=session, file_hash=file_hash) + async with session_factory() as session: + file = await get_file(session=session, file_hash=file_hash) assert file is not None assert file.hash == file_hash assert file.type == FileType.FILE @@ -164,7 +164,7 @@ async def add_file( async def add_file_with_message( api_client, - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, uri: str, file_content: bytes, error_code: int, @@ -179,7 +179,7 @@ async def add_file_with_message( ), ) - with session_factory() as session: + async with session_factory() as session: session.add( AlephBalanceDb( address="0x6dA130FD646f826C1b8080C07448923DF9a79aaA", @@ -188,7 +188,7 @@ async def add_file_with_message( eth_height=0, ) ) - session.commit() + await session.commit() form_data = aiohttp.FormData() @@ -206,7 +206,7 @@ async def add_file_with_message( async def add_file_with_message_202( api_client, - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, uri: str, file_content: bytes, size: str, @@ -221,7 +221,7 @@ async def add_file_with_message_202( message_status=MessageStatus.PENDING, ), ) - with session_factory() as session: + async with session_factory() as session: session.add( AlephBalanceDb( address="0x6dA130FD646f826C1b8080C07448923DF9a79aaA", @@ -230,7 +230,7 @@ async def add_file_with_message_202( eth_height=0, ) ) - session.commit() + await session.commit() form_data = aiohttp.FormData() @@ -247,7 +247,7 @@ async def add_file_with_message_202( @pytest.mark.asyncio -async def 
test_storage_add_file(api_client, session_factory: DbSessionFactory): +async def test_storage_add_file(api_client, session_factory: AsyncDbSessionFactory): await add_file( api_client, session_factory, @@ -259,7 +259,7 @@ async def test_storage_add_file(api_client, session_factory: DbSessionFactory): @pytest.mark.asyncio async def test_storage_add_file_raw_upload( - api_client, session_factory: DbSessionFactory + api_client, session_factory: AsyncDbSessionFactory ): await add_file_raw_upload( api_client, @@ -294,7 +294,7 @@ async def test_storage_add_file_with_message( api_client, fixture_product_prices_aggregate_in_db, fixture_settings_aggregate_in_db, - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, file_content, expected_hash, size: Optional[int], @@ -328,7 +328,7 @@ async def test_storage_add_file_with_message( @pytest.mark.asyncio async def test_storage_add_file_with_message_202( api_client, - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, file_content, expected_hash, size, @@ -351,7 +351,7 @@ async def test_storage_add_file_with_message_202( @pytest.mark.asyncio -async def test_ipfs_add_file(api_client, session_factory: DbSessionFactory): +async def test_ipfs_add_file(api_client, session_factory: AsyncDbSessionFactory): await add_file( api_client, session_factory, @@ -363,7 +363,7 @@ async def test_ipfs_add_file(api_client, session_factory: DbSessionFactory): async def add_json( api_client, - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, uri: str, json: Any, expected_file_hash: ItemHash, @@ -393,8 +393,8 @@ async def add_json( assert await get_file_response.read() == serialized_json # Check that the file appears in the DB - with session_factory() as session: - file = get_file(session=session, file_hash=file_hash) + async with session_factory() as session: + file = await get_file(session=session, file_hash=file_hash) assert file is not None assert file.hash == file_hash assert file.type == FileType.FILE @@ -403,7 +403,7 @@ async def add_json( @pytest.mark.asyncio -async def test_storage_add_json(api_client, session_factory: DbSessionFactory): +async def test_storage_add_json(api_client, session_factory: AsyncDbSessionFactory): await add_json( api_client, session_factory, @@ -414,7 +414,7 @@ async def test_storage_add_json(api_client, session_factory: DbSessionFactory): @pytest.mark.asyncio -async def test_ipfs_add_json(api_client, session_factory: DbSessionFactory): +async def test_ipfs_add_json(api_client, session_factory: AsyncDbSessionFactory): await add_json( api_client, session_factory, diff --git a/tests/balances/test_balances.py b/tests/balances/test_balances.py index 1bb25819b..fa12c888c 100644 --- a/tests/balances/test_balances.py +++ b/tests/balances/test_balances.py @@ -7,7 +7,7 @@ from aleph.db.accessors.balances import get_balance_by_chain, get_total_balance from aleph.db.models import AlephBalanceDb from aleph.handlers.content.post import update_balances -from aleph.types.db_session import DbSession, DbSessionFactory +from aleph.types.db_session import AsyncDbSession, AsyncDbSessionFactory BALANCES_CONTENT_SOL: Mapping[str, Any] = { "tags": ["SOL", "SPL", "CsZ5LZkDS7h9TDKjrbL7VAwQZ9nsRu8vJLhRYfmGaN8K", "mainnet"], @@ -63,62 +63,67 @@ } -def compare_balances( - session: DbSession, balances: Mapping[str, float], chain: Chain, dapp: Optional[str] +async def compare_balances( + session: AsyncDbSession, + balances: Mapping[str, float], + chain: Chain, + dapp: Optional[str], ): for 
address, expected_balance in balances.items(): - balance_db = get_balance_by_chain( + balance_db = await get_balance_by_chain( session, address=address, chain=chain, dapp=dapp ) assert balance_db is not None # Easier to compare decimals and floats as strings assert str(balance_db) == str(expected_balance) - nb_balances_db = AlephBalanceDb.count(session) + nb_balances_db = await AlephBalanceDb.count(session) assert nb_balances_db == len(balances) @pytest.mark.asyncio -async def test_process_balances_solana(session_factory: DbSessionFactory): +async def test_process_balances_solana(session_factory: AsyncDbSessionFactory): content = BALANCES_CONTENT_SOL - with session_factory() as session: - update_balances(session=session, content=content) - session.commit() + async with session_factory() as session: + await update_balances(session=session, content=content) + await session.commit() balances = content["balances"] - compare_balances(session=session, balances=balances, chain=Chain.SOL, dapp=None) + await compare_balances( + session=session, balances=balances, chain=Chain.SOL, dapp=None + ) @pytest.mark.asyncio -async def test_process_balances_sablier(session_factory: DbSessionFactory): +async def test_process_balances_sablier(session_factory: AsyncDbSessionFactory): content = BALANCES_CONTENT_SABLIER - with session_factory() as session: - update_balances(session=session, content=content) - session.commit() + async with session_factory() as session: + await update_balances(session=session, content=content) + await session.commit() balances = content["balances"] - compare_balances( + await compare_balances( session=session, balances=balances, chain=Chain.ETH, dapp="SABLIER" ) @pytest.mark.asyncio -async def test_update_balances(session_factory: DbSessionFactory): +async def test_update_balances(session_factory: AsyncDbSessionFactory): content = BALANCES_CONTENT_SOL - with session_factory() as session: - update_balances(session=session, content=content) - session.commit() + async with session_factory() as session: + await update_balances(session=session, content=content) + await session.commit() new_content = BALANCES_CONTENT_SOL_UPDATE - with session_factory() as session: - update_balances(session=session, content=new_content) - session.commit() + async with session_factory() as session: + await update_balances(session=session, content=new_content) + await session.commit() session.expire_all() - compare_balances( + await compare_balances( session=session, balances=new_content["balances"], chain=Chain.SOL, @@ -126,11 +131,12 @@ async def test_update_balances(session_factory: DbSessionFactory): ) -def test_get_total_balance(session_factory: DbSessionFactory): +@pytest.mark.asyncio +async def test_get_total_balance(session_factory: AsyncDbSessionFactory): address_1 = "my-address" address_2 = "your-address" - with session_factory() as session: + async with session_factory() as session: session.add( AlephBalanceDb( address=address_1, @@ -167,25 +173,25 @@ def test_get_total_balance(session_factory: DbSessionFactory): eth_height=0, ) ) - session.commit() + await session.commit() - with session_factory() as session: - balance_with_dapps = get_total_balance( + async with session_factory() as session: + balance_with_dapps = await get_total_balance( session=session, address=address_1, include_dapps=True ) assert balance_with_dapps == 1_001_100_000 - balance_no_dapps = get_total_balance( + balance_no_dapps = await get_total_balance( session=session, address=address_1, include_dapps=False ) assert 
balance_no_dapps == 1_100_000 - balance_address_2 = get_total_balance( + balance_address_2 = await get_total_balance( session=session, address=address_2, include_dapps=False ) assert balance_address_2 == 3 - balance_unknown_address = get_total_balance( + balance_unknown_address = await get_total_balance( session=session, address="unknown-address", include_dapps=False ) assert balance_unknown_address == Decimal(0) diff --git a/tests/chains/test_chain_data_service.py b/tests/chains/test_chain_data_service.py index 9fdd2a30d..37a2c8fab 100644 --- a/tests/chains/test_chain_data_service.py +++ b/tests/chains/test_chain_data_service.py @@ -18,7 +18,7 @@ from aleph.schemas.pending_messages import parse_message from aleph.toolkit.timestamp import timestamp_to_datetime from aleph.types.chain_sync import ChainSyncProtocol -from aleph.types.db_session import DbSession, DbSessionFactory +from aleph.types.db_session import AsyncDbSession, AsyncDbSessionFactory @pytest.mark.asyncio @@ -42,7 +42,7 @@ async def test_prepare_sync_event_payload(mocker): ] async def mock_add_file( - session: DbSession, file_content: bytes, engine: ItemType = ItemType.ipfs + session: AsyncDbSession, file_content: bytes, engine: ItemType = ItemType.ipfs ) -> str: content = file_content archive = OnChainSyncEventPayload.model_validate_json(content) @@ -70,7 +70,7 @@ async def mock_add_file( @pytest.mark.asyncio async def test_smart_contract_protocol_ipfs_store( - mocker, session_factory: DbSessionFactory + mocker, session_factory: AsyncDbSessionFactory ): payload = MessageEventPayload( timestamp=1668611900, @@ -122,7 +122,7 @@ async def test_smart_contract_protocol_ipfs_store( @pytest.mark.asyncio async def test_smart_contract_protocol_regular_message( - mocker, session_factory: DbSessionFactory + mocker, session_factory: AsyncDbSessionFactory ): content = PostContent( content={"body": "My first post on Tezos"}, diff --git a/tests/chains/test_confirmation.py b/tests/chains/test_confirmation.py index c874e3e81..c710ca04d 100644 --- a/tests/chains/test_confirmation.py +++ b/tests/chains/test_confirmation.py @@ -13,7 +13,7 @@ from aleph.handlers.message_handler import MessageHandler from aleph.storage import StorageService from aleph.types.chain_sync import ChainSyncProtocol -from aleph.types.db_session import DbSessionFactory +from aleph.types.db_session import AsyncDbSessionFactory MESSAGE_DICT: Mapping = { "chain": "ETH", @@ -53,7 +53,7 @@ def compare_chain_txs(expected: ChainTxDb, actual: ChainTxDb): @pytest.mark.asyncio async def test_confirm_message( mock_config: Config, - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, test_storage_service: StorageService, chain_tx: ChainTxDb, ): @@ -83,12 +83,14 @@ async def test_confirm_message( reception_time=dt.datetime(2022, 1, 1), fetched=True, ) - with session_factory() as session: + async with session_factory() as session: await message_handler.process(session=session, pending_message=pending_message) - session.commit() + await session.commit() - with session_factory() as session: - message_in_db = get_message_by_item_hash(session=session, item_hash=item_hash) + async with session_factory() as session: + message_in_db = await get_message_by_item_hash( + session=session, item_hash=item_hash + ) assert message_in_db is not None assert message_in_db.content == content @@ -97,17 +99,19 @@ async def test_confirm_message( # Now, confirm the message # Insert a transaction in the DB to validate the foreign key constraint - with session_factory() as session: + 
async with session_factory() as session: session.add(chain_tx) pending_message.tx = chain_tx pending_message.tx_hash = chain_tx.hash await message_handler.process(session=session, pending_message=pending_message) - session.commit() + await session.commit() - with session_factory() as session: - message_in_db = get_message_by_item_hash(session=session, item_hash=item_hash) + async with session_factory() as session: + message_in_db = await get_message_by_item_hash( + session=session, item_hash=item_hash + ) assert message_in_db is not None assert message_in_db.confirmed @@ -119,7 +123,7 @@ async def test_confirm_message( @pytest.mark.asyncio async def test_process_confirmed_message( mock_config: Config, - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, test_storage_service: StorageService, chain_tx: ChainTxDb, ): @@ -139,7 +143,7 @@ async def test_process_confirmed_message( ) # Insert a transaction in the DB to validate the foreign key constraint - with session_factory() as session: + async with session_factory() as session: session.add(chain_tx) pending_message = PendingMessageDb.from_message_dict( @@ -150,10 +154,12 @@ async def test_process_confirmed_message( pending_message.tx_hash = chain_tx.hash pending_message.tx = chain_tx await message_handler.process(session=session, pending_message=pending_message) - session.commit() + await session.commit() - with session_factory() as session: - message_in_db = get_message_by_item_hash(session=session, item_hash=item_hash) + async with session_factory() as session: + message_in_db = await get_message_by_item_hash( + session=session, item_hash=item_hash + ) assert message_in_db is not None assert message_in_db.confirmed diff --git a/tests/conftest.py b/tests/conftest.py index b50351ac0..5ebb69305 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -25,10 +25,15 @@ ) from aleph_message.models.execution.volume import ImmutableVolume from configmanager import Config +from sqlalchemy import text import aleph.config from aleph.db.accessors.files import insert_message_file_pin, upsert_file_tag -from aleph.db.connection import make_db_url, make_engine, make_session_factory +from aleph.db.connection import ( + make_async_engine, + make_async_session_factory, + make_db_url, +) from aleph.db.models import ( AlephBalanceDb, MessageStatusDb, @@ -49,7 +54,7 @@ SETTINGS_AGGREGATE_OWNER, ) from aleph.toolkit.timestamp import timestamp_to_datetime -from aleph.types.db_session import DbSession, DbSessionFactory +from aleph.types.db_session import AsyncDbSession, AsyncDbSessionFactory from aleph.types.files import FileTag, FileType from aleph.types.message_status import MessageStatus from aleph.web import create_aiohttp_app @@ -90,16 +95,18 @@ def run_db_migrations(config: Config): alembic.command.upgrade(alembic_cfg, "head", tag=db_url) -@pytest.fixture -def session_factory(mock_config): - engine = make_engine(config=mock_config, echo=False, application_name="aleph-tests") +@pytest_asyncio.fixture +async def session_factory(mock_config): + engine = make_async_engine( + config=mock_config, echo=False, application_name="aleph-tests" + ) - with engine.begin() as conn: - conn.execute("drop schema public cascade") - conn.execute("create schema public") + async with engine.begin() as conn: + await conn.execute(text("drop schema public cascade")) + await conn.execute(text("create schema public")) run_db_migrations(config=mock_config) - return make_session_factory(engine) + return make_async_session_factory(engine) @pytest.fixture @@ 
-182,8 +189,10 @@ async def ccn_api_client(
     return client
 
 
-@pytest.fixture
-def fixture_instance_message(session_factory: DbSessionFactory) -> PendingMessageDb:
+@pytest_asyncio.fixture
+async def fixture_instance_message(
+    session_factory: AsyncDbSessionFactory,
+) -> PendingMessageDb:
     content = {
         "address": "0x9319Ad3B7A8E0eE24f2E639c40D8eD124C5520Ba",
         "allow_amend": False,
@@ -267,7 +276,7 @@ def fixture_instance_message(session_factory: DbSessionFactory) -> PendingMessag
         retries=0,
         next_attempt=dt.datetime(2023, 1, 1),
     )
-    with session_factory() as session:
+    async with session_factory() as session:
         session.add(pending_message)
         session.add(
             MessageStatusDb(
@@ -276,18 +285,18 @@ def fixture_instance_message(session_factory: DbSessionFactory) -> PendingMessag
                 reception_time=pending_message.reception_time,
             )
         )
-        session.commit()
+        await session.commit()
 
     return pending_message
 
 
-@pytest.fixture
-def instance_message_with_volumes_in_db(
-    session_factory: DbSessionFactory, fixture_instance_message: PendingMessageDb
+@pytest_asyncio.fixture
+async def instance_message_with_volumes_in_db(
+    session_factory: AsyncDbSessionFactory, fixture_instance_message: PendingMessageDb
 ) -> None:
-    with session_factory() as session:
-        insert_volume_refs(session, fixture_instance_message)
-        session.commit()
+    async with session_factory() as session:
+        await insert_volume_refs(session, fixture_instance_message)
+        await session.commit()
 
 
 class Volume(Protocol):
@@ -314,7 +323,7 @@ def get_volume_refs(content: ExecutableContent) -> List[Volume]:
     return volumes
 
 
-def insert_volume_refs(session: DbSession, message: PendingMessageDb):
+async def insert_volume_refs(session: AsyncDbSession, message: PendingMessageDb):
     """
     Insert volume references in the DB to make the program processable.
     """
@@ -330,8 +339,8 @@ def insert_volume_refs(session: DbSession, message: PendingMessageDb):
         file_hash = volume.ref[::-1]
 
         session.add(StoredFileDb(hash=file_hash, size=1024 * 1024, type=FileType.FILE))
-        session.flush()
-        insert_message_file_pin(
+        await session.flush()
+        await insert_message_file_pin(
             session=session,
             file_hash=volume.ref[::-1],
             owner=content.address,
@@ -339,7 +348,7 @@ def insert_volume_refs(session: DbSession, message: PendingMessageDb):
             ref=None,
             created=created,
         )
-        upsert_file_tag(
+        await upsert_file_tag(
             session=session,
             tag=FileTag(volume.ref),
             owner=content.address,
@@ -348,8 +357,8 @@ def insert_volume_refs(session: DbSession, message: PendingMessageDb):
     )
 
 
-@pytest.fixture
-def user_balance(session_factory: DbSessionFactory) -> AlephBalanceDb:
+@pytest_asyncio.fixture
+async def user_balance(session_factory: AsyncDbSessionFactory) -> AlephBalanceDb:
     balance = AlephBalanceDb(
         address="0x9319Ad3B7A8E0eE24f2E639c40D8eD124C5520Ba",
         chain=Chain.ETH,
@@ -357,14 +366,16 @@ def user_balance(session_factory: DbSessionFactory) -> AlephBalanceDb:
         eth_height=0,
     )
 
-    with session_factory() as session:
+    async with session_factory() as session:
         session.add(balance)
-        session.commit()
+        await session.commit()
 
     return balance
 
 
-@pytest.fixture
-def user_balance_eth_avax(session_factory: DbSessionFactory) -> AlephBalanceDb:
+@pytest_asyncio.fixture
+async def user_balance_eth_avax(
+    session_factory: AsyncDbSessionFactory,
+) -> AlephBalanceDb:
     balance_eth = AlephBalanceDb(
         address="0x9319Ad3B7A8E0eE24f2E639c40D8eD124C5520Ba",
         chain=Chain.ETH,
@@ -383,13 +394,15 @@ def user_balance_eth_avax(session_factory: DbSessionFactory) -> AlephBalanceDb:
 
         session.add(balance_eth)
         session.add(balance_avax)
-        session.commit()
+        await session.commit()
 
     return balance_avax
 
 
-@pytest.fixture
-def
fixture_product_prices_aggregate_in_db(session_factory: DbSessionFactory) -> None: - with session_factory() as session: +@pytest_asyncio.fixture +async def fixture_product_prices_aggregate_in_db( + session_factory: AsyncDbSessionFactory, +) -> None: + async with session_factory() as session: item_hash = "7b74b9c5f73e7a0713dbe83a377b1d321ffb4a5411ea3df49790a9720b93a5bF" content = DEFAULT_PRICE_AGGREGATE session.add( @@ -413,12 +426,14 @@ def fixture_product_prices_aggregate_in_db(session_factory: DbSessionFactory) -> ) ) - session.commit() + await session.commit() -@pytest.fixture -def fixture_settings_aggregate_in_db(session_factory: DbSessionFactory) -> None: - with session_factory() as session: +@pytest_asyncio.fixture +async def fixture_settings_aggregate_in_db( + session_factory: AsyncDbSessionFactory, +) -> None: + async with session_factory() as session: item_hash = "a319a7216d39032212c2f11028a21efaac4e5f78254baa34001483c7af22b7a4" content = DEFAULT_SETTINGS_AGGREGATE @@ -443,4 +458,4 @@ def fixture_settings_aggregate_in_db(session_factory: DbSessionFactory) -> None: ) ) - session.commit() + await session.commit() diff --git a/tests/db/test_accounts.py b/tests/db/test_accounts.py index 2d71d72cf..3bfdc756c 100644 --- a/tests/db/test_accounts.py +++ b/tests/db/test_accounts.py @@ -12,7 +12,7 @@ from aleph.db.models import MessageDb from aleph.toolkit.timestamp import timestamp_to_datetime from aleph.types.channel import Channel -from aleph.types.db_session import DbSessionFactory +from aleph.types.db_session import AsyncDbSessionFactory @pytest.fixture @@ -60,21 +60,21 @@ def fixture_messages(): @pytest.mark.asyncio async def test_get_message_stats_by_address( - session_factory: DbSessionFactory, fixture_messages: List[MessageDb] + session_factory: AsyncDbSessionFactory, fixture_messages: List[MessageDb] ): # No data test - with session_factory() as session: - stats_no_data = get_message_stats_by_address(session) + async with session_factory() as session: + stats_no_data = await get_message_stats_by_address(session) assert stats_no_data == [] # Refresh the materialized view session.add_all(fixture_messages) - session.commit() + await session.commit() - refresh_address_stats_mat_view(session) - session.commit() + await refresh_address_stats_mat_view(session) + await session.commit() - stats = get_message_stats_by_address(session) + stats = await get_message_stats_by_address(session) assert len(stats) == 2 stats_by_address = {row.address: row for row in stats} @@ -90,7 +90,7 @@ async def test_get_message_stats_by_address( assert stats_by_address["0x1234"].nb_messages == 1 # Filter by address - stats = get_message_stats_by_address(session, addresses=("0x1234",)) + stats = await get_message_stats_by_address(session, addresses=("0x1234",)) assert len(stats) == 1 row = stats[0] assert row.address == "0x1234" diff --git a/tests/db/test_aggregates.py b/tests/db/test_aggregates.py index 582742864..795dfa56d 100644 --- a/tests/db/test_aggregates.py +++ b/tests/db/test_aggregates.py @@ -11,11 +11,11 @@ refresh_aggregate, ) from aleph.db.models import AggregateDb, AggregateElementDb -from aleph.types.db_session import DbSessionFactory +from aleph.types.db_session import AsyncDbSessionFactory @pytest.mark.asyncio -async def test_get_aggregate_by_key(session_factory: DbSessionFactory): +async def test_get_aggregate_by_key(session_factory: AsyncDbSessionFactory): key = "key" owner = "Me" creation_datetime = dt.datetime(2022, 1, 1) @@ -35,12 +35,12 @@ async def 
test_get_aggregate_by_key(session_factory: DbSessionFactory): ), ) - with session_factory() as session: + async with session_factory() as session: session.add(aggregate) - session.commit() + await session.commit() - with session_factory() as session: - aggregate_db = get_aggregate_by_key(session=session, owner=owner, key=key) + async with session_factory() as session: + aggregate_db = await get_aggregate_by_key(session=session, owner=owner, key=key) assert aggregate_db assert aggregate_db.key == key assert aggregate_db.owner == owner @@ -48,8 +48,8 @@ async def test_get_aggregate_by_key(session_factory: DbSessionFactory): assert aggregate_db.last_revision_hash == aggregate.last_revision.item_hash # Try not loading the content - with session_factory() as session: - aggregate_db = get_aggregate_by_key( + async with session_factory() as session: + aggregate_db = await get_aggregate_by_key( session=session, owner=owner, key=key, with_content=False ) @@ -59,9 +59,11 @@ async def test_get_aggregate_by_key(session_factory: DbSessionFactory): @pytest.mark.asyncio -async def test_get_aggregate_by_key_no_data(session_factory: DbSessionFactory): - with session_factory() as session: - aggregate = get_aggregate_by_key(session=session, owner="owner", key="key") +async def test_get_aggregate_by_key_no_data(session_factory: AsyncDbSessionFactory): + async with session_factory() as session: + aggregate = await get_aggregate_by_key( + session=session, owner="owner", key="key" + ) assert aggregate is None @@ -107,25 +109,25 @@ def aggregate_fixtures() -> Tuple[AggregateDb, Sequence[AggregateElementDb]]: ] -def _test_refresh_aggregate( - session_factory: DbSessionFactory, +async def _test_refresh_aggregate( + session_factory: AsyncDbSessionFactory, aggregate: Optional[AggregateDb], expected_aggregate: AggregateDb, elements: Sequence[AggregateElementDb], ): - with session_factory() as session: + async with session_factory() as session: session.add_all(elements) if aggregate: session.add(aggregate) - session.commit() + await session.commit() - with session_factory() as session: - refresh_aggregate( + async with session_factory() as session: + await refresh_aggregate( session=session, owner=expected_aggregate.owner, key=expected_aggregate.key ) - session.commit() + await session.commit() - aggregate_db = get_aggregate_by_key( + aggregate_db = await get_aggregate_by_key( session=session, owner=expected_aggregate.owner, key=expected_aggregate.key ) @@ -141,11 +143,11 @@ def _test_refresh_aggregate( @pytest.mark.asyncio async def test_refresh_aggregate_insert( - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, aggregate_fixtures: Tuple[AggregateDb, Sequence[AggregateElementDb]], ): aggregate, elements = aggregate_fixtures - _test_refresh_aggregate( + await _test_refresh_aggregate( session_factory=session_factory, aggregate=None, expected_aggregate=aggregate, @@ -155,7 +157,7 @@ async def test_refresh_aggregate_insert( @pytest.mark.asyncio async def test_refresh_aggregate_update( - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, aggregate_fixtures: Tuple[AggregateDb, Sequence[AggregateElementDb]], ): updated_aggregate, elements = aggregate_fixtures @@ -167,7 +169,7 @@ async def test_refresh_aggregate_update( last_revision_hash=elements[0].item_hash, dirty=True, ) - _test_refresh_aggregate( + await _test_refresh_aggregate( session_factory=session_factory, aggregate=aggregate, expected_aggregate=updated_aggregate, @@ -177,11 +179,11 @@ async def 
test_refresh_aggregate_update( @pytest.mark.asyncio async def test_refresh_aggregate_update_no_op( - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, aggregate_fixtures: Tuple[AggregateDb, Sequence[AggregateElementDb]], ): aggregate, elements = aggregate_fixtures - _test_refresh_aggregate( + await _test_refresh_aggregate( session_factory=session_factory, aggregate=aggregate, expected_aggregate=aggregate, @@ -191,27 +193,29 @@ async def test_refresh_aggregate_update_no_op( @pytest.mark.asyncio async def test_get_content_keys( - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, aggregate_fixtures: Tuple[AggregateDb, Sequence[AggregateElementDb]], ): aggregate, elements = aggregate_fixtures - with session_factory() as session: + async with session_factory() as session: session.add_all(elements) session.add(aggregate) - session.commit() + await session.commit() - with session_factory() as session: + async with session_factory() as session: keys = set( - get_aggregate_content_keys( + await get_aggregate_content_keys( session=session, key=aggregate.key, owner=aggregate.owner ) ) assert keys == set(aggregate.content.keys()) # Test no match - with session_factory() as session: + async with session_factory() as session: keys = set( - get_aggregate_content_keys(session=session, key="not-a-key", owner="no-one") + await get_aggregate_content_keys( + session=session, key="not-a-key", owner="no-one" + ) ) assert keys == set() diff --git a/tests/db/test_chains.py b/tests/db/test_chains.py index c4e2004aa..f7faddc0d 100644 --- a/tests/db/test_chains.py +++ b/tests/db/test_chains.py @@ -16,11 +16,11 @@ from aleph.db.models.chains import ChainSyncStatusDb, IndexerSyncStatusDb from aleph.toolkit.range import MultiRange, Range from aleph.types.chain_sync import ChainEventType -from aleph.types.db_session import DbSessionFactory +from aleph.types.db_session import AsyncDbSessionFactory @pytest.mark.asyncio -async def test_get_last_height(session_factory: DbSessionFactory): +async def test_get_last_height(session_factory: AsyncDbSessionFactory): sync_type = ChainEventType.SYNC eth_sync_status = ChainSyncStatusDb( chain=Chain.ETH, @@ -29,20 +29,22 @@ async def test_get_last_height(session_factory: DbSessionFactory): last_update=pytz.utc.localize(dt.datetime(2022, 10, 1)), ) - with session_factory() as session: + async with session_factory() as session: session.add(eth_sync_status) - session.commit() + await session.commit() - with session_factory() as session: - height = get_last_height(session=session, chain=Chain.ETH, sync_type=sync_type) + async with session_factory() as session: + height = await get_last_height( + session=session, chain=Chain.ETH, sync_type=sync_type + ) assert height == eth_sync_status.height @pytest.mark.asyncio -async def test_get_last_height_no_data(session_factory: DbSessionFactory): - with session_factory() as session: - height = get_last_height( +async def test_get_last_height_no_data(session_factory: AsyncDbSessionFactory): + async with session_factory() as session: + height = await get_last_height( session=session, chain=Chain.NULS2, sync_type=ChainEventType.SYNC ) @@ -50,26 +52,26 @@ async def test_get_last_height_no_data(session_factory: DbSessionFactory): @pytest.mark.asyncio -async def test_upsert_chain_sync_status_insert(session_factory: DbSessionFactory): +async def test_upsert_chain_sync_status_insert(session_factory: AsyncDbSessionFactory): chain = Chain.ETH sync_type = ChainEventType.SYNC update_datetime = 
pytz.utc.localize(dt.datetime(2022, 11, 1))
    height = 10
 
-    with session_factory() as session:
-        upsert_chain_sync_status(
+    async with session_factory() as session:
+        await upsert_chain_sync_status(
             session=session,
             chain=chain,
             sync_type=sync_type,
             height=height,
             update_datetime=update_datetime,
         )
-        session.commit()
+        await session.commit()
 
-    with session_factory() as session:
+    async with session_factory() as session:
         chain_sync_status = (
-            session.execute(
+            await session.execute(
                 select(ChainSyncStatusDb).where(ChainSyncStatusDb.chain == chain)
             )
         ).scalar_one()
@@ -81,7 +83,7 @@ async def test_upsert_chain_sync_status_insert(session_factory: DbSessionFactory
 
 
 @pytest.mark.asyncio
-async def test_upsert_peer_replace(session_factory: DbSessionFactory):
+async def test_upsert_peer_replace(session_factory: AsyncDbSessionFactory):
     existing_entry = ChainSyncStatusDb(
         chain=Chain.TEZOS,
         type=ChainEventType.SYNC,
@@ -89,26 +91,26 @@ async def test_upsert_peer_replace(session_factory: DbSessionFactory):
 
-    with session_factory() as session:
+    async with session_factory() as session:
         session.add(existing_entry)
-        session.commit()
+        await session.commit()
 
     new_height = 1001
     new_update_datetime = pytz.utc.localize(dt.datetime(2023, 2, 7))
 
-    with session_factory() as session:
-        upsert_chain_sync_status(
+    async with session_factory() as session:
+        await upsert_chain_sync_status(
             session=session,
             chain=existing_entry.chain,
             sync_type=ChainEventType.SYNC,
             height=new_height,
             update_datetime=new_update_datetime,
         )
-        session.commit()
+        await session.commit()
 
-    with session_factory() as session:
+    async with session_factory() as session:
         chain_sync_status = (
-            session.execute(
+            await session.execute(
                 select(ChainSyncStatusDb).where(
                     ChainSyncStatusDb.chain == existing_entry.chain
                 )
@@ -141,7 +143,8 @@ def indexer_multirange():
     )
 
 
-def test_get_indexer_multirange(session_factory: DbSessionFactory):
+@pytest.mark.asyncio
+async def test_get_indexer_multirange(session_factory: AsyncDbSessionFactory):
     chain = Chain.ETH
     event_type = ChainEventType.SYNC
 
@@ -175,12 +178,12 @@ def test_get_indexer_multirange(session_factory: DbSessionFactory):
         ),
     ]
 
-    with session_factory() as session:
+    async with session_factory() as session:
         session.add_all(ranges)
-        session.commit()
+        await session.commit()
 
-    with session_factory() as session:
-        db_multirange = get_indexer_multirange(
+    async with session_factory() as session:
+        db_multirange = await get_indexer_multirange(
             session=session, chain=chain, event_type=event_type
         )
 
@@ -191,17 +194,18 @@
 
 
-def test_update_indexer_multirange(
-    indexer_multirange: IndexerMultiRange, session_factory: DbSessionFactory
+@pytest.mark.asyncio
+async def test_update_indexer_multirange(
+    indexer_multirange: IndexerMultiRange, session_factory: AsyncDbSessionFactory
 ):
-    with session_factory() as session:
-        update_indexer_multirange(
+    async with session_factory() as session:
+        await update_indexer_multirange(
             session=session, indexer_multirange=indexer_multirange
         )
-        session.commit()
+        await session.commit()
 
-    with session_factory() as session:
-        indexer_multirange_db = get_indexer_multirange(
+    async with session_factory() as session:
+        indexer_multirange_db = await get_indexer_multirange(
             session=session,
             chain=indexer_multirange.chain,
             event_type=indexer_multirange.event_type,
@@ -210,16 +214,17 @@ def test_update_indexer_multirange(
 
     assert indexer_multirange_db == indexer_multirange
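The accessors awaited in these tests follow SQLAlchemy's 2.0-style asyncio API: statements are built with select() and executed through an awaited session.execute() on an AsyncSession. A minimal sketch of the pattern, assuming the real accessor lives in aleph.db.accessors.chains; the body below is illustrative, not copied from the source:

    from typing import Optional

    from aleph_message.models import Chain
    from sqlalchemy import select

    from aleph.db.models.chains import ChainSyncStatusDb
    from aleph.types.chain_sync import ChainEventType
    from aleph.types.db_session import AsyncDbSession


    async def get_last_height(
        session: AsyncDbSession, chain: Chain, sync_type: ChainEventType
    ) -> Optional[int]:
        # Build the statement, then await its execution on the async session.
        # scalar_one_or_none() yields None when the chain was never synced,
        # which is what test_get_last_height_no_data relies on.
        select_stmt = select(ChainSyncStatusDb.height).where(
            (ChainSyncStatusDb.chain == chain) & (ChainSyncStatusDb.type == sync_type)
        )
        return (await session.execute(select_stmt)).scalar_one_or_none()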
-def test_get_missing_indexer_datetime_multirange(
-    indexer_multirange: IndexerMultiRange, session_factory: DbSessionFactory
+@pytest.mark.asyncio
+async def test_get_missing_indexer_datetime_multirange(
+    indexer_multirange: IndexerMultiRange, session_factory: AsyncDbSessionFactory
 ):
-    with session_factory() as session:
-        update_indexer_multirange(
+    async with session_factory() as session:
+        await update_indexer_multirange(
             session=session, indexer_multirange=indexer_multirange
         )
-        session.commit()
+        await session.commit()
 
-    with session_factory() as session:
+    async with session_factory() as session:
         new_multirange = MultiRange(
             Range(
                 dt.datetime(2019, 1, 1, tzinfo=dt.timezone.utc),
@@ -227,7 +232,7 @@ def test_get_missing_indexer_datetime_multirange(
                 upper_inc=True,
             )
         )
-        dt_multirange = get_missing_indexer_datetime_multirange(
+        dt_multirange = await get_missing_indexer_datetime_multirange(
             session=session,
             chain=indexer_multirange.chain,
             event_type=indexer_multirange.event_type,
diff --git a/tests/db/test_cost.py b/tests/db/test_cost.py
index d7cd634ac..95d51f503 100644
--- a/tests/db/test_cost.py
+++ b/tests/db/test_cost.py
@@ -4,6 +4,7 @@
 from typing import List, Union
 
 import pytest
+import pytest_asyncio
 import pytz
 from aleph_message.models import (
     Chain,
@@ -25,7 +26,7 @@
 from aleph.db.models import AlephBalanceDb, MessageDb, MessageStatusDb, StoredFileDb
 from aleph.services.cost import get_total_and_detailed_costs
 from aleph.toolkit.timestamp import timestamp_to_datetime
-from aleph.types.db_session import DbSession, DbSessionFactory
+from aleph.types.db_session import AsyncDbSession, AsyncDbSessionFactory
 from aleph.types.files import FileTag, FileType
 from aleph.types.message_status import MessageStatus
 
@@ -55,7 +56,7 @@ def get_volume_refs(
     return volumes
 
 
-def insert_volume_refs(session: DbSession, message: MessageDb):
+async def insert_volume_refs(session: AsyncDbSession, message: MessageDb):
     """
     Insert volume references in the DB to make the program processable.
     """
@@ -74,8 +75,8 @@ def insert_volume_refs(session: DbSession, message: MessageDb):
         session.add(
             StoredFileDb(hash=file_hash, size=1024 * 1024, type=FileType.FILE)
         )
-        session.flush()
-        insert_message_file_pin(
+        await session.flush()
+        await insert_message_file_pin(
             session=session,
             file_hash=volume.ref[::-1],
             owner=content.address,
@@ -83,7 +84,7 @@ def insert_volume_refs(session: DbSession, message: MessageDb):
             ref=None,
             created=created,
         )
-        upsert_file_tag(
+        await upsert_file_tag(
             session=session,
             tag=FileTag(volume.ref),
             owner=content.address,
@@ -92,7 +93,7 @@ def insert_volume_refs(session: DbSession, message: MessageDb):
     )
 
 
-async def insert_costs(session: DbSession, message: MessageDb):
+async def insert_costs(session: AsyncDbSession, message: MessageDb):
     """
-    Insert volume references in the DB to make the program processable.
+    Insert the costs of the message in the DB.
""" @@ -100,15 +101,17 @@ async def insert_costs(session: DbSession, message: MessageDb): if message.item_content: content = InstanceContent.model_validate_json(message.item_content) - _, costs = get_total_and_detailed_costs(session, content, message.item_hash) + _, costs = await get_total_and_detailed_costs( + session, content, message.item_hash + ) if costs: insert_stmt = make_costs_upsert_query(costs) - session.execute(insert_stmt) + await session.execute(insert_stmt) -@pytest.fixture -def fixture_instance_message(session_factory: DbSessionFactory) -> MessageDb: +@pytest_asyncio.fixture +async def fixture_instance_message(session_factory: AsyncDbSessionFactory) -> MessageDb: content = { "address": "0x9319Ad3B7A8E0eE24f2E639c40D8eD124C5520Ba", "allow_amend": False, @@ -191,7 +194,7 @@ def fixture_instance_message(session_factory: DbSessionFactory) -> MessageDb: channel=None, size=2000, ) - with session_factory() as session: + async with session_factory() as session: session.add(message) session.add( MessageStatusDb( @@ -200,19 +203,19 @@ def fixture_instance_message(session_factory: DbSessionFactory) -> MessageDb: reception_time=reception_time, ) ) - session.commit() + await session.commit() return message @pytest.mark.asyncio async def test_get_total_cost_for_address( - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, fixture_instance_message, fixture_product_prices_aggregate_in_db, fixture_settings_aggregate_in_db, ): - with session_factory() as session: + async with session_factory() as session: session.add( AlephBalanceDb( address=fixture_instance_message.sender, @@ -222,11 +225,11 @@ async def test_get_total_cost_for_address( eth_height=0, ) ) - insert_volume_refs(session, fixture_instance_message) + await insert_volume_refs(session, fixture_instance_message) await insert_costs(session, fixture_instance_message) - session.commit() + await session.commit() - total_cost: Decimal = get_total_cost_for_address( + total_cost: Decimal = await get_total_cost_for_address( session=session, address=fixture_instance_message.sender ) diff --git a/tests/db/test_error_codes.py b/tests/db/test_error_codes.py index d5d31ce4d..870fb0eda 100644 --- a/tests/db/test_error_codes.py +++ b/tests/db/test_error_codes.py @@ -1,18 +1,20 @@ +import pytest from sqlalchemy import select from aleph.db.models import ErrorCodeDb -from aleph.types.db_session import DbSessionFactory +from aleph.types.db_session import AsyncDbSessionFactory from aleph.types.message_status import ErrorCode -def test_all_error_codes_mapped_in_db(session_factory: DbSessionFactory): +@pytest.mark.asyncio +async def test_all_error_codes_mapped_in_db(session_factory: AsyncDbSessionFactory): """ Check that the ErrorCode enum values are all mapped in the database and vice-versa. Sanity check for developers. 
""" - with session_factory() as session: - db_error_codes = session.execute(select(ErrorCodeDb)).scalars() + async with session_factory() as session: + db_error_codes = (await session.execute(select(ErrorCodeDb))).scalars() db_error_codes_dict = {e.code: e for e in db_error_codes} # All error code enum values must be mapped in the DB diff --git a/tests/db/test_files.py b/tests/db/test_files.py index 7927fa599..4a6ec1121 100644 --- a/tests/db/test_files.py +++ b/tests/db/test_files.py @@ -11,15 +11,15 @@ upsert_file_tag, ) from aleph.db.models import MessageFilePinDb, StoredFileDb, TxFilePinDb -from aleph.types.db_session import DbSessionFactory +from aleph.types.db_session import AsyncDbSessionFactory from aleph.types.files import FileTag, FileType @pytest.mark.asyncio -async def test_is_pinned_file(session_factory: DbSessionFactory): - def is_pinned(_session_factory, _file_hash) -> bool: +async def test_is_pinned_file(session_factory: AsyncDbSessionFactory): + async def is_pinned(_session_factory, _file_hash) -> bool: with _session_factory() as _session: - return is_pinned_file(session=_session, file_hash=_file_hash) + return await is_pinned_file(session=_session, file_hash=_file_hash) file = StoredFileDb( hash="QmTm7g1Mh3BhrQPjnedVQ5g67DR7cwhyMN3MvFt1JPPdWd", @@ -27,27 +27,27 @@ def is_pinned(_session_factory, _file_hash) -> bool: type=FileType.FILE, ) - with session_factory() as session: + async with session_factory() as session: session.add(file) - session.commit() + await session.commit() # We check for equality with True/False to determine that the function does indeed # return a boolean value - assert is_pinned(session_factory, file.hash) is False + assert await is_pinned(session_factory, file.hash) is False - with session_factory() as session: + async with session_factory() as session: session.add( TxFilePinDb( file_hash=file.hash, tx_hash="1234", created=dt.datetime(2020, 1, 1) ) ) - session.commit() + await session.commit() - assert is_pinned(session_factory, file.hash) is True + assert await is_pinned(session_factory, file.hash) is True @pytest.mark.asyncio -async def test_upsert_file_tag(session_factory: DbSessionFactory): +async def test_upsert_file_tag(session_factory: AsyncDbSessionFactory): original_file = StoredFileDb( hash="QmTm7g1Mh3BhrQPjnedVQ5g67DR7cwhyMN3MvFt1JPPdWd", size=32, @@ -63,57 +63,57 @@ async def test_upsert_file_tag(session_factory: DbSessionFactory): tag = FileTag("aleph/custom-tag") owner = "aleph" - with session_factory() as session: + async with session_factory() as session: session.add(original_file) session.add(new_version) - session.commit() + await session.commit() - with session_factory() as session: - upsert_file_tag( + async with session_factory() as session: + await upsert_file_tag( session=session, tag=tag, owner=owner, file_hash=original_file.hash, last_updated=original_datetime, ) - session.commit() + await session.commit() - file_tag_db = get_file_tag(session=session, tag=tag) + file_tag_db = await get_file_tag(session=session, tag=tag) assert file_tag_db is not None assert file_tag_db.owner == owner assert file_tag_db.file_hash == original_file.hash assert file_tag_db.last_updated == original_datetime # Update the tag - with session_factory() as session: + async with session_factory() as session: new_version_datetime = pytz.utc.localize(dt.datetime(2022, 1, 1)) - upsert_file_tag( + await upsert_file_tag( session=session, tag=tag, owner=owner, file_hash=new_version.hash, last_updated=new_version_datetime, ) - session.commit() + await 
session.commit() - file_tag_db = get_file_tag(session=session, tag=tag) + file_tag_db = await get_file_tag(session=session, tag=tag) assert file_tag_db is not None assert file_tag_db.owner == owner assert file_tag_db.file_hash == new_version.hash assert file_tag_db.last_updated == new_version_datetime # Try to update the tag to an older version and check it has no effect - with session_factory() as session: - upsert_file_tag( + async with session_factory() as session: + await upsert_file_tag( session=session, tag=tag, owner=owner, file_hash=original_file.hash, last_updated=original_datetime, ) - session.commit() + await session.commit() - file_tag_db = get_file_tag(session=session, tag=tag) + file_tag_db = await get_file_tag(session=session, tag=tag) assert file_tag_db is not None assert file_tag_db.owner == owner assert file_tag_db.file_hash == new_version.hash @@ -121,7 +121,7 @@ async def test_upsert_file_tag(session_factory: DbSessionFactory): @pytest.mark.asyncio -async def test_refresh_file_tag(session_factory: DbSessionFactory): +async def test_refresh_file_tag(session_factory: AsyncDbSessionFactory): files = [ StoredFileDb( hash="QmTm7g1Mh3BhrQPjnedVQ5g67DR7cwhyMN3MvFt1JPPdWd", @@ -152,15 +152,15 @@ async def test_refresh_file_tag(session_factory: DbSessionFactory): owner=owner, ) - with session_factory() as session: + async with session_factory() as session: session.add_all([first_pin, second_pin]) - session.commit() + await session.commit() - with session_factory() as session: - refresh_file_tag(session=session, tag=tag) - session.commit() + async with session_factory() as session: + await refresh_file_tag(session=session, tag=tag) + await session.commit() - file_tag_db = get_file_tag(session=session, tag=tag) + file_tag_db = await get_file_tag(session=session, tag=tag) assert file_tag_db assert file_tag_db.file_hash == second_pin.file_hash assert file_tag_db.last_updated == second_pin.created diff --git a/tests/db/test_messages.py b/tests/db/test_messages.py index 7b610b749..7b55387e7 100644 --- a/tests/db/test_messages.py +++ b/tests/db/test_messages.py @@ -22,7 +22,7 @@ from aleph.toolkit.timestamp import timestamp_to_datetime from aleph.types.chain_sync import ChainSyncProtocol from aleph.types.channel import Channel -from aleph.types.db_session import DbSessionFactory +from aleph.types.db_session import AsyncDbSessionFactory from aleph.types.message_status import MessageStatus @@ -66,14 +66,14 @@ def assert_messages_equal(expected: MessageDb, actual: MessageDb): @pytest.mark.asyncio async def test_get_message( - session_factory: DbSessionFactory, fixture_message: MessageDb + session_factory: AsyncDbSessionFactory, fixture_message: MessageDb ): - with session_factory() as session: + async with session_factory() as session: session.add(fixture_message) - session.commit() + await session.commit() - with session_factory() as session: - fetched_message = get_message_by_item_hash( + async with session_factory() as session: + fetched_message = await get_message_by_item_hash( session=session, item_hash=ItemHash(fixture_message.item_hash) ) @@ -87,7 +87,7 @@ async def test_get_message( @pytest.mark.asyncio async def test_get_message_with_confirmations( - session_factory: DbSessionFactory, fixture_message: MessageDb + session_factory: AsyncDbSessionFactory, fixture_message: MessageDb ): confirmations = [ ChainTxDb( @@ -114,12 +114,12 @@ async def test_get_message_with_confirmations( fixture_message.confirmations = confirmations - with session_factory() as session: + async with 
session_factory() as session:
         session.add(fixture_message)
-        session.commit()
+        await session.commit()
 
-    with session_factory() as session:
-        fetched_message = get_message_by_item_hash(
+    async with session_factory() as session:
+        fetched_message = await get_message_by_item_hash(
             session=session, item_hash=ItemHash(fixture_message.item_hash)
         )
 
@@ -141,43 +141,43 @@ async def test_get_message_with_confirmations(
 
 
 @pytest.mark.asyncio
-async def test_message_exists(session_factory: DbSessionFactory, fixture_message):
-    with session_factory() as session:
-        assert not message_exists(session=session, item_hash=fixture_message.item_hash)
+async def test_message_exists(session_factory: AsyncDbSessionFactory, fixture_message):
+    async with session_factory() as session:
+        assert not await message_exists(session=session, item_hash=fixture_message.item_hash)
         session.add(fixture_message)
-        session.commit()
+        await session.commit()
 
-        assert message_exists(session=session, item_hash=fixture_message.item_hash)
+        assert await message_exists(session=session, item_hash=fixture_message.item_hash)
 
 
 @pytest.mark.asyncio
 async def test_message_count(
-    session_factory: DbSessionFactory, fixture_message: MessageDb
+    session_factory: AsyncDbSessionFactory, fixture_message: MessageDb
 ):
-    with session_factory() as session:
+    async with session_factory() as session:
         session.add(fixture_message)
-        session.commit()
+        await session.commit()
 
         # Analyze updates the table size estimate
-        session.execute(text("analyze messages"))
-        session.commit()
+        await session.execute(text("analyze messages"))
+        await session.commit()
 
-    with session_factory() as session:
-        exact_count = MessageDb.count(session)
+    async with session_factory() as session:
+        exact_count = await MessageDb.count(session)
         assert exact_count == 1
 
-        estimate_count = MessageDb.estimate_count(session)
+        estimate_count = await MessageDb.estimate_count(session)
         assert isinstance(estimate_count, int)
         assert estimate_count == 1
 
-        fast_count = MessageDb.fast_count(session)
+        fast_count = await MessageDb.fast_count(session)
         assert fast_count == 1
 
 
 @pytest.mark.asyncio
 async def test_upsert_query_confirmation(
-    session_factory: DbSessionFactory, fixture_message: MessageDb
+    session_factory: AsyncDbSessionFactory, fixture_message: MessageDb
 ):
     item_hash = fixture_message.item_hash
 
@@ -196,31 +196,35 @@ async def test_upsert_query_confirmation(
         item_hash=item_hash, tx_hash=chain_tx.hash
     )
 
-    with session_factory() as session:
+    async with session_factory() as session:
         session.add(fixture_message)
         session.add(chain_tx)
-        session.commit()
+        await session.commit()
 
     # Insert
-    with session_factory() as session:
-        session.execute(upsert_stmt)
-        session.commit()
-
-        confirmation_db = session.execute(
-            select(message_confirmations).where(
-                message_confirmations.c.item_hash == item_hash
+    async with session_factory() as session:
+        await session.execute(upsert_stmt)
+        await session.commit()
+
+        confirmation_db = (
+            await session.execute(
+                select(message_confirmations).where(
+                    message_confirmations.c.item_hash == item_hash
+                )
             )
         ).one()
         assert confirmation_db.tx_hash == chain_tx.hash
 
     # Upsert
-    with session_factory() as session:
-        session.execute(upsert_stmt)
-        session.commit()
-
-        confirmation_db = session.execute(
-            select(message_confirmations).where(
-                message_confirmations.c.item_hash == item_hash
+    async with session_factory() as session:
+        await session.execute(upsert_stmt)
+        await session.commit()
+
+        confirmation_db = (
+            await session.execute(
+                select(message_confirmations).where(
+                    message_confirmations.c.item_hash == item_hash
+                )
            )
         ).one()
         assert confirmation_db.tx_hash == chain_tx.hash
@@ -228,22 +232,22 @@ async def
test_upsert_query_confirmation( @pytest.mark.asyncio async def test_upsert_query_message( - session_factory: DbSessionFactory, fixture_message: MessageDb + session_factory: AsyncDbSessionFactory, fixture_message: MessageDb ): message = copy(fixture_message) message.time = fixture_message.time - dt.timedelta(seconds=1) upsert_stmt = make_message_upsert_query(message) - with session_factory() as session: + async with session_factory() as session: session.add(message) - session.commit() + await session.commit() - with session_factory() as session: - session.execute(upsert_stmt) - session.commit() + async with session_factory() as session: + await session.execute(upsert_stmt) + await session.commit() - message_db = get_message_by_item_hash( + message_db = await get_message_by_item_hash( session=session, item_hash=ItemHash(message.item_hash) ) @@ -253,9 +257,9 @@ async def test_upsert_query_message( @pytest.mark.asyncio async def test_get_unconfirmed_messages( - session_factory: DbSessionFactory, fixture_message: MessageDb + session_factory: AsyncDbSessionFactory, fixture_message: MessageDb ): - with session_factory() as session: + async with session_factory() as session: session.add(fixture_message) session.add( MessageStatusDb( @@ -264,10 +268,10 @@ async def test_get_unconfirmed_messages( reception_time=fixture_message.time, ) ) - session.commit() + await session.commit() - with session_factory() as session: - unconfirmed_messages = list(get_unconfirmed_messages(session)) + async with session_factory() as session: + unconfirmed_messages = list(await get_unconfirmed_messages(session)) assert len(unconfirmed_messages) == 1 assert_messages_equal(fixture_message, unconfirmed_messages[0]) @@ -283,74 +287,76 @@ async def test_get_unconfirmed_messages( protocol_version=1, content="Qmsomething", ) - with session_factory() as session: + async with session_factory() as session: session.add(tx) - session.flush() - session.execute( + await session.flush() + await session.execute( insert(message_confirmations).values( item_hash=fixture_message.item_hash, tx_hash=tx.hash ) ) - session.commit() + await session.commit() - with session_factory() as session: + async with session_factory() as session: # Check that the message is now ignored - unconfirmed_messages = list(get_unconfirmed_messages(session)) + unconfirmed_messages = list(await get_unconfirmed_messages(session)) assert unconfirmed_messages == [] # Check that it is also ignored when the chain parameter is specified - unconfirmed_messages = list(get_unconfirmed_messages(session, chain=tx.chain)) + unconfirmed_messages = list( + await get_unconfirmed_messages(session, chain=tx.chain) + ) assert unconfirmed_messages == [] # Check that it reappears if we specify a different chain unconfirmed_messages = list( - get_unconfirmed_messages(session, chain=Chain.TEZOS) + await get_unconfirmed_messages(session, chain=Chain.TEZOS) ) assert len(unconfirmed_messages) == 1 assert_messages_equal(fixture_message, unconfirmed_messages[0]) # Check that the limit parameter is respected unconfirmed_messages = list( - get_unconfirmed_messages(session, chain=Chain.TEZOS, limit=0) + await get_unconfirmed_messages(session, chain=Chain.TEZOS, limit=0) ) assert unconfirmed_messages == [] @pytest.mark.asyncio async def test_get_unconfirmed_messages_trusted_messages( - session_factory: DbSessionFactory, fixture_message: MessageDb + session_factory: AsyncDbSessionFactory, fixture_message: MessageDb ): fixture_message.signature = None - with session_factory() as session: + async 
with session_factory() as session: session.add(fixture_message) - session.commit() + await session.commit() - with session_factory() as session: - unconfirmed_messages = list(get_unconfirmed_messages(session)) + async with session_factory() as session: + unconfirmed_messages = list(await get_unconfirmed_messages(session)) assert unconfirmed_messages == [] @pytest.mark.asyncio async def test_get_distinct_channels( - session_factory: DbSessionFactory, fixture_message: MessageDb + session_factory: AsyncDbSessionFactory, fixture_message: MessageDb ): # TODO: improve this test # * use several messages # * test if None if considered as a channel # * test - with session_factory() as session: + async with session_factory() as session: session.add(fixture_message) - session.commit() - channels = list(get_distinct_channels(session=session)) + await session.commit() + channels = list(await get_distinct_channels(session=session)) assert channels == [fixture_message.channel] @pytest.mark.asyncio async def test_forget_message( - session_factory: DbSessionFactory, fixture_message: MessageDb + session_factory: AsyncDbSessionFactory, fixture_message: MessageDb ): - with session_factory() as session: + async with session_factory() as session: session.add(fixture_message) session.add( MessageStatusDb( @@ -359,34 +365,34 @@ async def test_forget_message( reception_time=fixture_message.time, ) ) - session.commit() + await session.commit() forget_message_hash = ( "d06251c954d4c75476c749e80b8f2a4962d20282b28b3e237e30b0a76157df2d" ) - with session_factory() as session: - forget_message( + async with session_factory() as session: + await forget_message( session=session, item_hash=fixture_message.item_hash, forget_message_hash=forget_message_hash, ) - session.commit() + await session.commit() - message_status = get_message_status( + message_status = await get_message_status( session=session, item_hash=ItemHash(fixture_message.item_hash) ) assert message_status assert message_status.status == MessageStatus.FORGOTTEN # Assert that the message is not present in messages anymore - message = get_message_by_item_hash( + message = await get_message_by_item_hash( session=session, item_hash=ItemHash(fixture_message.item_hash) ) assert message is None # Assert that the metadata was inserted properly in forgotten_messages - forgotten_message = get_forgotten_message( + forgotten_message = await get_forgotten_message( session=session, item_hash=ItemHash(fixture_message.item_hash) ) assert forgotten_message @@ -406,14 +412,14 @@ async def test_forget_message( "2aa1f44199181110e0c6b79ccc5e40ceaf20eac791dcfcd1b4f8f2f32b2d8502" ) - append_to_forgotten_by( + await append_to_forgotten_by( session=session, forgotten_message_hash=fixture_message.item_hash, forget_message_hash=new_forget_message_hash, ) - session.commit() + await session.commit() - forgotten_message = get_forgotten_message( + forgotten_message = await get_forgotten_message( session=session, item_hash=fixture_message.item_hash ) assert forgotten_message diff --git a/tests/db/test_peers.py b/tests/db/test_peers.py index 3826b3997..c4da192e0 100644 --- a/tests/db/test_peers.py +++ b/tests/db/test_peers.py @@ -6,11 +6,11 @@ from aleph.db.accessors.peers import get_all_addresses_by_peer_type, upsert_peer from aleph.db.models.peers import PeerDb, PeerType -from aleph.types.db_session import DbSessionFactory +from aleph.types.db_session import AsyncDbSessionFactory @pytest.mark.asyncio -async def test_get_all_addresses_by_peer_type(session_factory: DbSessionFactory): +async 
def test_get_all_addresses_by_peer_type(session_factory: AsyncDbSessionFactory): peer_id = "some-peer-id" last_seen = pytz.utc.localize(dt.datetime(2022, 10, 1)) source = PeerType.P2P @@ -38,19 +38,19 @@ async def test_get_all_addresses_by_peer_type(session_factory: DbSessionFactory) last_seen=last_seen, ) - with session_factory() as session: + async with session_factory() as session: session.add_all([http_entry, p2p_entry, ipfs_entry]) - session.commit() + await session.commit() - with session_factory() as session: - http_entries = get_all_addresses_by_peer_type( + async with session_factory() as session: + http_entries = await get_all_addresses_by_peer_type( session=session, peer_type=PeerType.HTTP ) - p2p_entries = get_all_addresses_by_peer_type( + p2p_entries = await get_all_addresses_by_peer_type( session=session, peer_type=PeerType.P2P ) - ipfs_entries = get_all_addresses_by_peer_type( + ipfs_entries = await get_all_addresses_by_peer_type( session=session, peer_type=PeerType.IPFS ) @@ -62,24 +62,26 @@ async def test_get_all_addresses_by_peer_type(session_factory: DbSessionFactory) @pytest.mark.asyncio @pytest.mark.parametrize("peer_type", (PeerType.HTTP, PeerType.P2P, PeerType.IPFS)) async def test_get_all_addresses_by_peer_type_no_match( - session_factory: DbSessionFactory, peer_type: PeerType + session_factory: AsyncDbSessionFactory, peer_type: PeerType ): - with session_factory() as session: - entries = get_all_addresses_by_peer_type(session=session, peer_type=peer_type) + async with session_factory() as session: + entries = await get_all_addresses_by_peer_type( + session=session, peer_type=peer_type + ) assert entries == [] @pytest.mark.asyncio -async def test_upsert_peer_insert(session_factory: DbSessionFactory): +async def test_upsert_peer_insert(session_factory: AsyncDbSessionFactory): peer_id = "peer-id" peer_type = PeerType.HTTP address = "http://127.0.0.1:4024" source = PeerType.IPFS last_seen = pytz.utc.localize(dt.datetime(2022, 10, 1)) - with session_factory() as session: - upsert_peer( + async with session_factory() as session: + await upsert_peer( session=session, peer_id=peer_id, address=address, @@ -87,12 +89,12 @@ async def test_upsert_peer_insert(session_factory: DbSessionFactory): source=source, last_seen=last_seen, ) - session.commit() + await session.commit() - with session_factory() as session: + async with session_factory() as session: peer = ( ( - session.execute( + await session.execute( select(PeerDb).where( (PeerDb.peer_id == peer_id) & (PeerDb.peer_type == peer_type) ) @@ -110,15 +112,15 @@ async def test_upsert_peer_insert(session_factory: DbSessionFactory): @pytest.mark.asyncio -async def test_upsert_peer_replace(session_factory: DbSessionFactory): +async def test_upsert_peer_replace(session_factory: AsyncDbSessionFactory): peer_id = "peer-id" peer_type = PeerType.HTTP address = "http://127.0.0.1:4024" source = PeerType.P2P last_seen = pytz.utc.localize(dt.datetime(2022, 10, 1)) - with session_factory() as session: - upsert_peer( + async with session_factory() as session: + await upsert_peer( session=session, peer_id=peer_id, peer_type=peer_type, @@ -126,14 +128,14 @@ async def test_upsert_peer_replace(session_factory: DbSessionFactory): source=source, last_seen=last_seen, ) - session.commit() + await session.commit() new_address = "http://0.0.0.0:4024" new_source = PeerType.IPFS new_last_seen = pytz.utc.localize(dt.datetime(2022, 10, 2)) - with session_factory() as session: - upsert_peer( + async with session_factory() as session: + await upsert_peer( 
session=session, peer_id=peer_id, peer_type=peer_type, @@ -141,12 +143,12 @@ async def test_upsert_peer_replace(session_factory: DbSessionFactory): source=new_source, last_seen=new_last_seen, ) - session.commit() + await session.commit() - with session_factory() as session: + async with session_factory() as session: peer = ( ( - session.execute( + await session.execute( select(PeerDb).where( (PeerDb.peer_id == peer_id) & (PeerDb.peer_type == peer_type) ) diff --git a/tests/db/test_pending_messages_db.py b/tests/db/test_pending_messages_db.py index ea1c46ff1..4b2bd789a 100644 --- a/tests/db/test_pending_messages_db.py +++ b/tests/db/test_pending_messages_db.py @@ -1,5 +1,5 @@ import datetime as dt -from typing import List +from typing import List, Set import pytest from aleph_message.models import Chain, ItemType, MessageType @@ -7,10 +7,11 @@ from aleph.db.accessors.pending_messages import ( count_pending_messages, get_next_pending_messages, + get_next_pending_messages_by_address, ) from aleph.db.models import ChainTxDb, PendingMessageDb from aleph.types.chain_sync import ChainSyncProtocol -from aleph.types.db_session import DbSessionFactory +from aleph.types.db_session import AsyncDbSessionFactory @pytest.fixture @@ -87,58 +88,95 @@ def fixture_pending_messages(): reception_time=dt.datetime(2022, 10, 7, 21, 53, 10), fetched=True, ), + # Add additional messages with the same address as the first message + PendingMessageDb( + id=405, + item_hash="548b3c6f6455e6f4216b01b43522bddc3564a14c04799ed0ce8af4857c7ba15g", + type=MessageType.forget, + chain=Chain.ETH, + sender="0xaC033C1cA5C49Eff98A1D9a56BeDBC4840010BA4", + signature="0x3619c016987c4221c85842ce250f3e50a9b8e42c04d4f9fbdfdfad9941d6c5195a502a4f63289429513bf152d24d0a7bb0533701ec3c7bbca91b18ce7eaa7dee1b", + item_type=ItemType.inline, + item_content='{"address":"0xaC033C1cA5C49Eff98A1D9a56BeDBC4840010BA4","time":1648215809.0270267,"hashes":["fea0e00f73102aa951794a3ea85f6f1bbfd3decb804fb73232f2a645a379ae55"],"reason":"Another message"}', + channel="INTEGRATION_TESTS", + time=dt.datetime(2022, 10, 7, 17, 6), + next_attempt=dt.datetime(2022, 10, 7, 17, 6), + retries=0, + check_message=True, + reception_time=dt.datetime(2022, 10, 7, 17, 6, 10), + fetched=True, + ), + PendingMessageDb( + id=406, + item_hash="648b3c6f6455e6f4216b01b43522bddc3564a14c04799ed0ce8af4857c7ba15h", + type=MessageType.forget, + chain=Chain.ETH, + sender="0xaC033C1cA5C49Eff98A1D9a56BeDBC4840010BA4", + signature="0x3619c016987c4221c85842ce250f3e50a9b8e42c04d4f9fbdfdfad9941d6c5195a502a4f63289429513bf152d24d0a7bb0533701ec3c7bbca91b18ce7eaa7dee1b", + item_type=ItemType.inline, + item_content='{"address":"0xaC033C1cA5C49Eff98A1D9a56BeDBC4840010BA4","time":1648215809.0270267,"hashes":["fea0e00f73102aa951794a3ea85f6f1bbfd3decb804fb73232f2a645a379ae56"],"reason":"Yet another message"}', + channel="INTEGRATION_TESTS", + time=dt.datetime(2022, 10, 7, 17, 7), + next_attempt=dt.datetime(2022, 10, 7, 17, 7), + retries=0, + check_message=True, + reception_time=dt.datetime(2022, 10, 7, 17, 7, 10), + fetched=True, + ), ] @pytest.mark.asyncio async def test_count_pending_messages( - session_factory: DbSessionFactory, fixture_pending_messages: List[PendingMessageDb] + session_factory: AsyncDbSessionFactory, + fixture_pending_messages: List[PendingMessageDb], ): - with session_factory() as session: + async with session_factory() as session: session.add_all(fixture_pending_messages) - session.commit() + await session.commit() - with session_factory() as session: - count_all = 
count_pending_messages(session=session) - assert count_all == 3 + async with session_factory() as session: + count_all = await count_pending_messages(session=session) + assert count_all == 5 # Only one message is linked to an ETH transaction - count_eth = count_pending_messages(session=session, chain=Chain.ETH) + count_eth = await count_pending_messages(session=session, chain=Chain.ETH) assert count_eth == 1 # Only one message is linked to a TEZOS transaction - count_tezos = count_pending_messages(session=session, chain=Chain.TEZOS) + count_tezos = await count_pending_messages(session=session, chain=Chain.TEZOS) assert count_tezos == 1 # No message should be linked to any Solana transaction - count_sol = count_pending_messages(session=session, chain=Chain.SOL) + count_sol = await count_pending_messages(session=session, chain=Chain.SOL) assert count_sol == 0 @pytest.mark.asyncio async def test_get_pending_messages( - session_factory: DbSessionFactory, fixture_pending_messages: List[PendingMessageDb] + session_factory: AsyncDbSessionFactory, + fixture_pending_messages: List[PendingMessageDb], ): - with session_factory() as session: + async with session_factory() as session: session.add_all(fixture_pending_messages) - session.commit() + await session.commit() current_time = max( pending_message.next_attempt for pending_message in fixture_pending_messages ) - with session_factory() as session: + async with session_factory() as session: pending_messages = list( - get_next_pending_messages(session=session, current_time=current_time) + await get_next_pending_messages(session=session, current_time=current_time) ) - assert len(pending_messages) == 3 - # Check the order of messages - assert [m.id for m in pending_messages] == [404, 42, 27] + assert len(pending_messages) == 5 + # Check the order of messages (by next_attempt ascending) + assert [m.id for m in pending_messages] == [404, 42, 27, 405, 406] # Exclude hashes pending_messages = list( - get_next_pending_messages( + await get_next_pending_messages( session=session, current_time=current_time, exclude_item_hashes={ @@ -146,5 +184,78 @@ async def test_get_pending_messages( }, ) ) - assert len(pending_messages) == 2 - assert [m.id for m in pending_messages] == [404, 27] + assert len(pending_messages) == 4 + assert [m.id for m in pending_messages] == [404, 27, 405, 406] + + +@pytest.mark.asyncio +async def test_get_next_pending_messages_by_address( + session_factory: AsyncDbSessionFactory, + fixture_pending_messages: List[PendingMessageDb], +): + async with session_factory() as session: + session.add_all(fixture_pending_messages) + await session.commit() + + current_time = max( + pending_message.next_attempt for pending_message in fixture_pending_messages + ) + + async with session_factory() as session: + # Test fetching messages by address + pending_messages = await get_next_pending_messages_by_address( + session=session, current_time=current_time, batch_size=10 + ) + + # Should get all messages with address "0xaC033C1cA5C49Eff98A1D9a56BeDBC4840010BA4" + assert len(pending_messages) == 3 + addresses = [ + msg.content.get("address") + for msg in pending_messages + if msg.content and isinstance(msg.content, dict) + ] + assert all( + addr == "0xaC033C1cA5C49Eff98A1D9a56BeDBC4840010BA4" for addr in addresses + ) + + # Check sorting by next_attempt + assert [m.id for m in pending_messages] == [404, 405, 406] + + # Test with exclude_item_hashes + exclude_hashes: Set[str] = { + "448b3c6f6455e6f4216b01b43522bddc3564a14c04799ed0ce8af4857c7ba15f" + } + 
pending_messages = await get_next_pending_messages_by_address( + session=session, + current_time=current_time, + exclude_item_hashes=exclude_hashes, + batch_size=10, + ) + assert len(pending_messages) == 3 # Should fetch messages with another address + assert ( + pending_messages[0].content is not None + and pending_messages[0].content.get("address") + == "0x720F319A9c3226dCDd7D8C49163D79EDa1084E98" + ) + + # Test with exclude_addresses + exclude_addresses: Set[str] = {"0xaC033C1cA5C49Eff98A1D9a56BeDBC4840010BA4"} + pending_messages = await get_next_pending_messages_by_address( + session=session, + current_time=current_time, + exclude_addresses=exclude_addresses, + batch_size=10, + ) + assert len(pending_messages) == 2 # Should fetch messages with another address + assert ( + pending_messages[0].content is not None + and pending_messages[0].content.get("address") + == "0x720F319A9c3226dCDd7D8C49163D79EDa1084E98" + ) + + # Test batch size limit + pending_messages = await get_next_pending_messages_by_address( + session=session, current_time=current_time, batch_size=2 + ) + assert len(pending_messages) == 2 # Should limit to 2 messages + assert [m.id for m in pending_messages] == [404, 405] diff --git a/tests/db/test_pending_txs.py b/tests/db/test_pending_txs.py index caf57f8f8..25b6f71e8 100644 --- a/tests/db/test_pending_txs.py +++ b/tests/db/test_pending_txs.py @@ -7,7 +7,7 @@ from aleph.db.accessors.pending_txs import count_pending_txs, get_pending_txs from aleph.db.models import ChainTxDb, PendingTxDb from aleph.types.chain_sync import ChainSyncProtocol -from aleph.types.db_session import DbSessionFactory +from aleph.types.db_session import AsyncDbSessionFactory @pytest.fixture @@ -58,21 +58,21 @@ def assert_pending_txs_equal(expected: PendingTxDb, actual: PendingTxDb): @pytest.mark.asyncio async def test_get_pending_txs( - session_factory: DbSessionFactory, fixture_txs: Sequence[PendingTxDb] + session_factory: AsyncDbSessionFactory, fixture_txs: Sequence[PendingTxDb] ): - with session_factory() as session: + async with session_factory() as session: session.add_all(fixture_txs) - session.commit() + await session.commit() - with session_factory() as session: - pending_txs = list(get_pending_txs(session=session)) + async with session_factory() as session: + pending_txs = list(await get_pending_txs(session=session)) for expected_tx, actual_tx in zip(pending_txs, fixture_txs): assert_pending_txs_equal(expected_tx, actual_tx) # Test the limit parameter - with session_factory() as session: - pending_txs = list(get_pending_txs(session=session, limit=1)) + async with session_factory() as session: + pending_txs = list(await get_pending_txs(session=session, limit=1)) assert pending_txs assert len(pending_txs) == 1 @@ -81,13 +81,13 @@ async def test_get_pending_txs( @pytest.mark.asyncio async def test_count_pending_txs( - session_factory: DbSessionFactory, fixture_txs: Sequence[PendingTxDb] + session_factory: AsyncDbSessionFactory, fixture_txs: Sequence[PendingTxDb] ): - with session_factory() as session: + async with session_factory() as session: session.add_all(fixture_txs) - session.commit() + await session.commit() - with session_factory() as session: - nb_txs = count_pending_txs(session=session) + async with session_factory() as session: + nb_txs = await count_pending_txs(session=session) assert nb_txs == len(fixture_txs) diff --git a/tests/db/test_posts.py b/tests/db/test_posts.py index e7185720e..451030de0 100644 --- a/tests/db/test_posts.py +++ b/tests/db/test_posts.py @@ -21,7 +21,7 @@ 
from aleph.db.models import MessageDb from aleph.db.models.posts import PostDb from aleph.types.channel import Channel -from aleph.types.db_session import DbSessionFactory +from aleph.types.db_session import AsyncDbSessionFactory from aleph.types.sort_order import SortOrder @@ -195,43 +195,47 @@ def assert_posts_v0_equal( @pytest.mark.asyncio async def test_get_post_no_amend( - original_post: PostDb, session_factory: DbSessionFactory + original_post: PostDb, session_factory: AsyncDbSessionFactory ): """ Checks that getting a post without amends works. """ - with session_factory() as session: + async with session_factory() as session: session.add(original_post) - session.commit() + await session.commit() - with session_factory() as session: - post = get_post(session=session, item_hash=original_post.item_hash) + async with session_factory() as session: + post = await get_post(session=session, item_hash=original_post.item_hash) assert post assert_posts_equal(merged_post=post, original=original_post) @pytest.mark.asyncio async def test_get_post_with_one_amend( - original_post: PostDb, first_amend_post: PostDb, session_factory: DbSessionFactory + original_post: PostDb, + first_amend_post: PostDb, + session_factory: AsyncDbSessionFactory, ): """ Checks that getting an amended post will return the amend and not the original. """ - with session_factory() as session: + async with session_factory() as session: session.add(original_post) session.add(first_amend_post) original_post.latest_amend = first_amend_post.item_hash - session.commit() + await session.commit() - with session_factory() as session: - post = get_post(session=session, item_hash=original_post.item_hash) + async with session_factory() as session: + post = await get_post(session=session, item_hash=original_post.item_hash) assert post assert_posts_equal( merged_post=post, original=original_post, last_amend=first_amend_post ) # Check that the query will not return a result when addressing the amend hash - amend_post = get_post(session=session, item_hash=first_amend_post.item_hash) + amend_post = await get_post( + session=session, item_hash=first_amend_post.item_hash + ) assert amend_post is None @@ -240,20 +244,20 @@ async def test_get_post_with_two_amends( original_post: PostDb, first_amend_post: PostDb, second_amend_post, - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, ): """ Checks that getting a post amended twice will return the latest amend. """ - with session_factory() as session: + async with session_factory() as session: session.add(original_post) session.add(first_amend_post) session.add(second_amend_post) original_post.latest_amend = second_amend_post.item_hash - session.commit() + await session.commit() - with session_factory() as session: - post = get_post(session=session, item_hash=original_post.item_hash) + async with session_factory() as session: + post = await get_post(session=session, item_hash=original_post.item_hash) assert post assert_posts_equal( merged_post=post, original=original_post, last_amend=second_amend_post @@ -265,28 +269,28 @@ async def test_get_matching_posts( original_post: PostDb, first_amend_post: PostDb, post_from_second_user: PostDb, - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, ): """ Tests that the list getter works. 
""" - with session_factory() as session: + async with session_factory() as session: session.add(original_post) session.add(first_amend_post) original_post.latest_amend = first_amend_post.item_hash session.add(post_from_second_user) - session.commit() + await session.commit() - with session_factory() as session: + async with session_factory() as session: # Get all posts, no filter - matching_posts = get_matching_posts(session=session) + matching_posts = await get_matching_posts(session=session) assert len(matching_posts) == 2 - nb_posts = count_matching_posts(session=session) + nb_posts = await count_matching_posts(session=session) assert nb_posts == 2 # Get by hash - matching_hash_posts = get_matching_posts( + matching_hash_posts = await get_matching_posts( session=session, hashes=[original_post.item_hash] ) assert matching_hash_posts @@ -295,33 +299,33 @@ async def test_get_matching_posts( original=original_post, last_amend=first_amend_post, ) - nb_matching_hash_posts = count_matching_posts( + nb_matching_hash_posts = await count_matching_posts( session=session, hashes=[original_post.item_hash] ) assert nb_matching_hash_posts == 1 # Get by owner address - matching_address_posts = get_matching_posts( + matching_address_posts = await get_matching_posts( session=session, addresses=[post_from_second_user.owner] ) assert matching_address_posts assert_posts_equal( merged_post=one(matching_address_posts), original=post_from_second_user ) - nb_matching_address_posts = count_matching_posts( + nb_matching_address_posts = await count_matching_posts( session=session, addresses=[post_from_second_user.owner] ) assert nb_matching_address_posts == 1 # Get by channel - matching_channel_posts = get_matching_posts( + matching_channel_posts = await get_matching_posts( session=session, channels=[post_from_second_user.channel] ) assert matching_channel_posts assert_posts_equal( merged_post=one(matching_channel_posts), original=post_from_second_user ) - nb_matching_channel_posts = count_matching_posts( + nb_matching_channel_posts = await count_matching_posts( session=session, channels=[post_from_second_user.channel] ) assert nb_matching_channel_posts == 1 @@ -335,13 +339,13 @@ async def test_get_matching_posts_legacy( first_amend_message: MessageDb, post_from_second_user: PostDb, message_from_second_user: MessageDb, - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, ): """ Tests that the list getter works with the legacy format. Same test logic as the test above. 
""" - with session_factory() as session: + async with session_factory() as session: session.add_all( [original_message, first_amend_message, message_from_second_user] ) @@ -349,17 +353,17 @@ async def test_get_matching_posts_legacy( session.add(first_amend_post) original_post.latest_amend = first_amend_post.item_hash session.add(post_from_second_user) - session.commit() + await session.commit() - with session_factory() as session: + async with session_factory() as session: # Get all posts, no filter - matching_posts = get_matching_posts_legacy(session=session) + matching_posts = await get_matching_posts_legacy(session=session) assert len(matching_posts) == 2 - nb_posts = count_matching_posts(session=session) + nb_posts = await count_matching_posts(session=session) assert nb_posts == 2 # Get by hash - matching_hash_posts = get_matching_posts_legacy( + matching_hash_posts = await get_matching_posts_legacy( session=session, hashes=[original_post.item_hash] ) assert matching_hash_posts @@ -370,13 +374,13 @@ async def test_get_matching_posts_legacy( last_amend=first_amend_post, amend_message=first_amend_message, ) - nb_matching_hash_posts = count_matching_posts( + nb_matching_hash_posts = await count_matching_posts( session=session, hashes=[original_post.item_hash] ) assert nb_matching_hash_posts == 1 # Get by owner address - matching_address_posts = get_matching_posts_legacy( + matching_address_posts = await get_matching_posts_legacy( session=session, addresses=[post_from_second_user.owner] ) assert matching_address_posts @@ -385,13 +389,13 @@ async def test_get_matching_posts_legacy( original=post_from_second_user, original_message=message_from_second_user, ) - nb_matching_address_posts = count_matching_posts( + nb_matching_address_posts = await count_matching_posts( session=session, addresses=[post_from_second_user.owner] ) assert nb_matching_address_posts == 1 # Get by channel - matching_channel_posts = get_matching_posts_legacy( + matching_channel_posts = await get_matching_posts_legacy( session=session, channels=[post_from_second_user.channel] ) assert matching_channel_posts @@ -400,7 +404,7 @@ async def test_get_matching_posts_legacy( original=post_from_second_user, original_message=message_from_second_user, ) - nb_matching_channel_posts = count_matching_posts( + nb_matching_channel_posts = await count_matching_posts( session=session, channels=[post_from_second_user.channel] ) assert nb_matching_channel_posts == 1 @@ -411,25 +415,25 @@ async def test_get_matching_posts_time_filters( original_post: PostDb, first_amend_post: PostDb, post_from_second_user: PostDb, - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, ): """ Tests that the time filters for the list getter work. 
""" - with session_factory() as session: + async with session_factory() as session: session.add(original_post) session.add(first_amend_post) original_post.latest_amend = first_amend_post.item_hash session.add(post_from_second_user) - session.commit() + await session.commit() - with session_factory() as session: + async with session_factory() as session: start_datetime = first_amend_post.creation_datetime end_datetime = start_datetime + dt.timedelta(days=1) # Sanity check, the amend is supposed to be the latest entry assert start_datetime > post_from_second_user.creation_datetime - matching_posts = get_matching_posts( + matching_posts = await get_matching_posts( session=session, start_date=start_datetime, end_date=end_datetime ) assert matching_posts @@ -445,22 +449,24 @@ async def test_get_matching_posts_sort_order( original_post: PostDb, first_amend_post: PostDb, post_from_second_user: PostDb, - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, ): """ Tests that the sort order specifier for the list getter work. """ - with session_factory() as session: + async with session_factory() as session: session.add(original_post) session.add(first_amend_post) original_post.latest_amend = first_amend_post.item_hash session.add(post_from_second_user) - session.commit() + await session.commit() - with session_factory() as session: + async with session_factory() as session: # Ascending order first - asc_posts = get_matching_posts(session=session, sort_order=SortOrder.ASCENDING) + asc_posts = await get_matching_posts( + session=session, sort_order=SortOrder.ASCENDING + ) assert asc_posts assert_posts_equal(merged_post=asc_posts[0], original=post_from_second_user) assert_posts_equal( @@ -470,7 +476,9 @@ async def test_get_matching_posts_sort_order( ) # Descending order first - asc_posts = get_matching_posts(session=session, sort_order=SortOrder.DESCENDING) + asc_posts = await get_matching_posts( + session=session, sort_order=SortOrder.DESCENDING + ) assert asc_posts assert_posts_equal( merged_post=asc_posts[0], @@ -482,54 +490,60 @@ async def test_get_matching_posts_sort_order( @pytest.mark.asyncio async def test_get_matching_posts_no_data( - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, ): """ Tests that the list getter works when a node starts syncing. 
""" - with session_factory() as session: - posts = list(get_matching_posts(session=session)) + async with session_factory() as session: + posts = list(await get_matching_posts(session=session)) assert posts == [] @pytest.mark.asyncio async def test_refresh_latest_amend( - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, original_post: PostDb, first_amend_post: PostDb, second_amend_post: PostDb, ): - with session_factory() as session: + async with session_factory() as session: session.add(original_post) session.add(first_amend_post) session.add(second_amend_post) - session.commit() + await session.commit() - with session_factory() as session: - refresh_latest_amend(session, original_post.item_hash) - session.commit() + async with session_factory() as session: + await refresh_latest_amend(session, original_post.item_hash) + await session.commit() - original_post_db = get_original_post(session, item_hash=original_post.item_hash) + original_post_db = await get_original_post( + session, item_hash=original_post.item_hash + ) assert original_post_db assert original_post_db.latest_amend == second_amend_post.item_hash # Now, delete the second post and check that refreshing the latest amend works - with session_factory() as session: - delete_post(session, item_hash=second_amend_post.item_hash) - refresh_latest_amend(session=session, item_hash=original_post.item_hash) - session.commit() + async with session_factory() as session: + await delete_post(session, item_hash=second_amend_post.item_hash) + await refresh_latest_amend(session=session, item_hash=original_post.item_hash) + await session.commit() - original_post_db = get_original_post(session, item_hash=original_post.item_hash) + original_post_db = await get_original_post( + session, item_hash=original_post.item_hash + ) assert original_post_db assert original_post_db.latest_amend == first_amend_post.item_hash # Delete the last amend, check that latest_amend is now null - with session_factory() as session: - delete_post(session, item_hash=first_amend_post.item_hash) - refresh_latest_amend(session=session, item_hash=original_post.item_hash) - session.commit() + async with session_factory() as session: + await delete_post(session, item_hash=first_amend_post.item_hash) + await refresh_latest_amend(session=session, item_hash=original_post.item_hash) + await session.commit() - original_post_db = get_original_post(session, item_hash=original_post.item_hash) + original_post_db = await get_original_post( + session, item_hash=original_post.item_hash + ) assert original_post_db assert original_post_db.latest_amend is None diff --git a/tests/db/test_programs_db.py b/tests/db/test_programs_db.py index 3ae5af7d3..70fe7d97a 100644 --- a/tests/db/test_programs_db.py +++ b/tests/db/test_programs_db.py @@ -26,7 +26,7 @@ RuntimeDb, VmVersionDb, ) -from aleph.types.db_session import DbSessionFactory +from aleph.types.db_session import AsyncDbSessionFactory from aleph.types.vms import VmVersion @@ -177,13 +177,14 @@ def assert_programs_equal(expected: ProgramDb, actual: ProgramDb): assert actual.runtime.comment == expected.runtime.comment -def test_program_accessors( - session_factory: DbSessionFactory, +@pytest.mark.asyncio +async def test_program_accessors( + session_factory: AsyncDbSessionFactory, original_program: ProgramDb, program_update: ProgramDb, program_with_many_volumes: ProgramDb, ): - with session_factory() as session: + async with session_factory() as session: session.add(original_program) session.add(program_update) 
session.add( @@ -203,22 +204,22 @@ def test_program_accessors( last_updated=program_with_many_volumes.created, ) ) - session.commit() + await session.commit() - with session_factory() as session: - original_program_db = get_program( + async with session_factory() as session: + original_program_db = await get_program( session=session, item_hash=original_program.item_hash ) assert original_program_db is not None assert_programs_equal(expected=original_program, actual=original_program_db) - program_update_db = get_program( + program_update_db = await get_program( session=session, item_hash=program_update.item_hash ) assert program_update_db is not None assert_programs_equal(expected=program_update_db, actual=program_update) - program_with_many_volumes_db = get_program( + program_with_many_volumes_db = await get_program( session=session, item_hash=program_with_many_volumes.item_hash ) assert program_with_many_volumes_db is not None @@ -226,75 +227,78 @@ def test_program_accessors( expected=program_with_many_volumes_db, actual=program_with_many_volumes ) - is_amend_allowed = is_vm_amend_allowed( + is_amend_allowed = await is_vm_amend_allowed( session=session, vm_hash=original_program.item_hash ) assert is_amend_allowed is False - is_amend_allowed = is_vm_amend_allowed( + is_amend_allowed = await is_vm_amend_allowed( session=session, vm_hash=program_with_many_volumes.item_hash ) assert is_amend_allowed is True -def test_refresh_program( - session_factory: DbSessionFactory, +@pytest.mark.asyncio +async def test_refresh_program( + session_factory: AsyncDbSessionFactory, original_program: ProgramDb, program_update: ProgramDb, ): program_hash = original_program.item_hash - def get_program_version(session) -> Optional[VmVersionDb]: - return session.execute( - select(VmVersionDb).where(VmVersionDb.vm_hash == program_hash) + async def get_program_version(session) -> Optional[VmVersionDb]: + return ( + await session.execute( + select(VmVersionDb).where(VmVersionDb.vm_hash == program_hash) + ) ).scalar_one_or_none() # Insert program version with refresh_program_version - with session_factory() as session: + async with session_factory() as session: session.add(original_program) - session.commit() + await session.commit() - refresh_vm_version(session=session, vm_hash=program_hash) - session.commit() + await refresh_vm_version(session=session, vm_hash=program_hash) + await session.commit() - program_version_db = get_program_version(session) + program_version_db = await get_program_version(session) assert program_version_db is not None assert program_version_db.current_version == program_hash assert program_version_db.last_updated == original_program.created # Update the version of the program, program_versions should be updated - with session_factory() as session: + async with session_factory() as session: session.add(program_update) - session.commit() + await session.commit() - refresh_vm_version(session=session, vm_hash=program_hash) - session.commit() + await refresh_vm_version(session=session, vm_hash=program_hash) + await session.commit() - program_version_db = get_program_version(session) + program_version_db = await get_program_version(session) assert program_version_db is not None assert program_version_db.current_version == program_update.item_hash assert program_version_db.last_updated == program_update.created # Delete the update, the original should be back in program_versions - with session_factory() as session: - delete_vm(session=session, vm_hash=program_update.item_hash) - session.commit() + 
async with session_factory() as session: + await delete_vm(session=session, vm_hash=program_update.item_hash) + await session.commit() - refresh_vm_version(session=session, vm_hash=program_hash) - session.commit() + await refresh_vm_version(session=session, vm_hash=program_hash) + await session.commit() - program_version_db = get_program_version(session) + program_version_db = await get_program_version(session) assert program_version_db is not None assert program_version_db.current_version == program_hash assert program_version_db.last_updated == original_program.created # Delete the original, no entry should be left in program_versions - with session_factory() as session: - delete_vm(session=session, vm_hash=original_program.item_hash) - session.commit() + async with session_factory() as session: + await delete_vm(session=session, vm_hash=original_program.item_hash) + await session.commit() - refresh_vm_version(session=session, vm_hash=program_hash) - session.commit() + await refresh_vm_version(session=session, vm_hash=program_hash) + await session.commit() - program_version_db = get_program_version(session) + program_version_db = await get_program_version(session) assert program_version_db is None diff --git a/tests/helpers/message_test_helpers.py b/tests/helpers/message_test_helpers.py index ef79e851a..700d7eec2 100644 --- a/tests/helpers/message_test_helpers.py +++ b/tests/helpers/message_test_helpers.py @@ -8,7 +8,7 @@ from aleph.db.models import MessageDb, MessageStatusDb, PendingMessageDb from aleph.jobs.process_pending_messages import PendingMessageProcessor from aleph.toolkit.timestamp import utc_now -from aleph.types.db_session import DbSession +from aleph.types.db_session import AsyncDbSession from aleph.types.message_processing_result import MessageProcessingResult from aleph.types.message_status import MessageStatus @@ -45,7 +45,7 @@ def make_validated_message_from_dict( async def process_pending_messages( message_processor: PendingMessageProcessor, pending_messages: Sequence[PendingMessageDb], - session: DbSession, + session: AsyncDbSession, ) -> Iterable[MessageProcessingResult]: for pending_message in pending_messages: @@ -58,7 +58,7 @@ async def process_pending_messages( ) ) session.add_all(pending_messages) - session.commit() + await session.commit() pipeline = message_processor.make_pipeline() diff --git a/tests/message_processing/conftest.py b/tests/message_processing/conftest.py index eb08606b4..38cc0217b 100644 --- a/tests/message_processing/conftest.py +++ b/tests/message_processing/conftest.py @@ -13,7 +13,7 @@ from aleph.handlers.message_handler import MessageHandler from aleph.jobs.process_pending_messages import PendingMessageProcessor from aleph.storage import StorageService -from aleph.types.db_session import DbSessionFactory +from aleph.types.db_session import AsyncDbSessionFactory from .load_fixtures import load_fixture_messages @@ -27,7 +27,7 @@ def fixture_messages(): # it could make sense to have some general fixtures available to all the test cases # to reduce duplication between DB tests, API tests, etc. 
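+# NOTE: _load_fixtures reads pending messages (and their chain confirmations) from a
+# JSON fixture file and commits them through the async session factory before tests run.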
async def _load_fixtures( - session_factory: DbSessionFactory, filename: str + session_factory: AsyncDbSessionFactory, filename: str ) -> Sequence[Dict[str, Any]]: fixtures_dir = Path(__file__).parent / "fixtures" fixtures_file = fixtures_dir / filename @@ -51,30 +51,32 @@ async def _load_fixtures( chain_txs.append(ChainTxDb.from_dict(confirmation)) tx_hashes.add(tx_hash) - with session_factory() as session: + async with session_factory() as session: session.add_all(pending_messages) session.add_all(chain_txs) - session.commit() + await session.commit() return messages_json @pytest_asyncio.fixture async def fixture_aggregate_messages( - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, ) -> Sequence[Dict[str, Any]]: return await _load_fixtures(session_factory, "test-data-aggregates.json") @pytest_asyncio.fixture async def fixture_post_messages( - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, ) -> Sequence[Dict[str, Any]]: return await _load_fixtures(session_factory, "test-data-posts.json") -@pytest.fixture -def message_processor(mocker, mock_config: Config, session_factory: DbSessionFactory): +@pytest_asyncio.fixture +async def message_processor( + mocker, mock_config: Config, session_factory: AsyncDbSessionFactory +): storage_engine = InMemoryStorageEngine(files={}) storage_service = StorageService( storage_engine=storage_engine, diff --git a/tests/message_processing/test_process_aggregates.py b/tests/message_processing/test_process_aggregates.py index cad49b04f..9bc0ceeab 100644 --- a/tests/message_processing/test_process_aggregates.py +++ b/tests/message_processing/test_process_aggregates.py @@ -20,7 +20,7 @@ from aleph.storage import StorageService from aleph.toolkit.timestamp import timestamp_to_datetime from aleph.types.channel import Channel -from aleph.types.db_session import DbSession, DbSessionFactory +from aleph.types.db_session import AsyncDbSession, AsyncDbSessionFactory from aleph.types.message_processing_result import ProcessedMessage @@ -28,7 +28,7 @@ async def test_process_aggregate_first_element( mocker, mock_config: Config, - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, fixture_aggregate_messages: List[Dict], ): storage_service = StorageService( @@ -45,9 +45,9 @@ async def test_process_aggregate_first_element( item_hash = "a87004aa03f8ae63d2c4bbe84b93b9ce70ca6482ce36c82ab0b0f689fc273f34" - with session_factory() as session: + async with session_factory() as session: pending_message = ( - session.execute( + await session.execute( select(PendingMessageDb) .where(PendingMessageDb.item_hash == item_hash) .options(selectinload(PendingMessageDb.tx)) @@ -55,7 +55,7 @@ async def test_process_aggregate_first_element( ).scalar_one() await message_handler.process(session=session, pending_message=pending_message) - session.commit() + await session.commit() # Check the aggregate content = json.loads(pending_message.item_content) @@ -63,9 +63,9 @@ async def test_process_aggregate_first_element( expected_key = content["key"] expected_creation_datetime = timestamp_to_datetime(content["time"]) - with session_factory() as session: + async with session_factory() as session: elements = list( - get_aggregate_elements( + await get_aggregate_elements( session=session, key=expected_key, owner=pending_message.sender ) ) @@ -75,7 +75,7 @@ async def test_process_aggregate_first_element( assert element.creation_datetime == expected_creation_datetime assert element.content == content["content"] - aggregate = 
get_aggregate_by_key(
+        aggregate = await get_aggregate_by_key(
             session=session,
             owner=pending_message.sender,
             key=expected_key,
@@ -91,7 +91,7 @@
 @pytest.mark.asyncio
 async def test_process_aggregates(
-    session_factory: DbSessionFactory,
+    session_factory: AsyncDbSessionFactory,
     message_processor: PendingMessageProcessor,
     fixture_aggregate_messages: List[Dict],
 ):
@@ -140,14 +140,14 @@ def aggregate_updates() -> Sequence[PendingMessageDb]:
 
 
 async def process_aggregates_one_by_one(
-    session: DbSession,
+    session: AsyncDbSession,
     message_processor: PendingMessageProcessor,
     aggregates: Iterable[PendingMessageDb],
 ) -> Sequence[MessageDb]:
     messages = []
     for pending_aggregate in aggregates:
         session.add(pending_aggregate)
-        session.commit()
+        await session.commit()
         result = one(
             await process_pending_messages(
@@ -164,11 +164,11 @@ async def test_process_aggregates_in_order(
-    session_factory: DbSessionFactory,
+    session_factory: AsyncDbSessionFactory,
     message_processor: PendingMessageProcessor,
     aggregate_updates: Sequence[PendingMessageDb],
 ):
-    with session_factory() as session:
+    async with session_factory() as session:
         original, update = await process_aggregates_one_by_one(
             session=session,
             message_processor=message_processor,
@@ -181,7 +181,7 @@
         content = original.parsed_content
         assert isinstance(content, AggregateContent)
-        aggregate = get_aggregate_by_key(
+        aggregate = await get_aggregate_by_key(
             session=session, key=str(content.key), owner=content.address
         )
         assert aggregate
@@ -191,11 +191,11 @@ async def test_process_aggregates_reverse_order(
-    session_factory: DbSessionFactory,
+    session_factory: AsyncDbSessionFactory,
     message_processor: PendingMessageProcessor,
     aggregate_updates: Sequence[PendingMessageDb],
 ):
-    with session_factory() as session:
+    async with session_factory() as session:
         update, original = await process_aggregates_one_by_one(
             session=session,
             message_processor=message_processor,
@@ -204,7 +204,7 @@
         content = original.parsed_content
         assert isinstance(content, AggregateContent)
-        aggregate = get_aggregate_by_key(
+        aggregate = await get_aggregate_by_key(
             session=session, key=str(content.key), owner=content.address
         )
         assert aggregate
@@ -215,9 +215,9 @@ async def test_delete_aggregate_one_element(
     mocker,
-    session_factory: DbSessionFactory,
+    session_factory: AsyncDbSessionFactory,
 ):
-    with session_factory() as session:
+    async with session_factory() as session:
         element = AggregateElementDb(
             item_hash="d73d50b2d2c670d4c6c8e03ad0e4e2145642375f92784c68539a3400e0e4e242",
             key="my-aggregate",
@@ -236,7 +236,7 @@
                 dirty=False,
             )
         )
-        session.commit()
+        await session.commit()
 
         message = mocker.MagicMock()
         message.item_hash = element.item_hash
@@ -249,14 +249,14 @@
         aggregate_handler = AggregateMessageHandler()
         await aggregate_handler.forget_message(session=session, message=message)
-        session.commit()
+        await session.commit()
 
-        aggregate = get_aggregate_by_key(
+        # The accessor is async now; without await this would compare a coroutine to None.
+        aggregate = await get_aggregate_by_key(
             session=session, owner=element.owner, key=element.key
         )
         assert aggregate is None
         aggregate_elements = list(
-            get_aggregate_elements(
+            await get_aggregate_elements(
                 session=session,
owner=element.owner, key=element.key ) ) @@ -267,10 +267,10 @@ async def test_delete_aggregate_one_element( @pytest.mark.parametrize("element_to_forget", ["first", "last"]) async def test_delete_aggregate_two_elements( mocker, - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, element_to_forget: str, ): - with session_factory() as session: + async with session_factory() as session: first_element = AggregateElementDb( item_hash="d73d50b2d2c670d4c6c8e03ad0e4e2145642375f92784c68539a3400e0e4e242", key="my-aggregate", @@ -297,7 +297,7 @@ async def test_delete_aggregate_two_elements( dirty=False, ) ) - session.commit() + await session.commit() if element_to_forget == "first": element_to_delete, element_to_keep = first_element, last_element @@ -315,9 +315,9 @@ async def test_delete_aggregate_two_elements( aggregate_handler = AggregateMessageHandler() await aggregate_handler.forget_message(session=session, message=message) - session.commit() + await session.commit() - aggregate = get_aggregate_by_key( + aggregate = await get_aggregate_by_key( session=session, owner=first_element.owner, key=first_element.key ) assert aggregate is not None @@ -329,7 +329,7 @@ async def test_delete_aggregate_two_elements( assert aggregate.creation_datetime == element_to_keep.creation_datetime aggregate_elements = list( - get_aggregate_elements( + await get_aggregate_elements( session=session, owner=first_element.owner, key=first_element.key ) ) diff --git a/tests/message_processing/test_process_confidential.py b/tests/message_processing/test_process_confidential.py index 625a771a3..343430590 100644 --- a/tests/message_processing/test_process_confidential.py +++ b/tests/message_processing/test_process_confidential.py @@ -5,6 +5,7 @@ from typing import List, Protocol, cast import pytest +import pytest_asyncio import pytz from aleph_message.models import ( Chain, @@ -16,6 +17,7 @@ from aleph_message.models.execution.program import ProgramContent from aleph_message.models.execution.volume import ImmutableVolume from more_itertools import one +from sqlalchemy import select from aleph.db.accessors.files import insert_message_file_pin, upsert_file_tag from aleph.db.accessors.vms import get_instance, get_vm_version @@ -31,7 +33,7 @@ from aleph.jobs.process_pending_messages import PendingMessageProcessor from aleph.toolkit.timestamp import timestamp_to_datetime from aleph.types.channel import Channel -from aleph.types.db_session import DbSession, DbSessionFactory +from aleph.types.db_session import AsyncDbSession, AsyncDbSessionFactory from aleph.types.files import FileTag, FileType from aleph.types.message_status import MessageStatus @@ -41,9 +43,9 @@ class Volume(Protocol): use_latest: bool -@pytest.fixture -def fixture_confidential_vm_message( - session_factory: DbSessionFactory, +@pytest_asyncio.fixture +async def fixture_confidential_vm_message( + session_factory: AsyncDbSessionFactory, ) -> PendingMessageDb: content = { "address": "0x9319Ad3B7A8E0eE24f2E639c40D8eD124C5520Ba", @@ -143,7 +145,7 @@ def fixture_confidential_vm_message( retries=1, next_attempt=dt.datetime(2023, 1, 1), ) - with session_factory() as session: + async with session_factory() as session: session.add(pending_message) session.add( MessageStatusDb( @@ -152,13 +154,13 @@ def fixture_confidential_vm_message( reception_time=pending_message.reception_time, ) ) - session.commit() + await session.commit() return pending_message -@pytest.fixture -def user_balance(session_factory: DbSessionFactory) -> AlephBalanceDb: 
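+# Async fixtures must be declared with pytest_asyncio.fixture; a plain pytest.fixture
+# would hand dependent tests an un-awaited coroutine instead of the balance row.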
+@pytest_asyncio.fixture
+async def user_balance(session_factory: AsyncDbSessionFactory) -> AlephBalanceDb:
     balance = AlephBalanceDb(
         address="0x9319Ad3B7A8E0eE24f2E639c40D8eD124C5520Ba",
         chain=Chain.ETH,
@@ -166,9 +168,9 @@
         eth_height=0,
     )
 
-    with session_factory() as session:
+    async with session_factory() as session:
         session.add(balance)
-        session.commit()
+        await session.commit()
 
     return balance
@@ -192,7 +194,7 @@
     return volumes
 
 
-def insert_volume_refs(session: DbSession, message: PendingMessageDb):
+async def insert_volume_refs(session: AsyncDbSession, message: PendingMessageDb):
     item_content = message.item_content if message.item_content is not None else ""
     content = InstanceContent.model_validate_json(item_content)
     volumes = get_volume_refs(content)
@@ -200,13 +202,14 @@
     for volume in volumes:
         file_hash = volume.ref[::-1]
 
-        existing_file = session.query(StoredFileDb).filter_by(hash=file_hash).first()
+        stmt = select(StoredFileDb).where(StoredFileDb.hash == file_hash)
+        existing_file = (await session.execute(stmt)).first()
         if not existing_file:
             session.add(
                 StoredFileDb(hash=file_hash, size=1024 * 1024, type=FileType.FILE)
             )
-        session.flush()
-        insert_message_file_pin(
+        await session.flush()
+        await insert_message_file_pin(
             session=session,
             file_hash=volume.ref[::-1],
             owner=content.address,
@@ -214,7 +217,7 @@
             ref=None,
             created=created,
         )
-        upsert_file_tag(
+        await upsert_file_tag(
             session=session,
             tag=FileTag(volume.ref),
             owner=content.address,
@@ -225,16 +228,16 @@
 
 @pytest.mark.asyncio
 async def test_process_confidential_vm(
-    session_factory: DbSessionFactory,
+    session_factory: AsyncDbSessionFactory,
     message_processor: PendingMessageProcessor,
     fixture_confidential_vm_message: PendingMessageDb,
     fixture_product_prices_aggregate_in_db,
     fixture_settings_aggregate_in_db,
     user_balance: AlephBalanceDb,
 ):
-    with session_factory() as session:
-        insert_volume_refs(session, fixture_confidential_vm_message)
-        session.commit()
+    async with session_factory() as session:
+        await insert_volume_refs(session, fixture_confidential_vm_message)
+        await session.commit()
 
     pipeline = message_processor.make_pipeline()
     _ = [message async for message in pipeline]
@@ -242,8 +245,8 @@
     assert fixture_confidential_vm_message.item_content
     content_dict = json.loads(fixture_confidential_vm_message.item_content)
 
-    with session_factory() as session:
-        instance = get_instance(
+    async with session_factory() as session:
+        instance = await get_instance(
             session=session, item_hash=fixture_confidential_vm_message.item_hash
         )
         assert instance is not None
@@ -298,7 +301,7 @@
         assert ephemeral_volume.mount == "/var/cache"
         assert ephemeral_volume.size_mib == 5
 
-        instance_version = get_vm_version(
+        instance_version = await get_vm_version(
             session=session, vm_hash=fixture_confidential_vm_message.item_hash
         )
         assert instance_version
diff --git a/tests/message_processing/test_process_forgets.py b/tests/message_processing/test_process_forgets.py
index 3a6c3dc96..5bc09cdc3 100644
--- a/tests/message_processing/test_process_forgets.py
+++ b/tests/message_processing/test_process_forgets.py
@@ -27,7 +27,7 
@@ from aleph.jobs.process_pending_messages import PendingMessageProcessor from aleph.toolkit.timestamp import timestamp_to_datetime from aleph.types.channel import Channel -from aleph.types.db_session import DbSessionFactory +from aleph.types.db_session import AsyncDbSessionFactory from aleph.types.files import FileType from aleph.types.message_processing_result import ProcessedMessage, RejectedMessage from aleph.types.message_status import MessageStatus @@ -53,7 +53,7 @@ def forget_handler(mocker) -> ForgetMessageHandler: @pytest.mark.asyncio async def test_forget_post_message( - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, message_processor: PendingMessageProcessor, mock_config: Config, ): @@ -100,7 +100,7 @@ async def test_forget_post_message( forget_message_dict, reception_time=dt.datetime(2022, 1, 2), fetched=True ) - with session_factory() as session: + async with session_factory() as session: target_message_result = one( await process_pending_messages( message_processor=message_processor, @@ -112,7 +112,9 @@ async def test_forget_post_message( target_message = target_message_result.message # Sanity check - post = get_post(session=session, item_hash=ItemHash(target_message.item_hash)) + post = await get_post( + session=session, item_hash=ItemHash(target_message.item_hash) + ) assert post # Now process, the forget message @@ -127,13 +129,13 @@ async def test_forget_post_message( assert isinstance(forget_message_result, ProcessedMessage) forget_message = forget_message_result.message - target_message_status = get_message_status( + target_message_status = await get_message_status( session=session, item_hash=ItemHash(target_message.item_hash) ) assert target_message_status assert target_message_status.status == MessageStatus.FORGOTTEN - forget_message_status = get_message_status( + forget_message_status = await get_message_status( session=session, item_hash=ItemHash(forget_message.item_hash) ) assert forget_message_status @@ -151,7 +153,7 @@ async def test_forget_post_message( @pytest.mark.asyncio async def test_forget_store_message( - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, message_processor: PendingMessageProcessor, mock_config: Config, fixture_product_prices_aggregate_in_db, @@ -199,7 +201,7 @@ async def test_forget_store_message( content=b"Test", ) - with session_factory() as session: + async with session_factory() as session: target_message_result = one( await process_pending_messages( message_processor=message_processor, @@ -210,7 +212,7 @@ async def test_forget_store_message( assert isinstance(target_message_result, ProcessedMessage) # Sanity check - nb_references = count_file_pins(session=session, file_hash=file_hash) + nb_references = await count_file_pins(session=session, file_hash=file_hash) assert nb_references == 1 forget_message_result = one( @@ -223,7 +225,7 @@ async def test_forget_store_message( assert isinstance(forget_message_result, ProcessedMessage) # Check that the file is pinned with a grace period - file = get_file(session=session, file_hash=file_hash) + file = await get_file(session=session, file_hash=file_hash) assert file assert file.pins @@ -233,7 +235,7 @@ async def test_forget_store_message( @pytest.mark.asyncio async def test_forget_forget_message( - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, message_processor: PendingMessageProcessor, mock_config: Config, ): @@ -280,7 +282,7 @@ async def test_forget_forget_message( reception_time=dt.datetime(2022, 1, 2), 
) - with session_factory() as session: + async with session_factory() as session: session.add(target_message) session.add( MessageStatusDb( @@ -289,7 +291,7 @@ async def test_forget_forget_message( reception_time=dt.datetime(2022, 1, 1), ) ) - session.commit() + await session.commit() processed_message_results = list( await process_pending_messages( @@ -303,13 +305,13 @@ async def test_forget_forget_message( for result in processed_message_results: assert isinstance(result, RejectedMessage) - target_message_status = get_message_status( + target_message_status = await get_message_status( session=session, item_hash=ItemHash(target_message.item_hash) ) assert target_message_status assert target_message_status.status == MessageStatus.PROCESSED - forget_message_status = get_message_status( + forget_message_status = await get_message_status( session=session, item_hash=ItemHash(pending_forget_message.item_hash) ) assert forget_message_status @@ -318,7 +320,7 @@ async def test_forget_forget_message( @pytest.mark.asyncio async def test_forget_store_multi_users( - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, message_processor: PendingMessageProcessor, mock_config: Config, ): @@ -398,7 +400,7 @@ async def test_forget_store_multi_users( storage_engine = message_processor.message_handler.storage_service.storage_engine await storage_engine.write(filename=file_hash, content=file_content) - with session_factory() as session: + async with session_factory() as session: # Add messages, file references, etc session.add(store_message_user1) session.add(store_message_user2) @@ -450,20 +452,22 @@ async def test_forget_store_multi_users( ) assert isinstance(forget_message_result, ProcessedMessage) - message1_status = get_message_status( + message1_status = await get_message_status( session=session, item_hash=ItemHash(store_message_user1.item_hash) ) assert message1_status assert message1_status.status == MessageStatus.FORGOTTEN # Check that the second message and its linked objects are still there - message2_status = get_message_status( + message2_status = await get_message_status( session=session, item_hash=ItemHash(store_message_user2.item_hash) ) assert message2_status assert message2_status.status == MessageStatus.PROCESSED - file_pin = session.execute( - select(FilePinDb).where(FilePinDb.file_hash == file_hash) + file_pin = ( + await session.execute( + select(FilePinDb).where(FilePinDb.file_hash == file_hash) + ) ).scalar_one() assert file_pin.item_hash == store_message_user2.item_hash assert file_pin.owner == store_message_user2.sender @@ -475,7 +479,7 @@ async def test_forget_store_multi_users( @pytest.mark.asyncio async def test_forget_store_message_dependent( - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, message_processor: PendingMessageProcessor, mock_config: Config, fixture_product_prices_aggregate_in_db, @@ -562,7 +566,7 @@ async def test_forget_store_message_dependent( content=b"Runtime", ) - with session_factory() as session: + async with session_factory() as session: target_message_result = one( await process_pending_messages( message_processor=message_processor, @@ -592,7 +596,7 @@ async def test_forget_store_message_dependent( assert isinstance(target_message_result2, ProcessedMessage) # Sanity check - nb_references = count_file_pins(session=session, file_hash=file_hash) + nb_references = await count_file_pins(session=session, file_hash=file_hash) assert nb_references == 1 forget_message_result = one( @@ -606,5 +610,5 @@ async def 
test_forget_store_message_dependent(
         assert forget_message_result.error_code == 503
 
         # Check that the file continues pinned with a grace period
-        nb_references = count_file_pins(session=session, file_hash=file_hash)
+        nb_references = await count_file_pins(session=session, file_hash=file_hash)
         assert nb_references == 1
diff --git a/tests/message_processing/test_process_forgotten_messages.py b/tests/message_processing/test_process_forgotten_messages.py
index 54a5b9e6a..c433d3111 100644
--- a/tests/message_processing/test_process_forgotten_messages.py
+++ b/tests/message_processing/test_process_forgotten_messages.py
@@ -1,14 +1,14 @@
 import datetime as dt
-from typing import cast
 
 import pytest
+from aleph_message.models import ItemHash
 from configmanager import Config
+from sqlalchemy import select
 
+from aleph.db.accessors.messages import get_message_by_item_hash
 from aleph.db.models import PendingMessageDb
-from aleph.db.models.messages import ForgottenMessageDb, MessageDb
+from aleph.db.models.messages import ForgottenMessageDb
 from aleph.handlers.message_handler import MessageHandler
 from aleph.storage import StorageService
-from aleph.types.db_session import DbSessionFactory
+from aleph.types.db_session import AsyncDbSessionFactory
 from aleph.types.message_processing_result import ProcessedMessage, RejectedMessage
 from aleph.types.message_status import ErrorCode
@@ -19,7 +19,7 @@
 async def test_duplicated_forgotten_message(
     mocker,
     mock_config: Config,
-    session_factory: DbSessionFactory,
+    session_factory: AsyncDbSessionFactory,
     test_storage_service: StorageService,
 ):
     signature_verifier = mocker.AsyncMock()
@@ -43,7 +43,7 @@
         config=mock_config,
     )
 
-    with session_factory() as session:
+    async with session_factory() as session:
         # 1) process post message
         test1 = await message_handler.process(
             session=session,
@@ -51,10 +51,10 @@
         )
         assert isinstance(test1, ProcessedMessage)
 
-        res1 = cast(
-            MessageDb,
-            session.query(MessageDb).where(MessageDb.item_hash == post_hash).first(),
+        res1 = await get_message_by_item_hash(
+            session=session, item_hash=ItemHash(post_hash)
         )
+        assert res1
         assert res1.item_hash == post_hash
 
         # 2) process forget message
@@ -63,10 +63,8 @@
             pending_message=m2,
         )
         assert isinstance(test2, ProcessedMessage)
-
-        res2 = cast(
-            MessageDb,
-            session.query(MessageDb).where(MessageDb.item_hash == post_hash).first(),
+        res2 = await get_message_by_item_hash(
+            session=session, item_hash=ItemHash(post_hash)
         )
         assert res2 is None
 
@@ -78,16 +76,13 @@
         assert isinstance(test3, RejectedMessage)
         assert test3.error_code == ErrorCode.FORGOTTEN_DUPLICATE
 
-        res3 = cast(
-            MessageDb,
-            session.query(MessageDb).where(MessageDb.item_hash == post_hash).first(),
+        res3 = await get_message_by_item_hash(
+            session=session, item_hash=ItemHash(post_hash)
         )
+
         assert res3 is None
 
-        res4 = cast(
-            ForgottenMessageDb,
-            session.query(ForgottenMessageDb)
-            .where(ForgottenMessageDb.item_hash == post_hash)
-            .first(),
-        )
+        # The message row is gone, but the forgotten-message marker must remain,
+        # so query the ForgottenMessageDb table directly.
+        res4 = (
+            await session.execute(
+                select(ForgottenMessageDb).where(
+                    ForgottenMessageDb.item_hash == post_hash
+                )
+            )
+        ).scalar_one_or_none()
         assert res4
diff --git a/tests/message_processing/test_process_instances.py b/tests/message_processing/test_process_instances.py
index 704c51569..5eb34bfed 100644
--- a/tests/message_processing/test_process_instances.py
+++ b/tests/message_processing/test_process_instances.py
@@ -5,6 +5,7 @@
 from typing import List, Union
 
 import pytest
+import pytest_asyncio
 import pytz
 from aleph_message.models import (
     Chain,
@@ 
-46,13 +47,15 @@ get_total_and_detailed_costs_from_db, ) from aleph.toolkit.timestamp import timestamp_to_datetime -from aleph.types.db_session import DbSession, DbSessionFactory +from aleph.types.db_session import AsyncDbSession, AsyncDbSessionFactory from aleph.types.files import FileTag, FileType from aleph.types.message_status import ErrorCode, MessageStatus -@pytest.fixture -def fixture_instance_message(session_factory: DbSessionFactory) -> PendingMessageDb: +@pytest_asyncio.fixture +async def fixture_instance_message( + session_factory: AsyncDbSessionFactory, +) -> PendingMessageDb: content = { "address": "0x9319Ad3B7A8E0eE24f2E639c40D8eD124C5520Ba", "allow_amend": False, @@ -137,7 +140,7 @@ def fixture_instance_message(session_factory: DbSessionFactory) -> PendingMessag retries=0, next_attempt=dt.datetime(2023, 1, 1), ) - with session_factory() as session: + async with session_factory() as session: session.add(pending_message) session.add( @@ -147,14 +150,14 @@ def fixture_instance_message(session_factory: DbSessionFactory) -> PendingMessag reception_time=pending_message.reception_time, ) ) - session.commit() + await session.commit() return pending_message -@pytest.fixture -def fixture_instance_message_payg( - session_factory: DbSessionFactory, +@pytest_asyncio.fixture +async def fixture_instance_message_payg( + session_factory: AsyncDbSessionFactory, ) -> PendingMessageDb: content = { "address": "0x9319Ad3B7A8E0eE24f2E639c40D8eD124C5520Ba", @@ -245,7 +248,7 @@ def fixture_instance_message_payg( retries=0, next_attempt=dt.datetime(2023, 1, 1), ) - with session_factory() as session: + async with session_factory() as session: session.add(pending_message) session.add( @@ -255,13 +258,13 @@ def fixture_instance_message_payg( reception_time=pending_message.reception_time, ) ) - session.commit() + await session.commit() return pending_message -@pytest.fixture -def user_balance(session_factory: DbSessionFactory) -> AlephBalanceDb: +@pytest_asyncio.fixture +async def user_balance(session_factory: AsyncDbSessionFactory) -> AlephBalanceDb: balance = AlephBalanceDb( address="0x9319Ad3B7A8E0eE24f2E639c40D8eD124C5520Ba", chain=Chain.ETH, @@ -269,14 +272,14 @@ def user_balance(session_factory: DbSessionFactory) -> AlephBalanceDb: eth_height=0, ) - with session_factory() as session: + async with session_factory() as session: session.add(balance) - session.commit() + await session.commit() return balance -@pytest.fixture -def fixture_forget_instance_message( +@pytest_asyncio.fixture +async def fixture_forget_instance_message( fixture_instance_message: PendingMessageDb, user_balance: AlephBalanceDb, ) -> PendingMessageDb: @@ -332,7 +335,7 @@ def get_volume_refs( return volumes -def insert_volume_refs(session: DbSession, message: PendingMessageDb): +async def insert_volume_refs(session: AsyncDbSession, message: PendingMessageDb): """ Insert volume references in the DB to make the program processable. 
""" @@ -349,8 +352,8 @@ def insert_volume_refs(session: DbSession, message: PendingMessageDb): file_hash = volume.ref[::-1] session.add(StoredFileDb(hash=file_hash, size=1024 * 1024, type=FileType.FILE)) - session.flush() - insert_message_file_pin( + await session.flush() + await insert_message_file_pin( session=session, file_hash=volume.ref[::-1], owner=content.address, @@ -358,7 +361,7 @@ def insert_volume_refs(session: DbSession, message: PendingMessageDb): ref=None, created=created, ) - upsert_file_tag( + await upsert_file_tag( session=session, tag=FileTag(volume.ref), owner=content.address, @@ -369,16 +372,16 @@ def insert_volume_refs(session: DbSession, message: PendingMessageDb): @pytest.mark.asyncio async def test_process_instance( - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, message_processor: PendingMessageProcessor, fixture_instance_message: PendingMessageDb, user_balance: AlephBalanceDb, fixture_product_prices_aggregate_in_db, fixture_settings_aggregate_in_db, ): - with session_factory() as session: - insert_volume_refs(session, fixture_instance_message) - session.commit() + async with session_factory() as session: + await insert_volume_refs(session, fixture_instance_message) + await session.commit() pipeline = message_processor.make_pipeline() # Exhaust the iterator @@ -387,8 +390,8 @@ async def test_process_instance( assert fixture_instance_message.item_content content_dict = json.loads(fixture_instance_message.item_content) - with session_factory() as session: - instance = get_instance( + async with session_factory() as session: + instance = await get_instance( session=session, item_hash=fixture_instance_message.item_hash ) assert instance is not None @@ -445,7 +448,7 @@ async def test_process_instance( assert ephemeral_volume.mount == "/var/cache" assert ephemeral_volume.size_mib == 5 - instance_version = get_vm_version( + instance_version = await get_vm_version( session=session, vm_hash=fixture_instance_message.item_hash ) assert instance_version @@ -456,7 +459,7 @@ async def test_process_instance( @pytest.mark.asyncio async def test_process_instance_missing_volumes( - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, message_processor: PendingMessageProcessor, fixture_instance_message: PendingMessageDb, user_balance: AlephBalanceDb, @@ -471,17 +474,17 @@ async def test_process_instance_missing_volumes( # Exhaust the iterator _ = [message async for message in pipeline] - with session_factory() as session: - instance = get_instance(session=session, item_hash=vm_hash) + async with session_factory() as session: + instance = await get_instance(session=session, item_hash=vm_hash) assert instance is None - message_status = get_message_status( + message_status = await get_message_status( session=session, item_hash=ItemHash(vm_hash) ) assert message_status is not None assert message_status.status == MessageStatus.REJECTED - rejected_message = get_rejected_message( + rejected_message = await get_rejected_message( session=session, item_hash=ItemHash(vm_hash) ) assert rejected_message is not None @@ -499,7 +502,7 @@ async def test_process_instance_missing_volumes( @pytest.mark.asyncio async def test_forget_instance_message( - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, message_processor: PendingMessageProcessor, fixture_instance_message: PendingMessageDb, user_balance: AlephBalanceDb, @@ -510,51 +513,51 @@ async def test_forget_instance_message( vm_hash = fixture_instance_message.item_hash # 
Process the instance message - with session_factory() as session: - insert_volume_refs(session, fixture_instance_message) - session.commit() + async with session_factory() as session: + await insert_volume_refs(session, fixture_instance_message) + await session.commit() pipeline = message_processor.make_pipeline() # Exhaust the iterator _ = [message async for message in pipeline] # Sanity check - with session_factory() as session: - instance = get_instance(session=session, item_hash=vm_hash) + async with session_factory() as session: + instance = await get_instance(session=session, item_hash=vm_hash) assert instance is not None # Insert the FORGET message and process it session.add(fixture_forget_instance_message) - session.commit() + await session.commit() pipeline = message_processor.make_pipeline() # Exhaust the iterator _ = [message async for message in pipeline] - with session_factory() as session: - instance = get_instance(session=session, item_hash=vm_hash) + async with session_factory() as session: + instance = await get_instance(session=session, item_hash=vm_hash) assert instance is None, "The instance is still present despite being forgotten" - instance_version = get_vm_version(session=session, vm_hash=vm_hash) + instance_version = await get_vm_version(session=session, vm_hash=vm_hash) assert instance_version is None @pytest.mark.asyncio async def test_process_instance_balance( - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, message_processor: PendingMessageProcessor, fixture_instance_message: PendingMessageDb, ): - with session_factory() as session: - insert_volume_refs(session, fixture_instance_message) - session.commit() + async with session_factory() as session: + await insert_volume_refs(session, fixture_instance_message) + await session.commit() pipeline = message_processor.make_pipeline() # Exhaust the iterator _ = [message async for message in pipeline] - with session_factory() as session: - rejected_message = get_rejected_message( + async with session_factory() as session: + rejected_message = await get_rejected_message( session=session, item_hash=fixture_instance_message.item_hash ) assert rejected_message is not None @@ -562,24 +565,24 @@ async def test_process_instance_balance( @pytest.mark.asyncio async def test_get_additional_storage_price( - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, fixture_instance_message: PendingMessageDb, fixture_product_prices_aggregate_in_db, fixture_settings_aggregate_in_db, ): - with session_factory() as session: - insert_volume_refs(session, fixture_instance_message) - session.commit() + async with session_factory() as session: + await insert_volume_refs(session, fixture_instance_message) + await session.commit() if fixture_instance_message.item_content: content = InstanceContent.model_validate_json( fixture_instance_message.item_content ) - with session_factory() as session: - settings = _get_settings(session) - pricing = _get_product_price(session, content, settings) + async with session_factory() as session: + settings = await _get_settings(session) + pricing = await _get_product_price(session, content, settings) - additional_price = _get_additional_storage_price( + additional_price = await _get_additional_storage_price( content=content, pricing=pricing, session=session, @@ -594,15 +597,15 @@ async def test_get_additional_storage_price( @pytest.mark.asyncio async def test_get_total_and_detailed_costs_from_db( - session_factory: DbSessionFactory, + session_factory: 
AsyncDbSessionFactory, message_processor: PendingMessageProcessor, fixture_instance_message: PendingMessageDb, fixture_product_prices_aggregate_in_db, fixture_settings_aggregate_in_db, ): - with session_factory() as session: - insert_volume_refs(session, fixture_instance_message) - session.commit() + async with session_factory() as session: + await insert_volume_refs(session, fixture_instance_message) + await session.commit() pipeline = message_processor.make_pipeline() # Exhaust the iterator @@ -612,8 +615,8 @@ async def test_get_total_and_detailed_costs_from_db( content = InstanceContent.model_validate_json( fixture_instance_message.item_content ) - with session_factory() as session: - cost, _ = get_total_and_detailed_costs( + async with session_factory() as session: + cost, _ = await get_total_and_detailed_costs( session=session, content=content, item_hash=fixture_instance_message.item_hash, @@ -624,16 +627,16 @@ async def test_get_total_and_detailed_costs_from_db( @pytest.mark.asyncio async def test_compare_account_cost_with_cost_function_hold( - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, message_processor: PendingMessageProcessor, fixture_instance_message: PendingMessageDb, user_balance: AlephBalanceDb, fixture_product_prices_aggregate_in_db, fixture_settings_aggregate_in_db, ): - with session_factory() as session: - insert_volume_refs(session, fixture_instance_message) - session.commit() + async with session_factory() as session: + await insert_volume_refs(session, fixture_instance_message) + await session.commit() pipeline = message_processor.make_pipeline() # Exhaust the iterator @@ -641,14 +644,14 @@ async def test_compare_account_cost_with_cost_function_hold( assert fixture_instance_message.item_content content = InstanceContent.model_validate_json(fixture_instance_message.item_content) - with session_factory() as session: - db_cost, _ = get_total_and_detailed_costs_from_db( + async with session_factory() as session: + db_cost, _ = await get_total_and_detailed_costs_from_db( session=session, content=content, item_hash=fixture_instance_message.item_hash, ) - cost, _ = get_total_and_detailed_costs( + cost, _ = await get_total_and_detailed_costs( session=session, content=content, item_hash=fixture_instance_message.item_hash, @@ -659,16 +662,16 @@ async def test_compare_account_cost_with_cost_function_hold( @pytest.mark.asyncio async def test_compare_account_cost_with_cost_payg_funct( - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, message_processor: PendingMessageProcessor, fixture_instance_message_payg: PendingMessageDb, fixture_product_prices_aggregate_in_db, fixture_settings_aggregate_in_db, user_balance: AlephBalanceDb, ): - with session_factory() as session: - insert_volume_refs(session, fixture_instance_message_payg) - session.commit() + async with session_factory() as session: + await insert_volume_refs(session, fixture_instance_message_payg) + await session.commit() pipeline = message_processor.make_pipeline() # Exhaust the iterator @@ -680,15 +683,15 @@ async def test_compare_account_cost_with_cost_payg_funct( fixture_instance_message_payg.item_content ) # Parse again - with session_factory() as session: + async with session_factory() as session: assert content.payment.type == PaymentType.superfluid - cost, details = get_total_and_detailed_costs( + cost, details = await get_total_and_detailed_costs( session=session, content=content, item_hash=fixture_instance_message_payg.item_hash, ) - db_cost, details = 
get_total_and_detailed_costs_from_db(
+        db_cost, details = await get_total_and_detailed_costs_from_db(
             session=session,
             content=content,
             item_hash=fixture_instance_message_payg.item_hash,
@@ -698,9 +701,10 @@
     assert cost == db_cost
 
 
-@pytest.fixture
-def fixture_instance_message_only_rootfs(
-    session_factory: DbSessionFactory,
+# Use pytest_asyncio.fixture (not pytest.mark.asyncio) to declare an async fixture.
+@pytest_asyncio.fixture
+async def fixture_instance_message_only_rootfs(
+    session_factory: AsyncDbSessionFactory,
 ) -> PendingMessageDb:
     content = {
         "address": "0x9319Ad3B7A8E0eE24f2E639c40D8eD124C5520Ba",
@@ -750,7 +754,7 @@
         retries=0,
         next_attempt=dt.datetime(2023, 1, 1),
     )
-    with session_factory() as session:
+    async with session_factory() as session:
         session.add(pending_message)
 
         session.add(
@@ -760,23 +764,23 @@
                 reception_time=pending_message.reception_time,
             )
         )
-        session.commit()
+        await session.commit()
 
     return pending_message
 
 
 @pytest.mark.asyncio
 async def test_compare_account_cost_with_cost_function_without_volume(
-    session_factory: DbSessionFactory,
+    session_factory: AsyncDbSessionFactory,
     message_processor: PendingMessageProcessor,
     fixture_instance_message_only_rootfs: PendingMessageDb,
     user_balance: AlephBalanceDb,
     fixture_product_prices_aggregate_in_db,
     fixture_settings_aggregate_in_db,
 ):
-    with session_factory() as session:
-        insert_volume_refs(session, fixture_instance_message_only_rootfs)
-        session.commit()
+    async with session_factory() as session:
+        await insert_volume_refs(session, fixture_instance_message_only_rootfs)
+        await session.commit()
 
     pipeline = message_processor.make_pipeline()
     # Exhaust the iterator
@@ -786,12 +790,12 @@
     content = InstanceContent.model_validate_json(
         fixture_instance_message_only_rootfs.item_content
     )
-    with session_factory() as session:
-        cost, details = get_total_and_detailed_costs(
+    async with session_factory() as session:
+        cost, details = await get_total_and_detailed_costs(
             session=session, content=content, item_hash="abab"
         )
 
-        db_cost, details = get_total_and_detailed_costs_from_db(
+        db_cost, details = await get_total_and_detailed_costs_from_db(
             session=session,
             content=content,
             item_hash=fixture_instance_message_only_rootfs.item_hash,
diff --git a/tests/message_processing/test_process_pending_messages.py b/tests/message_processing/test_process_pending_messages.py
index 0cae91bb1..16deb1d74 100644
--- a/tests/message_processing/test_process_pending_messages.py
+++ b/tests/message_processing/test_process_pending_messages.py
@@ -5,7 +5,7 @@
 from aleph.handlers.message_handler import MessagePublisher
 from aleph.storage import StorageService
 from aleph.toolkit.timestamp import utc_now
-from aleph.types.db_session import DbSessionFactory
+from aleph.types.db_session import AsyncDbSessionFactory
 from aleph.types.message_status import MessageOrigin
 
 from .load_fixtures import load_fixture_message
@@ -15,7 +15,7 @@
 async def test_duplicated_pending_message(
     mocker,
     mock_config: Config,
-    session_factory: DbSessionFactory,
+    session_factory: AsyncDbSessionFactory,
     test_storage_service: StorageService,
 ):
     message = load_fixture_message("test-data-pending-messaging.json")
@@ -44,6 +44,6 @@
     assert test2.content == test1.content
     assert test2.reception_time == test1.reception_time
 
-    with session_factory() as session:
-        pending_messages = session.query(PendingMessageDb).count()
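+    # AsyncSession has no legacy Query API, so count rows through the model helper.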
+    async with session_factory() as session:
+        pending_messages = (
+            await session.execute(select(func.count()).select_from(PendingMessageDb))
+        ).scalar_one()
     assert pending_messages == 1
diff --git a/tests/message_processing/test_process_pending_txs.py b/tests/message_processing/test_process_pending_txs.py
index 1f1f98980..010c5a170 100644
--- a/tests/message_processing/test_process_pending_txs.py
+++ b/tests/message_processing/test_process_pending_txs.py
@@ -17,7 +17,7 @@
 from aleph.storage import StorageService
 from aleph.toolkit.timestamp import timestamp_to_datetime
 from aleph.types.chain_sync import ChainSyncProtocol
-from aleph.types.db_session import DbSessionFactory
+from aleph.types.db_session import AsyncDbSessionFactory
 from aleph.types.message_status import MessageStatus
 
 from .load_fixtures import load_fixture_messages
@@ -35,7 +35,7 @@ async def get_fixture_chaindata_messages(
 async def test_process_pending_tx_on_chain_protocol(
     mocker,
     mock_config: Config,
-    session_factory: DbSessionFactory,
+    session_factory: AsyncDbSessionFactory,
     test_storage_service: StorageService,
 ):
     chain_data_service = mocker.AsyncMock()
@@ -66,9 +66,9 @@ async def test_process_pending_tx_on_chain_protocol(
 
     pending_tx = PendingTxDb(tx=chain_tx)
 
-    with session_factory() as session:
+    async with session_factory() as session:
         session.add(pending_tx)
-        session.commit()
+        await session.commit()
 
     seen_ids: Set[str] = set()
     await pending_tx_processor.handle_pending_tx(
@@ -77,14 +77,14 @@ async def test_process_pending_tx_on_chain_protocol(
 
     fixture_messages = load_fixture_messages(f"{pending_tx.tx.content}.json")
 
-    with session_factory() as session:
-        pending_txs = session.execute(select(PendingTxDb)).scalars().all()
+    async with session_factory() as session:
+        pending_txs = (await session.execute(select(PendingTxDb))).scalars().all()
         assert not pending_txs
 
         for fixture_message in fixture_messages:
             item_hash = fixture_message["item_hash"]
             message_status_db = (
-                session.execute(
+                await session.execute(
                     select(MessageStatusDb).where(
                         MessageStatusDb.item_hash == item_hash
                     )
@@ -93,7 +93,7 @@ async def test_process_pending_tx_on_chain_protocol(
             assert message_status_db.status == MessageStatus.PENDING
 
             pending_message_db = (
-                session.execute(
+                await session.execute(
                     select(PendingMessageDb).where(
                         PendingMessageDb.item_hash == item_hash
                     )
@@ -108,7 +108,7 @@
 async def _process_smart_contract_tx(
     mocker,
     mock_config: Config,
-    session_factory: DbSessionFactory,
+    session_factory: AsyncDbSessionFactory,
     test_storage_service: StorageService,
     payload: MessageEventPayload,
 ):
@@ -141,17 +141,19 @@ async def _process_smart_contract_tx(
 
     pending_tx = PendingTxDb(tx=tx)
 
-    with session_factory() as session:
+    async with session_factory() as session:
         session.add(pending_tx)
-        session.commit()
+        await session.commit()
 
     await pending_tx_processor.handle_pending_tx(pending_tx=pending_tx)
 
-    with session_factory() as session:
-        pending_txs = session.execute(select(PendingTxDb)).scalars().all()
+    async with session_factory() as session:
+        pending_txs = (await session.execute(select(PendingTxDb))).scalars().all()
         assert not pending_txs
 
-        pending_messages = list(session.execute(select(PendingMessageDb)).scalars())
+        pending_messages = list(
+            (await session.execute(select(PendingMessageDb))).scalars()
+        )
         assert len(pending_messages) == 1
 
         pending_message_db = pending_messages[0]
@@ -165,7 +167,7 @@ async def _process_smart_contract_tx(
         assert pending_message_db.type == MessageType(payload.message_type)
         message_status_db = (
-            session.execute(
+            await session.execute(
                 select(MessageStatusDb).where(
                     MessageStatusDb.item_hash == pending_message_db.item_hash
                 )
@@ -178,7 +180,7 @@
 async def test_process_pending_smart_contract_tx_store_ipfs(
     mocker,
     mock_config: Config,
-    session_factory: DbSessionFactory,
+    session_factory: AsyncDbSessionFactory,
     test_storage_service: StorageService,
 ):
     payload = MessageEventPayload(
@@ -201,7 +203,7 @@
 async def test_process_pending_smart_contract_tx_post(
     mocker,
     mock_config: Config,
-    session_factory: DbSessionFactory,
+    session_factory: AsyncDbSessionFactory,
     test_storage_service: StorageService,
 ):
     payload = MessageEventPayload(
@@ -214,7 +216,7 @@ async def test_process_pending_smart_contract_tx_post(
             type="my-type",
             address="KT1VBeLD7hzKpj17aRJ3Kc6QQFeikCEXi7W6",
             time=1000,
-        ).json(),
+        ).model_dump_json(),
     )
 
     await _process_smart_contract_tx(
diff --git a/tests/message_processing/test_process_posts.py b/tests/message_processing/test_process_posts.py
index 0bf88f42c..619210102 100644
--- a/tests/message_processing/test_process_posts.py
+++ b/tests/message_processing/test_process_posts.py
@@ -12,12 +12,12 @@
 from aleph.handlers.content.post import PostMessageHandler
 from aleph.jobs.process_pending_messages import PendingMessageProcessor
 from aleph.toolkit.timestamp import timestamp_to_datetime
-from aleph.types.db_session import DbSessionFactory
+from aleph.types.db_session import AsyncDbSessionFactory
 
 
 @pytest.mark.asyncio
 async def test_process_post_and_amend(
-    session_factory: DbSessionFactory,
+    session_factory: AsyncDbSessionFactory,
     mock_config: Config,
     message_processor: PendingMessageProcessor,
     fixture_post_messages: List[Dict],
@@ -29,9 +29,9 @@ async def test_process_post_and_amend(
     original_hash = "9f02e3b5efdbdc0b487359117ae3af40db654892487feae452689a0b84dc1025"
     amend_hash = "93776ad67063b955869a7fa705ea2987add39486e1ed5951e9842291cf0f566c"
 
-    with session_factory() as session:
+    async with session_factory() as session:
         # We should now have one post
-        post = get_post(session=session, item_hash=original_hash)
+        post = await get_post(session=session, item_hash=original_hash)
 
         fixtures_by_item_hash = {m["item_hash"]: m for m in fixture_post_messages}
         original = fixtures_by_item_hash[original_hash]
@@ -51,7 +51,7 @@ async def test_process_post_and_amend(
 
 @pytest.mark.asyncio
 async def test_forget_original_post(
-    session_factory: DbSessionFactory,
+    session_factory: AsyncDbSessionFactory,
     mock_config: Config,
     message_processor: PendingMessageProcessor,
     fixture_post_messages: List[Dict],
@@ -66,8 +66,8 @@ async def test_forget_original_post(
     content_handler = PostMessageHandler(
         balances_addresses=[], balances_post_type="no-balances-today"
     )
-    with session_factory() as session:
-        original_message = get_message_by_item_hash(
+    async with session_factory() as session:
+        original_message = await get_message_by_item_hash(
             session=session, item_hash=ItemHash(original_hash)
         )
         assert original_message is not None
@@ -75,9 +75,9 @@ async def test_forget_original_post(
             session=session,
             message=original_message,
         )
-        session.commit()
+        await session.commit()
 
         assert additional_hashes_to_forget == {amend_hash}
 
-        posts = list(session.execute(select(PostDb)).scalars())
+        posts = list((await session.execute(select(PostDb))).scalars())
         assert posts == []
 
diff --git a/tests/message_processing/test_process_programs.py b/tests/message_processing/test_process_programs.py
index 99fbce657..d81d40e79 100644
--- a/tests/message_processing/test_process_programs.py
+++ b/tests/message_processing/test_process_programs.py
@@ -5,6 +5,7 @@ from typing import List
 
 import pytest
+import pytest_asyncio
 import pytz
 from aleph_message.models import Chain, ItemHash, ItemType, MessageType
 from aleph_message.models.execution import MachineType
@@ -28,13 +29,15 @@
 )
 from aleph.jobs.process_pending_messages import PendingMessageProcessor
 from aleph.toolkit.timestamp import timestamp_to_datetime
-from aleph.types.db_session import DbSession, DbSessionFactory
+from aleph.types.db_session import AsyncDbSession, AsyncDbSessionFactory
 from aleph.types.files import FileTag, FileType
 from aleph.types.message_status import ErrorCode, MessageStatus
 
 
-@pytest.fixture
-def fixture_program_message(session_factory: DbSessionFactory) -> PendingMessageDb:
+@pytest_asyncio.fixture
+async def fixture_program_message(
+    session_factory: AsyncDbSessionFactory,
+) -> PendingMessageDb:
     pending_message = PendingMessageDb(
         item_hash="734a1287a2b7b5be060312ff5b05ad1bcf838950492e3428f2ac6437a1acad26",
         type=MessageType.program,
@@ -51,7 +54,7 @@ def fixture_program_message(session_factory: DbSessionFactory) -> PendingMessage
         retries=0,
         next_attempt=dt.datetime(2023, 1, 1),
     )
-    with session_factory() as session:
+    async with session_factory() as session:
         session.add(pending_message)
         session.add(
             MessageStatusDb(
@@ -60,14 +63,14 @@ def fixture_program_message(session_factory: DbSessionFactory) -> PendingMessage
                 reception_time=pending_message.reception_time,
             )
         )
-        session.commit()
+        await session.commit()
 
     return pending_message
 
 
-@pytest.fixture
-def fixture_program_message_with_subscriptions(
-    session_factory: DbSessionFactory,
+@pytest_asyncio.fixture
+async def fixture_program_message_with_subscriptions(
+    session_factory: AsyncDbSessionFactory,
 ) -> PendingMessageDb:
     pending_message = PendingMessageDb(
         item_hash="cad11970efe9b7478300fd04d7cc91c646ca0a792b9cc718650f86e1ccfac73e",
@@ -85,7 +88,7 @@ def fixture_program_message_with_subscriptions(
         retries=0,
         next_attempt=dt.datetime(2023, 1, 1),
     )
-    with session_factory() as session:
+    async with session_factory() as session:
         session.add(pending_message)
         session.add(
             MessageStatusDb(
@@ -94,7 +97,7 @@ def fixture_program_message_with_subscriptions(
                 reception_time=pending_message.reception_time,
             )
         )
-        session.commit()
+        await session.commit()
 
     return pending_message
 
@@ -111,7 +114,8 @@ def get_volumes_with_ref(content: ProgramContent) -> List:
     return volumes
 
 
-def insert_volume_refs(session: DbSession, message: PendingMessageDb):
+async def insert_volume_refs(session: AsyncDbSession, message: PendingMessageDb):
     """
     Insert volume references in the DB to make the program processable.
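+    File hashes are derived by reversing the volume ref (a test-only convention).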
""" @@ -128,8 +131,8 @@ def insert_volume_refs(session: DbSession, message: PendingMessageDb): file_hash = volume.ref[::-1] session.add(StoredFileDb(hash=file_hash, size=1024 * 1024, type=FileType.FILE)) - session.flush() - insert_message_file_pin( + await session.flush() + await insert_message_file_pin( session=session, file_hash=volume.ref[::-1], owner=content.address, @@ -138,7 +141,7 @@ def insert_volume_refs(session: DbSession, message: PendingMessageDb): created=created, ) if volume.use_latest: - upsert_file_tag( + await upsert_file_tag( session=session, tag=FileTag(volume.ref), owner=content.address, @@ -147,8 +150,8 @@ def insert_volume_refs(session: DbSession, message: PendingMessageDb): ) -@pytest.fixture -def user_balance(session_factory: DbSessionFactory) -> AlephBalanceDb: +@pytest.mark.asyncio +async def user_balance(session_factory: AsyncDbSessionFactory) -> AlephBalanceDb: balance = AlephBalanceDb( address="0x7083b90eBA420832A03C6ac7e6328d37c72e0260", chain=Chain.ETH, @@ -156,24 +159,24 @@ def user_balance(session_factory: DbSessionFactory) -> AlephBalanceDb: eth_height=0, ) - with session_factory() as session: + async with session_factory() as session: session.add(balance) - session.commit() + await session.commit() return balance @pytest.mark.asyncio async def test_process_program( user_balance: AlephBalanceDb, - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, message_processor: PendingMessageProcessor, fixture_program_message: PendingMessageDb, fixture_product_prices_aggregate_in_db, fixture_settings_aggregate_in_db, ): - with session_factory() as session: - insert_volume_refs(session, fixture_program_message) - session.commit() + async with session_factory() as session: + await insert_volume_refs(session, fixture_program_message) + await session.commit() pipeline = message_processor.make_pipeline() # Exhaust the iterator @@ -182,8 +185,8 @@ async def test_process_program( assert fixture_program_message.item_content content_dict = json.loads(fixture_program_message.item_content) - with session_factory() as session: - program = get_program( + async with session_factory() as session: + program = await get_program( session=session, item_hash=fixture_program_message.item_hash ) assert program is not None @@ -239,16 +242,16 @@ async def test_process_program( @pytest.mark.asyncio async def test_program_with_subscriptions( - session_factory: DbSessionFactory, + session_factory: AsyncDbSessionFactory, message_processor: PendingMessageProcessor, fixture_program_message_with_subscriptions: PendingMessageDb, fixture_product_prices_aggregate_in_db, fixture_settings_aggregate_in_db, ): program_message = fixture_program_message_with_subscriptions - with session_factory() as session: - insert_volume_refs(session, program_message) - session.commit() + async with session_factory() as session: + await insert_volume_refs(session, program_message) + await session.commit() pipeline = message_processor.make_pipeline() # Exhaust the iterator @@ -257,8 +260,8 @@ async def test_program_with_subscriptions( assert program_message.item_content json.loads(program_message.item_content) - with session_factory() as session: - program: VmBaseDb = session.execute(select(VmBaseDb)).scalar_one() + async with session_factory() as session: + program: VmBaseDb = (await session.execute(select(VmBaseDb))).scalar_one() message_triggers = program.message_triggers assert message_triggers assert len(message_triggers) == 1 @@ -270,7 +273,7 @@ async def test_program_with_subscriptions( 
 
 @pytest.mark.asyncio
 async def test_process_program_missing_volumes(
-    session_factory: DbSessionFactory,
+    session_factory: AsyncDbSessionFactory,
     message_processor: PendingMessageProcessor,
     fixture_program_message_with_subscriptions: PendingMessageDb,
 ):
@@ -285,17 +289,19 @@ async def test_process_program_missing_volumes(
     # Exhaust the iterator
     _ = [message async for message in pipeline]
 
-    with session_factory() as session:
-        program_db = get_program(session=session, item_hash=ItemHash(program_hash))
+    async with session_factory() as session:
+        program_db = await get_program(
+            session=session, item_hash=ItemHash(program_hash)
+        )
         assert program_db is None
 
-        message_status = get_message_status(
+        message_status = await get_message_status(
             session=session, item_hash=ItemHash(program_hash)
         )
         assert message_status is not None
         assert message_status.status == MessageStatus.REJECTED
 
-        rejected_message = get_rejected_message(
+        rejected_message = await get_rejected_message(
             session=session, item_hash=ItemHash(program_hash)
         )
         assert rejected_message is not None
diff --git a/tests/message_processing/test_process_stores.py b/tests/message_processing/test_process_stores.py
index d06a8b06d..10de99e1a 100644
--- a/tests/message_processing/test_process_stores.py
+++ b/tests/message_processing/test_process_stores.py
@@ -19,7 +19,7 @@
 from aleph.toolkit.constants import STORE_AND_PROGRAM_COST_DEADLINE_TIMESTAMP
 from aleph.toolkit.timestamp import timestamp_to_datetime
 from aleph.types.channel import Channel
-from aleph.types.db_session import DbSessionFactory
+from aleph.types.db_session import AsyncDbSessionFactory
 from aleph.types.message_status import InsufficientBalanceException, MessageStatus
 
 
@@ -87,7 +87,7 @@ async def exists(self, filename: str) -> bool:
 async def test_process_store(
     mocker,
     mock_config: Config,
-    session_factory: DbSessionFactory,
+    session_factory: AsyncDbSessionFactory,
     fixture_product_prices_aggregate_in_db,
     fixture_settings_aggregate_in_db,
     fixture_store_message: PendingMessageDb,
@@ -109,13 +109,13 @@ async def test_process_store(
         config=mock_config,
     )
 
-    with session_factory() as session:
+    async with session_factory() as session:
         await message_handler.process(
             session=session, pending_message=fixture_store_message
         )
-        session.commit()
+        await session.commit()
 
-        cost, _ = get_total_and_detailed_costs_from_db(
+        cost, _ = await get_total_and_detailed_costs_from_db(
             session=session,
             content=fixture_store_message.content,
             item_hash=fixture_store_message.item_hash,
@@ -127,7 +127,7 @@ async def test_process_store(
 @pytest.mark.asyncio
 async def test_process_store_no_signature(
     mocker,
-    session_factory: DbSessionFactory,
+    session_factory: AsyncDbSessionFactory,
     message_processor: PendingMessageProcessor,
     fixture_store_message: PendingMessageDb,
     fixture_product_prices_aggregate_in_db,
@@ -145,7 +145,7 @@ async def test_process_store_no_signature(
     content = json.loads(fixture_store_message.item_content)
     fixture_store_message.content = content
 
-    with session_factory() as session:
+    async with session_factory() as session:
         session.add(fixture_store_message)
         session.add(
             MessageStatusDb(
@@ -154,7 +154,7 @@ async def test_process_store_no_signature(
                 reception_time=fixture_store_message.reception_time,
             )
         )
-        session.commit()
+        await session.commit()
 
     storage_service = StorageService(
         storage_engine=MockStorageEngine(
@@ -176,15 +176,15 @@ async def test_process_store_no_signature(
     # Exhaust the iterator
     _ = [message async for message in pipeline]
 
-    with session_factory() as session:
-        message_db = get_message_by_item_hash(
+    async with session_factory() as session:
+        message_db = await get_message_by_item_hash(
             session=session, item_hash=ItemHash(fixture_store_message.item_hash)
         )
         assert message_db is not None
         assert message_db.signature is None
 
-        file_pin = get_message_file_pin(
+        file_pin = await get_message_file_pin(
             session=session, item_hash=ItemHash(fixture_store_message.item_hash)
         )
         assert file_pin is not None
@@ -195,7 +195,7 @@
 async def test_process_store_with_not_enough_balance(
     mocker,
     mock_config: Config,
-    session_factory: DbSessionFactory,
+    session_factory: AsyncDbSessionFactory,
     fixture_product_prices_aggregate_in_db,
     fixture_settings_aggregate_in_db,
     fixture_store_message_with_cost: PendingMessageDb,
@@ -220,7 +220,7 @@ async def test_process_store_with_not_enough_balance(
         config=mock_config,
     )
 
-    with session_factory() as session:
+    async with session_factory() as session:
         # NOTE: Account balance is 0 at this point
         with pytest.raises(InsufficientBalanceException):
             await message_handler.process(
@@ -232,7 +232,7 @@
 async def test_process_store_small_file_no_balance_required(
     mocker,
     mock_config: Config,
-    session_factory: DbSessionFactory,
+    session_factory: AsyncDbSessionFactory,
     fixture_product_prices_aggregate_in_db,
     fixture_settings_aggregate_in_db,
     fixture_store_message_with_cost: PendingMessageDb,
@@ -261,22 +261,22 @@ async def test_process_store_small_file_no_balance_required(
         config=mock_config,
     )
 
-    with session_factory() as session:
+    async with session_factory() as session:
         # NOTE: Account balance is 0 at this point, but since the file is small
         # it should still be processed
         await message_handler.process(
             session=session, pending_message=fixture_store_message_with_cost
         )
-        session.commit()
+        await session.commit()
 
         # Verify that the message was processed successfully
-        message_db = get_message_by_item_hash(
+        message_db = await get_message_by_item_hash(
             session=session,
             item_hash=ItemHash(fixture_store_message_with_cost.item_hash),
         )
         assert message_db is not None
 
-        file_pin = get_message_file_pin(
+        file_pin = await get_message_file_pin(
             session=session,
             item_hash=ItemHash(fixture_store_message_with_cost.item_hash),
         )
diff --git a/tests/permissions/test_check_sender_authorization.py b/tests/permissions/test_check_sender_authorization.py
index 6da789607..a396e915c 100644
--- a/tests/permissions/test_check_sender_authorization.py
+++ b/tests/permissions/test_check_sender_authorization.py
@@ -7,7 +7,7 @@
 from aleph.db.models import AggregateDb, AggregateElementDb
 from aleph.permissions import check_sender_authorization
 from aleph.toolkit.timestamp import timestamp_to_datetime
-from aleph.types.db_session import DbSessionFactory
+from aleph.types.db_session import AsyncDbSessionFactory
 
 
 @pytest.mark.asyncio
@@ -101,7 +101,7 @@ async def test_authorized(mocker):
 
 
 @pytest.mark.asyncio
-async def test_authorized_with_db(session_factory: DbSessionFactory):
+async def test_authorized_with_db(session_factory: AsyncDbSessionFactory):
     aggregate_content = {
         "authorizations": [{"address": "0x86F39e17910E3E6d9F38412EB7F24Bf0Ba31eb2E"}]
     }
@@ -125,9 +125,9 @@ async def test_authorized_with_db(session_factory: DbSessionFactory):
         AUTHORIZED_MESSAGE, AUTHORIZED_MESSAGE["item_content"]
     )
 
-    with session_factory() as session:
+    async with session_factory() as session:
         session.add(aggregate)
-        session.commit()
+        await session.commit()
 
         is_authorized = await check_sender_authorization(
             session=session, message=message
diff --git a/tests/services/test_cost_service.py b/tests/services/test_cost_service.py
index 8d8441841..7f06e8a9e 100644
--- a/tests/services/test_cost_service.py
+++ b/tests/services/test_cost_service.py
@@ -14,7 +14,7 @@
     _get_settings_aggregate,
     get_total_and_detailed_costs,
 )
-from aleph.types.db_session import DbSessionFactory
+from aleph.types.db_session import AsyncDbSessionFactory
 
 
 class StoredFileDb:
@@ -270,8 +270,9 @@ def fixture_hold_program_message_complete() -> ExecutableContent:
     return CostEstimationProgramContent.model_validate(content)
 
 
-def test_compute_cost(
-    session_factory: DbSessionFactory,
+@pytest.mark.asyncio
+async def test_compute_cost(
+    session_factory: AsyncDbSessionFactory,
     fixture_product_prices_aggregate_in_db,
     fixture_settings_aggregate_in_db,
     fixture_hold_instance_message,
@@ -280,15 +281,16 @@ def test_compute_cost(
     mock = Mock()
     mock.patch("_get_file_from_ref", return_value=file_db)
 
-    with session_factory() as session:
-        cost, details = get_total_and_detailed_costs(
+    async with session_factory() as session:
+        cost, details = await get_total_and_detailed_costs(
             session=session, content=fixture_hold_instance_message, item_hash="abab"
         )
 
     assert cost == Decimal("1000")
 
 
-def test_compute_cost_conf(
-    session_factory: DbSessionFactory,
+@pytest.mark.asyncio
+async def test_compute_cost_conf(
+    session_factory: AsyncDbSessionFactory,
     fixture_product_prices_aggregate_in_db,
     fixture_settings_aggregate_in_db,
     fixture_hold_instance_message,
@@ -312,15 +314,16 @@ def test_compute_cost_conf(
     mock = Mock()
     mock.patch("_get_file_from_ref", return_value=file_db)
 
-    with session_factory() as session:
-        cost, _ = get_total_and_detailed_costs(
+    async with session_factory() as session:
+        cost, _ = await get_total_and_detailed_costs(
             session=session, content=rebuilt_message, item_hash="abab"
         )
 
     assert cost == 2000
 
 
-def test_get_additional_storage_price(
-    session_factory: DbSessionFactory,
+@pytest.mark.asyncio
+async def test_get_additional_storage_price(
+    session_factory: AsyncDbSessionFactory,
     fixture_product_prices_aggregate_in_db,
     fixture_settings_aggregate_in_db,
     fixture_hold_instance_message,
@@ -329,14 +332,14 @@ def test_get_additional_storage_price(
     mock = Mock()
     mock.patch("_get_file_from_ref", return_value=file_db)
 
-    with session_factory() as session:
+    async with session_factory() as session:
         content = fixture_hold_instance_message
-        settings = _get_settings(session)
-        pricing = _get_product_price(
+        settings = await _get_settings(session)
+        pricing = await _get_product_price(
             session=session, content=content, settings=settings
         )
 
-        cost = _get_additional_storage_price(
+        cost = await _get_additional_storage_price(
             session=session,
             content=content,
             item_hash="abab",
@@ -348,8 +351,9 @@ def test_get_additional_storage_price(
     assert additional_cost == 0
 
 
-def test_compute_cost_instance_complete(
-    session_factory: DbSessionFactory,
+@pytest.mark.asyncio
+async def test_compute_cost_instance_complete(
+    session_factory: AsyncDbSessionFactory,
     fixture_product_prices_aggregate_in_db,
     fixture_settings_aggregate_in_db,
     fixture_hold_instance_message_complete,
@@ -358,8 +362,8 @@ def test_compute_cost_instance_complete(
     mock = Mock()
     mock.patch("_get_file_from_ref", return_value=file_db)
 
-    with session_factory() as session:
-        cost, _ = get_total_and_detailed_costs(
+    async with session_factory() as session:
+        cost, _ = await get_total_and_detailed_costs(
             session=session,
             content=fixture_hold_instance_message_complete,
             item_hash="abab",
@@ -367,8 +371,9 @@ def test_compute_cost_instance_complete(
 
     assert cost == 1017.50
 
 
-def test_compute_cost_program_complete(
-    session_factory: DbSessionFactory,
+@pytest.mark.asyncio
+async def test_compute_cost_program_complete(
+    session_factory: AsyncDbSessionFactory,
     fixture_product_prices_aggregate_in_db,
     fixture_settings_aggregate_in_db,
     fixture_hold_program_message_complete,
@@ -377,8 +382,8 @@ def test_compute_cost_program_complete(
     mock = Mock()
     mock.patch("_get_file_from_ref", return_value=file_db)
 
-    with session_factory() as session:
-        cost, _ = get_total_and_detailed_costs(
+    async with session_factory() as session:
+        cost, _ = await get_total_and_detailed_costs(
             session=session,
             content=fixture_hold_program_message_complete,
             item_hash="asdf",
@@ -386,8 +391,9 @@ def test_compute_cost_program_complete(
 
     assert cost == Decimal("630.400000000000000000")
 
 
-def test_compute_flow_cost(
-    session_factory: DbSessionFactory,
+@pytest.mark.asyncio
+async def test_compute_flow_cost(
+    session_factory: AsyncDbSessionFactory,
     fixture_product_prices_aggregate_in_db,
     fixture_settings_aggregate_in_db,
     fixture_flow_instance_message,
@@ -396,16 +402,17 @@ def test_compute_flow_cost(
     mock = Mock()
     mock.patch("_get_file_from_ref", return_value=file_db)
 
-    with session_factory() as session:
-        cost, _ = get_total_and_detailed_costs(
+    async with session_factory() as session:
+        cost, _ = await get_total_and_detailed_costs(
             session=session, content=fixture_flow_instance_message, item_hash="abab"
         )
 
     assert cost == Decimal("0.000015277777777777")
 
 
-def test_compute_flow_cost_conf(
-    session_factory: DbSessionFactory,
+@pytest.mark.asyncio
+async def test_compute_flow_cost_conf(
+    session_factory: AsyncDbSessionFactory,
     fixture_product_prices_aggregate_in_db,
     fixture_settings_aggregate_in_db,
     fixture_flow_instance_message,
@@ -430,16 +437,17 @@ def test_compute_flow_cost_conf(
     mock = Mock()
     mock.patch("_get_file_from_ref", return_value=file_db)
 
-    with session_factory() as session:
-        cost, _ = get_total_and_detailed_costs(
+    async with session_factory() as session:
+        cost, _ = await get_total_and_detailed_costs(
             session=session, content=rebuilt_message, item_hash="abab"
         )
 
     assert cost == Decimal("0.000030555555555555")
 
 
-def test_compute_flow_cost_complete(
-    session_factory: DbSessionFactory,
+@pytest.mark.asyncio
+async def test_compute_flow_cost_complete(
+    session_factory: AsyncDbSessionFactory,
     fixture_product_prices_aggregate_in_db,
     fixture_settings_aggregate_in_db,
     fixture_flow_instance_message_complete,
@@ -448,8 +456,8 @@ def test_compute_flow_cost_complete(
     mock = Mock()
     mock.patch("_get_file_from_ref", return_value=file_db)
 
-    with session_factory() as session:
-        cost, _ = get_total_and_detailed_costs(
+    async with session_factory() as session:
+        cost, _ = await get_total_and_detailed_costs(
             session=session,
             content=fixture_flow_instance_message_complete,
             item_hash="abab",
@@ -458,33 +466,37 @@ def test_compute_flow_cost_complete(
 
     assert cost == Decimal("0.000032243382777775")
 
 
-def test_default_settings_aggregates(
-    session_factory: DbSessionFactory,
+@pytest.mark.asyncio
+async def test_default_settings_aggregates(
+    session_factory: AsyncDbSessionFactory,
 ):
-    with session_factory() as session:
-        aggregate = _get_settings_aggregate(session)
+    async with session_factory() as session:
+        aggregate = await _get_settings_aggregate(session)
         assert isinstance(aggregate, dict)
 
 
-def test_default_price_aggregates(
-    session_factory: DbSessionFactory,
+@pytest.mark.asyncio
+async def test_default_price_aggregates(
+    session_factory: AsyncDbSessionFactory,
 ):
-    with session_factory() as session:
-        price_aggregate = _get_price_aggregate(session=session)
+    async with session_factory() as session:
+        price_aggregate = await _get_price_aggregate(session=session)
         assert isinstance(price_aggregate, dict)
 
 
-def test_default_settings_aggregates_db(
-    session_factory: DbSessionFactory, fixture_settings_aggregate_in_db
+@pytest.mark.asyncio
+async def test_default_settings_aggregates_db(
+    session_factory: AsyncDbSessionFactory, fixture_settings_aggregate_in_db
 ):
-    with session_factory() as session:
-        aggregate = _get_settings_aggregate(session)
+    async with session_factory() as session:
+        aggregate = await _get_settings_aggregate(session)
         assert isinstance(aggregate, AggregateDb)
 
 
-def test_default_price_aggregates_db(
-    session_factory: DbSessionFactory, fixture_product_prices_aggregate_in_db
+@pytest.mark.asyncio
+async def test_default_price_aggregates_db(
+    session_factory: AsyncDbSessionFactory, fixture_product_prices_aggregate_in_db
 ):
-    with session_factory() as session:
-        price_aggregate = _get_price_aggregate(session=session)
+    async with session_factory() as session:
+        price_aggregate = await _get_price_aggregate(session=session)
         assert isinstance(price_aggregate, AggregateDb)
diff --git a/tests/services/test_garbage_collector.py b/tests/services/test_garbage_collector.py
index cdde1d4a5..895ab91f4 100644
--- a/tests/services/test_garbage_collector.py
+++ b/tests/services/test_garbage_collector.py
@@ -14,13 +14,13 @@
 from aleph.services.storage.engine import StorageEngine
 from aleph.services.storage.garbage_collector import GarbageCollector
 from aleph.storage import StorageService
-from aleph.types.db_session import DbSession, DbSessionFactory
+from aleph.types.db_session import AsyncDbSession, AsyncDbSessionFactory
 from aleph.types.files import FileType
 
 
 @pytest.fixture
 def gc(
-    session_factory: DbSessionFactory, test_storage_service: StorageService
+    session_factory: AsyncDbSessionFactory, test_storage_service: StorageService
 ) -> GarbageCollector:
     return GarbageCollector(
         session_factory=session_factory, storage_service=test_storage_service
@@ -29,7 +29,7 @@ def gc(
 @pytest_asyncio.fixture
 async def fixture_files(
-    session_factory: DbSessionFactory, test_storage_service: StorageService
+    session_factory: AsyncDbSessionFactory, test_storage_service: StorageService
 ):
     files = [
         StoredFileDb(
@@ -89,17 +89,17 @@ async def fixture_files(
             filename=file.hash, content=b"test"
         )
 
-    with session_factory() as session:
+    async with session_factory() as session:
         session.add_all(files)
-        session.commit()
+        await session.commit()
 
     yield files
 
 
 async def assert_file_is_deleted(
-    session: DbSession, storage_engine: StorageEngine, file_hash: str
+    session: AsyncDbSession, storage_engine: StorageEngine, file_hash: str
 ):
-    file_db = get_file(session=session, file_hash=file_hash)
+    file_db = await get_file(session=session, file_hash=file_hash)
     assert file_db is None
 
     content = await storage_engine.read(filename=file_hash)
@@ -107,9 +107,9 @@ async def assert_file_is_deleted(
 
 
 async def assert_file_exists(
-    session: DbSession, storage_engine: StorageEngine, file_hash: str
+    session: AsyncDbSession, storage_engine: StorageEngine, file_hash: str
 ):
-    file_db = get_file(session=session, file_hash=file_hash)
+    file_db = await get_file(session=session, file_hash=file_hash)
     assert file_db
 
     content = await storage_engine.read(filename=file_hash)
@@ -126,14 +126,14 @@ async def assert_file_exists(
     ],
 )
 async def test_garbage_collector_collect(
-    session_factory: DbSessionFactory,
+    session_factory: AsyncDbSessionFactory,
     gc: GarbageCollector,
     fixture_files: List[StoredFileDb],
     gc_run_datetime: dt.datetime,
 ):
-    with session_factory() as session:
+    async with session_factory() as session:
         await gc.collect(datetime=gc_run_datetime)
-        session.commit()
+        await session.commit()
 
     storage_engine = gc.storage_service.storage_engine
     for fixture_file in fixture_files:
diff --git a/tests/storage/test_store_message.py b/tests/storage/test_store_message.py
index c0a732202..d867c7d45 100644
--- a/tests/storage/test_store_message.py
+++ b/tests/storage/test_store_message.py
@@ -11,7 +11,7 @@
 from aleph.schemas.message_content import ContentSource, RawContent
 from aleph.services.ipfs import IpfsService
 from aleph.storage import StorageService
-from aleph.types.db_session import DbSessionFactory
+from aleph.types.db_session import AsyncDbSessionFactory
 from aleph.types.files import FileType
 
 
@@ -56,7 +56,7 @@ def fixture_message_directory() -> MessageDb:
 @pytest.mark.asyncio
 async def test_handle_new_storage_file(
     mocker,
-    session_factory: DbSessionFactory,
+    session_factory: AsyncDbSessionFactory,
     mock_config: Config,
     fixture_message_file: MessageDb,
 ):
@@ -88,14 +88,14 @@ async def test_handle_new_storage_file(
     store_message_handler = StoreMessageHandler(
         storage_service=storage_service, grace_period=24
     )
-    with session_factory() as session:
+    async with session_factory() as session:
         await store_message_handler.fetch_related_content(
             session=session, message=message
         )
-        session.commit()
+        await session.commit()
 
-    with session_factory() as session:
-        stored_files = list((session.execute(select(StoredFileDb))).scalars())
+    async with session_factory() as session:
+        stored_files = list((await session.execute(select(StoredFileDb))).scalars())
         assert len(stored_files) == 1
         stored_file: StoredFileDb = stored_files[0]
 
@@ -110,7 +110,7 @@ async def test_handle_new_storage_file(
 @pytest.mark.asyncio
 async def test_handle_new_storage_directory(
     mocker,
-    session_factory: DbSessionFactory,
+    session_factory: AsyncDbSessionFactory,
     mock_config: Config,
     fixture_message_directory: MessageDb,
 ):
@@ -136,14 +136,14 @@ async def test_handle_new_storage_directory(
         storage_service=storage_service, grace_period=24
     )
 
-    with session_factory() as session:
+    async with session_factory() as session:
         await store_message_handler.fetch_related_content(
             session=session, message=message
         )
-        session.commit()
+        await session.commit()
 
-    with session_factory() as session:
-        stored_files = list((session.execute(select(StoredFileDb))).scalars())
+    async with session_factory() as session:
+        stored_files = list((await session.execute(select(StoredFileDb))).scalars())
         assert len(stored_files) == 1
         stored_file = stored_files[0]
 
@@ -159,7 +159,7 @@ async def test_handle_new_storage_directory(
 @pytest.mark.asyncio
 async def test_store_files_is_false(
     mocker,
-    session_factory: DbSessionFactory,
+    session_factory: AsyncDbSessionFactory,
     mock_config: Config,
     fixture_message_directory: MessageDb,
 ):
@@ -188,14 +188,14 @@ async def test_store_files_is_false(
         storage_service=storage_service, grace_period=24
     )
 
-    with session_factory() as session:
+    async with session_factory() as session:
         await store_message_handler.fetch_related_content(
             session=session, message=message
        )
-        session.commit()
+        await session.commit()
 
-    with session_factory() as session:
-        stored_files = list((session.execute(select(StoredFileDb))).scalars())
+    async with session_factory() as session:
+        stored_files = list((await session.execute(select(StoredFileDb))).scalars())
         assert len(stored_files) == 1
         stored_file = stored_files[0]
 
@@ -212,7 +212,7 @@ async def test_store_files_is_false(
 @pytest.mark.asyncio
 async def test_store_files_is_false_ipfs_is_disabled(
     mocker,
-    session_factory: DbSessionFactory,
+    session_factory: AsyncDbSessionFactory,
     mock_config: Config,
     fixture_message_directory: MessageDb,
 ):
@@ -242,14 +242,14 @@ async def test_store_files_is_false_ipfs_is_disabled(
         storage_service=storage_service, grace_period=24
     )
 
-    with session_factory() as session:
+    async with session_factory() as session:
         await store_message_handler.fetch_related_content(
             session=session, message=message
         )
-        session.commit()
+        await session.commit()
 
-    with session_factory() as session:
-        stored_files = list((session.execute(select(StoredFileDb))).scalars())
+    async with session_factory() as session:
+        stored_files = list((await session.execute(select(StoredFileDb))).scalars())
         assert len(stored_files) == 1
         stored_file = stored_files[0]
 
diff --git a/tests/test_network.py b/tests/test_network.py
index 9ae83e6b5..cd3f63ed9 100644
--- a/tests/test_network.py
+++ b/tests/test_network.py
@@ -8,7 +8,7 @@
 from aleph.handlers.message_handler import MessageHandler
 from aleph.schemas.pending_messages import parse_message
 from aleph.storage import StorageService
-from aleph.types.db_session import DbSessionFactory
+from aleph.types.db_session import AsyncDbSessionFactory
 from aleph.types.message_status import InvalidMessageException
 
 
@@ -91,7 +91,7 @@ async def test_invalid_signature_message_2(mocker):
 @pytest.mark.asyncio
 async def test_incoming_inline_content(
     mock_config: Config,
-    session_factory: DbSessionFactory,
+    session_factory: AsyncDbSessionFactory,
     test_storage_service: StorageService,
 ):
     message_dict = {
@@ -121,7 +121,7 @@ async def test_incoming_inline_content(
         fetched=True,
     )
 
-    with session_factory() as session:
+    async with session_factory() as session:
         message = await message_handler.verify_and_fetch(
             session=session, pending_message=pending_message
         )
diff --git a/tests/web/controllers/test_programs.py b/tests/web/controllers/test_programs.py
index b796c81f1..26c36b2bf 100644
--- a/tests/web/controllers/test_programs.py
+++ b/tests/web/controllers/test_programs.py
@@ -8,12 +8,12 @@
 from message_test_helpers import make_validated_message_from_dict
 
 from aleph.db.models import MessageDb
-from aleph.types.db_session import DbSessionFactory
+from aleph.types.db_session import AsyncDbSessionFactory
 
 
 @pytest_asyncio.fixture
 async def fixture_program_messages(
-    session_factory: DbSessionFactory,
+    session_factory: AsyncDbSessionFactory,
 ) -> List[MessageDb]:
     fixtures_file = Path(__file__).parent / "fixtures/messages/program.json"
 
@@ -37,9 +37,9 @@ async def fixture_program_messages(
         )
     )
 
-    with session_factory() as session:
+    async with session_factory() as session:
         session.add_all(messages)
-        session.commit()
+        await session.commit()
 
     return messages