diff --git a/services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/api/rest/dependencies.py b/services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/api/rest/dependencies.py index 49ce9523cfe..dacf0ff08b5 100644 --- a/services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/api/rest/dependencies.py +++ b/services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/api/rest/dependencies.py @@ -4,16 +4,11 @@ # import logging -from collections.abc import AsyncGenerator, Callable -from typing import Annotated -from fastapi import Depends from fastapi.requests import Request from servicelib.fastapi.dependencies import get_app, get_reverse_url_mapper from sqlalchemy.ext.asyncio import AsyncEngine -from ...services.modules.db.repositories._base import BaseRepository - logger = logging.getLogger(__name__) @@ -23,15 +18,6 @@ def get_resource_tracker_db_engine(request: Request) -> AsyncEngine: return engine -def get_repository(repo_type: type[BaseRepository]) -> Callable: - async def _get_repo( - engine: Annotated[AsyncEngine, Depends(get_resource_tracker_db_engine)], - ) -> AsyncGenerator[BaseRepository, None]: - yield repo_type(db_engine=engine) - - return _get_repo - - assert get_reverse_url_mapper # nosec assert get_app # nosec diff --git a/services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/api/rpc/_resource_tracker.py b/services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/api/rpc/_resource_tracker.py index d7e9a5ca74d..5a382782f9d 100644 --- a/services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/api/rpc/_resource_tracker.py +++ b/services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/api/rpc/_resource_tracker.py @@ -29,9 +29,6 @@ from ...core.settings import ApplicationSettings from ...services import pricing_plans, pricing_units, service_runs -from ...services.modules.db.repositories.resource_tracker import ( - ResourceTrackerRepository, -) from ...services.modules.s3 import get_s3_client router = RPCRouter() @@ -56,7 +53,7 @@ async def get_service_run_page( return await service_runs.list_service_runs( user_id=user_id, product_name=product_name, - resource_tracker_repo=ResourceTrackerRepository(db_engine=app.state.engine), + db_engine=app.state.engine, limit=limit, offset=offset, wallet_id=wallet_id, @@ -87,7 +84,7 @@ async def export_service_runs( s3_region=s3_settings.S3_REGION, user_id=user_id, product_name=product_name, - resource_tracker_repo=ResourceTrackerRepository(db_engine=app.state.engine), + db_engine=app.state.engine, wallet_id=wallet_id, access_all_wallet_usage=access_all_wallet_usage, order_by=order_by, @@ -111,7 +108,7 @@ async def get_osparc_credits_aggregated_usages_page( return await service_runs.get_osparc_credits_aggregated_usages_page( user_id=user_id, product_name=product_name, - resource_tracker_repo=ResourceTrackerRepository(db_engine=app.state.engine), + db_engine=app.state.engine, aggregated_by=aggregated_by, time_period=time_period, limit=limit, @@ -134,7 +131,7 @@ async def get_pricing_plan( return await pricing_plans.get_pricing_plan( product_name=product_name, pricing_plan_id=pricing_plan_id, - resource_tracker_repo=ResourceTrackerRepository(db_engine=app.state.engine), + db_engine=app.state.engine, ) @@ -146,7 +143,7 @@ async def list_pricing_plans( ) -> list[PricingPlanGet]: return await pricing_plans.list_pricing_plans_by_product( product_name=product_name, - resource_tracker_repo=ResourceTrackerRepository(db_engine=app.state.engine), + db_engine=app.state.engine, ) @@ -158,7 +155,7 @@ async def create_pricing_plan( ) -> PricingPlanGet: return await pricing_plans.create_pricing_plan( data=data, - resource_tracker_repo=ResourceTrackerRepository(db_engine=app.state.engine), + db_engine=app.state.engine, ) @@ -172,7 +169,7 @@ async def update_pricing_plan( return await pricing_plans.update_pricing_plan( product_name=product_name, data=data, - resource_tracker_repo=ResourceTrackerRepository(db_engine=app.state.engine), + db_engine=app.state.engine, ) @@ -191,7 +188,7 @@ async def get_pricing_unit( product_name=product_name, pricing_plan_id=pricing_plan_id, pricing_unit_id=pricing_unit_id, - resource_tracker_repo=ResourceTrackerRepository(db_engine=app.state.engine), + db_engine=app.state.engine, ) @@ -205,7 +202,7 @@ async def create_pricing_unit( return await pricing_units.create_pricing_unit( product_name=product_name, data=data, - resource_tracker_repo=ResourceTrackerRepository(db_engine=app.state.engine), + db_engine=app.state.engine, ) @@ -219,7 +216,7 @@ async def update_pricing_unit( return await pricing_units.update_pricing_unit( product_name=product_name, data=data, - resource_tracker_repo=ResourceTrackerRepository(db_engine=app.state.engine), + db_engine=app.state.engine, ) @@ -238,7 +235,7 @@ async def list_connected_services_to_pricing_plan_by_pricing_plan( ] = await pricing_plans.list_connected_services_to_pricing_plan_by_pricing_plan( product_name=product_name, pricing_plan_id=pricing_plan_id, - resource_tracker_repo=ResourceTrackerRepository(db_engine=app.state.engine), + db_engine=app.state.engine, ) return output @@ -257,5 +254,5 @@ async def connect_service_to_pricing_plan( pricing_plan_id=pricing_plan_id, service_key=service_key, service_version=service_version, - resource_tracker_repo=ResourceTrackerRepository(db_engine=app.state.engine), + db_engine=app.state.engine, ) diff --git a/services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/services/background_task_periodic_heartbeat_check.py b/services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/services/background_task_periodic_heartbeat_check.py index 256b737d479..fba9332502e 100644 --- a/services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/services/background_task_periodic_heartbeat_check.py +++ b/services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/services/background_task_periodic_heartbeat_check.py @@ -10,11 +10,12 @@ ServiceRunStatus, ) from pydantic import NonNegativeInt, PositiveInt +from sqlalchemy.ext.asyncio import AsyncEngine from ..core.settings import ApplicationSettings from ..models.credit_transactions import CreditTransactionCreditsAndStatusUpdate from ..models.service_runs import ServiceRunStoppedAtUpdate -from .modules.db.repositories.resource_tracker import ResourceTrackerRepository +from .modules.db import credit_transactions_db, service_runs_db from .utils import compute_service_run_credit_costs, make_negative _logger = logging.getLogger(__name__) @@ -23,7 +24,7 @@ async def _check_service_heartbeat( - resource_tracker_repo: ResourceTrackerRepository, + db_engine: AsyncEngine, base_start_timestamp: datetime, resource_usage_tracker_missed_heartbeat_interval: timedelta, resource_usage_tracker_missed_heartbeat_counter_fail: NonNegativeInt, @@ -55,7 +56,7 @@ async def _check_service_heartbeat( missed_heartbeat_counter, ) await _close_unhealthy_service( - resource_tracker_repo, service_run_id, base_start_timestamp + db_engine, service_run_id, base_start_timestamp ) else: _logger.warning( @@ -63,13 +64,16 @@ async def _check_service_heartbeat( service_run_id, missed_heartbeat_counter, ) - await resource_tracker_repo.update_service_missed_heartbeat_counter( - service_run_id, last_heartbeat_at, missed_heartbeat_counter + await service_runs_db.update_service_missed_heartbeat_counter( + db_engine, + service_run_id=service_run_id, + last_heartbeat_at=last_heartbeat_at, + missed_heartbeat_counter=missed_heartbeat_counter, ) async def _close_unhealthy_service( - resource_tracker_repo: ResourceTrackerRepository, + db_engine: AsyncEngine, service_run_id: ServiceRunId, base_start_timestamp: datetime, ): @@ -80,8 +84,8 @@ async def _close_unhealthy_service( service_run_status=ServiceRunStatus.ERROR, service_run_status_msg="Service missed more heartbeats. It's considered unhealthy.", ) - running_service = await resource_tracker_repo.update_service_run_stopped_at( - update_service_run_stopped_at + running_service = await service_runs_db.update_service_run_stopped_at( + db_engine, data=update_service_run_stopped_at ) if running_service is None: @@ -108,8 +112,8 @@ async def _close_unhealthy_service( else CreditTransactionStatus.BILLED ), ) - await resource_tracker_repo.update_credit_transaction_credits_and_status( - update_credit_transaction + await credit_transactions_db.update_credit_transaction_credits_and_status( + db_engine, data=update_credit_transaction ) @@ -118,19 +122,18 @@ async def periodic_check_of_running_services_task(app: FastAPI) -> None: # This check runs across all products app_settings: ApplicationSettings = app.state.settings - resource_tracker_repo: ResourceTrackerRepository = ResourceTrackerRepository( - db_engine=app.state.engine - ) + _db_engine = app.state.engine base_start_timestamp = datetime.now(tz=timezone.utc) # Get all current running services (across all products) - total_count: PositiveInt = ( - await resource_tracker_repo.total_service_runs_with_running_status_across_all_products() + total_count: PositiveInt = await service_runs_db.total_service_runs_with_running_status_across_all_products( + _db_engine ) for offset in range(0, total_count, _BATCH_SIZE): - batch_check_services = await resource_tracker_repo.list_service_runs_with_running_status_across_all_products( + batch_check_services = await service_runs_db.list_service_runs_with_running_status_across_all_products( + _db_engine, offset=offset, limit=_BATCH_SIZE, ) @@ -138,7 +141,7 @@ async def periodic_check_of_running_services_task(app: FastAPI) -> None: await asyncio.gather( *( _check_service_heartbeat( - resource_tracker_repo=resource_tracker_repo, + db_engine=_db_engine, base_start_timestamp=base_start_timestamp, resource_usage_tracker_missed_heartbeat_interval=app_settings.RESOURCE_USAGE_TRACKER_MISSED_HEARTBEAT_INTERVAL_SEC, resource_usage_tracker_missed_heartbeat_counter_fail=app_settings.RESOURCE_USAGE_TRACKER_MISSED_HEARTBEAT_COUNTER_FAIL, diff --git a/services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/services/credit_transactions.py b/services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/services/credit_transactions.py index 0d4362e9748..c58eb76be8a 100644 --- a/services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/services/credit_transactions.py +++ b/services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/services/credit_transactions.py @@ -13,19 +13,18 @@ ) from models_library.wallets import WalletID from servicelib.rabbitmq import RabbitMQClient +from sqlalchemy.ext.asyncio import AsyncEngine -from ..api.rest.dependencies import get_repository +from ..api.rest.dependencies import get_resource_tracker_db_engine from ..models.credit_transactions import CreditTransactionCreate -from .modules.db.repositories.resource_tracker import ResourceTrackerRepository +from .modules.db import credit_transactions_db from .modules.rabbitmq import get_rabbitmq_client_from_request from .utils import sum_credit_transactions_and_publish_to_rabbitmq async def create_credit_transaction( credit_transaction_create_body: CreditTransactionCreateBody, - resource_tracker_repo: Annotated[ - ResourceTrackerRepository, Depends(get_repository(ResourceTrackerRepository)) - ], + db_engine: Annotated[AsyncEngine, Depends(get_resource_tracker_db_engine)], rabbitmq_client: Annotated[ RabbitMQClient, Depends(get_rabbitmq_client_from_request) ], @@ -47,12 +46,12 @@ async def create_credit_transaction( created_at=credit_transaction_create_body.created_at, last_heartbeat_at=credit_transaction_create_body.created_at, ) - transaction_id = await resource_tracker_repo.create_credit_transaction( - transaction_create + transaction_id = await credit_transactions_db.create_credit_transaction( + db_engine, data=transaction_create ) await sum_credit_transactions_and_publish_to_rabbitmq( - resource_tracker_repo, + db_engine, rabbitmq_client, credit_transaction_create_body.product_name, credit_transaction_create_body.wallet_id, @@ -64,10 +63,8 @@ async def create_credit_transaction( async def sum_credit_transactions_by_product_and_wallet( product_name: ProductName, wallet_id: WalletID, - resource_tracker_repo: Annotated[ - ResourceTrackerRepository, Depends(get_repository(ResourceTrackerRepository)) - ], + db_engine: Annotated[AsyncEngine, Depends(get_resource_tracker_db_engine)], ) -> WalletTotalCredits: - return await resource_tracker_repo.sum_credit_transactions_by_product_and_wallet( - product_name, wallet_id + return await credit_transactions_db.sum_credit_transactions_by_product_and_wallet( + db_engine, product_name=product_name, wallet_id=wallet_id ) diff --git a/services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/services/modules/db/credit_transactions_db.py b/services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/services/modules/db/credit_transactions_db.py new file mode 100644 index 00000000000..76a8e9f1dfe --- /dev/null +++ b/services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/services/modules/db/credit_transactions_db.py @@ -0,0 +1,162 @@ +import logging +from decimal import Decimal +from typing import cast + +import sqlalchemy as sa +from models_library.api_schemas_resource_usage_tracker.credit_transactions import ( + WalletTotalCredits, +) +from models_library.products import ProductName +from models_library.resource_tracker import CreditTransactionId, CreditTransactionStatus +from models_library.wallets import WalletID +from simcore_postgres_database.models.resource_tracker_credit_transactions import ( + resource_tracker_credit_transactions, +) +from simcore_postgres_database.utils_repos import transaction_context +from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine + +from ....exceptions.errors import CreditTransactionNotCreatedDBError +from ....models.credit_transactions import ( + CreditTransactionCreate, + CreditTransactionCreditsAndStatusUpdate, + CreditTransactionCreditsUpdate, +) + +_logger = logging.getLogger(__name__) + + +async def create_credit_transaction( + engine: AsyncEngine, + connection: AsyncConnection | None = None, + *, + data: CreditTransactionCreate +) -> CreditTransactionId: + async with transaction_context(engine, connection) as conn: + insert_stmt = ( + resource_tracker_credit_transactions.insert() + .values( + product_name=data.product_name, + wallet_id=data.wallet_id, + wallet_name=data.wallet_name, + pricing_plan_id=data.pricing_plan_id, + pricing_unit_id=data.pricing_unit_id, + pricing_unit_cost_id=data.pricing_unit_cost_id, + user_id=data.user_id, + user_email=data.user_email, + osparc_credits=data.osparc_credits, + transaction_status=data.transaction_status, + transaction_classification=data.transaction_classification, + service_run_id=data.service_run_id, + payment_transaction_id=data.payment_transaction_id, + created=data.created_at, + last_heartbeat_at=data.last_heartbeat_at, + modified=sa.func.now(), + ) + .returning(resource_tracker_credit_transactions.c.transaction_id) + ) + result = await conn.execute(insert_stmt) + row = result.first() + if row is None: + raise CreditTransactionNotCreatedDBError(data=data) + return cast(CreditTransactionId, row[0]) + + +async def update_credit_transaction_credits( + engine: AsyncEngine, + connection: AsyncConnection | None = None, + *, + data: CreditTransactionCreditsUpdate +) -> CreditTransactionId | None: + async with transaction_context(engine, connection) as conn: + update_stmt = ( + resource_tracker_credit_transactions.update() + .values( + modified=sa.func.now(), + osparc_credits=data.osparc_credits, + last_heartbeat_at=data.last_heartbeat_at, + ) + .where( + ( + resource_tracker_credit_transactions.c.service_run_id + == data.service_run_id + ) + & ( + resource_tracker_credit_transactions.c.transaction_status + == CreditTransactionStatus.PENDING + ) + & ( + resource_tracker_credit_transactions.c.last_heartbeat_at + <= data.last_heartbeat_at + ) + ) + .returning(resource_tracker_credit_transactions.c.service_run_id) + ) + result = await conn.execute(update_stmt) + row = result.first() + if row is None: + return None + return cast(CreditTransactionId | None, row[0]) + + +async def update_credit_transaction_credits_and_status( + engine: AsyncEngine, + connection: AsyncConnection | None = None, + *, + data: CreditTransactionCreditsAndStatusUpdate +) -> CreditTransactionId | None: + async with transaction_context(engine, connection) as conn: + update_stmt = ( + resource_tracker_credit_transactions.update() + .values( + modified=sa.func.now(), + osparc_credits=data.osparc_credits, + transaction_status=data.transaction_status, + ) + .where( + ( + resource_tracker_credit_transactions.c.service_run_id + == data.service_run_id + ) + & ( + resource_tracker_credit_transactions.c.transaction_status + == CreditTransactionStatus.PENDING + ) + ) + .returning(resource_tracker_credit_transactions.c.service_run_id) + ) + result = await conn.execute(update_stmt) + row = result.first() + if row is None: + return None + return cast(CreditTransactionId | None, row[0]) + + +async def sum_credit_transactions_by_product_and_wallet( + engine: AsyncEngine, + connection: AsyncConnection | None = None, + *, + product_name: ProductName, + wallet_id: WalletID +) -> WalletTotalCredits: + async with transaction_context(engine, connection) as conn: + sum_stmt = sa.select( + sa.func.sum(resource_tracker_credit_transactions.c.osparc_credits) + ).where( + (resource_tracker_credit_transactions.c.product_name == product_name) + & (resource_tracker_credit_transactions.c.wallet_id == wallet_id) + & ( + resource_tracker_credit_transactions.c.transaction_status.in_( + [ + CreditTransactionStatus.BILLED, + CreditTransactionStatus.PENDING, + ] + ) + ) + ) + result = await conn.execute(sum_stmt) + row = result.first() + if row is None or row[0] is None: + return WalletTotalCredits( + wallet_id=wallet_id, available_osparc_credits=Decimal(0) + ) + return WalletTotalCredits(wallet_id=wallet_id, available_osparc_credits=row[0]) diff --git a/services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/services/modules/db/pricing_plans_db.py b/services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/services/modules/db/pricing_plans_db.py new file mode 100644 index 00000000000..ea6376cc15b --- /dev/null +++ b/services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/services/modules/db/pricing_plans_db.py @@ -0,0 +1,668 @@ +import logging + +import sqlalchemy as sa +from models_library.products import ProductName +from models_library.resource_tracker import ( + PricingPlanCreate, + PricingPlanId, + PricingPlanUpdate, + PricingUnitCostId, + PricingUnitId, + PricingUnitWithCostCreate, + PricingUnitWithCostUpdate, +) +from models_library.services import ServiceKey, ServiceVersion +from simcore_postgres_database.models.resource_tracker_pricing_plan_to_service import ( + resource_tracker_pricing_plan_to_service, +) +from simcore_postgres_database.models.resource_tracker_pricing_plans import ( + resource_tracker_pricing_plans, +) +from simcore_postgres_database.models.resource_tracker_pricing_unit_costs import ( + resource_tracker_pricing_unit_costs, +) +from simcore_postgres_database.models.resource_tracker_pricing_units import ( + resource_tracker_pricing_units, +) +from simcore_postgres_database.utils_repos import transaction_context +from sqlalchemy.dialects.postgresql import ARRAY, INTEGER +from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine + +from ....exceptions.errors import ( + PricingPlanAndPricingUnitCombinationDoesNotExistsDBError, + PricingPlanDoesNotExistsDBError, + PricingPlanNotCreatedDBError, + PricingPlanToServiceNotCreatedDBError, + PricingUnitCostDoesNotExistsDBError, + PricingUnitCostNotCreatedDBError, + PricingUnitNotCreatedDBError, +) +from ....models.pricing_plans import ( + PricingPlansDB, + PricingPlansWithServiceDefaultPlanDB, + PricingPlanToServiceDB, +) +from ....models.pricing_unit_costs import PricingUnitCostsDB +from ....models.pricing_units import PricingUnitsDB + +_logger = logging.getLogger(__name__) + + +################################# +# Pricing plans +################################# + + +async def list_active_service_pricing_plans_by_product_and_service( + engine: AsyncEngine, + connection: AsyncConnection | None = None, + *, + product_name: ProductName, + service_key: ServiceKey, + service_version: ServiceVersion, +) -> list[PricingPlansWithServiceDefaultPlanDB]: + # NOTE: consilidate with utils_services_environmnets.py + def _version(column_or_value): + # converts version value string to array[integer] that can be compared + return sa.func.string_to_array(column_or_value, ".").cast(ARRAY(INTEGER)) + + async with transaction_context(engine, connection) as conn: + # Firstly find the correct service version + query = ( + sa.select( + resource_tracker_pricing_plan_to_service.c.service_key, + resource_tracker_pricing_plan_to_service.c.service_version, + ) + .select_from( + resource_tracker_pricing_plan_to_service.join( + resource_tracker_pricing_plans, + ( + resource_tracker_pricing_plan_to_service.c.pricing_plan_id + == resource_tracker_pricing_plans.c.pricing_plan_id + ), + ) + ) + .where( + ( + _version(resource_tracker_pricing_plan_to_service.c.service_version) + <= _version(service_version) + ) + & ( + resource_tracker_pricing_plan_to_service.c.service_key + == service_key + ) + & (resource_tracker_pricing_plans.c.product_name == product_name) + & (resource_tracker_pricing_plans.c.is_active.is_(True)) + ) + .order_by( + _version( + resource_tracker_pricing_plan_to_service.c.service_version + ).desc() + ) + .limit(1) + ) + result = await conn.execute(query) + row = result.first() + if row is None: + return [] + latest_service_key, latest_service_version = row + # Now choose all pricing plans connected to this service + query = ( + sa.select( + resource_tracker_pricing_plans.c.pricing_plan_id, + resource_tracker_pricing_plans.c.display_name, + resource_tracker_pricing_plans.c.description, + resource_tracker_pricing_plans.c.classification, + resource_tracker_pricing_plans.c.is_active, + resource_tracker_pricing_plans.c.created, + resource_tracker_pricing_plans.c.pricing_plan_key, + resource_tracker_pricing_plan_to_service.c.service_default_plan, + ) + .select_from( + resource_tracker_pricing_plan_to_service.join( + resource_tracker_pricing_plans, + ( + resource_tracker_pricing_plan_to_service.c.pricing_plan_id + == resource_tracker_pricing_plans.c.pricing_plan_id + ), + ) + ) + .where( + ( + _version(resource_tracker_pricing_plan_to_service.c.service_version) + == _version(latest_service_version) + ) + & ( + resource_tracker_pricing_plan_to_service.c.service_key + == latest_service_key + ) + & (resource_tracker_pricing_plans.c.product_name == product_name) + & (resource_tracker_pricing_plans.c.is_active.is_(True)) + ) + .order_by(resource_tracker_pricing_plan_to_service.c.pricing_plan_id.desc()) + ) + result = await conn.execute(query) + + return [ + PricingPlansWithServiceDefaultPlanDB.model_validate(row) + for row in result.fetchall() + ] + + +async def get_pricing_plan( + engine: AsyncEngine, + connection: AsyncConnection | None = None, + *, + product_name: ProductName, + pricing_plan_id: PricingPlanId, +) -> PricingPlansDB: + async with transaction_context(engine, connection) as conn: + select_stmt = sa.select( + resource_tracker_pricing_plans.c.pricing_plan_id, + resource_tracker_pricing_plans.c.display_name, + resource_tracker_pricing_plans.c.description, + resource_tracker_pricing_plans.c.classification, + resource_tracker_pricing_plans.c.is_active, + resource_tracker_pricing_plans.c.created, + resource_tracker_pricing_plans.c.pricing_plan_key, + ).where( + (resource_tracker_pricing_plans.c.pricing_plan_id == pricing_plan_id) + & (resource_tracker_pricing_plans.c.product_name == product_name) + ) + result = await conn.execute(select_stmt) + row = result.first() + if row is None: + raise PricingPlanDoesNotExistsDBError(pricing_plan_id=pricing_plan_id) + return PricingPlansDB.model_validate(row) + + +async def list_pricing_plans_by_product( + engine: AsyncEngine, + connection: AsyncConnection | None = None, + *, + product_name: ProductName, +) -> list[PricingPlansDB]: + async with transaction_context(engine, connection) as conn: + select_stmt = sa.select( + resource_tracker_pricing_plans.c.pricing_plan_id, + resource_tracker_pricing_plans.c.display_name, + resource_tracker_pricing_plans.c.description, + resource_tracker_pricing_plans.c.classification, + resource_tracker_pricing_plans.c.is_active, + resource_tracker_pricing_plans.c.created, + resource_tracker_pricing_plans.c.pricing_plan_key, + ).where(resource_tracker_pricing_plans.c.product_name == product_name) + result = await conn.execute(select_stmt) + + return [PricingPlansDB.model_validate(row) for row in result.fetchall()] + + +async def create_pricing_plan( + engine: AsyncEngine, + connection: AsyncConnection | None = None, + *, + data: PricingPlanCreate, +) -> PricingPlansDB: + async with transaction_context(engine, connection) as conn: + insert_stmt = ( + resource_tracker_pricing_plans.insert() + .values( + product_name=data.product_name, + display_name=data.display_name, + description=data.description, + classification=data.classification, + is_active=True, + created=sa.func.now(), + modified=sa.func.now(), + pricing_plan_key=data.pricing_plan_key, + ) + .returning( + *[ + resource_tracker_pricing_plans.c.pricing_plan_id, + resource_tracker_pricing_plans.c.display_name, + resource_tracker_pricing_plans.c.description, + resource_tracker_pricing_plans.c.classification, + resource_tracker_pricing_plans.c.is_active, + resource_tracker_pricing_plans.c.created, + resource_tracker_pricing_plans.c.pricing_plan_key, + ] + ) + ) + result = await conn.execute(insert_stmt) + row = result.first() + if row is None: + raise PricingPlanNotCreatedDBError(data=data) + return PricingPlansDB.model_validate(row) + + +async def update_pricing_plan( + engine: AsyncEngine, + connection: AsyncConnection | None = None, + *, + product_name: ProductName, + data: PricingPlanUpdate, +) -> PricingPlansDB | None: + async with transaction_context(engine, connection) as conn: + update_stmt = ( + resource_tracker_pricing_plans.update() + .values( + display_name=data.display_name, + description=data.description, + is_active=data.is_active, + modified=sa.func.now(), + ) + .where( + ( + resource_tracker_pricing_plans.c.pricing_plan_id + == data.pricing_plan_id + ) + & (resource_tracker_pricing_plans.c.product_name == product_name) + ) + .returning( + *[ + resource_tracker_pricing_plans.c.pricing_plan_id, + resource_tracker_pricing_plans.c.display_name, + resource_tracker_pricing_plans.c.description, + resource_tracker_pricing_plans.c.classification, + resource_tracker_pricing_plans.c.is_active, + resource_tracker_pricing_plans.c.created, + resource_tracker_pricing_plans.c.pricing_plan_key, + ] + ) + ) + result = await conn.execute(update_stmt) + row = result.first() + if row is None: + return None + return PricingPlansDB.model_validate(row) + + +################################# +# Pricing plan to service +################################# + + +async def list_connected_services_to_pricing_plan_by_pricing_plan( + engine: AsyncEngine, + connection: AsyncConnection | None = None, + *, + product_name: ProductName, + pricing_plan_id: PricingPlanId, +) -> list[PricingPlanToServiceDB]: + async with transaction_context(engine, connection) as conn: + query = ( + sa.select( + resource_tracker_pricing_plan_to_service.c.pricing_plan_id, + resource_tracker_pricing_plan_to_service.c.service_key, + resource_tracker_pricing_plan_to_service.c.service_version, + resource_tracker_pricing_plan_to_service.c.created, + ) + .select_from( + resource_tracker_pricing_plan_to_service.join( + resource_tracker_pricing_plans, + ( + resource_tracker_pricing_plan_to_service.c.pricing_plan_id + == resource_tracker_pricing_plans.c.pricing_plan_id + ), + ) + ) + .where( + (resource_tracker_pricing_plans.c.product_name == product_name) + & (resource_tracker_pricing_plans.c.pricing_plan_id == pricing_plan_id) + ) + .order_by(resource_tracker_pricing_plan_to_service.c.pricing_plan_id.desc()) + ) + result = await conn.execute(query) + + return [PricingPlanToServiceDB.model_validate(row) for row in result.fetchall()] + + +async def upsert_service_to_pricing_plan( + engine: AsyncEngine, + connection: AsyncConnection | None = None, + *, + product_name: ProductName, + pricing_plan_id: PricingPlanId, + service_key: ServiceKey, + service_version: ServiceVersion, +) -> PricingPlanToServiceDB: + async with transaction_context(engine, connection) as conn: + query = ( + sa.select( + resource_tracker_pricing_plan_to_service.c.pricing_plan_id, + resource_tracker_pricing_plan_to_service.c.service_key, + resource_tracker_pricing_plan_to_service.c.service_version, + resource_tracker_pricing_plan_to_service.c.created, + ) + .select_from( + resource_tracker_pricing_plan_to_service.join( + resource_tracker_pricing_plans, + ( + resource_tracker_pricing_plan_to_service.c.pricing_plan_id + == resource_tracker_pricing_plans.c.pricing_plan_id + ), + ) + ) + .where( + (resource_tracker_pricing_plans.c.product_name == product_name) + & (resource_tracker_pricing_plans.c.pricing_plan_id == pricing_plan_id) + & ( + resource_tracker_pricing_plan_to_service.c.service_key + == service_key + ) + & ( + resource_tracker_pricing_plan_to_service.c.service_version + == service_version + ) + ) + ) + result = await conn.execute(query) + row = result.first() + + if row is not None: + delete_stmt = resource_tracker_pricing_plan_to_service.delete().where( + (resource_tracker_pricing_plans.c.pricing_plan_id == pricing_plan_id) + & ( + resource_tracker_pricing_plan_to_service.c.service_key + == service_key + ) + & ( + resource_tracker_pricing_plan_to_service.c.service_version + == service_version + ) + ) + await conn.execute(delete_stmt) + + insert_stmt = ( + resource_tracker_pricing_plan_to_service.insert() + .values( + pricing_plan_id=pricing_plan_id, + service_key=service_key, + service_version=service_version, + created=sa.func.now(), + modified=sa.func.now(), + service_default_plan=True, + ) + .returning( + *[ + resource_tracker_pricing_plan_to_service.c.pricing_plan_id, + resource_tracker_pricing_plan_to_service.c.service_key, + resource_tracker_pricing_plan_to_service.c.service_version, + resource_tracker_pricing_plan_to_service.c.created, + ] + ) + ) + result = await conn.execute(insert_stmt) + row = result.first() + if row is None: + raise PricingPlanToServiceNotCreatedDBError( + data=f"pricing_plan_id {pricing_plan_id}, service_key {service_key}, service_version {service_version}" + ) + return PricingPlanToServiceDB.model_validate(row) + + +################################# +# Pricing units +################################# + + +def _pricing_units_select_stmt(): + return sa.select( + resource_tracker_pricing_units.c.pricing_unit_id, + resource_tracker_pricing_units.c.pricing_plan_id, + resource_tracker_pricing_units.c.unit_name, + resource_tracker_pricing_units.c.unit_extra_info, + resource_tracker_pricing_units.c.default, + resource_tracker_pricing_units.c.specific_info, + resource_tracker_pricing_units.c.created, + resource_tracker_pricing_units.c.modified, + resource_tracker_pricing_unit_costs.c.cost_per_unit.label( + "current_cost_per_unit" + ), + resource_tracker_pricing_unit_costs.c.pricing_unit_cost_id.label( + "current_cost_per_unit_id" + ), + ) + + +async def list_pricing_units_by_pricing_plan( + engine: AsyncEngine, + connection: AsyncConnection | None = None, + *, + pricing_plan_id: PricingPlanId, +) -> list[PricingUnitsDB]: + async with transaction_context(engine, connection) as conn: + query = ( + _pricing_units_select_stmt() + .select_from( + resource_tracker_pricing_units.join( + resource_tracker_pricing_unit_costs, + ( + ( + resource_tracker_pricing_units.c.pricing_plan_id + == resource_tracker_pricing_unit_costs.c.pricing_plan_id + ) + & ( + resource_tracker_pricing_units.c.pricing_unit_id + == resource_tracker_pricing_unit_costs.c.pricing_unit_id + ) + ), + ) + ) + .where( + (resource_tracker_pricing_units.c.pricing_plan_id == pricing_plan_id) + & (resource_tracker_pricing_unit_costs.c.valid_to.is_(None)) + ) + .order_by(resource_tracker_pricing_unit_costs.c.cost_per_unit.asc()) + ) + result = await conn.execute(query) + + return [PricingUnitsDB.model_validate(row) for row in result.fetchall()] + + +async def get_valid_pricing_unit( + engine: AsyncEngine, + connection: AsyncConnection | None = None, + *, + product_name: ProductName, + pricing_plan_id: PricingPlanId, + pricing_unit_id: PricingUnitId, +) -> PricingUnitsDB: + async with transaction_context(engine, connection) as conn: + query = ( + _pricing_units_select_stmt() + .select_from( + resource_tracker_pricing_units.join( + resource_tracker_pricing_unit_costs, + ( + ( + resource_tracker_pricing_units.c.pricing_plan_id + == resource_tracker_pricing_unit_costs.c.pricing_plan_id + ) + & ( + resource_tracker_pricing_units.c.pricing_unit_id + == resource_tracker_pricing_unit_costs.c.pricing_unit_id + ) + ), + ).join( + resource_tracker_pricing_plans, + ( + resource_tracker_pricing_plans.c.pricing_plan_id + == resource_tracker_pricing_units.c.pricing_plan_id + ), + ) + ) + .where( + (resource_tracker_pricing_units.c.pricing_plan_id == pricing_plan_id) + & (resource_tracker_pricing_units.c.pricing_unit_id == pricing_unit_id) + & (resource_tracker_pricing_unit_costs.c.valid_to.is_(None)) + & (resource_tracker_pricing_plans.c.product_name == product_name) + ) + ) + result = await conn.execute(query) + + row = result.first() + if row is None: + raise PricingPlanAndPricingUnitCombinationDoesNotExistsDBError( + pricing_plan_id=pricing_plan_id, + pricing_unit_id=pricing_unit_id, + product_name=product_name, + ) + return PricingUnitsDB.model_validate(row) + + +async def create_pricing_unit_with_cost( + engine: AsyncEngine, + connection: AsyncConnection | None = None, + *, + data: PricingUnitWithCostCreate, + pricing_plan_key: str, +) -> tuple[PricingUnitId, PricingUnitCostId]: + async with transaction_context(engine, connection) as conn: + # pricing units table + insert_stmt = ( + resource_tracker_pricing_units.insert() + .values( + pricing_plan_id=data.pricing_plan_id, + unit_name=data.unit_name, + unit_extra_info=data.unit_extra_info.model_dump(), + default=data.default, + specific_info=data.specific_info.model_dump(), + created=sa.func.now(), + modified=sa.func.now(), + ) + .returning(resource_tracker_pricing_units.c.pricing_unit_id) + ) + result = await conn.execute(insert_stmt) + row = result.first() + if row is None: + raise PricingUnitNotCreatedDBError(data=data) + _pricing_unit_id = row[0] + + # pricing unit cost table + insert_stmt = ( + resource_tracker_pricing_unit_costs.insert() + .values( + pricing_plan_id=data.pricing_plan_id, + pricing_plan_key=pricing_plan_key, + pricing_unit_id=_pricing_unit_id, + pricing_unit_name=data.unit_name, + cost_per_unit=data.cost_per_unit, + valid_from=sa.func.now(), + valid_to=None, + created=sa.func.now(), + comment=data.comment, + modified=sa.func.now(), + ) + .returning(resource_tracker_pricing_unit_costs.c.pricing_unit_cost_id) + ) + result = await conn.execute(insert_stmt) + row = result.first() + if row is None: + raise PricingUnitCostNotCreatedDBError(data=data) + _pricing_unit_cost_id = row[0] + + return (_pricing_unit_id, _pricing_unit_cost_id) + + +async def update_pricing_unit_with_cost( + engine: AsyncEngine, + connection: AsyncConnection | None = None, + *, + data: PricingUnitWithCostUpdate, + pricing_plan_key: str, +) -> None: + async with transaction_context(engine, connection) as conn: + # pricing units table + update_stmt = ( + resource_tracker_pricing_units.update() + .values( + unit_name=data.unit_name, + unit_extra_info=data.unit_extra_info.model_dump(), + default=data.default, + specific_info=data.specific_info.model_dump(), + modified=sa.func.now(), + ) + .where( + resource_tracker_pricing_units.c.pricing_unit_id == data.pricing_unit_id + ) + .returning(resource_tracker_pricing_units.c.pricing_unit_id) + ) + await conn.execute(update_stmt) + + # If price change, then we update pricing unit cost table + if data.pricing_unit_cost_update: + # Firstly we close previous price + update_stmt = ( + resource_tracker_pricing_unit_costs.update() + .values( + valid_to=sa.func.now(), # <-- Closing previous price + modified=sa.func.now(), + ) + .where( + resource_tracker_pricing_unit_costs.c.pricing_unit_id + == data.pricing_unit_id + ) + .returning(resource_tracker_pricing_unit_costs.c.pricing_unit_id) + ) + result = await conn.execute(update_stmt) + + # Then we create a new price + insert_stmt = ( + resource_tracker_pricing_unit_costs.insert() + .values( + pricing_plan_id=data.pricing_plan_id, + pricing_plan_key=pricing_plan_key, + pricing_unit_id=data.pricing_unit_id, + pricing_unit_name=data.unit_name, + cost_per_unit=data.pricing_unit_cost_update.cost_per_unit, + valid_from=sa.func.now(), + valid_to=None, # <-- New price is valid + created=sa.func.now(), + comment=data.pricing_unit_cost_update.comment, + modified=sa.func.now(), + ) + .returning(resource_tracker_pricing_unit_costs.c.pricing_unit_cost_id) + ) + result = await conn.execute(insert_stmt) + row = result.first() + if row is None: + raise PricingUnitCostNotCreatedDBError(data=data) + + +################################# +# Pricing unit-costs +################################# + + +async def get_pricing_unit_cost_by_id( + engine: AsyncEngine, + connection: AsyncConnection | None = None, + *, + pricing_unit_cost_id: PricingUnitCostId, +) -> PricingUnitCostsDB: + async with transaction_context(engine, connection) as conn: + query = sa.select( + resource_tracker_pricing_unit_costs.c.pricing_unit_cost_id, + resource_tracker_pricing_unit_costs.c.pricing_plan_id, + resource_tracker_pricing_unit_costs.c.pricing_plan_key, + resource_tracker_pricing_unit_costs.c.pricing_unit_id, + resource_tracker_pricing_unit_costs.c.pricing_unit_name, + resource_tracker_pricing_unit_costs.c.cost_per_unit, + resource_tracker_pricing_unit_costs.c.valid_from, + resource_tracker_pricing_unit_costs.c.valid_to, + resource_tracker_pricing_unit_costs.c.created, + resource_tracker_pricing_unit_costs.c.comment, + resource_tracker_pricing_unit_costs.c.modified, + ).where( + resource_tracker_pricing_unit_costs.c.pricing_unit_cost_id + == pricing_unit_cost_id + ) + result = await conn.execute(query) + + row = result.first() + if row is None: + raise PricingUnitCostDoesNotExistsDBError( + pricing_unit_cost_id=pricing_unit_cost_id + ) + return PricingUnitCostsDB.model_validate(row) diff --git a/services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/services/modules/db/repositories/__init__.py b/services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/services/modules/db/repositories/__init__.py deleted file mode 100644 index 93da4003de3..00000000000 --- a/services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/services/modules/db/repositories/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from ._base import BaseRepository - -__all__: tuple[str, ...] = ("BaseRepository",) diff --git a/services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/services/modules/db/repositories/_base.py b/services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/services/modules/db/repositories/_base.py deleted file mode 100644 index 4a20b37c735..00000000000 --- a/services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/services/modules/db/repositories/_base.py +++ /dev/null @@ -1,12 +0,0 @@ -from dataclasses import dataclass - -from sqlalchemy.ext.asyncio import AsyncEngine - - -@dataclass -class BaseRepository: - """ - Repositories are pulled at every request - """ - - db_engine: AsyncEngine diff --git a/services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/services/modules/db/repositories/resource_tracker.py b/services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/services/modules/db/repositories/resource_tracker.py deleted file mode 100644 index 46439f26e38..00000000000 --- a/services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/services/modules/db/repositories/resource_tracker.py +++ /dev/null @@ -1,1382 +0,0 @@ -import logging -from datetime import datetime -from decimal import Decimal -from typing import cast - -import sqlalchemy as sa -from models_library.api_schemas_resource_usage_tracker.credit_transactions import ( - WalletTotalCredits, -) -from models_library.api_schemas_storage import S3BucketName -from models_library.products import ProductName -from models_library.resource_tracker import ( - CreditClassification, - CreditTransactionId, - CreditTransactionStatus, - PricingPlanCreate, - PricingPlanId, - PricingPlanUpdate, - PricingUnitCostId, - PricingUnitId, - PricingUnitWithCostCreate, - PricingUnitWithCostUpdate, - ServiceRunId, - ServiceRunStatus, -) -from models_library.rest_ordering import OrderBy, OrderDirection -from models_library.services import ServiceKey, ServiceVersion -from models_library.users import UserID -from models_library.wallets import WalletID -from pydantic import PositiveInt -from simcore_postgres_database.models.projects_tags import projects_tags -from simcore_postgres_database.models.resource_tracker_credit_transactions import ( - resource_tracker_credit_transactions, -) -from simcore_postgres_database.models.resource_tracker_pricing_plan_to_service import ( - resource_tracker_pricing_plan_to_service, -) -from simcore_postgres_database.models.resource_tracker_pricing_plans import ( - resource_tracker_pricing_plans, -) -from simcore_postgres_database.models.resource_tracker_pricing_unit_costs import ( - resource_tracker_pricing_unit_costs, -) -from simcore_postgres_database.models.resource_tracker_pricing_units import ( - resource_tracker_pricing_units, -) -from simcore_postgres_database.models.resource_tracker_service_runs import ( - resource_tracker_service_runs, -) -from simcore_postgres_database.models.tags import tags -from sqlalchemy.dialects.postgresql import ARRAY, INTEGER - -from .....exceptions.errors import ( - CreditTransactionNotCreatedDBError, - PricingPlanAndPricingUnitCombinationDoesNotExistsDBError, - PricingPlanDoesNotExistsDBError, - PricingPlanNotCreatedDBError, - PricingPlanToServiceNotCreatedDBError, - PricingUnitCostDoesNotExistsDBError, - PricingUnitCostNotCreatedDBError, - PricingUnitNotCreatedDBError, - ServiceRunNotCreatedDBError, -) -from .....models.credit_transactions import ( - CreditTransactionCreate, - CreditTransactionCreditsAndStatusUpdate, - CreditTransactionCreditsUpdate, -) -from .....models.pricing_plans import ( - PricingPlansDB, - PricingPlansWithServiceDefaultPlanDB, - PricingPlanToServiceDB, -) -from .....models.pricing_unit_costs import PricingUnitCostsDB -from .....models.pricing_units import PricingUnitsDB -from .....models.service_runs import ( - OsparcCreditsAggregatedByServiceKeyDB, - ServiceRunCreate, - ServiceRunDB, - ServiceRunForCheckDB, - ServiceRunLastHeartbeatUpdate, - ServiceRunStoppedAtUpdate, - ServiceRunWithCreditsDB, -) -from ._base import BaseRepository - -_logger = logging.getLogger(__name__) - - -class ResourceTrackerRepository( - BaseRepository -): # pylint: disable=too-many-public-methods - ############### - # Service Run - ############### - - async def create_service_run(self, data: ServiceRunCreate) -> ServiceRunId: - async with self.db_engine.begin() as conn: - insert_stmt = ( - resource_tracker_service_runs.insert() - .values( - product_name=data.product_name, - service_run_id=data.service_run_id, - wallet_id=data.wallet_id, - wallet_name=data.wallet_name, - pricing_plan_id=data.pricing_plan_id, - pricing_unit_id=data.pricing_unit_id, - pricing_unit_cost_id=data.pricing_unit_cost_id, - pricing_unit_cost=data.pricing_unit_cost, - simcore_user_agent=data.simcore_user_agent, - user_id=data.user_id, - user_email=data.user_email, - project_id=f"{data.project_id}", - project_name=data.project_name, - node_id=f"{data.node_id}", - node_name=data.node_name, - parent_project_id=f"{data.parent_project_id}", - root_parent_project_id=f"{data.root_parent_project_id}", - root_parent_project_name=data.root_parent_project_name, - parent_node_id=f"{data.parent_node_id}", - root_parent_node_id=f"{data.root_parent_node_id}", - service_key=data.service_key, - service_version=data.service_version, - service_type=data.service_type, - service_resources=data.service_resources, - service_additional_metadata=data.service_additional_metadata, - started_at=data.started_at, - stopped_at=None, - service_run_status=ServiceRunStatus.RUNNING, - modified=sa.func.now(), - last_heartbeat_at=data.last_heartbeat_at, - ) - .returning(resource_tracker_service_runs.c.service_run_id) - ) - result = await conn.execute(insert_stmt) - row = result.first() - if row is None: - raise ServiceRunNotCreatedDBError(data=data) - return cast(ServiceRunId, row[0]) - - async def update_service_run_last_heartbeat( - self, data: ServiceRunLastHeartbeatUpdate - ) -> ServiceRunDB | None: - async with self.db_engine.begin() as conn: - update_stmt = ( - resource_tracker_service_runs.update() - .values( - modified=sa.func.now(), - last_heartbeat_at=data.last_heartbeat_at, - missed_heartbeat_counter=0, - ) - .where( - ( - resource_tracker_service_runs.c.service_run_id - == data.service_run_id - ) - & ( - resource_tracker_service_runs.c.service_run_status - == ServiceRunStatus.RUNNING - ) - & ( - resource_tracker_service_runs.c.last_heartbeat_at - <= data.last_heartbeat_at - ) - ) - .returning(sa.literal_column("*")) - ) - result = await conn.execute(update_stmt) - row = result.first() - if row is None: - return None - return ServiceRunDB.model_validate(row) - - async def update_service_run_stopped_at( - self, data: ServiceRunStoppedAtUpdate - ) -> ServiceRunDB | None: - async with self.db_engine.begin() as conn: - update_stmt = ( - resource_tracker_service_runs.update() - .values( - modified=sa.func.now(), - stopped_at=data.stopped_at, - service_run_status=data.service_run_status, - service_run_status_msg=data.service_run_status_msg, - ) - .where( - ( - resource_tracker_service_runs.c.service_run_id - == data.service_run_id - ) - & ( - resource_tracker_service_runs.c.service_run_status - == ServiceRunStatus.RUNNING - ) - ) - .returning(sa.literal_column("*")) - ) - result = await conn.execute(update_stmt) - row = result.first() - if row is None: - return None - return ServiceRunDB.model_validate(row) - - async def get_service_run_by_id( - self, service_run_id: ServiceRunId - ) -> ServiceRunDB | None: - async with self.db_engine.begin() as conn: - stmt = sa.select(resource_tracker_service_runs).where( - resource_tracker_service_runs.c.service_run_id == service_run_id - ) - result = await conn.execute(stmt) - row = result.first() - if row is None: - return None - return ServiceRunDB.model_validate(row) - - _project_tags_subquery = ( - sa.select( - projects_tags.c.project_uuid_for_rut, - sa.func.array_agg(tags.c.name).label("project_tags"), - ) - .select_from(projects_tags.join(tags, projects_tags.c.tag_id == tags.c.id)) - .group_by(projects_tags.c.project_uuid_for_rut) - ).subquery("project_tags_subquery") - - async def list_service_runs_by_product_and_user_and_wallet( - self, - product_name: ProductName, - *, - user_id: UserID | None, - wallet_id: WalletID | None, - offset: int, - limit: int, - service_run_status: ServiceRunStatus | None = None, - started_from: datetime | None = None, - started_until: datetime | None = None, - order_by: OrderBy | None = None, - ) -> list[ServiceRunWithCreditsDB]: - async with self.db_engine.begin() as conn: - query = ( - sa.select( - resource_tracker_service_runs.c.product_name, - resource_tracker_service_runs.c.service_run_id, - resource_tracker_service_runs.c.wallet_id, - resource_tracker_service_runs.c.wallet_name, - resource_tracker_service_runs.c.pricing_plan_id, - resource_tracker_service_runs.c.pricing_unit_id, - resource_tracker_service_runs.c.pricing_unit_cost_id, - resource_tracker_service_runs.c.pricing_unit_cost, - resource_tracker_service_runs.c.user_id, - resource_tracker_service_runs.c.user_email, - resource_tracker_service_runs.c.project_id, - resource_tracker_service_runs.c.project_name, - resource_tracker_service_runs.c.node_id, - resource_tracker_service_runs.c.node_name, - resource_tracker_service_runs.c.parent_project_id, - resource_tracker_service_runs.c.root_parent_project_id, - resource_tracker_service_runs.c.root_parent_project_name, - resource_tracker_service_runs.c.parent_node_id, - resource_tracker_service_runs.c.root_parent_node_id, - resource_tracker_service_runs.c.service_key, - resource_tracker_service_runs.c.service_version, - resource_tracker_service_runs.c.service_type, - resource_tracker_service_runs.c.service_resources, - resource_tracker_service_runs.c.started_at, - resource_tracker_service_runs.c.stopped_at, - resource_tracker_service_runs.c.service_run_status, - resource_tracker_service_runs.c.modified, - resource_tracker_service_runs.c.last_heartbeat_at, - resource_tracker_service_runs.c.service_run_status_msg, - resource_tracker_service_runs.c.missed_heartbeat_counter, - resource_tracker_credit_transactions.c.osparc_credits, - resource_tracker_credit_transactions.c.transaction_status, - sa.func.coalesce( - self._project_tags_subquery.c.project_tags, - sa.cast(sa.text("'{}'"), sa.ARRAY(sa.String)), - ).label("project_tags"), - ) - .select_from( - resource_tracker_service_runs.join( - resource_tracker_credit_transactions, - ( - resource_tracker_service_runs.c.product_name - == resource_tracker_credit_transactions.c.product_name - ) - & ( - resource_tracker_service_runs.c.service_run_id - == resource_tracker_credit_transactions.c.service_run_id - ), - isouter=True, - ).join( - self._project_tags_subquery, - resource_tracker_service_runs.c.project_id - == self._project_tags_subquery.c.project_uuid_for_rut, - isouter=True, - ) - ) - .where(resource_tracker_service_runs.c.product_name == product_name) - .offset(offset) - .limit(limit) - ) - - if user_id: - query = query.where(resource_tracker_service_runs.c.user_id == user_id) - if wallet_id: - query = query.where( - resource_tracker_service_runs.c.wallet_id == wallet_id - ) - if service_run_status: - query = query.where( - resource_tracker_service_runs.c.service_run_status - == service_run_status - ) - if started_from: - query = query.where( - sa.func.DATE(resource_tracker_service_runs.c.started_at) - >= started_from.date() - ) - if started_until: - query = query.where( - sa.func.DATE(resource_tracker_service_runs.c.started_at) - <= started_until.date() - ) - - if order_by: - if order_by.direction == OrderDirection.ASC: - query = query.order_by(sa.asc(order_by.field)) - else: - query = query.order_by(sa.desc(order_by.field)) - else: - # Default ordering - query = query.order_by( - resource_tracker_service_runs.c.started_at.desc() - ) - - result = await conn.execute(query) - - return [ - ServiceRunWithCreditsDB.model_validate(row) for row in result.fetchall() - ] - - async def get_osparc_credits_aggregated_by_service( - self, - product_name: ProductName, - *, - user_id: UserID | None, - wallet_id: WalletID, - offset: int, - limit: int, - started_from: datetime | None = None, - started_until: datetime | None = None, - ) -> tuple[int, list[OsparcCreditsAggregatedByServiceKeyDB]]: - async with self.db_engine.begin() as conn: - base_query = ( - sa.select( - resource_tracker_service_runs.c.service_key, - sa.func.SUM( - resource_tracker_credit_transactions.c.osparc_credits - ).label("osparc_credits"), - sa.func.SUM( - sa.func.round( - ( - sa.func.extract( - "epoch", - resource_tracker_service_runs.c.stopped_at, - ) - - sa.func.extract( - "epoch", - resource_tracker_service_runs.c.started_at, - ) - ) - / 3600, - 2, - ) - ).label("running_time_in_hours"), - ) - .select_from( - resource_tracker_service_runs.join( - resource_tracker_credit_transactions, - ( - resource_tracker_service_runs.c.product_name - == resource_tracker_credit_transactions.c.product_name - ) - & ( - resource_tracker_service_runs.c.service_run_id - == resource_tracker_credit_transactions.c.service_run_id - ), - isouter=True, - ) - ) - .where( - (resource_tracker_service_runs.c.product_name == product_name) - & ( - resource_tracker_credit_transactions.c.transaction_status - == CreditTransactionStatus.BILLED - ) - & ( - resource_tracker_credit_transactions.c.transaction_classification - == CreditClassification.DEDUCT_SERVICE_RUN - ) - & (resource_tracker_credit_transactions.c.wallet_id == wallet_id) - ) - .group_by(resource_tracker_service_runs.c.service_key) - ) - - if user_id: - base_query = base_query.where( - resource_tracker_service_runs.c.user_id == user_id - ) - if started_from: - base_query = base_query.where( - sa.func.DATE(resource_tracker_service_runs.c.started_at) - >= started_from.date() - ) - if started_until: - base_query = base_query.where( - sa.func.DATE(resource_tracker_service_runs.c.started_at) - <= started_until.date() - ) - - subquery = base_query.subquery() - count_query = sa.select(sa.func.count()).select_from(subquery) - count_result = await conn.execute(count_query) - - # Default ordering and pagination - list_query = ( - base_query.order_by(resource_tracker_service_runs.c.service_key.asc()) - .offset(offset) - .limit(limit) - ) - list_result = await conn.execute(list_query) - - return ( - cast(int, count_result.scalar()), - [ - OsparcCreditsAggregatedByServiceKeyDB.model_validate(row) - for row in list_result.fetchall() - ], - ) - - async def export_service_runs_table_to_s3( - self, - product_name: ProductName, - s3_bucket_name: S3BucketName, - s3_key: str, - s3_region: str, - *, - user_id: UserID | None, - wallet_id: WalletID | None, - started_from: datetime | None = None, - started_until: datetime | None = None, - order_by: OrderBy | None = None, - ): - async with self.db_engine.begin() as conn: - query = ( - sa.select( - resource_tracker_service_runs.c.product_name, - resource_tracker_service_runs.c.service_run_id, - resource_tracker_service_runs.c.wallet_name, - resource_tracker_service_runs.c.user_email, - resource_tracker_service_runs.c.root_parent_project_name.label( - "project_name" - ), - resource_tracker_service_runs.c.node_name, - resource_tracker_service_runs.c.service_key, - resource_tracker_service_runs.c.service_version, - resource_tracker_service_runs.c.service_type, - resource_tracker_service_runs.c.started_at, - resource_tracker_service_runs.c.stopped_at, - resource_tracker_credit_transactions.c.osparc_credits, - resource_tracker_credit_transactions.c.transaction_status, - sa.func.coalesce( - self._project_tags_subquery.c.project_tags, - sa.cast(sa.text("'{}'"), sa.ARRAY(sa.String)), - ).label("project_tags"), - ) - .select_from( - resource_tracker_service_runs.join( - resource_tracker_credit_transactions, - resource_tracker_service_runs.c.service_run_id - == resource_tracker_credit_transactions.c.service_run_id, - isouter=True, - ).join( - self._project_tags_subquery, - resource_tracker_service_runs.c.project_id - == self._project_tags_subquery.c.project_uuid_for_rut, - isouter=True, - ) - ) - .where(resource_tracker_service_runs.c.product_name == product_name) - ) - - if user_id: - query = query.where(resource_tracker_service_runs.c.user_id == user_id) - if wallet_id: - query = query.where( - resource_tracker_service_runs.c.wallet_id == wallet_id - ) - if started_from: - query = query.where( - sa.func.DATE(resource_tracker_service_runs.c.started_at) - >= started_from.date() - ) - if started_until: - query = query.where( - sa.func.DATE(resource_tracker_service_runs.c.started_at) - <= started_until.date() - ) - - if order_by: - if order_by.direction == OrderDirection.ASC: - query = query.order_by(sa.asc(order_by.field)) - else: - query = query.order_by(sa.desc(order_by.field)) - else: - # Default ordering - query = query.order_by( - resource_tracker_service_runs.c.started_at.desc() - ) - - compiled_query = ( - str(query.compile(compile_kwargs={"literal_binds": True})) - .replace("\n", "") - .replace("'", "''") - ) - - result = await conn.execute( - sa.DDL( - f""" - SELECT * from aws_s3.query_export_to_s3('{compiled_query}', - aws_commons.create_s3_uri('{s3_bucket_name}', '{s3_key}', '{s3_region}'), 'format csv, HEADER true'); - """ # noqa: S608 - ) - ) - row = result.first() - assert row - _logger.info( - "Rows uploaded %s, Files uploaded %s, Bytes uploaded %s", - row[0], - row[1], - row[2], - ) - - async def total_service_runs_by_product_and_user_and_wallet( - self, - product_name: ProductName, - *, - user_id: UserID | None, - wallet_id: WalletID | None, - service_run_status: ServiceRunStatus | None = None, - started_from: datetime | None = None, - started_until: datetime | None = None, - ) -> PositiveInt: - async with self.db_engine.begin() as conn: - query = ( - sa.select(sa.func.count()) - .select_from(resource_tracker_service_runs) - .where(resource_tracker_service_runs.c.product_name == product_name) - ) - - if user_id: - query = query.where(resource_tracker_service_runs.c.user_id == user_id) - if wallet_id: - query = query.where( - resource_tracker_service_runs.c.wallet_id == wallet_id - ) - if started_from: - query = query.where( - sa.func.DATE(resource_tracker_service_runs.c.started_at) - >= started_from.date() - ) - if started_until: - query = query.where( - sa.func.DATE(resource_tracker_service_runs.c.started_at) - <= started_until.date() - ) - if service_run_status: - query = query.where( - resource_tracker_service_runs.c.service_run_status - == service_run_status - ) - - result = await conn.execute(query) - row = result.first() - return cast(PositiveInt, row[0]) if row else 0 - - ### For Background check purpose: - - async def list_service_runs_with_running_status_across_all_products( - self, - *, - offset: int, - limit: int, - ) -> list[ServiceRunForCheckDB]: - async with self.db_engine.begin() as conn: - query = ( - sa.select( - resource_tracker_service_runs.c.service_run_id, - resource_tracker_service_runs.c.last_heartbeat_at, - resource_tracker_service_runs.c.missed_heartbeat_counter, - resource_tracker_service_runs.c.modified, - ) - .where( - resource_tracker_service_runs.c.service_run_status - == ServiceRunStatus.RUNNING - ) - .order_by(resource_tracker_service_runs.c.started_at.desc()) # NOTE: - .offset(offset) - .limit(limit) - ) - result = await conn.execute(query) - - return [ServiceRunForCheckDB.model_validate(row) for row in result.fetchall()] - - async def total_service_runs_with_running_status_across_all_products( - self, - ) -> PositiveInt: - async with self.db_engine.begin() as conn: - query = ( - sa.select(sa.func.count()) - .select_from(resource_tracker_service_runs) - .where( - resource_tracker_service_runs.c.service_run_status - == ServiceRunStatus.RUNNING - ) - ) - result = await conn.execute(query) - row = result.first() - return cast(PositiveInt, row[0]) if row else 0 - - async def update_service_missed_heartbeat_counter( - self, - service_run_id: ServiceRunId, - last_heartbeat_at: datetime, - missed_heartbeat_counter: int, - ) -> ServiceRunDB | None: - async with self.db_engine.begin() as conn: - update_stmt = ( - resource_tracker_service_runs.update() - .values( - modified=sa.func.now(), - missed_heartbeat_counter=missed_heartbeat_counter, - ) - .where( - (resource_tracker_service_runs.c.service_run_id == service_run_id) - & ( - resource_tracker_service_runs.c.service_run_status - == ServiceRunStatus.RUNNING - ) - & ( - resource_tracker_service_runs.c.last_heartbeat_at - == last_heartbeat_at - ) - ) - .returning(sa.literal_column("*")) - ) - - result = await conn.execute(update_stmt) - row = result.first() - if row is None: - return None - return ServiceRunDB.model_validate(row) - - ################################# - # Credit transactions - ################################# - - async def create_credit_transaction( - self, data: CreditTransactionCreate - ) -> CreditTransactionId: - async with self.db_engine.begin() as conn: - insert_stmt = ( - resource_tracker_credit_transactions.insert() - .values( - product_name=data.product_name, - wallet_id=data.wallet_id, - wallet_name=data.wallet_name, - pricing_plan_id=data.pricing_plan_id, - pricing_unit_id=data.pricing_unit_id, - pricing_unit_cost_id=data.pricing_unit_cost_id, - user_id=data.user_id, - user_email=data.user_email, - osparc_credits=data.osparc_credits, - transaction_status=data.transaction_status, - transaction_classification=data.transaction_classification, - service_run_id=data.service_run_id, - payment_transaction_id=data.payment_transaction_id, - created=data.created_at, - last_heartbeat_at=data.last_heartbeat_at, - modified=sa.func.now(), - ) - .returning(resource_tracker_credit_transactions.c.transaction_id) - ) - result = await conn.execute(insert_stmt) - row = result.first() - if row is None: - raise CreditTransactionNotCreatedDBError(data=data) - return cast(CreditTransactionId, row[0]) - - async def update_credit_transaction_credits( - self, data: CreditTransactionCreditsUpdate - ) -> CreditTransactionId | None: - async with self.db_engine.begin() as conn: - update_stmt = ( - resource_tracker_credit_transactions.update() - .values( - modified=sa.func.now(), - osparc_credits=data.osparc_credits, - last_heartbeat_at=data.last_heartbeat_at, - ) - .where( - ( - resource_tracker_credit_transactions.c.service_run_id - == data.service_run_id - ) - & ( - resource_tracker_credit_transactions.c.transaction_status - == CreditTransactionStatus.PENDING - ) - & ( - resource_tracker_credit_transactions.c.last_heartbeat_at - <= data.last_heartbeat_at - ) - ) - .returning(resource_tracker_credit_transactions.c.service_run_id) - ) - result = await conn.execute(update_stmt) - row = result.first() - if row is None: - return None - return cast(CreditTransactionId | None, row[0]) - - async def update_credit_transaction_credits_and_status( - self, data: CreditTransactionCreditsAndStatusUpdate - ) -> CreditTransactionId | None: - async with self.db_engine.begin() as conn: - update_stmt = ( - resource_tracker_credit_transactions.update() - .values( - modified=sa.func.now(), - osparc_credits=data.osparc_credits, - transaction_status=data.transaction_status, - ) - .where( - ( - resource_tracker_credit_transactions.c.service_run_id - == data.service_run_id - ) - & ( - resource_tracker_credit_transactions.c.transaction_status - == CreditTransactionStatus.PENDING - ) - ) - .returning(resource_tracker_credit_transactions.c.service_run_id) - ) - result = await conn.execute(update_stmt) - row = result.first() - if row is None: - return None - return cast(CreditTransactionId | None, row[0]) - - async def sum_credit_transactions_by_product_and_wallet( - self, product_name: ProductName, wallet_id: WalletID - ) -> WalletTotalCredits: - async with self.db_engine.begin() as conn: - sum_stmt = sa.select( - sa.func.sum(resource_tracker_credit_transactions.c.osparc_credits) - ).where( - (resource_tracker_credit_transactions.c.product_name == product_name) - & (resource_tracker_credit_transactions.c.wallet_id == wallet_id) - & ( - resource_tracker_credit_transactions.c.transaction_status.in_( - [ - CreditTransactionStatus.BILLED, - CreditTransactionStatus.PENDING, - ] - ) - ) - ) - result = await conn.execute(sum_stmt) - row = result.first() - if row is None or row[0] is None: - return WalletTotalCredits( - wallet_id=wallet_id, available_osparc_credits=Decimal(0) - ) - return WalletTotalCredits(wallet_id=wallet_id, available_osparc_credits=row[0]) - - ################################# - # Pricing plans - ################################# - - async def list_active_service_pricing_plans_by_product_and_service( - self, - product_name: ProductName, - service_key: ServiceKey, - service_version: ServiceVersion, - ) -> list[PricingPlansWithServiceDefaultPlanDB]: - # NOTE: consilidate with utils_services_environmnets.py - def _version(column_or_value): - # converts version value string to array[integer] that can be compared - return sa.func.string_to_array(column_or_value, ".").cast(ARRAY(INTEGER)) - - async with self.db_engine.begin() as conn: - # Firstly find the correct service version - query = ( - sa.select( - resource_tracker_pricing_plan_to_service.c.service_key, - resource_tracker_pricing_plan_to_service.c.service_version, - ) - .select_from( - resource_tracker_pricing_plan_to_service.join( - resource_tracker_pricing_plans, - ( - resource_tracker_pricing_plan_to_service.c.pricing_plan_id - == resource_tracker_pricing_plans.c.pricing_plan_id - ), - ) - ) - .where( - ( - _version( - resource_tracker_pricing_plan_to_service.c.service_version - ) - <= _version(service_version) - ) - & ( - resource_tracker_pricing_plan_to_service.c.service_key - == service_key - ) - & (resource_tracker_pricing_plans.c.product_name == product_name) - & (resource_tracker_pricing_plans.c.is_active.is_(True)) - ) - .order_by( - _version( - resource_tracker_pricing_plan_to_service.c.service_version - ).desc() - ) - .limit(1) - ) - result = await conn.execute(query) - row = result.first() - if row is None: - return [] - latest_service_key, latest_service_version = row - # Now choose all pricing plans connected to this service - query = ( - sa.select( - resource_tracker_pricing_plans.c.pricing_plan_id, - resource_tracker_pricing_plans.c.display_name, - resource_tracker_pricing_plans.c.description, - resource_tracker_pricing_plans.c.classification, - resource_tracker_pricing_plans.c.is_active, - resource_tracker_pricing_plans.c.created, - resource_tracker_pricing_plans.c.pricing_plan_key, - resource_tracker_pricing_plan_to_service.c.service_default_plan, - ) - .select_from( - resource_tracker_pricing_plan_to_service.join( - resource_tracker_pricing_plans, - ( - resource_tracker_pricing_plan_to_service.c.pricing_plan_id - == resource_tracker_pricing_plans.c.pricing_plan_id - ), - ) - ) - .where( - ( - _version( - resource_tracker_pricing_plan_to_service.c.service_version - ) - == _version(latest_service_version) - ) - & ( - resource_tracker_pricing_plan_to_service.c.service_key - == latest_service_key - ) - & (resource_tracker_pricing_plans.c.product_name == product_name) - & (resource_tracker_pricing_plans.c.is_active.is_(True)) - ) - .order_by( - resource_tracker_pricing_plan_to_service.c.pricing_plan_id.desc() - ) - ) - result = await conn.execute(query) - - return [ - PricingPlansWithServiceDefaultPlanDB.model_validate(row) - for row in result.fetchall() - ] - - async def get_pricing_plan( - self, product_name: ProductName, pricing_plan_id: PricingPlanId - ) -> PricingPlansDB: - async with self.db_engine.begin() as conn: - select_stmt = sa.select( - resource_tracker_pricing_plans.c.pricing_plan_id, - resource_tracker_pricing_plans.c.display_name, - resource_tracker_pricing_plans.c.description, - resource_tracker_pricing_plans.c.classification, - resource_tracker_pricing_plans.c.is_active, - resource_tracker_pricing_plans.c.created, - resource_tracker_pricing_plans.c.pricing_plan_key, - ).where( - (resource_tracker_pricing_plans.c.pricing_plan_id == pricing_plan_id) - & (resource_tracker_pricing_plans.c.product_name == product_name) - ) - result = await conn.execute(select_stmt) - row = result.first() - if row is None: - raise PricingPlanDoesNotExistsDBError(pricing_plan_id=pricing_plan_id) - return PricingPlansDB.model_validate(row) - - async def list_pricing_plans_by_product( - self, product_name: ProductName - ) -> list[PricingPlansDB]: - async with self.db_engine.begin() as conn: - select_stmt = sa.select( - resource_tracker_pricing_plans.c.pricing_plan_id, - resource_tracker_pricing_plans.c.display_name, - resource_tracker_pricing_plans.c.description, - resource_tracker_pricing_plans.c.classification, - resource_tracker_pricing_plans.c.is_active, - resource_tracker_pricing_plans.c.created, - resource_tracker_pricing_plans.c.pricing_plan_key, - ).where(resource_tracker_pricing_plans.c.product_name == product_name) - result = await conn.execute(select_stmt) - - return [PricingPlansDB.model_validate(row) for row in result.fetchall()] - - async def create_pricing_plan(self, data: PricingPlanCreate) -> PricingPlansDB: - async with self.db_engine.begin() as conn: - insert_stmt = ( - resource_tracker_pricing_plans.insert() - .values( - product_name=data.product_name, - display_name=data.display_name, - description=data.description, - classification=data.classification, - is_active=True, - created=sa.func.now(), - modified=sa.func.now(), - pricing_plan_key=data.pricing_plan_key, - ) - .returning( - *[ - resource_tracker_pricing_plans.c.pricing_plan_id, - resource_tracker_pricing_plans.c.display_name, - resource_tracker_pricing_plans.c.description, - resource_tracker_pricing_plans.c.classification, - resource_tracker_pricing_plans.c.is_active, - resource_tracker_pricing_plans.c.created, - resource_tracker_pricing_plans.c.pricing_plan_key, - ] - ) - ) - result = await conn.execute(insert_stmt) - row = result.first() - if row is None: - raise PricingPlanNotCreatedDBError(data=data) - return PricingPlansDB.model_validate(row) - - async def update_pricing_plan( - self, product_name: ProductName, data: PricingPlanUpdate - ) -> PricingPlansDB | None: - async with self.db_engine.begin() as conn: - update_stmt = ( - resource_tracker_pricing_plans.update() - .values( - display_name=data.display_name, - description=data.description, - is_active=data.is_active, - modified=sa.func.now(), - ) - .where( - ( - resource_tracker_pricing_plans.c.pricing_plan_id - == data.pricing_plan_id - ) - & (resource_tracker_pricing_plans.c.product_name == product_name) - ) - .returning( - *[ - resource_tracker_pricing_plans.c.pricing_plan_id, - resource_tracker_pricing_plans.c.display_name, - resource_tracker_pricing_plans.c.description, - resource_tracker_pricing_plans.c.classification, - resource_tracker_pricing_plans.c.is_active, - resource_tracker_pricing_plans.c.created, - resource_tracker_pricing_plans.c.pricing_plan_key, - ] - ) - ) - result = await conn.execute(update_stmt) - row = result.first() - if row is None: - return None - return PricingPlansDB.model_validate(row) - - ################################# - # Pricing plan to service - ################################# - - async def list_connected_services_to_pricing_plan_by_pricing_plan( - self, product_name: ProductName, pricing_plan_id: PricingPlanId - ) -> list[PricingPlanToServiceDB]: - async with self.db_engine.begin() as conn: - query = ( - sa.select( - resource_tracker_pricing_plan_to_service.c.pricing_plan_id, - resource_tracker_pricing_plan_to_service.c.service_key, - resource_tracker_pricing_plan_to_service.c.service_version, - resource_tracker_pricing_plan_to_service.c.created, - ) - .select_from( - resource_tracker_pricing_plan_to_service.join( - resource_tracker_pricing_plans, - ( - resource_tracker_pricing_plan_to_service.c.pricing_plan_id - == resource_tracker_pricing_plans.c.pricing_plan_id - ), - ) - ) - .where( - (resource_tracker_pricing_plans.c.product_name == product_name) - & ( - resource_tracker_pricing_plans.c.pricing_plan_id - == pricing_plan_id - ) - ) - .order_by( - resource_tracker_pricing_plan_to_service.c.pricing_plan_id.desc() - ) - ) - result = await conn.execute(query) - - return [ - PricingPlanToServiceDB.model_validate(row) for row in result.fetchall() - ] - - async def upsert_service_to_pricing_plan( - self, - product_name: ProductName, - pricing_plan_id: PricingPlanId, - service_key: ServiceKey, - service_version: ServiceVersion, - ) -> PricingPlanToServiceDB: - async with self.db_engine.begin() as conn: - query = ( - sa.select( - resource_tracker_pricing_plan_to_service.c.pricing_plan_id, - resource_tracker_pricing_plan_to_service.c.service_key, - resource_tracker_pricing_plan_to_service.c.service_version, - resource_tracker_pricing_plan_to_service.c.created, - ) - .select_from( - resource_tracker_pricing_plan_to_service.join( - resource_tracker_pricing_plans, - ( - resource_tracker_pricing_plan_to_service.c.pricing_plan_id - == resource_tracker_pricing_plans.c.pricing_plan_id - ), - ) - ) - .where( - (resource_tracker_pricing_plans.c.product_name == product_name) - & ( - resource_tracker_pricing_plans.c.pricing_plan_id - == pricing_plan_id - ) - & ( - resource_tracker_pricing_plan_to_service.c.service_key - == service_key - ) - & ( - resource_tracker_pricing_plan_to_service.c.service_version - == service_version - ) - ) - ) - result = await conn.execute(query) - row = result.first() - - if row is not None: - delete_stmt = resource_tracker_pricing_plan_to_service.delete().where( - ( - resource_tracker_pricing_plans.c.pricing_plan_id - == pricing_plan_id - ) - & ( - resource_tracker_pricing_plan_to_service.c.service_key - == service_key - ) - & ( - resource_tracker_pricing_plan_to_service.c.service_version - == service_version - ) - ) - await conn.execute(delete_stmt) - - insert_stmt = ( - resource_tracker_pricing_plan_to_service.insert() - .values( - pricing_plan_id=pricing_plan_id, - service_key=service_key, - service_version=service_version, - created=sa.func.now(), - modified=sa.func.now(), - service_default_plan=True, - ) - .returning( - *[ - resource_tracker_pricing_plan_to_service.c.pricing_plan_id, - resource_tracker_pricing_plan_to_service.c.service_key, - resource_tracker_pricing_plan_to_service.c.service_version, - resource_tracker_pricing_plan_to_service.c.created, - ] - ) - ) - result = await conn.execute(insert_stmt) - row = result.first() - if row is None: - raise PricingPlanToServiceNotCreatedDBError( - data=f"pricing_plan_id {pricing_plan_id}, service_key {service_key}, service_version {service_version}" - ) - return PricingPlanToServiceDB.model_validate(row) - - ################################# - # Pricing units - ################################# - - @staticmethod - def _pricing_units_select_stmt(): - return sa.select( - resource_tracker_pricing_units.c.pricing_unit_id, - resource_tracker_pricing_units.c.pricing_plan_id, - resource_tracker_pricing_units.c.unit_name, - resource_tracker_pricing_units.c.unit_extra_info, - resource_tracker_pricing_units.c.default, - resource_tracker_pricing_units.c.specific_info, - resource_tracker_pricing_units.c.created, - resource_tracker_pricing_units.c.modified, - resource_tracker_pricing_unit_costs.c.cost_per_unit.label( - "current_cost_per_unit" - ), - resource_tracker_pricing_unit_costs.c.pricing_unit_cost_id.label( - "current_cost_per_unit_id" - ), - ) - - async def list_pricing_units_by_pricing_plan( - self, - pricing_plan_id: PricingPlanId, - ) -> list[PricingUnitsDB]: - async with self.db_engine.begin() as conn: - query = ( - self._pricing_units_select_stmt() - .select_from( - resource_tracker_pricing_units.join( - resource_tracker_pricing_unit_costs, - ( - ( - resource_tracker_pricing_units.c.pricing_plan_id - == resource_tracker_pricing_unit_costs.c.pricing_plan_id - ) - & ( - resource_tracker_pricing_units.c.pricing_unit_id - == resource_tracker_pricing_unit_costs.c.pricing_unit_id - ) - ), - ) - ) - .where( - ( - resource_tracker_pricing_units.c.pricing_plan_id - == pricing_plan_id - ) - & (resource_tracker_pricing_unit_costs.c.valid_to.is_(None)) - ) - .order_by(resource_tracker_pricing_unit_costs.c.cost_per_unit.asc()) - ) - result = await conn.execute(query) - - return [PricingUnitsDB.model_validate(row) for row in result.fetchall()] - - async def get_valid_pricing_unit( - self, - product_name: ProductName, - pricing_plan_id: PricingPlanId, - pricing_unit_id: PricingUnitId, - ) -> PricingUnitsDB: - async with self.db_engine.begin() as conn: - query = ( - self._pricing_units_select_stmt() - .select_from( - resource_tracker_pricing_units.join( - resource_tracker_pricing_unit_costs, - ( - ( - resource_tracker_pricing_units.c.pricing_plan_id - == resource_tracker_pricing_unit_costs.c.pricing_plan_id - ) - & ( - resource_tracker_pricing_units.c.pricing_unit_id - == resource_tracker_pricing_unit_costs.c.pricing_unit_id - ) - ), - ).join( - resource_tracker_pricing_plans, - ( - resource_tracker_pricing_plans.c.pricing_plan_id - == resource_tracker_pricing_units.c.pricing_plan_id - ), - ) - ) - .where( - ( - resource_tracker_pricing_units.c.pricing_plan_id - == pricing_plan_id - ) - & ( - resource_tracker_pricing_units.c.pricing_unit_id - == pricing_unit_id - ) - & (resource_tracker_pricing_unit_costs.c.valid_to.is_(None)) - & (resource_tracker_pricing_plans.c.product_name == product_name) - ) - ) - result = await conn.execute(query) - - row = result.first() - if row is None: - raise PricingPlanAndPricingUnitCombinationDoesNotExistsDBError( - pricing_plan_id=pricing_plan_id, - pricing_unit_id=pricing_unit_id, - product_name=product_name, - ) - return PricingUnitsDB.model_validate(row) - - async def create_pricing_unit_with_cost( - self, data: PricingUnitWithCostCreate, pricing_plan_key: str - ) -> tuple[PricingUnitId, PricingUnitCostId]: - async with self.db_engine.begin() as conn: - # pricing units table - insert_stmt = ( - resource_tracker_pricing_units.insert() - .values( - pricing_plan_id=data.pricing_plan_id, - unit_name=data.unit_name, - unit_extra_info=data.unit_extra_info.model_dump(), - default=data.default, - specific_info=data.specific_info.model_dump(), - created=sa.func.now(), - modified=sa.func.now(), - ) - .returning(resource_tracker_pricing_units.c.pricing_unit_id) - ) - result = await conn.execute(insert_stmt) - row = result.first() - if row is None: - raise PricingUnitNotCreatedDBError(data=data) - _pricing_unit_id = row[0] - - # pricing unit cost table - insert_stmt = ( - resource_tracker_pricing_unit_costs.insert() - .values( - pricing_plan_id=data.pricing_plan_id, - pricing_plan_key=pricing_plan_key, - pricing_unit_id=_pricing_unit_id, - pricing_unit_name=data.unit_name, - cost_per_unit=data.cost_per_unit, - valid_from=sa.func.now(), - valid_to=None, - created=sa.func.now(), - comment=data.comment, - modified=sa.func.now(), - ) - .returning(resource_tracker_pricing_unit_costs.c.pricing_unit_cost_id) - ) - result = await conn.execute(insert_stmt) - row = result.first() - if row is None: - raise PricingUnitCostNotCreatedDBError(data=data) - _pricing_unit_cost_id = row[0] - - return (_pricing_unit_id, _pricing_unit_cost_id) - - async def update_pricing_unit_with_cost( - self, data: PricingUnitWithCostUpdate, pricing_plan_key: str - ) -> None: - async with self.db_engine.begin() as conn: - # pricing units table - update_stmt = ( - resource_tracker_pricing_units.update() - .values( - unit_name=data.unit_name, - unit_extra_info=data.unit_extra_info.model_dump(), - default=data.default, - specific_info=data.specific_info.model_dump(), - modified=sa.func.now(), - ) - .where( - resource_tracker_pricing_units.c.pricing_unit_id - == data.pricing_unit_id - ) - .returning(resource_tracker_pricing_units.c.pricing_unit_id) - ) - await conn.execute(update_stmt) - - # If price change, then we update pricing unit cost table - if data.pricing_unit_cost_update: - # Firstly we close previous price - update_stmt = ( - resource_tracker_pricing_unit_costs.update() - .values( - valid_to=sa.func.now(), # <-- Closing previous price - modified=sa.func.now(), - ) - .where( - resource_tracker_pricing_unit_costs.c.pricing_unit_id - == data.pricing_unit_id - ) - .returning(resource_tracker_pricing_unit_costs.c.pricing_unit_id) - ) - result = await conn.execute(update_stmt) - - # Then we create a new price - insert_stmt = ( - resource_tracker_pricing_unit_costs.insert() - .values( - pricing_plan_id=data.pricing_plan_id, - pricing_plan_key=pricing_plan_key, - pricing_unit_id=data.pricing_unit_id, - pricing_unit_name=data.unit_name, - cost_per_unit=data.pricing_unit_cost_update.cost_per_unit, - valid_from=sa.func.now(), - valid_to=None, # <-- New price is valid - created=sa.func.now(), - comment=data.pricing_unit_cost_update.comment, - modified=sa.func.now(), - ) - .returning( - resource_tracker_pricing_unit_costs.c.pricing_unit_cost_id - ) - ) - result = await conn.execute(insert_stmt) - row = result.first() - if row is None: - raise PricingUnitCostNotCreatedDBError(data=data) - - ################################# - # Pricing unit-costs - ################################# - - async def get_pricing_unit_cost_by_id( - self, pricing_unit_cost_id: PricingUnitCostId - ) -> PricingUnitCostsDB: - async with self.db_engine.begin() as conn: - query = sa.select( - resource_tracker_pricing_unit_costs.c.pricing_unit_cost_id, - resource_tracker_pricing_unit_costs.c.pricing_plan_id, - resource_tracker_pricing_unit_costs.c.pricing_plan_key, - resource_tracker_pricing_unit_costs.c.pricing_unit_id, - resource_tracker_pricing_unit_costs.c.pricing_unit_name, - resource_tracker_pricing_unit_costs.c.cost_per_unit, - resource_tracker_pricing_unit_costs.c.valid_from, - resource_tracker_pricing_unit_costs.c.valid_to, - resource_tracker_pricing_unit_costs.c.created, - resource_tracker_pricing_unit_costs.c.comment, - resource_tracker_pricing_unit_costs.c.modified, - ).where( - resource_tracker_pricing_unit_costs.c.pricing_unit_cost_id - == pricing_unit_cost_id - ) - result = await conn.execute(query) - - row = result.first() - if row is None: - raise PricingUnitCostDoesNotExistsDBError( - pricing_unit_cost_id=pricing_unit_cost_id - ) - return PricingUnitCostsDB.model_validate(row) diff --git a/services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/services/modules/db/service_runs_db.py b/services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/services/modules/db/service_runs_db.py new file mode 100644 index 00000000000..a4ea563803d --- /dev/null +++ b/services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/services/modules/db/service_runs_db.py @@ -0,0 +1,622 @@ +# pylint: disable=too-many-arguments +import logging +from datetime import datetime +from typing import cast + +import sqlalchemy as sa +from models_library.api_schemas_storage import S3BucketName +from models_library.products import ProductName +from models_library.resource_tracker import ( + CreditClassification, + CreditTransactionStatus, + ServiceRunId, + ServiceRunStatus, +) +from models_library.rest_ordering import OrderBy, OrderDirection +from models_library.users import UserID +from models_library.wallets import WalletID +from pydantic import PositiveInt +from simcore_postgres_database.models.projects_tags import projects_tags +from simcore_postgres_database.models.resource_tracker_credit_transactions import ( + resource_tracker_credit_transactions, +) +from simcore_postgres_database.models.resource_tracker_service_runs import ( + resource_tracker_service_runs, +) +from simcore_postgres_database.models.tags import tags +from simcore_postgres_database.utils_repos import transaction_context +from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine + +from ....exceptions.errors import ServiceRunNotCreatedDBError +from ....models.service_runs import ( + OsparcCreditsAggregatedByServiceKeyDB, + ServiceRunCreate, + ServiceRunDB, + ServiceRunForCheckDB, + ServiceRunLastHeartbeatUpdate, + ServiceRunStoppedAtUpdate, + ServiceRunWithCreditsDB, +) + +_logger = logging.getLogger(__name__) + + +async def create_service_run( + engine: AsyncEngine, + connection: AsyncConnection | None = None, + *, + data: ServiceRunCreate, +) -> ServiceRunId: + async with transaction_context(engine, connection) as conn: + insert_stmt = ( + resource_tracker_service_runs.insert() + .values( + product_name=data.product_name, + service_run_id=data.service_run_id, + wallet_id=data.wallet_id, + wallet_name=data.wallet_name, + pricing_plan_id=data.pricing_plan_id, + pricing_unit_id=data.pricing_unit_id, + pricing_unit_cost_id=data.pricing_unit_cost_id, + pricing_unit_cost=data.pricing_unit_cost, + simcore_user_agent=data.simcore_user_agent, + user_id=data.user_id, + user_email=data.user_email, + project_id=f"{data.project_id}", + project_name=data.project_name, + node_id=f"{data.node_id}", + node_name=data.node_name, + parent_project_id=f"{data.parent_project_id}", + root_parent_project_id=f"{data.root_parent_project_id}", + root_parent_project_name=data.root_parent_project_name, + parent_node_id=f"{data.parent_node_id}", + root_parent_node_id=f"{data.root_parent_node_id}", + service_key=data.service_key, + service_version=data.service_version, + service_type=data.service_type, + service_resources=data.service_resources, + service_additional_metadata=data.service_additional_metadata, + started_at=data.started_at, + stopped_at=None, + service_run_status=ServiceRunStatus.RUNNING, + modified=sa.func.now(), + last_heartbeat_at=data.last_heartbeat_at, + ) + .returning(resource_tracker_service_runs.c.service_run_id) + ) + result = await conn.execute(insert_stmt) + row = result.first() + if row is None: + raise ServiceRunNotCreatedDBError(data=data) + return cast(ServiceRunId, row[0]) + + +async def update_service_run_last_heartbeat( + engine: AsyncEngine, + connection: AsyncConnection | None = None, + *, + data: ServiceRunLastHeartbeatUpdate, +) -> ServiceRunDB | None: + async with transaction_context(engine, connection) as conn: + update_stmt = ( + resource_tracker_service_runs.update() + .values( + modified=sa.func.now(), + last_heartbeat_at=data.last_heartbeat_at, + missed_heartbeat_counter=0, + ) + .where( + (resource_tracker_service_runs.c.service_run_id == data.service_run_id) + & ( + resource_tracker_service_runs.c.service_run_status + == ServiceRunStatus.RUNNING + ) + & ( + resource_tracker_service_runs.c.last_heartbeat_at + <= data.last_heartbeat_at + ) + ) + .returning(sa.literal_column("*")) + ) + result = await conn.execute(update_stmt) + row = result.first() + if row is None: + return None + return ServiceRunDB.model_validate(row) + + +async def update_service_run_stopped_at( + engine: AsyncEngine, + connection: AsyncConnection | None = None, + *, + data: ServiceRunStoppedAtUpdate, +) -> ServiceRunDB | None: + async with transaction_context(engine, connection) as conn: + update_stmt = ( + resource_tracker_service_runs.update() + .values( + modified=sa.func.now(), + stopped_at=data.stopped_at, + service_run_status=data.service_run_status, + service_run_status_msg=data.service_run_status_msg, + ) + .where( + (resource_tracker_service_runs.c.service_run_id == data.service_run_id) + & ( + resource_tracker_service_runs.c.service_run_status + == ServiceRunStatus.RUNNING + ) + ) + .returning(sa.literal_column("*")) + ) + result = await conn.execute(update_stmt) + row = result.first() + if row is None: + return None + return ServiceRunDB.model_validate(row) + + +async def get_service_run_by_id( + engine: AsyncEngine, + connection: AsyncConnection | None = None, + *, + service_run_id: ServiceRunId, +) -> ServiceRunDB | None: + async with transaction_context(engine, connection) as conn: + stmt = sa.select(resource_tracker_service_runs).where( + resource_tracker_service_runs.c.service_run_id == service_run_id + ) + result = await conn.execute(stmt) + row = result.first() + if row is None: + return None + return ServiceRunDB.model_validate(row) + + +_project_tags_subquery = ( + sa.select( + projects_tags.c.project_uuid_for_rut, + sa.func.array_agg(tags.c.name).label("project_tags"), + ) + .select_from(projects_tags.join(tags, projects_tags.c.tag_id == tags.c.id)) + .group_by(projects_tags.c.project_uuid_for_rut) +).subquery("project_tags_subquery") + + +async def list_service_runs_by_product_and_user_and_wallet( + engine: AsyncEngine, + connection: AsyncConnection | None = None, + *, + product_name: ProductName, + user_id: UserID | None, + wallet_id: WalletID | None, + offset: int, + limit: int, + service_run_status: ServiceRunStatus | None = None, + started_from: datetime | None = None, + started_until: datetime | None = None, + order_by: OrderBy | None = None, +) -> list[ServiceRunWithCreditsDB]: + async with transaction_context(engine, connection) as conn: + query = ( + sa.select( + resource_tracker_service_runs.c.product_name, + resource_tracker_service_runs.c.service_run_id, + resource_tracker_service_runs.c.wallet_id, + resource_tracker_service_runs.c.wallet_name, + resource_tracker_service_runs.c.pricing_plan_id, + resource_tracker_service_runs.c.pricing_unit_id, + resource_tracker_service_runs.c.pricing_unit_cost_id, + resource_tracker_service_runs.c.pricing_unit_cost, + resource_tracker_service_runs.c.user_id, + resource_tracker_service_runs.c.user_email, + resource_tracker_service_runs.c.project_id, + resource_tracker_service_runs.c.project_name, + resource_tracker_service_runs.c.node_id, + resource_tracker_service_runs.c.node_name, + resource_tracker_service_runs.c.parent_project_id, + resource_tracker_service_runs.c.root_parent_project_id, + resource_tracker_service_runs.c.root_parent_project_name, + resource_tracker_service_runs.c.parent_node_id, + resource_tracker_service_runs.c.root_parent_node_id, + resource_tracker_service_runs.c.service_key, + resource_tracker_service_runs.c.service_version, + resource_tracker_service_runs.c.service_type, + resource_tracker_service_runs.c.service_resources, + resource_tracker_service_runs.c.started_at, + resource_tracker_service_runs.c.stopped_at, + resource_tracker_service_runs.c.service_run_status, + resource_tracker_service_runs.c.modified, + resource_tracker_service_runs.c.last_heartbeat_at, + resource_tracker_service_runs.c.service_run_status_msg, + resource_tracker_service_runs.c.missed_heartbeat_counter, + resource_tracker_credit_transactions.c.osparc_credits, + resource_tracker_credit_transactions.c.transaction_status, + sa.func.coalesce( + _project_tags_subquery.c.project_tags, + sa.cast(sa.text("'{}'"), sa.ARRAY(sa.String)), + ).label("project_tags"), + ) + .select_from( + resource_tracker_service_runs.join( + resource_tracker_credit_transactions, + ( + resource_tracker_service_runs.c.product_name + == resource_tracker_credit_transactions.c.product_name + ) + & ( + resource_tracker_service_runs.c.service_run_id + == resource_tracker_credit_transactions.c.service_run_id + ), + isouter=True, + ).join( + _project_tags_subquery, + resource_tracker_service_runs.c.project_id + == _project_tags_subquery.c.project_uuid_for_rut, + isouter=True, + ) + ) + .where(resource_tracker_service_runs.c.product_name == product_name) + .offset(offset) + .limit(limit) + ) + + if user_id: + query = query.where(resource_tracker_service_runs.c.user_id == user_id) + if wallet_id: + query = query.where(resource_tracker_service_runs.c.wallet_id == wallet_id) + if service_run_status: + query = query.where( + resource_tracker_service_runs.c.service_run_status == service_run_status + ) + if started_from: + query = query.where( + sa.func.DATE(resource_tracker_service_runs.c.started_at) + >= started_from.date() + ) + if started_until: + query = query.where( + sa.func.DATE(resource_tracker_service_runs.c.started_at) + <= started_until.date() + ) + + if order_by: + if order_by.direction == OrderDirection.ASC: + query = query.order_by(sa.asc(order_by.field)) + else: + query = query.order_by(sa.desc(order_by.field)) + else: + # Default ordering + query = query.order_by(resource_tracker_service_runs.c.started_at.desc()) + + result = await conn.execute(query) + + return [ServiceRunWithCreditsDB.model_validate(row) for row in result.fetchall()] + + +async def get_osparc_credits_aggregated_by_service( + engine: AsyncEngine, + connection: AsyncConnection | None = None, + *, + product_name: ProductName, + user_id: UserID | None, + wallet_id: WalletID, + offset: int, + limit: int, + started_from: datetime | None = None, + started_until: datetime | None = None, +) -> tuple[int, list[OsparcCreditsAggregatedByServiceKeyDB]]: + async with transaction_context(engine, connection) as conn: + base_query = ( + sa.select( + resource_tracker_service_runs.c.service_key, + sa.func.SUM( + resource_tracker_credit_transactions.c.osparc_credits + ).label("osparc_credits"), + sa.func.SUM( + sa.func.round( + ( + sa.func.extract( + "epoch", + resource_tracker_service_runs.c.stopped_at, + ) + - sa.func.extract( + "epoch", + resource_tracker_service_runs.c.started_at, + ) + ) + / 3600, + 2, + ) + ).label("running_time_in_hours"), + ) + .select_from( + resource_tracker_service_runs.join( + resource_tracker_credit_transactions, + ( + resource_tracker_service_runs.c.product_name + == resource_tracker_credit_transactions.c.product_name + ) + & ( + resource_tracker_service_runs.c.service_run_id + == resource_tracker_credit_transactions.c.service_run_id + ), + isouter=True, + ) + ) + .where( + (resource_tracker_service_runs.c.product_name == product_name) + & ( + resource_tracker_credit_transactions.c.transaction_status + == CreditTransactionStatus.BILLED + ) + & ( + resource_tracker_credit_transactions.c.transaction_classification + == CreditClassification.DEDUCT_SERVICE_RUN + ) + & (resource_tracker_credit_transactions.c.wallet_id == wallet_id) + ) + .group_by(resource_tracker_service_runs.c.service_key) + ) + + if user_id: + base_query = base_query.where( + resource_tracker_service_runs.c.user_id == user_id + ) + if started_from: + base_query = base_query.where( + sa.func.DATE(resource_tracker_service_runs.c.started_at) + >= started_from.date() + ) + if started_until: + base_query = base_query.where( + sa.func.DATE(resource_tracker_service_runs.c.started_at) + <= started_until.date() + ) + + subquery = base_query.subquery() + count_query = sa.select(sa.func.count()).select_from(subquery) + count_result = await conn.execute(count_query) + + # Default ordering and pagination + list_query = ( + base_query.order_by(resource_tracker_service_runs.c.service_key.asc()) + .offset(offset) + .limit(limit) + ) + list_result = await conn.execute(list_query) + + return ( + cast(int, count_result.scalar()), + [ + OsparcCreditsAggregatedByServiceKeyDB.model_validate(row) + for row in list_result.fetchall() + ], + ) + + +async def export_service_runs_table_to_s3( + engine: AsyncEngine, + connection: AsyncConnection | None = None, + *, + product_name: ProductName, + s3_bucket_name: S3BucketName, + s3_key: str, + s3_region: str, + user_id: UserID | None, + wallet_id: WalletID | None, + started_from: datetime | None = None, + started_until: datetime | None = None, + order_by: OrderBy | None = None, +): + async with transaction_context(engine, connection) as conn: + query = ( + sa.select( + resource_tracker_service_runs.c.product_name, + resource_tracker_service_runs.c.service_run_id, + resource_tracker_service_runs.c.wallet_name, + resource_tracker_service_runs.c.user_email, + resource_tracker_service_runs.c.root_parent_project_name.label( + "project_name" + ), + resource_tracker_service_runs.c.node_name, + resource_tracker_service_runs.c.service_key, + resource_tracker_service_runs.c.service_version, + resource_tracker_service_runs.c.service_type, + resource_tracker_service_runs.c.started_at, + resource_tracker_service_runs.c.stopped_at, + resource_tracker_credit_transactions.c.osparc_credits, + resource_tracker_credit_transactions.c.transaction_status, + sa.func.coalesce( + _project_tags_subquery.c.project_tags, + sa.cast(sa.text("'{}'"), sa.ARRAY(sa.String)), + ).label("project_tags"), + ) + .select_from( + resource_tracker_service_runs.join( + resource_tracker_credit_transactions, + resource_tracker_service_runs.c.service_run_id + == resource_tracker_credit_transactions.c.service_run_id, + isouter=True, + ).join( + _project_tags_subquery, + resource_tracker_service_runs.c.project_id + == _project_tags_subquery.c.project_uuid_for_rut, + isouter=True, + ) + ) + .where(resource_tracker_service_runs.c.product_name == product_name) + ) + + if user_id: + query = query.where(resource_tracker_service_runs.c.user_id == user_id) + if wallet_id: + query = query.where(resource_tracker_service_runs.c.wallet_id == wallet_id) + if started_from: + query = query.where( + sa.func.DATE(resource_tracker_service_runs.c.started_at) + >= started_from.date() + ) + if started_until: + query = query.where( + sa.func.DATE(resource_tracker_service_runs.c.started_at) + <= started_until.date() + ) + + if order_by: + if order_by.direction == OrderDirection.ASC: + query = query.order_by(sa.asc(order_by.field)) + else: + query = query.order_by(sa.desc(order_by.field)) + else: + # Default ordering + query = query.order_by(resource_tracker_service_runs.c.started_at.desc()) + + compiled_query = ( + str(query.compile(compile_kwargs={"literal_binds": True})) + .replace("\n", "") + .replace("'", "''") + ) + + result = await conn.execute( + sa.DDL( + f""" + SELECT * from aws_s3.query_export_to_s3('{compiled_query}', + aws_commons.create_s3_uri('{s3_bucket_name}', '{s3_key}', '{s3_region}'), 'format csv, HEADER true'); + """ # noqa: S608 + ) + ) + row = result.first() + assert row + _logger.info( + "Rows uploaded %s, Files uploaded %s, Bytes uploaded %s", + row[0], + row[1], + row[2], + ) + + +async def total_service_runs_by_product_and_user_and_wallet( + engine: AsyncEngine, + connection: AsyncConnection | None = None, + *, + product_name: ProductName, + user_id: UserID | None, + wallet_id: WalletID | None, + service_run_status: ServiceRunStatus | None = None, + started_from: datetime | None = None, + started_until: datetime | None = None, +) -> PositiveInt: + async with transaction_context(engine, connection) as conn: + query = ( + sa.select(sa.func.count()) + .select_from(resource_tracker_service_runs) + .where(resource_tracker_service_runs.c.product_name == product_name) + ) + + if user_id: + query = query.where(resource_tracker_service_runs.c.user_id == user_id) + if wallet_id: + query = query.where(resource_tracker_service_runs.c.wallet_id == wallet_id) + if started_from: + query = query.where( + sa.func.DATE(resource_tracker_service_runs.c.started_at) + >= started_from.date() + ) + if started_until: + query = query.where( + sa.func.DATE(resource_tracker_service_runs.c.started_at) + <= started_until.date() + ) + if service_run_status: + query = query.where( + resource_tracker_service_runs.c.service_run_status == service_run_status + ) + + result = await conn.execute(query) + row = result.first() + return cast(PositiveInt, row[0]) if row else 0 + + +### For Background check purpose: + + +async def list_service_runs_with_running_status_across_all_products( + engine: AsyncEngine, + connection: AsyncConnection | None = None, + *, + offset: int, + limit: int, +) -> list[ServiceRunForCheckDB]: + async with transaction_context(engine, connection) as conn: + query = ( + sa.select( + resource_tracker_service_runs.c.service_run_id, + resource_tracker_service_runs.c.last_heartbeat_at, + resource_tracker_service_runs.c.missed_heartbeat_counter, + resource_tracker_service_runs.c.modified, + ) + .where( + resource_tracker_service_runs.c.service_run_status + == ServiceRunStatus.RUNNING + ) + .order_by(resource_tracker_service_runs.c.started_at.desc()) # NOTE: + .offset(offset) + .limit(limit) + ) + result = await conn.execute(query) + + return [ServiceRunForCheckDB.model_validate(row) for row in result.fetchall()] + + +async def total_service_runs_with_running_status_across_all_products( + engine: AsyncEngine, connection: AsyncConnection | None = None +) -> PositiveInt: + async with transaction_context(engine, connection) as conn: + query = ( + sa.select(sa.func.count()) + .select_from(resource_tracker_service_runs) + .where( + resource_tracker_service_runs.c.service_run_status + == ServiceRunStatus.RUNNING + ) + ) + result = await conn.execute(query) + row = result.first() + return cast(PositiveInt, row[0]) if row else 0 + + +async def update_service_missed_heartbeat_counter( + engine: AsyncEngine, + connection: AsyncConnection | None = None, + *, + service_run_id: ServiceRunId, + last_heartbeat_at: datetime, + missed_heartbeat_counter: int, +) -> ServiceRunDB | None: + async with transaction_context(engine, connection) as conn: + update_stmt = ( + resource_tracker_service_runs.update() + .values( + modified=sa.func.now(), + missed_heartbeat_counter=missed_heartbeat_counter, + ) + .where( + (resource_tracker_service_runs.c.service_run_id == service_run_id) + & ( + resource_tracker_service_runs.c.service_run_status + == ServiceRunStatus.RUNNING + ) + & ( + resource_tracker_service_runs.c.last_heartbeat_at + == last_heartbeat_at + ) + ) + .returning(sa.literal_column("*")) + ) + + result = await conn.execute(update_stmt) + row = result.first() + if row is None: + return None + return ServiceRunDB.model_validate(row) diff --git a/services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/services/pricing_plans.py b/services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/services/pricing_plans.py index 9c3dc38bef3..ed34c334187 100644 --- a/services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/services/pricing_plans.py +++ b/services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/services/pricing_plans.py @@ -14,12 +14,13 @@ ) from models_library.services import ServiceKey, ServiceVersion from pydantic import TypeAdapter +from sqlalchemy.ext.asyncio import AsyncEngine -from ..api.rest.dependencies import get_repository +from ..api.rest.dependencies import get_resource_tracker_db_engine from ..exceptions.errors import PricingPlanNotFoundForServiceError from ..models.pricing_plans import PricingPlansDB, PricingPlanToServiceDB from ..models.pricing_units import PricingUnitsDB -from .modules.db.repositories.resource_tracker import ResourceTrackerRepository +from .modules.db import pricing_plans_db async def _create_pricing_plan_get( @@ -52,12 +53,15 @@ async def get_service_default_pricing_plan( product_name: ProductName, service_key: ServiceKey, service_version: ServiceVersion, - resource_tracker_repo: Annotated[ - ResourceTrackerRepository, Depends(get_repository(ResourceTrackerRepository)) - ], + db_engine: Annotated[AsyncEngine, Depends(get_resource_tracker_db_engine)], ) -> PricingPlanGet: - active_service_pricing_plans = await resource_tracker_repo.list_active_service_pricing_plans_by_product_and_service( - product_name, service_key, service_version + active_service_pricing_plans = ( + await pricing_plans_db.list_active_service_pricing_plans_by_product_and_service( + db_engine, + product_name=product_name, + service_key=service_key, + service_version=service_version, + ) ) default_pricing_plan = None @@ -71,10 +75,8 @@ async def get_service_default_pricing_plan( service_key=service_key, service_version=service_version ) - pricing_plan_unit_db = ( - await resource_tracker_repo.list_pricing_units_by_pricing_plan( - pricing_plan_id=default_pricing_plan.pricing_plan_id - ) + pricing_plan_unit_db = await pricing_plans_db.list_pricing_units_by_pricing_plan( + db_engine, pricing_plan_id=default_pricing_plan.pricing_plan_id ) return await _create_pricing_plan_get(default_pricing_plan, pricing_plan_unit_db) @@ -83,14 +85,12 @@ async def get_service_default_pricing_plan( async def list_connected_services_to_pricing_plan_by_pricing_plan( product_name: ProductName, pricing_plan_id: PricingPlanId, - resource_tracker_repo: Annotated[ - ResourceTrackerRepository, Depends(get_repository(ResourceTrackerRepository)) - ], + db_engine: Annotated[AsyncEngine, Depends(get_resource_tracker_db_engine)], ): output_list: list[ PricingPlanToServiceDB - ] = await resource_tracker_repo.list_connected_services_to_pricing_plan_by_pricing_plan( - product_name=product_name, pricing_plan_id=pricing_plan_id + ] = await pricing_plans_db.list_connected_services_to_pricing_plan_by_pricing_plan( + db_engine, product_name=product_name, pricing_plan_id=pricing_plan_id ) return [ TypeAdapter(PricingPlanToServiceGet).validate_python(item.model_dump()) @@ -103,12 +103,11 @@ async def connect_service_to_pricing_plan( pricing_plan_id: PricingPlanId, service_key: ServiceKey, service_version: ServiceVersion, - resource_tracker_repo: Annotated[ - ResourceTrackerRepository, Depends(get_repository(ResourceTrackerRepository)) - ], + db_engine: Annotated[AsyncEngine, Depends(get_resource_tracker_db_engine)], ) -> PricingPlanToServiceGet: output: PricingPlanToServiceDB = ( - await resource_tracker_repo.upsert_service_to_pricing_plan( + await pricing_plans_db.upsert_service_to_pricing_plan( + db_engine, product_name=product_name, pricing_plan_id=pricing_plan_id, service_key=service_key, @@ -120,14 +119,12 @@ async def connect_service_to_pricing_plan( async def list_pricing_plans_by_product( product_name: ProductName, - resource_tracker_repo: Annotated[ - ResourceTrackerRepository, Depends(get_repository(ResourceTrackerRepository)) - ], + db_engine: Annotated[AsyncEngine, Depends(get_resource_tracker_db_engine)], ) -> list[PricingPlanGet]: pricing_plans_list_db: list[ PricingPlansDB - ] = await resource_tracker_repo.list_pricing_plans_by_product( - product_name=product_name + ] = await pricing_plans_db.list_pricing_plans_by_product( + db_engine, product_name=product_name ) return [ PricingPlanGet( @@ -147,32 +144,24 @@ async def list_pricing_plans_by_product( async def get_pricing_plan( product_name: ProductName, pricing_plan_id: PricingPlanId, - resource_tracker_repo: Annotated[ - ResourceTrackerRepository, Depends(get_repository(ResourceTrackerRepository)) - ], + db_engine: Annotated[AsyncEngine, Depends(get_resource_tracker_db_engine)], ) -> PricingPlanGet: - pricing_plan_db = await resource_tracker_repo.get_pricing_plan( - product_name=product_name, pricing_plan_id=pricing_plan_id + pricing_plan_db = await pricing_plans_db.get_pricing_plan( + db_engine, product_name=product_name, pricing_plan_id=pricing_plan_id ) - pricing_plan_unit_db = ( - await resource_tracker_repo.list_pricing_units_by_pricing_plan( - pricing_plan_id=pricing_plan_db.pricing_plan_id - ) + pricing_plan_unit_db = await pricing_plans_db.list_pricing_units_by_pricing_plan( + db_engine, pricing_plan_id=pricing_plan_db.pricing_plan_id ) return await _create_pricing_plan_get(pricing_plan_db, pricing_plan_unit_db) async def create_pricing_plan( data: PricingPlanCreate, - resource_tracker_repo: Annotated[ - ResourceTrackerRepository, Depends(get_repository(ResourceTrackerRepository)) - ], + db_engine: Annotated[AsyncEngine, Depends(get_resource_tracker_db_engine)], ) -> PricingPlanGet: - pricing_plan_db = await resource_tracker_repo.create_pricing_plan(data=data) - pricing_plan_unit_db = ( - await resource_tracker_repo.list_pricing_units_by_pricing_plan( - pricing_plan_id=pricing_plan_db.pricing_plan_id - ) + pricing_plan_db = await pricing_plans_db.create_pricing_plan(db_engine, data=data) + pricing_plan_unit_db = await pricing_plans_db.list_pricing_units_by_pricing_plan( + db_engine, pricing_plan_id=pricing_plan_db.pricing_plan_id ) return await _create_pricing_plan_get(pricing_plan_db, pricing_plan_unit_db) @@ -180,24 +169,20 @@ async def create_pricing_plan( async def update_pricing_plan( product_name: ProductName, data: PricingPlanUpdate, - resource_tracker_repo: Annotated[ - ResourceTrackerRepository, Depends(get_repository(ResourceTrackerRepository)) - ], + db_engine: Annotated[AsyncEngine, Depends(get_resource_tracker_db_engine)], ) -> PricingPlanGet: # Check whether pricing plan exists - pricing_plan_db = await resource_tracker_repo.get_pricing_plan( - product_name=product_name, pricing_plan_id=data.pricing_plan_id + pricing_plan_db = await pricing_plans_db.get_pricing_plan( + db_engine, product_name=product_name, pricing_plan_id=data.pricing_plan_id ) # Update pricing plan - pricing_plan_updated_db = await resource_tracker_repo.update_pricing_plan( - product_name=product_name, data=data + pricing_plan_updated_db = await pricing_plans_db.update_pricing_plan( + db_engine, product_name=product_name, data=data ) if pricing_plan_updated_db: pricing_plan_db = pricing_plan_updated_db - pricing_plan_unit_db = ( - await resource_tracker_repo.list_pricing_units_by_pricing_plan( - pricing_plan_id=pricing_plan_db.pricing_plan_id - ) + pricing_plan_unit_db = await pricing_plans_db.list_pricing_units_by_pricing_plan( + db_engine, pricing_plan_id=pricing_plan_db.pricing_plan_id ) return await _create_pricing_plan_get(pricing_plan_db, pricing_plan_unit_db) diff --git a/services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/services/pricing_units.py b/services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/services/pricing_units.py index f2aee53dd80..0a1e72cad65 100644 --- a/services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/services/pricing_units.py +++ b/services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/services/pricing_units.py @@ -11,21 +11,23 @@ PricingUnitWithCostCreate, PricingUnitWithCostUpdate, ) +from sqlalchemy.ext.asyncio import AsyncEngine -from ..api.rest.dependencies import get_repository -from .modules.db.repositories.resource_tracker import ResourceTrackerRepository +from ..api.rest.dependencies import get_resource_tracker_db_engine +from .modules.db import pricing_plans_db async def get_pricing_unit( product_name: ProductName, pricing_plan_id: PricingPlanId, pricing_unit_id: PricingUnitId, - resource_tracker_repo: Annotated[ - ResourceTrackerRepository, Depends(get_repository(ResourceTrackerRepository)) - ], + db_engine: Annotated[AsyncEngine, Depends(get_resource_tracker_db_engine)], ) -> PricingUnitGet: - pricing_unit = await resource_tracker_repo.get_valid_pricing_unit( - product_name, pricing_plan_id, pricing_unit_id + pricing_unit = await pricing_plans_db.get_valid_pricing_unit( + db_engine, + product_name=product_name, + pricing_plan_id=pricing_plan_id, + pricing_unit_id=pricing_unit_id, ) return PricingUnitGet( @@ -42,21 +44,22 @@ async def get_pricing_unit( async def create_pricing_unit( product_name: ProductName, data: PricingUnitWithCostCreate, - resource_tracker_repo: Annotated[ - ResourceTrackerRepository, Depends(get_repository(ResourceTrackerRepository)) - ], + db_engine: Annotated[AsyncEngine, Depends(get_resource_tracker_db_engine)], ) -> PricingUnitGet: # Check whether pricing plan exists - pricing_plan_db = await resource_tracker_repo.get_pricing_plan( - product_name=product_name, pricing_plan_id=data.pricing_plan_id + pricing_plan_db = await pricing_plans_db.get_pricing_plan( + db_engine, product_name=product_name, pricing_plan_id=data.pricing_plan_id ) # Create new pricing unit - pricing_unit_id, _ = await resource_tracker_repo.create_pricing_unit_with_cost( - data=data, pricing_plan_key=pricing_plan_db.pricing_plan_key + pricing_unit_id, _ = await pricing_plans_db.create_pricing_unit_with_cost( + db_engine, data=data, pricing_plan_key=pricing_plan_db.pricing_plan_key ) - pricing_unit = await resource_tracker_repo.get_valid_pricing_unit( - product_name, data.pricing_plan_id, pricing_unit_id + pricing_unit = await pricing_plans_db.get_valid_pricing_unit( + db_engine, + product_name=product_name, + pricing_plan_id=data.pricing_plan_id, + pricing_unit_id=pricing_unit_id, ) return PricingUnitGet( pricing_unit_id=pricing_unit.pricing_unit_id, @@ -72,26 +75,30 @@ async def create_pricing_unit( async def update_pricing_unit( product_name: ProductName, data: PricingUnitWithCostUpdate, - resource_tracker_repo: Annotated[ - ResourceTrackerRepository, Depends(get_repository(ResourceTrackerRepository)) - ], + db_engine: Annotated[AsyncEngine, Depends(get_resource_tracker_db_engine)], ) -> PricingUnitGet: # Check whether pricing unit exists - await resource_tracker_repo.get_valid_pricing_unit( - product_name, data.pricing_plan_id, data.pricing_unit_id + await pricing_plans_db.get_valid_pricing_unit( + db_engine, + product_name=product_name, + pricing_plan_id=data.pricing_plan_id, + pricing_unit_id=data.pricing_unit_id, ) # Get pricing plan - pricing_plan_db = await resource_tracker_repo.get_pricing_plan( - product_name, data.pricing_plan_id + pricing_plan_db = await pricing_plans_db.get_pricing_plan( + db_engine, product_name=product_name, pricing_plan_id=data.pricing_plan_id ) # Update pricing unit and cost - await resource_tracker_repo.update_pricing_unit_with_cost( - data=data, pricing_plan_key=pricing_plan_db.pricing_plan_key + await pricing_plans_db.update_pricing_unit_with_cost( + db_engine, data=data, pricing_plan_key=pricing_plan_db.pricing_plan_key ) - pricing_unit = await resource_tracker_repo.get_valid_pricing_unit( - product_name, data.pricing_plan_id, data.pricing_unit_id + pricing_unit = await pricing_plans_db.get_valid_pricing_unit( + db_engine, + product_name=product_name, + pricing_plan_id=data.pricing_plan_id, + pricing_unit_id=data.pricing_unit_id, ) return PricingUnitGet( pricing_unit_id=pricing_unit.pricing_unit_id, diff --git a/services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/services/process_message_running_service.py b/services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/services/process_message_running_service.py index 4907c84ecb1..8300ede8283 100644 --- a/services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/services/process_message_running_service.py +++ b/services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/services/process_message_running_service.py @@ -21,6 +21,7 @@ ) from models_library.services import ServiceType from pydantic import TypeAdapter +from sqlalchemy.ext.asyncio import AsyncEngine from ..models.credit_transactions import ( CreditTransactionCreate, @@ -32,7 +33,7 @@ ServiceRunLastHeartbeatUpdate, ServiceRunStoppedAtUpdate, ) -from .modules.db.repositories.resource_tracker import ResourceTrackerRepository +from .modules.db import credit_transactions_db, pricing_plans_db, service_runs_db from .modules.rabbitmq import RabbitMQClient, get_rabbitmq_client from .utils import ( compute_service_run_credit_costs, @@ -53,24 +54,22 @@ async def process_message(app: FastAPI, data: bytes) -> bool: rabbit_message.message_type, rabbit_message.service_run_id, ) - resource_tracker_repo: ResourceTrackerRepository = ResourceTrackerRepository( - db_engine=app.state.engine - ) + _db_engine = app.state.engine rabbitmq_client = get_rabbitmq_client(app) await RABBIT_MSG_TYPE_TO_PROCESS_HANDLER[rabbit_message.message_type]( - resource_tracker_repo, rabbit_message, rabbitmq_client + _db_engine, rabbit_message, rabbitmq_client ) return True async def _process_start_event( - resource_tracker_repo: ResourceTrackerRepository, + db_engine: AsyncEngine, msg: RabbitResourceTrackingStartedMessage, rabbitmq_client: RabbitMQClient, ): - service_run_db = await resource_tracker_repo.get_service_run_by_id( - service_run_id=msg.service_run_id + service_run_db = await service_runs_db.get_service_run_by_id( + db_engine, service_run_id=msg.service_run_id ) if service_run_db: # NOTE: After we find out why sometimes RUT recieves multiple start events and fix it, we can change it to log level `error` @@ -90,8 +89,8 @@ async def _process_start_event( ) pricing_unit_cost = None if msg.pricing_unit_cost_id: - pricing_unit_cost_db = await resource_tracker_repo.get_pricing_unit_cost_by_id( - pricing_unit_cost_id=msg.pricing_unit_cost_id + pricing_unit_cost_db = await pricing_plans_db.get_pricing_unit_cost_by_id( + db_engine, pricing_unit_cost_id=msg.pricing_unit_cost_id ) pricing_unit_cost = pricing_unit_cost_db.cost_per_unit @@ -125,7 +124,9 @@ async def _process_start_event( service_run_status=ServiceRunStatus.RUNNING, last_heartbeat_at=msg.created_at, ) - service_run_id = await resource_tracker_repo.create_service_run(create_service_run) + service_run_id = await service_runs_db.create_service_run( + db_engine, data=create_service_run + ) if msg.wallet_id and msg.wallet_name: transaction_create = CreditTransactionCreate( @@ -145,21 +146,23 @@ async def _process_start_event( created_at=msg.created_at, last_heartbeat_at=msg.created_at, ) - await resource_tracker_repo.create_credit_transaction(transaction_create) + await credit_transactions_db.create_credit_transaction( + db_engine, data=transaction_create + ) # Publish wallet total credits to RabbitMQ await sum_credit_transactions_and_publish_to_rabbitmq( - resource_tracker_repo, rabbitmq_client, msg.product_name, msg.wallet_id + db_engine, rabbitmq_client, msg.product_name, msg.wallet_id ) async def _process_heartbeat_event( - resource_tracker_repo: ResourceTrackerRepository, + db_engine: AsyncEngine, msg: RabbitResourceTrackingHeartbeatMessage, rabbitmq_client: RabbitMQClient, ): - service_run_db = await resource_tracker_repo.get_service_run_by_id( - service_run_id=msg.service_run_id + service_run_db = await service_runs_db.get_service_run_by_id( + db_engine, service_run_id=msg.service_run_id ) if not service_run_db: _logger.error( @@ -181,8 +184,8 @@ async def _process_heartbeat_event( update_service_run_last_heartbeat = ServiceRunLastHeartbeatUpdate( service_run_id=msg.service_run_id, last_heartbeat_at=msg.created_at ) - running_service = await resource_tracker_repo.update_service_run_last_heartbeat( - update_service_run_last_heartbeat + running_service = await service_runs_db.update_service_run_last_heartbeat( + db_engine, data=update_service_run_last_heartbeat ) if running_service is None: _logger.info("Nothing to update: %s", msg) @@ -201,19 +204,19 @@ async def _process_heartbeat_event( osparc_credits=make_negative(computed_credits), last_heartbeat_at=msg.created_at, ) - await resource_tracker_repo.update_credit_transaction_credits( - update_credit_transaction + await credit_transactions_db.update_credit_transaction_credits( + db_engine, data=update_credit_transaction ) # Publish wallet total credits to RabbitMQ wallet_total_credits = await sum_credit_transactions_and_publish_to_rabbitmq( - resource_tracker_repo, + db_engine, rabbitmq_client, running_service.product_name, running_service.wallet_id, ) if wallet_total_credits.available_osparc_credits < CreditsLimit.OUT_OF_CREDITS: await publish_to_rabbitmq_wallet_credits_limit_reached( - resource_tracker_repo, + db_engine, rabbitmq_client, product_name=running_service.product_name, wallet_id=running_service.wallet_id, @@ -223,12 +226,12 @@ async def _process_heartbeat_event( async def _process_stop_event( - resource_tracker_repo: ResourceTrackerRepository, + db_engine: AsyncEngine, msg: RabbitResourceTrackingStoppedMessage, rabbitmq_client: RabbitMQClient, ): - service_run_db = await resource_tracker_repo.get_service_run_by_id( - service_run_id=msg.service_run_id + service_run_db = await service_runs_db.get_service_run_by_id( + db_engine, service_run_id=msg.service_run_id ) if not service_run_db: # NOTE: ANE/MD discussed. When the RUT receives a stop event and has not received before any start or heartbeat event, it probably means that @@ -262,8 +265,8 @@ async def _process_stop_event( service_run_status_msg=_run_status_msg, ) - running_service = await resource_tracker_repo.update_service_run_stopped_at( - update_service_run_stopped_at + running_service = await service_runs_db.update_service_run_stopped_at( + db_engine, data=update_service_run_stopped_at ) if running_service is None: @@ -287,12 +290,12 @@ async def _process_stop_event( else CreditTransactionStatus.NOT_BILLED ), ) - await resource_tracker_repo.update_credit_transaction_credits_and_status( - update_credit_transaction + await credit_transactions_db.update_credit_transaction_credits_and_status( + db_engine, data=update_credit_transaction ) # Publish wallet total credits to RabbitMQ await sum_credit_transactions_and_publish_to_rabbitmq( - resource_tracker_repo, + db_engine, rabbitmq_client, running_service.product_name, running_service.wallet_id, diff --git a/services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/services/service_runs.py b/services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/services/service_runs.py index fff896c8ec0..b4d9127733e 100644 --- a/services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/services/service_runs.py +++ b/services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/services/service_runs.py @@ -19,9 +19,10 @@ from models_library.users import UserID from models_library.wallets import WalletID from pydantic import AnyUrl, PositiveInt, TypeAdapter +from sqlalchemy.ext.asyncio import AsyncEngine from ..models.service_runs import ServiceRunWithCreditsDB -from .modules.db.repositories.resource_tracker import ResourceTrackerRepository +from .modules.db import service_runs_db _PRESIGNED_LINK_EXPIRATION_SEC = 7200 @@ -29,7 +30,7 @@ async def list_service_runs( user_id: UserID, product_name: ProductName, - resource_tracker_repo: ResourceTrackerRepository, + db_engine: AsyncEngine, limit: int = 20, offset: int = 0, wallet_id: WalletID | None = None, @@ -45,17 +46,21 @@ async def list_service_runs( # Situation when we want to see all usage of a specific user (ex. for Non billable product) if wallet_id is None and access_all_wallet_usage is False: - total_service_runs: PositiveInt = await resource_tracker_repo.total_service_runs_by_product_and_user_and_wallet( - product_name, - user_id=user_id, - wallet_id=None, - started_from=started_from, - started_until=started_until, + total_service_runs: PositiveInt = ( + await service_runs_db.total_service_runs_by_product_and_user_and_wallet( + db_engine, + product_name=product_name, + user_id=user_id, + wallet_id=None, + started_from=started_from, + started_until=started_until, + ) ) service_runs_db_model: list[ ServiceRunWithCreditsDB - ] = await resource_tracker_repo.list_service_runs_by_product_and_user_and_wallet( - product_name, + ] = await service_runs_db.list_service_runs_by_product_and_user_and_wallet( + db_engine, + product_name=product_name, user_id=user_id, wallet_id=None, offset=offset, @@ -66,8 +71,9 @@ async def list_service_runs( ) # Situation when accountant user can see all users usage of the wallet elif wallet_id and access_all_wallet_usage is True: - total_service_runs: PositiveInt = await resource_tracker_repo.total_service_runs_by_product_and_user_and_wallet( # type: ignore[no-redef] - product_name, + total_service_runs: PositiveInt = await service_runs_db.total_service_runs_by_product_and_user_and_wallet( # type: ignore[no-redef] + db_engine, + product_name=product_name, user_id=None, wallet_id=wallet_id, started_from=started_from, @@ -75,8 +81,9 @@ async def list_service_runs( ) service_runs_db_model: list[ # type: ignore[no-redef] ServiceRunWithCreditsDB - ] = await resource_tracker_repo.list_service_runs_by_product_and_user_and_wallet( - product_name, + ] = await service_runs_db.list_service_runs_by_product_and_user_and_wallet( + db_engine, + product_name=product_name, user_id=None, wallet_id=wallet_id, offset=offset, @@ -87,8 +94,9 @@ async def list_service_runs( ) # Situation when regular user can see only his usage of the wallet elif wallet_id and access_all_wallet_usage is False: - total_service_runs: PositiveInt = await resource_tracker_repo.total_service_runs_by_product_and_user_and_wallet( # type: ignore[no-redef] - product_name, + total_service_runs: PositiveInt = await service_runs_db.total_service_runs_by_product_and_user_and_wallet( # type: ignore[no-redef] + db_engine, + product_name=product_name, user_id=user_id, wallet_id=wallet_id, started_from=started_from, @@ -96,8 +104,9 @@ async def list_service_runs( ) service_runs_db_model: list[ # type: ignore[no-redef] ServiceRunWithCreditsDB - ] = await resource_tracker_repo.list_service_runs_by_product_and_user_and_wallet( - product_name, + ] = await service_runs_db.list_service_runs_by_product_and_user_and_wallet( + db_engine, + product_name=product_name, user_id=user_id, wallet_id=wallet_id, offset=offset, @@ -147,7 +156,7 @@ async def export_service_runs( s3_region: str, user_id: UserID, product_name: ProductName, - resource_tracker_repo: ResourceTrackerRepository, + db_engine: AsyncEngine, wallet_id: WalletID | None = None, access_all_wallet_usage: bool = False, order_by: OrderBy | None = None, @@ -165,7 +174,8 @@ async def export_service_runs( ) # Export CSV to S3 - await resource_tracker_repo.export_service_runs_table_to_s3( + await service_runs_db.export_service_runs_table_to_s3( + db_engine, product_name=product_name, s3_bucket_name=s3_bucket_name, s3_key=s3_object_key, @@ -188,7 +198,7 @@ async def export_service_runs( async def get_osparc_credits_aggregated_usages_page( user_id: UserID, product_name: ProductName, - resource_tracker_repo: ResourceTrackerRepository, + db_engine: AsyncEngine, aggregated_by: ServicesAggregatedUsagesType, time_period: ServicesAggregatedUsagesTimePeriod, wallet_id: WalletID, @@ -204,7 +214,8 @@ async def get_osparc_credits_aggregated_usages_page( ( count_output_list_db, output_list_db, - ) = await resource_tracker_repo.get_osparc_credits_aggregated_by_service( + ) = await service_runs_db.get_osparc_credits_aggregated_by_service( + db_engine, product_name=product_name, user_id=user_id if access_all_wallet_usage is False else None, wallet_id=wallet_id, diff --git a/services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/services/utils.py b/services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/services/utils.py index 73aa7416244..6047ac2e904 100644 --- a/services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/services/utils.py +++ b/services/resource-usage-tracker/src/simcore_service_resource_usage_tracker/services/utils.py @@ -19,8 +19,9 @@ from models_library.wallets import WalletID from pydantic import PositiveInt from servicelib.rabbitmq import RabbitMQClient +from sqlalchemy.ext.asyncio import AsyncEngine -from .modules.db.repositories.resource_tracker import ResourceTrackerRepository +from .modules.db import credit_transactions_db, service_runs_db _logger = logging.getLogger(__name__) @@ -30,15 +31,16 @@ def make_negative(n): async def sum_credit_transactions_and_publish_to_rabbitmq( - resource_tracker_repo: ResourceTrackerRepository, + db_engine: AsyncEngine, rabbitmq_client: RabbitMQClient, product_name: ProductName, wallet_id: WalletID, ) -> WalletTotalCredits: wallet_total_credits = ( - await resource_tracker_repo.sum_credit_transactions_by_product_and_wallet( - product_name, - wallet_id, + await credit_transactions_db.sum_credit_transactions_by_product_and_wallet( + db_engine, + product_name=product_name, + wallet_id=wallet_id, ) ) publish_message = WalletCreditsMessage.model_construct( @@ -77,7 +79,7 @@ async def _publish_to_rabbitmq_wallet_credits_limit_reached( async def publish_to_rabbitmq_wallet_credits_limit_reached( - resource_tracker_repo: ResourceTrackerRepository, + db_engine: AsyncEngine, rabbitmq_client: RabbitMQClient, product_name: ProductName, wallet_id: WalletID, @@ -86,8 +88,9 @@ async def publish_to_rabbitmq_wallet_credits_limit_reached( ): # Get all current running services for that wallet total_count: PositiveInt = ( - await resource_tracker_repo.total_service_runs_by_product_and_user_and_wallet( - product_name, + await service_runs_db.total_service_runs_by_product_and_user_and_wallet( + db_engine, + product_name=product_name, user_id=None, wallet_id=wallet_id, service_run_status=ServiceRunStatus.RUNNING, @@ -95,13 +98,16 @@ async def publish_to_rabbitmq_wallet_credits_limit_reached( ) for offset in range(0, total_count, _BATCH_SIZE): - batch_services = await resource_tracker_repo.list_service_runs_by_product_and_user_and_wallet( - product_name, - user_id=None, - wallet_id=wallet_id, - offset=offset, - limit=_BATCH_SIZE, - service_run_status=ServiceRunStatus.RUNNING, + batch_services = ( + await service_runs_db.list_service_runs_by_product_and_user_and_wallet( + db_engine, + product_name=product_name, + user_id=None, + wallet_id=wallet_id, + offset=offset, + limit=_BATCH_SIZE, + service_run_status=ServiceRunStatus.RUNNING, + ) ) await asyncio.gather( diff --git a/services/resource-usage-tracker/tests/unit/with_dbs/test_api_resource_tracker_service_runs__export.py b/services/resource-usage-tracker/tests/unit/with_dbs/test_api_resource_tracker_service_runs__export.py index 56c9c102df6..44a6ce56016 100644 --- a/services/resource-usage-tracker/tests/unit/with_dbs/test_api_resource_tracker_service_runs__export.py +++ b/services/resource-usage-tracker/tests/unit/with_dbs/test_api_resource_tracker_service_runs__export.py @@ -31,7 +31,7 @@ @pytest.fixture async def mocked_export(mocker: MockerFixture) -> AsyncMock: return mocker.patch( - "simcore_service_resource_usage_tracker.services.service_runs.ResourceTrackerRepository.export_service_runs_table_to_s3", + "simcore_service_resource_usage_tracker.services.service_runs.service_runs_db.export_service_runs_table_to_s3", autospec=True, ) diff --git a/services/resource-usage-tracker/tests/unit/with_dbs/test_background_task_periodic_heartbeat_check.py b/services/resource-usage-tracker/tests/unit/with_dbs/test_background_task_periodic_heartbeat_check.py index 35114a3cdf6..8ebe34bbd2d 100644 --- a/services/resource-usage-tracker/tests/unit/with_dbs/test_background_task_periodic_heartbeat_check.py +++ b/services/resource-usage-tracker/tests/unit/with_dbs/test_background_task_periodic_heartbeat_check.py @@ -23,9 +23,6 @@ from simcore_service_resource_usage_tracker.services.background_task_periodic_heartbeat_check import ( periodic_check_of_running_services_task, ) -from simcore_service_resource_usage_tracker.services.modules.db.repositories.resource_tracker import ( - ResourceTrackerRepository, -) pytest_simcore_core_services_selection = ["postgres", "rabbit"] pytest_simcore_ops_services_selection = [ @@ -132,9 +129,6 @@ async def test_process_event_functions( ): engine = initialized_app.state.engine app_settings: ApplicationSettings = initialized_app.state.settings - resource_tracker_repo: ResourceTrackerRepository = ResourceTrackerRepository( - db_engine=engine - ) for _ in range(app_settings.RESOURCE_USAGE_TRACKER_MISSED_HEARTBEAT_COUNTER_FAIL): await periodic_check_of_running_services_task(initialized_app) diff --git a/services/resource-usage-tracker/tests/unit/with_dbs/test_process_rabbitmq_message.py b/services/resource-usage-tracker/tests/unit/with_dbs/test_process_rabbitmq_message.py index da321f593f3..57eb9735e68 100644 --- a/services/resource-usage-tracker/tests/unit/with_dbs/test_process_rabbitmq_message.py +++ b/services/resource-usage-tracker/tests/unit/with_dbs/test_process_rabbitmq_message.py @@ -8,9 +8,6 @@ SimcorePlatformStatus, ) from servicelib.rabbitmq import RabbitMQClient -from simcore_service_resource_usage_tracker.services.modules.db.repositories.resource_tracker import ( - ResourceTrackerRepository, -) from simcore_service_resource_usage_tracker.services.process_message_running_service import ( _process_heartbeat_event, _process_start_event, @@ -43,10 +40,7 @@ async def test_process_event_functions( pricing_unit_id=None, pricing_unit_cost_id=None, ) - resource_tracker_repo: ResourceTrackerRepository = ResourceTrackerRepository( - db_engine=engine - ) - await _process_start_event(resource_tracker_repo, msg, publisher) + await _process_start_event(engine, msg, publisher) output = await assert_service_runs_db_row(postgres_db, msg.service_run_id) assert output.stopped_at is None assert output.service_run_status == "RUNNING" @@ -55,7 +49,7 @@ async def test_process_event_functions( heartbeat_msg = RabbitResourceTrackingHeartbeatMessage( service_run_id=msg.service_run_id, created_at=datetime.now(tz=timezone.utc) ) - await _process_heartbeat_event(resource_tracker_repo, heartbeat_msg, publisher) + await _process_heartbeat_event(engine, heartbeat_msg, publisher) output = await assert_service_runs_db_row(postgres_db, msg.service_run_id) assert output.stopped_at is None assert output.service_run_status == "RUNNING" @@ -66,7 +60,7 @@ async def test_process_event_functions( created_at=datetime.now(tz=timezone.utc), simcore_platform_status=SimcorePlatformStatus.OK, ) - await _process_stop_event(resource_tracker_repo, stopped_msg, publisher) + await _process_stop_event(engine, stopped_msg, publisher) output = await assert_service_runs_db_row(postgres_db, msg.service_run_id) assert output.stopped_at is not None assert output.service_run_status == "SUCCESS" diff --git a/services/resource-usage-tracker/tests/unit/with_dbs/test_process_rabbitmq_message_with_billing.py b/services/resource-usage-tracker/tests/unit/with_dbs/test_process_rabbitmq_message_with_billing.py index 637a2219f94..b29863f0b57 100644 --- a/services/resource-usage-tracker/tests/unit/with_dbs/test_process_rabbitmq_message_with_billing.py +++ b/services/resource-usage-tracker/tests/unit/with_dbs/test_process_rabbitmq_message_with_billing.py @@ -31,9 +31,6 @@ resource_tracker_pricing_units, ) from simcore_postgres_database.models.services import services_meta_data -from simcore_service_resource_usage_tracker.services.modules.db.repositories.resource_tracker import ( - ResourceTrackerRepository, -) from simcore_service_resource_usage_tracker.services.process_message_running_service import ( _process_heartbeat_event, _process_start_event, @@ -207,10 +204,8 @@ async def test_process_event_functions( pricing_unit_id=1, pricing_unit_cost_id=1, ) - resource_tracker_repo: ResourceTrackerRepository = ResourceTrackerRepository( - db_engine=engine - ) - await _process_start_event(resource_tracker_repo, msg, publisher) + + await _process_start_event(engine, msg, publisher) output = await assert_credit_transactions_db_row(postgres_db, msg.service_run_id) assert output.osparc_credits == 0.0 assert output.transaction_status == "PENDING" @@ -222,7 +217,7 @@ async def test_process_event_functions( heartbeat_msg = RabbitResourceTrackingHeartbeatMessage( service_run_id=msg.service_run_id, created_at=datetime.now(tz=timezone.utc) ) - await _process_heartbeat_event(resource_tracker_repo, heartbeat_msg, publisher) + await _process_heartbeat_event(engine, heartbeat_msg, publisher) output = await assert_credit_transactions_db_row( postgres_db, msg.service_run_id, modified_at ) @@ -240,7 +235,7 @@ async def test_process_event_functions( created_at=datetime.now(tz=timezone.utc), simcore_platform_status=SimcorePlatformStatus.OK, ) - await _process_stop_event(resource_tracker_repo, stopped_msg, publisher) + await _process_stop_event(engine, stopped_msg, publisher) output = await assert_credit_transactions_db_row( postgres_db, msg.service_run_id, modified_at ) diff --git a/services/resource-usage-tracker/tests/unit/with_dbs/test_process_rabbitmq_message_with_billing_cost_0.py b/services/resource-usage-tracker/tests/unit/with_dbs/test_process_rabbitmq_message_with_billing_cost_0.py index 5b903cf759d..ccffbc9f42e 100644 --- a/services/resource-usage-tracker/tests/unit/with_dbs/test_process_rabbitmq_message_with_billing_cost_0.py +++ b/services/resource-usage-tracker/tests/unit/with_dbs/test_process_rabbitmq_message_with_billing_cost_0.py @@ -31,9 +31,6 @@ resource_tracker_pricing_units, ) from simcore_postgres_database.models.services import services_meta_data -from simcore_service_resource_usage_tracker.services.modules.db.repositories.resource_tracker import ( - ResourceTrackerRepository, -) from simcore_service_resource_usage_tracker.services.process_message_running_service import ( _process_heartbeat_event, _process_start_event, @@ -149,10 +146,8 @@ async def test_process_event_functions( pricing_unit_id=1, pricing_unit_cost_id=1, ) - resource_tracker_repo: ResourceTrackerRepository = ResourceTrackerRepository( - db_engine=engine - ) - await _process_start_event(resource_tracker_repo, msg, publisher) + + await _process_start_event(engine, msg, publisher) output = await assert_credit_transactions_db_row(postgres_db, msg.service_run_id) assert output.osparc_credits == 0.0 assert output.transaction_status == "PENDING" @@ -164,7 +159,7 @@ async def test_process_event_functions( heartbeat_msg = RabbitResourceTrackingHeartbeatMessage( service_run_id=msg.service_run_id, created_at=datetime.now(tz=timezone.utc) ) - await _process_heartbeat_event(resource_tracker_repo, heartbeat_msg, publisher) + await _process_heartbeat_event(engine, heartbeat_msg, publisher) output = await assert_credit_transactions_db_row( postgres_db, msg.service_run_id, modified_at ) @@ -177,7 +172,7 @@ async def test_process_event_functions( created_at=datetime.now(tz=timezone.utc), simcore_platform_status=SimcorePlatformStatus.OK, ) - await _process_stop_event(resource_tracker_repo, stopped_msg, publisher) + await _process_stop_event(engine, stopped_msg, publisher) output = await assert_credit_transactions_db_row( postgres_db, msg.service_run_id, modified_at )