Skip to content

Commit

Permalink
Move tasks initialization to workflow service
Browse files Browse the repository at this point in the history
  • Loading branch information
dgarros committed Sep 13, 2024
1 parent 974fc49 commit 2f8d489
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 31 deletions.
9 changes: 9 additions & 0 deletions backend/infrahub/cli/git_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from infrahub.services.adapters.cache.redis import RedisCache
from infrahub.services.adapters.message_bus.nats import NATSMessageBus
from infrahub.services.adapters.message_bus.rabbitmq import RabbitMQMessageBus
from infrahub.services.adapters.workflow.local import WorkflowLocalExecution
from infrahub.services.adapters.workflow.worker import WorkflowWorkerExecution
from infrahub.trace import configure_trace

if TYPE_CHECKING:
Expand Down Expand Up @@ -111,6 +113,12 @@ async def start(

database = await context.get_db(retry=1)

workflow = config.OVERRIDE.workflow or (
WorkflowWorkerExecution()
if config.SETTINGS.workflow.driver == config.WorkflowDriver.WORKER
else WorkflowLocalExecution()
)

message_bus = config.OVERRIDE.message_bus or (
NATSMessageBus() if config.SETTINGS.broker.driver == config.BrokerDriver.NATS else RabbitMQMessageBus()
)
Expand All @@ -122,6 +130,7 @@ async def start(
cache=cache,
client=client,
database=database,
workflow=workflow,
message_bus=message_bus,
component_type=ComponentType.GIT_AGENT,
)
Expand Down
27 changes: 0 additions & 27 deletions backend/infrahub/core/initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,6 @@
from typing import Optional
from uuid import uuid4

from prefect.client.orchestration import get_client
from prefect.client.schemas.actions import WorkPoolCreate
from prefect.exceptions import ObjectAlreadyExists

from infrahub import config, lock
from infrahub.core import registry
from infrahub.core.branch import Branch
Expand All @@ -32,7 +28,6 @@
from infrahub.permissions import PermissionBackend
from infrahub.storage import InfrahubObjectStorage
from infrahub.utils import format_label
from infrahub.workflows.catalogue import worker_pools, workflows

log = get_logger()

Expand Down Expand Up @@ -115,26 +110,6 @@ async def initialize_registry(db: InfrahubDatabase, initialize: bool = False) ->
registry.permission_backends = initialize_permission_backends()


async def initialize_tasks() -> None:
async with get_client(sync_client=False) as client:
for worker in worker_pools:
wp = WorkPoolCreate(
name=worker.name,
type=worker.worker_type,
description=worker.description,
)
try:
await client.create_work_pool(work_pool=wp)
log.info(f"work pool {worker} created successfully ... ")
except ObjectAlreadyExists:
log.info(f"work pool {worker} already present ")

# Create deployment
for workflow in workflows:
flow_id = await client.create_flow_from_name(workflow.name)
await client.create_deployment(flow_id=flow_id, **workflow.to_deployment())


async def initialization(db: InfrahubDatabase) -> None:
if config.SETTINGS.database.db_type == config.DatabaseType.MEMGRAPH:
session = await db.session()
Expand All @@ -156,8 +131,6 @@ async def initialization(db: InfrahubDatabase) -> None:
else:
log.warning("The database index manager hasn't been initialized.")

await initialize_tasks()

# ---------------------------------------------------
# Load all schema in the database into the registry
# ... Unless the schema has been initialized already
Expand Down
1 change: 1 addition & 0 deletions backend/infrahub/services/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ async def initialize(self) -> None:
await self.message_bus.initialize(service=self)
await self.cache.initialize(service=self)
await self.scheduler.initialize(service=self)
await self.workflow.initialize(service=self)

async def shutdown(self) -> None:
"""Initialize the Services"""
Expand Down
13 changes: 10 additions & 3 deletions backend/infrahub/services/adapters/workflow/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from typing import Any, Awaitable, Callable, ParamSpec, TypeVar
from __future__ import annotations

from infrahub.workflows.models import WorkflowDefinition
from typing import TYPE_CHECKING, Any, Awaitable, Callable, ParamSpec, TypeVar

if TYPE_CHECKING:
from infrahub.services import InfrahubServices
from infrahub.workflows.models import WorkflowDefinition

Return = TypeVar("Return")
Params = ParamSpec("Params")
Expand All @@ -9,10 +13,13 @@


class InfrahubWorkflow:
async def initialize(self, service: InfrahubServices) -> None:
"""Initialize the Workflow engine"""

async def execute(
self,
workflow: WorkflowDefinition | None = None,
function: Callable[..., Awaitable[Return]] | None = None,
**kwargs: dict[str, Any],
) -> Any:
) -> Return:
raise NotImplementedError()
31 changes: 30 additions & 1 deletion backend/infrahub/services/adapters/workflow/worker.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,49 @@
from __future__ import annotations

