From 2f6ab0a64c5fcc35e10567b8fefc2637e1db092c Mon Sep 17 00:00:00 2001 From: matusdrobuliak66 <60785969+matusdrobuliak66@users.noreply.github.com> Date: Fri, 17 May 2024 16:48:08 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=8E=A8=20Check=20for=20zero=20credits=20(?= =?UTF-8?q?if=20pricing=20unit=20cost=20is=20greater=20than=200)=20(#5835)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../src/models_library/wallets.py | 9 +++- services/director-v2/openapi.json | 12 +++-- .../api/routes/computations.py | 7 ++- .../core/errors.py | 4 ++ .../models/pricing.py | 28 +++++++++++ .../db/repositories/comp_tasks/_core.py | 5 +- .../db/repositories/comp_tasks/_utils.py | 49 ++++++++++++++----- .../with_dbs/test_api_route_computations.py | 47 ++++++++++++++---- .../director_v2/_api_utils.py | 13 +++-- .../projects/_nodes_handlers.py | 25 +++++----- .../projects/projects_api.py | 33 ++++++++----- 11 files changed, 171 insertions(+), 61 deletions(-) create mode 100644 services/director-v2/src/simcore_service_director_v2/models/pricing.py diff --git a/packages/models-library/src/models_library/wallets.py b/packages/models-library/src/models_library/wallets.py index f2b9439e818..08651353daa 100644 --- a/packages/models-library/src/models_library/wallets.py +++ b/packages/models-library/src/models_library/wallets.py @@ -18,10 +18,17 @@ class WalletStatus(StrAutoEnum): class WalletInfo(BaseModel): wallet_id: WalletID wallet_name: str + wallet_credit_amount: Decimal class Config: schema_extra: ClassVar[dict[str, Any]] = { - "examples": [{"wallet_id": 1, "wallet_name": "My Wallet"}] + "examples": [ + { + "wallet_id": 1, + "wallet_name": "My Wallet", + "wallet_credit_amount": Decimal(10), + } + ] } diff --git a/services/director-v2/openapi.json b/services/director-v2/openapi.json index a329a785efe..5be4cf92355 100644 --- a/services/director-v2/openapi.json +++ b/services/director-v2/openapi.json @@ -3,7 +3,7 @@ "info": { "title": "simcore-service-director-v2", "description": "Orchestrates the pipeline of services defined by the user", - "version": "2.2.0" + "version": "2.3.0" }, "servers": [ { @@ -2494,7 +2494,8 @@ }, "wallet_info": { "wallet_id": 1, - "wallet_name": "My Wallet" + "wallet_name": "My Wallet", + "wallet_credit_amount": 10 }, "pricing_info": { "pricing_plan_id": 1, @@ -3859,12 +3860,17 @@ "wallet_name": { "type": "string", "title": "Wallet Name" + }, + "wallet_credit_amount": { + "type": "number", + "title": "Wallet Credit Amount" } }, "type": "object", "required": [ "wallet_id", - "wallet_name" + "wallet_name", + "wallet_credit_amount" ], "title": "WalletInfo" }, diff --git a/services/director-v2/src/simcore_service_director_v2/api/routes/computations.py b/services/director-v2/src/simcore_service_director_v2/api/routes/computations.py index d523a2042f4..83a9b9a3b4d 100644 --- a/services/director-v2/src/simcore_service_director_v2/api/routes/computations.py +++ b/services/director-v2/src/simcore_service_director_v2/api/routes/computations.py @@ -55,6 +55,7 @@ PricingPlanUnitNotFoundError, ProjectNotFoundError, SchedulerError, + WalletNotEnoughCreditsError, ) from ...models.comp_pipelines import CompPipelineAtDB from ...models.comp_runs import CompRunsAtDB, ProjectMetadataDict, RunMetadataDict @@ -318,7 +319,7 @@ async def create_computation( # noqa: PLR0913 user_id=computation.user_id, product_name=computation.product_name, rut_client=rut_client, - is_wallet=bool(computation.wallet_info), + wallet_info=computation.wallet_info, rabbitmq_rpc_client=rpc_client, ) @@ -393,6 +394,10 @@ async def create_computation( # noqa: PLR0913 ) from e except ConfigurationError as e: raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=f"{e}") from e + except WalletNotEnoughCreditsError as e: + raise HTTPException( + status_code=status.HTTP_402_PAYMENT_REQUIRED, detail=f"{e}" + ) from e @router.get( diff --git a/services/director-v2/src/simcore_service_director_v2/core/errors.py b/services/director-v2/src/simcore_service_director_v2/core/errors.py index cf12fe0379c..e6cfc7c8cd8 100644 --- a/services/director-v2/src/simcore_service_director_v2/core/errors.py +++ b/services/director-v2/src/simcore_service_director_v2/core/errors.py @@ -114,6 +114,10 @@ class ComputationalTaskNotFoundError(PydanticErrorMixin, DirectorError): msg_template = "Computational task {node_id} not found" +class WalletNotEnoughCreditsError(PydanticErrorMixin, DirectorError): + msg_template = "Wallet '{wallet_name}' has {wallet_credit_amount} credits." + + # # SCHEDULER ERRORS # diff --git a/services/director-v2/src/simcore_service_director_v2/models/pricing.py b/services/director-v2/src/simcore_service_director_v2/models/pricing.py new file mode 100644 index 00000000000..4aabef7cd10 --- /dev/null +++ b/services/director-v2/src/simcore_service_director_v2/models/pricing.py @@ -0,0 +1,28 @@ +from decimal import Decimal +from typing import Any, ClassVar + +from models_library.resource_tracker import ( + PricingPlanId, + PricingUnitCostId, + PricingUnitId, +) +from pydantic import BaseModel + + +class PricingInfo(BaseModel): + pricing_plan_id: PricingPlanId + pricing_unit_id: PricingUnitId + pricing_unit_cost_id: PricingUnitCostId + pricing_unit_cost: Decimal + + class Config: + schema_extra: ClassVar[dict[str, Any]] = { + "examples": [ + { + "pricing_plan_id": 1, + "pricing_unit_id": 1, + "pricing_unit_cost_id": 1, + "pricing_unit_cost": Decimal(10), + } + ] + } diff --git a/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/comp_tasks/_core.py b/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/comp_tasks/_core.py index 03ee104ae17..4ab318a8d10 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/comp_tasks/_core.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/comp_tasks/_core.py @@ -10,6 +10,7 @@ from models_library.projects_nodes_io import NodeID from models_library.projects_state import RunningState from models_library.users import UserID +from models_library.wallets import WalletInfo from servicelib.logging_utils import log_context from servicelib.rabbitmq import RabbitMQRPCClient from servicelib.utils import logged_gather @@ -94,7 +95,7 @@ async def upsert_tasks_from_project( user_id: UserID, product_name: str, rut_client: ResourceUsageTrackerClient, - is_wallet: bool, + wallet_info: WalletInfo | None, rabbitmq_rpc_client: RabbitMQRPCClient, ) -> list[CompTaskAtDB]: # NOTE: really do an upsert here because of issue https://github.com/ITISFoundation/osparc-simcore/issues/2125 @@ -110,7 +111,7 @@ async def upsert_tasks_from_project( product_name=product_name, connection=conn, rut_client=rut_client, - is_wallet=is_wallet, + wallet_info=wallet_info, rabbitmq_rpc_client=rabbitmq_rpc_client, ) # get current tasks diff --git a/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/comp_tasks/_utils.py b/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/comp_tasks/_utils.py index f497131d3f9..19641b217f1 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/comp_tasks/_utils.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/comp_tasks/_utils.py @@ -1,5 +1,6 @@ import asyncio import logging +from decimal import Decimal from typing import Any, Final, cast import aiopg.sa @@ -15,7 +16,7 @@ from models_library.projects_nodes import Node from models_library.projects_nodes_io import NodeID from models_library.projects_state import RunningState -from models_library.resource_tracker import HardwareInfo, PricingInfo +from models_library.resource_tracker import HardwareInfo from models_library.service_settings_labels import ( SimcoreServiceLabels, SimcoreServiceSettingsLabel, @@ -34,6 +35,7 @@ ServiceResourcesDictHelpers, ) from models_library.users import UserID +from models_library.wallets import ZERO_CREDITS, WalletInfo from pydantic import parse_obj_as from servicelib.rabbitmq import ( RabbitMQRPCClient, @@ -45,8 +47,13 @@ ) from simcore_postgres_database.utils_projects_nodes import ProjectNodesRepo -from .....core.errors import ClustersKeeperNotAvailableError, ConfigurationError +from .....core.errors import ( + ClustersKeeperNotAvailableError, + ConfigurationError, + WalletNotEnoughCreditsError, +) from .....models.comp_tasks import CompTaskAtDB, Image, NodeSchema +from .....models.pricing import PricingInfo from .....modules.resource_usage_tracker_client import ResourceUsageTrackerClient from .....utils.comp_scheduler import COMPLETED_STATES from .....utils.computations import to_node_class @@ -201,17 +208,12 @@ async def _get_pricing_and_hardware_infos( # this will need to move away and be in sync. if output: pricing_plan_id, pricing_unit_id = output - pricing_unit_get = await rut_client.get_pricing_unit( - product_name, pricing_plan_id, pricing_unit_id - ) - pricing_unit_cost_id = pricing_unit_get.current_cost_per_unit_id - aws_ec2_instances = pricing_unit_get.specific_info.aws_ec2_instances else: ( pricing_plan_id, pricing_unit_id, - pricing_unit_cost_id, - aws_ec2_instances, + _, + _, ) = await rut_client.get_default_pricing_and_hardware_info( product_name, node_key, node_version ) @@ -222,10 +224,17 @@ async def _get_pricing_and_hardware_infos( pricing_unit_id=pricing_unit_id, ) + pricing_unit_get = await rut_client.get_pricing_unit( + product_name, pricing_plan_id, pricing_unit_id + ) + pricing_unit_cost_id = pricing_unit_get.current_cost_per_unit_id + aws_ec2_instances = pricing_unit_get.specific_info.aws_ec2_instances + pricing_info = PricingInfo( pricing_plan_id=pricing_plan_id, pricing_unit_id=pricing_unit_id, pricing_unit_cost_id=pricing_unit_cost_id, + pricing_unit_cost=pricing_unit_get.current_cost_per_unit, ) hardware_info = HardwareInfo(aws_ec2_instances=aws_ec2_instances) return pricing_info, hardware_info @@ -323,7 +332,7 @@ async def generate_tasks_list_from_project( product_name: str, connection: aiopg.sa.connection.SAConnection, rut_client: ResourceUsageTrackerClient, - is_wallet: bool, + wallet_info: WalletInfo | None, rabbitmq_rpc_client: RabbitMQRPCClient, ) -> list[CompTaskAtDB]: list_comp_tasks = [] @@ -373,17 +382,29 @@ async def generate_tasks_list_from_project( pricing_info, hardware_info = await _get_pricing_and_hardware_infos( connection, rut_client, - is_wallet=is_wallet, + is_wallet=bool(wallet_info), project_id=project.uuid, node_id=NodeID(node_id), product_name=product_name, node_key=node.key, node_version=node.version, ) + # Check for zero credits (if pricing unit is greater than 0). + if ( + wallet_info + and pricing_info + and pricing_info.pricing_unit_cost > Decimal(0) + and wallet_info.wallet_credit_amount <= ZERO_CREDITS + ): + raise WalletNotEnoughCreditsError( + wallet_name=wallet_info.wallet_name, + wallet_credit_amount=wallet_info.wallet_credit_amount, + ) + assert rabbitmq_rpc_client # nosec await _update_project_node_resources_from_hardware_info( connection, - is_wallet=is_wallet, + is_wallet=bool(wallet_info), project_id=project.uuid, node_id=NodeID(node_id), hardware_info=hardware_info, @@ -420,7 +441,9 @@ async def generate_tasks_list_from_project( last_heartbeat=None, created=arrow.utcnow().datetime, modified=arrow.utcnow().datetime, - pricing_info=pricing_info.dict() if pricing_info else None, + pricing_info=pricing_info.dict(exclude={"pricing_unit_cost"}) + if pricing_info + else None, hardware_info=hardware_info, ) diff --git a/services/director-v2/tests/unit/with_dbs/test_api_route_computations.py b/services/director-v2/tests/unit/with_dbs/test_api_route_computations.py index 4000f9cfa94..290422710ac 100644 --- a/services/director-v2/tests/unit/with_dbs/test_api_route_computations.py +++ b/services/director-v2/tests/unit/with_dbs/test_api_route_computations.py @@ -10,6 +10,7 @@ import re import urllib.parse from collections.abc import Awaitable, Callable, Iterator +from decimal import Decimal from pathlib import Path from random import choice from typing import Any @@ -29,6 +30,7 @@ from models_library.api_schemas_directorv2.services import ServiceExtras from models_library.api_schemas_resource_usage_tracker.pricing_plans import ( PricingPlanGet, + PricingUnitGet, ) from models_library.basic_types import VersionStr from models_library.clusters import DEFAULT_CLUSTER_ID, Cluster, ClusterID @@ -291,12 +293,26 @@ def _mocked_service_default_pricing_plan( 200, json=jsonable_encoder(default_pricing_plan, by_alias=True) ) + def _mocked_get_pricing_unit(request, pricing_plan_id: int) -> httpx.Response: + return httpx.Response( + 200, + json=jsonable_encoder( + ( + default_pricing_plan.pricing_units[0] + if default_pricing_plan.pricing_units + else PricingUnitGet.Config.schema_extra["examples"][0] + ), + by_alias=True, + ), + ) + # pylint: disable=not-context-manager with respx.mock( base_url=minimal_app.state.settings.DIRECTOR_V2_RESOURCE_USAGE_TRACKER.api_base_url, assert_all_called=False, assert_all_mocked=True, ) as respx_mock: + respx_mock.get( re.compile( r"services/(?Psimcore/services/(comp|dynamic|frontend)/[^/]+)/(?P[^\.]+.[^\.]+.[^/\?]+)/pricing-plan.+" @@ -304,6 +320,11 @@ def _mocked_service_default_pricing_plan( name="get_service_default_pricing_plan", ).mock(side_effect=_mocked_service_default_pricing_plan) + respx_mock.get( + re.compile(r"pricing-plans/(?P\d+)/pricing-units.+"), + name="get_pricing_unit", + ).mock(side_effect=_mocked_get_pricing_unit) + yield respx_mock @@ -384,7 +405,11 @@ async def test_create_computation( @pytest.fixture def wallet_info(faker: Faker) -> WalletInfo: - return WalletInfo(wallet_id=faker.pyint(), wallet_name=faker.name()) + return WalletInfo( + wallet_id=faker.pyint(), + wallet_name=faker.name(), + wallet_credit_amount=Decimal(faker.pyint(min_value=12, max_value=129312)), + ) @pytest.fixture @@ -483,12 +508,16 @@ async def test_create_computation_with_wallet( assert response.status_code == status.HTTP_201_CREATED, response.text if default_pricing_plan_aws_ec2_type: mocked_clusters_keeper_service_get_instance_type_details.assert_called() - assert mocked_resource_usage_tracker_service_fcts.calls.call_count == len( - [ - v - for v in proj.workbench.values() - if to_node_class(v.key) != NodeClass.FRONTEND - ] + assert ( + mocked_resource_usage_tracker_service_fcts.calls.call_count + == len( + [ + v + for v in proj.workbench.values() + if to_node_class(v.key) != NodeClass.FRONTEND + ] + ) + * 2 ) # check the project nodes were really overriden now async with aiopg_engine.acquire() as connection: @@ -540,7 +569,7 @@ async def test_create_computation_with_wallet( @pytest.mark.parametrize( "default_pricing_plan", - [PricingPlanGet.Config.schema_extra["examples"][0]], + [PricingPlanGet.construct(**PricingPlanGet.Config.schema_extra["examples"][0])], ) async def test_create_computation_with_wallet_with_invalid_pricing_unit_name_raises_409( minimal_configuration: None, @@ -578,7 +607,7 @@ async def test_create_computation_with_wallet_with_invalid_pricing_unit_name_rai @pytest.mark.parametrize( "default_pricing_plan", - [PricingPlanGet.Config.schema_extra["examples"][0]], + [PricingPlanGet.construct(**PricingPlanGet.Config.schema_extra["examples"][0])], ) async def test_create_computation_with_wallet_with_no_clusters_keeper_raises_503( minimal_configuration: None, diff --git a/services/web/server/src/simcore_service_webserver/director_v2/_api_utils.py b/services/web/server/src/simcore_service_webserver/director_v2/_api_utils.py index ce5795d5385..e9bbca91c50 100644 --- a/services/web/server/src/simcore_service_webserver/director_v2/_api_utils.py +++ b/services/web/server/src/simcore_service_webserver/director_v2/_api_utils.py @@ -1,7 +1,7 @@ from aiohttp import web from models_library.projects import ProjectID from models_library.users import UserID -from models_library.wallets import ZERO_CREDITS, WalletID, WalletInfo +from models_library.wallets import WalletID, WalletInfo from pydantic import parse_obj_as from ..application_settings import get_application_settings @@ -10,7 +10,6 @@ from ..users import preferences_api as user_preferences_api from ..users.exceptions import UserDefaultWalletNotFoundError from ..wallets import api as wallets_api -from ..wallets.errors import WalletNotEnoughCreditsError async def get_wallet_info( @@ -54,8 +53,8 @@ async def get_wallet_info( wallet_id=project_wallet_id, product_name=product_name, ) - if wallet.available_credits <= ZERO_CREDITS: - raise WalletNotEnoughCreditsError( - reason=f"Wallet '{wallet.name}' has {wallet.available_credits} credits." - ) - return WalletInfo(wallet_id=project_wallet_id, wallet_name=wallet.name) + return WalletInfo( + wallet_id=project_wallet_id, + wallet_name=wallet.name, + wallet_credit_amount=wallet.available_credits, + ) diff --git a/services/web/server/src/simcore_service_webserver/projects/_nodes_handlers.py b/services/web/server/src/simcore_service_webserver/projects/_nodes_handlers.py index e3e22c01064..67f3104a829 100644 --- a/services/web/server/src/simcore_service_webserver/projects/_nodes_handlers.py +++ b/services/web/server/src/simcore_service_webserver/projects/_nodes_handlers.py @@ -101,6 +101,10 @@ async def wrapper(request: web.Request) -> web.StreamResponse: raise web.HTTPPaymentRequired(reason=f"{exc}") from exc except ProjectInvalidRightsError as exc: raise web.HTTPUnauthorized(reason=f"{exc}") from exc + except ProjectStartsTooManyDynamicNodesError as exc: + raise web.HTTPConflict(reason=f"{exc}") from exc + except ClustersKeeperNotAvailableError as exc: + raise web.HTTPServiceUnavailable(reason=f"{exc}") from exc return wrapper @@ -303,21 +307,16 @@ async def start_node(request: web.Request) -> web.Response: """Has only effect on nodes associated to dynamic services""" req_ctx = RequestContext.parse_obj(request) path_params = parse_request_path_parameters_as(NodePathParams, request) - try: - await projects_api.start_project_node( - request, - product_name=req_ctx.product_name, - user_id=req_ctx.user_id, - project_id=path_params.project_id, - node_id=path_params.node_id, - ) - raise web.HTTPNoContent(content_type=MIMETYPE_APPLICATION_JSON) + await projects_api.start_project_node( + request, + product_name=req_ctx.product_name, + user_id=req_ctx.user_id, + project_id=path_params.project_id, + node_id=path_params.node_id, + ) - except ProjectStartsTooManyDynamicNodesError as exc: - raise web.HTTPConflict(reason=f"{exc}") from exc - except ClustersKeeperNotAvailableError as exc: - raise web.HTTPServiceUnavailable(reason=f"{exc}") from exc + raise web.HTTPNoContent(content_type=MIMETYPE_APPLICATION_JSON) async def _stop_dynamic_service_task( diff --git a/services/web/server/src/simcore_service_webserver/projects/projects_api.py b/services/web/server/src/simcore_service_webserver/projects/projects_api.py index e6c5587a7cc..2e9dfcc3dfd 100644 --- a/services/web/server/src/simcore_service_webserver/projects/projects_api.py +++ b/services/web/server/src/simcore_service_webserver/projects/projects_api.py @@ -16,6 +16,7 @@ from collections import defaultdict from collections.abc import Generator from contextlib import suppress +from decimal import Decimal from pprint import pformat from typing import Any, Final from uuid import UUID, uuid4 @@ -535,12 +536,10 @@ async def _start_dynamic_service( product_name=product_name, ) ) - if wallet.available_credits <= ZERO_CREDITS: - raise WalletNotEnoughCreditsError( - reason=f"Wallet '{wallet.name}' has {wallet.available_credits} credits." - ) wallet_info = WalletInfo( - wallet_id=project_wallet_id, wallet_name=wallet.name + wallet_id=project_wallet_id, + wallet_name=wallet.name, + wallet_credit_amount=wallet.available_credits, ) # Deal with Pricing plan/unit @@ -549,17 +548,12 @@ async def _start_dynamic_service( ) if output: pricing_plan_id, pricing_unit_id = output - pricing_unit_get = await rut_api.get_pricing_plan_unit( - request.app, product_name, pricing_plan_id, pricing_unit_id - ) - pricing_unit_cost_id = pricing_unit_get.current_cost_per_unit_id - aws_ec2_instances = pricing_unit_get.specific_info.aws_ec2_instances else: ( pricing_plan_id, pricing_unit_id, - pricing_unit_cost_id, - aws_ec2_instances, + _, + _, ) = await _get_default_pricing_and_hardware_info( request.app, product_name, @@ -575,6 +569,21 @@ async def _start_dynamic_service( pricing_unit_id, ) + # Check for zero credits (if pricing unit is greater than 0). + pricing_unit_get = await rut_api.get_pricing_plan_unit( + request.app, product_name, pricing_plan_id, pricing_unit_id + ) + pricing_unit_cost_id = pricing_unit_get.current_cost_per_unit_id + aws_ec2_instances = pricing_unit_get.specific_info.aws_ec2_instances + + if ( + pricing_unit_get.current_cost_per_unit > Decimal(0) + and wallet.available_credits <= ZERO_CREDITS + ): + raise WalletNotEnoughCreditsError( + reason=f"Wallet '{wallet.name}' has {wallet.available_credits} credits." + ) + pricing_info = PricingInfo( pricing_plan_id=pricing_plan_id, pricing_unit_id=pricing_unit_id,