Skip to content

Commit

Permalink
Update worker execution code
Browse files Browse the repository at this point in the history
Ensure prefect worker read the proper config file
Set transform_render_jinja2_template as persistent
  • Loading branch information
dgarros committed Sep 13, 2024
1 parent 2f8d489 commit c9d1d35
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 16 deletions.
2 changes: 1 addition & 1 deletion backend/infrahub/message_bus/operations/transform/jinja.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ async def template(message: TransformJinjaTemplate, service: InfrahubServices) -
await service.reply(message=response, initiator=message)


@flow
@flow(persist_result=True)
async def transform_render_jinja2_template(message: TransformJinjaTemplateData) -> str:
service = services.service
repo = await get_initialized_repo(
Expand Down
15 changes: 1 addition & 14 deletions backend/infrahub/services/adapters/workflow/worker.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
from __future__ import annotations

import base64
import json
from pathlib import Path
from typing import TYPE_CHECKING, Any, Awaitable, Callable

import cloudpickle
from prefect.client.orchestration import get_client
from prefect.client.schemas.actions import WorkPoolCreate
from prefect.deployments import run_deployment
Expand Down Expand Up @@ -54,16 +50,7 @@ async def execute(
response: FlowRun = await run_deployment(name=workflow.full_name, parameters=kwargs or {}) # type: ignore[return-value, misc]
if not response.state:
raise RuntimeError("Unable to read state from the response")
result = response.state.result()

with Path(result.storage_key).open(encoding="utf-8") as f:
result_data = json.load(f)
encoded_data = result_data["data"]
decoded_data = base64.b64decode(encoded_data)

if result_data["serializer"]["type"] == "pickle":
return cloudpickle.loads(decoded_data)
raise ValueError("Unsupported serializer type")
return await response.state.result(fetch=True, raise_on_failure=True) # type: ignore[call-overload]

if function:
return await function(**kwargs)
Expand Down
4 changes: 3 additions & 1 deletion backend/infrahub/workers/infrahub_async.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import importlib
import logging
import os
from typing import Any, Optional

import typer
Expand Down Expand Up @@ -57,7 +58,8 @@ async def setup(self, **kwargs: dict[str, Any]) -> None:
logging.getLogger("aiormq").setLevel(logging.ERROR)
logging.getLogger("git").setLevel(logging.ERROR)

config.load_and_exit()
config_file = os.environ.get("INFRAHUB_CONFIG", "infrahub.toml")
config.load_and_exit(config_file_name=config_file)

self._logger.debug(f"Using Infrahub API at {config.SETTINGS.main.internal_address}")
client = InfrahubClient(
Expand Down

0 comments on commit c9d1d35

Please sign in to comment.