import base64
import json
from pathlib import Path
from typing import TYPE_CHECKING, Any, Awaitable, Callable

import cloudpickle
from prefect.client.orchestration import get_client
from prefect.client.schemas.actions import WorkPoolCreate
from prefect.deployments import run_deployment
from prefect.exceptions import ObjectAlreadyExists

from infrahub.workflows.models import WorkflowDefinition
from infrahub.workflows.catalogue import worker_pools, workflows

from . import InfrahubWorkflow, Return

if TYPE_CHECKING:
from prefect.client.schemas.objects import FlowRun

from infrahub.services import InfrahubServices
from infrahub.workflows.models import WorkflowDefinition


class WorkflowWorkerExecution(InfrahubWorkflow):
async def initialize(self, service: InfrahubServices) -> None:
"""Initialize the Workflow engine"""

async with get_client(sync_client=False) as client:
for worker in worker_pools:
wp = WorkPoolCreate(
name=worker.name,
type=worker.worker_type,
description=worker.description,
)
try:
await client.create_work_pool(work_pool=wp)
service.log.info(f"work pool {worker} created successfully ... ")
except ObjectAlreadyExists:
service.log.info(f"work pool {worker} already present ")

# Create deployment
for workflow in workflows:
flow_id = await client.create_flow_from_name(workflow.name)
await client.create_deployment(flow_id=flow_id, **workflow.to_deployment())

async def execute(
self,
workflow: WorkflowDefinition | None = None,
Expand Down
9 changes: 9 additions & 0 deletions backend/infrahub/workers/infrahub_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from infrahub.services.adapters.cache.redis import RedisCache
from infrahub.services.adapters.message_bus.nats import NATSMessageBus
from infrahub.services.adapters.message_bus.rabbitmq import RabbitMQMessageBus
from infrahub.services.adapters.workflow.local import WorkflowLocalExecution
from infrahub.services.adapters.workflow.worker import WorkflowWorkerExecution


class InfrahubWorkerAsyncConfiguration(BaseJobConfiguration):
Expand Down Expand Up @@ -69,6 +71,12 @@ async def setup(self, **kwargs: dict[str, Any]) -> None:

database = InfrahubDatabase(driver=await get_db(retry=1))

workflow = config.OVERRIDE.workflow or (
WorkflowWorkerExecution()
if config.SETTINGS.workflow.driver == config.WorkflowDriver.WORKER
else WorkflowLocalExecution()
)

message_bus = config.OVERRIDE.message_bus or (
NATSMessageBus() if config.SETTINGS.broker.driver == config.BrokerDriver.NATS else RabbitMQMessageBus()
)
Expand All @@ -81,6 +89,7 @@ async def setup(self, **kwargs: dict[str, Any]) -> None:
client=client,
database=database,
message_bus=message_bus,
workflow=workflow,
component_type=ComponentType.GIT_AGENT,
)
services.service = service
Expand Down

0 comments on commit 2f8d489

Please sign in to comment.