From 2f8d48902b967eefc7fd258bef95c9796445f604 Mon Sep 17 00:00:00 2001 From: Damien Garros Date: Tue, 10 Sep 2024 08:52:27 +0200 Subject: [PATCH] Move tasks initialization to workflow service --- backend/infrahub/cli/git_agent.py | 9 ++++++ backend/infrahub/core/initialization.py | 27 ---------------- backend/infrahub/services/__init__.py | 1 + .../services/adapters/workflow/__init__.py | 13 ++++++-- .../services/adapters/workflow/worker.py | 31 ++++++++++++++++++- backend/infrahub/workers/infrahub_async.py | 9 ++++++ 6 files changed, 59 insertions(+), 31 deletions(-) diff --git a/backend/infrahub/cli/git_agent.py b/backend/infrahub/cli/git_agent.py index 5985f995d5..3f0e20d319 100644 --- a/backend/infrahub/cli/git_agent.py +++ b/backend/infrahub/cli/git_agent.py @@ -23,6 +23,8 @@ from infrahub.services.adapters.cache.redis import RedisCache from infrahub.services.adapters.message_bus.nats import NATSMessageBus from infrahub.services.adapters.message_bus.rabbitmq import RabbitMQMessageBus +from infrahub.services.adapters.workflow.local import WorkflowLocalExecution +from infrahub.services.adapters.workflow.worker import WorkflowWorkerExecution from infrahub.trace import configure_trace if TYPE_CHECKING: @@ -111,6 +113,12 @@ async def start( database = await context.get_db(retry=1) + workflow = config.OVERRIDE.workflow or ( + WorkflowWorkerExecution() + if config.SETTINGS.workflow.driver == config.WorkflowDriver.WORKER + else WorkflowLocalExecution() + ) + message_bus = config.OVERRIDE.message_bus or ( NATSMessageBus() if config.SETTINGS.broker.driver == config.BrokerDriver.NATS else RabbitMQMessageBus() ) @@ -122,6 +130,7 @@ async def start( cache=cache, client=client, database=database, + workflow=workflow, message_bus=message_bus, component_type=ComponentType.GIT_AGENT, ) diff --git a/backend/infrahub/core/initialization.py b/backend/infrahub/core/initialization.py index 47b0164e42..7620b9aa5a 100644 --- a/backend/infrahub/core/initialization.py +++ b/backend/infrahub/core/initialization.py @@ -2,10 +2,6 @@ from typing import Optional from uuid import uuid4 -from prefect.client.orchestration import get_client -from prefect.client.schemas.actions import WorkPoolCreate -from prefect.exceptions import ObjectAlreadyExists - from infrahub import config, lock from infrahub.core import registry from infrahub.core.branch import Branch @@ -32,7 +28,6 @@ from infrahub.permissions import PermissionBackend from infrahub.storage import InfrahubObjectStorage from infrahub.utils import format_label -from infrahub.workflows.catalogue import worker_pools, workflows log = get_logger() @@ -115,26 +110,6 @@ async def initialize_registry(db: InfrahubDatabase, initialize: bool = False) -> registry.permission_backends = initialize_permission_backends() -async def initialize_tasks() -> None: - async with get_client(sync_client=False) as client: - for worker in worker_pools: - wp = WorkPoolCreate( - name=worker.name, - type=worker.worker_type, - description=worker.description, - ) - try: - await client.create_work_pool(work_pool=wp) - log.info(f"work pool {worker} created successfully ... ") - except ObjectAlreadyExists: - log.info(f"work pool {worker} already present ") - - # Create deployment - for workflow in workflows: - flow_id = await client.create_flow_from_name(workflow.name) - await client.create_deployment(flow_id=flow_id, **workflow.to_deployment()) - - async def initialization(db: InfrahubDatabase) -> None: if config.SETTINGS.database.db_type == config.DatabaseType.MEMGRAPH: session = await db.session() @@ -156,8 +131,6 @@ async def initialization(db: InfrahubDatabase) -> None: else: log.warning("The database index manager hasn't been initialized.") - await initialize_tasks() - # --------------------------------------------------- # Load all schema in the database into the registry # ... Unless the schema has been initialized already diff --git a/backend/infrahub/services/__init__.py b/backend/infrahub/services/__init__.py index 5cfddbd845..7ecef707a6 100644 --- a/backend/infrahub/services/__init__.py +++ b/backend/infrahub/services/__init__.py @@ -102,6 +102,7 @@ async def initialize(self) -> None: await self.message_bus.initialize(service=self) await self.cache.initialize(service=self) await self.scheduler.initialize(service=self) + await self.workflow.initialize(service=self) async def shutdown(self) -> None: """Initialize the Services""" diff --git a/backend/infrahub/services/adapters/workflow/__init__.py b/backend/infrahub/services/adapters/workflow/__init__.py index 4ce2e1bec2..1e16e65ab9 100644 --- a/backend/infrahub/services/adapters/workflow/__init__.py +++ b/backend/infrahub/services/adapters/workflow/__init__.py @@ -1,6 +1,10 @@ -from typing import Any, Awaitable, Callable, ParamSpec, TypeVar +from __future__ import annotations -from infrahub.workflows.models import WorkflowDefinition +from typing import TYPE_CHECKING, Any, Awaitable, Callable, ParamSpec, TypeVar + +if TYPE_CHECKING: + from infrahub.services import InfrahubServices + from infrahub.workflows.models import WorkflowDefinition Return = TypeVar("Return") Params = ParamSpec("Params") @@ -9,10 +13,13 @@ class InfrahubWorkflow: + async def initialize(self, service: InfrahubServices) -> None: + """Initialize the Workflow engine""" + async def execute( self, workflow: WorkflowDefinition | None = None, function: Callable[..., Awaitable[Return]] | None = None, **kwargs: dict[str, Any], - ) -> Any: + ) -> Return: raise NotImplementedError() diff --git a/backend/infrahub/services/adapters/workflow/worker.py b/backend/infrahub/services/adapters/workflow/worker.py index 7ff7b0caae..668b9e1700 100644 --- a/backend/infrahub/services/adapters/workflow/worker.py +++ b/backend/infrahub/services/adapters/workflow/worker.py @@ -1,20 +1,49 @@ +from __future__ import annotations + import base64 import json from pathlib import Path from typing import TYPE_CHECKING, Any, Awaitable, Callable import cloudpickle +from prefect.client.orchestration import get_client +from prefect.client.schemas.actions import WorkPoolCreate from prefect.deployments import run_deployment +from prefect.exceptions import ObjectAlreadyExists -from infrahub.workflows.models import WorkflowDefinition +from infrahub.workflows.catalogue import worker_pools, workflows from . import InfrahubWorkflow, Return if TYPE_CHECKING: from prefect.client.schemas.objects import FlowRun + from infrahub.services import InfrahubServices + from infrahub.workflows.models import WorkflowDefinition + class WorkflowWorkerExecution(InfrahubWorkflow): + async def initialize(self, service: InfrahubServices) -> None: + """Initialize the Workflow engine""" + + async with get_client(sync_client=False) as client: + for worker in worker_pools: + wp = WorkPoolCreate( + name=worker.name, + type=worker.worker_type, + description=worker.description, + ) + try: + await client.create_work_pool(work_pool=wp) + service.log.info(f"work pool {worker} created successfully ... ") + except ObjectAlreadyExists: + service.log.info(f"work pool {worker} already present ") + + # Create deployment + for workflow in workflows: + flow_id = await client.create_flow_from_name(workflow.name) + await client.create_deployment(flow_id=flow_id, **workflow.to_deployment()) + async def execute( self, workflow: WorkflowDefinition | None = None, diff --git a/backend/infrahub/workers/infrahub_async.py b/backend/infrahub/workers/infrahub_async.py index 0b007ce1db..63d14efcb4 100644 --- a/backend/infrahub/workers/infrahub_async.py +++ b/backend/infrahub/workers/infrahub_async.py @@ -22,6 +22,8 @@ from infrahub.services.adapters.cache.redis import RedisCache from infrahub.services.adapters.message_bus.nats import NATSMessageBus from infrahub.services.adapters.message_bus.rabbitmq import RabbitMQMessageBus +from infrahub.services.adapters.workflow.local import WorkflowLocalExecution +from infrahub.services.adapters.workflow.worker import WorkflowWorkerExecution class InfrahubWorkerAsyncConfiguration(BaseJobConfiguration): @@ -69,6 +71,12 @@ async def setup(self, **kwargs: dict[str, Any]) -> None: database = InfrahubDatabase(driver=await get_db(retry=1)) + workflow = config.OVERRIDE.workflow or ( + WorkflowWorkerExecution() + if config.SETTINGS.workflow.driver == config.WorkflowDriver.WORKER + else WorkflowLocalExecution() + ) + message_bus = config.OVERRIDE.message_bus or ( NATSMessageBus() if config.SETTINGS.broker.driver == config.BrokerDriver.NATS else RabbitMQMessageBus() ) @@ -81,6 +89,7 @@ async def setup(self, **kwargs: dict[str, Any]) -> None: client=client, database=database, message_bus=message_bus, + workflow=workflow, component_type=ComponentType.GIT_AGENT, ) services.service = service