Skip to content

Commit

Permalink
Merge pull request #4147 from opsmill/dga-20240818-prefect
Browse files Browse the repository at this point in the history
Initial integration of Prefect for task & workflow execution
  • Loading branch information
dgarros authored Sep 17, 2024
2 parents 536f408 + 5a3bcf6 commit 76bf288
Show file tree
Hide file tree
Showing 46 changed files with 2,381 additions and 185 deletions.
15 changes: 8 additions & 7 deletions backend/infrahub/api/transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@
from infrahub.graphql import prepare_graphql_params
from infrahub.graphql.utils import extract_data
from infrahub.message_bus.messages import (
TransformJinjaTemplate,
TransformJinjaTemplateResponse,
TransformPythonData,
TransformPythonDataResponse,
)
from infrahub.message_bus.messages.transform_jinja_template import TransformJinjaTemplateData
from infrahub.workflows.catalogue import TRANSFORM_JINJA2_RENDER

if TYPE_CHECKING:
from infrahub.services import InfrahubServices
Expand Down Expand Up @@ -134,9 +134,7 @@ async def transform_jinja2(

data = extract_data(query_name=query.name.value, result=result)

service: InfrahubServices = request.app.state.service

message = TransformJinjaTemplate(
message = TransformJinjaTemplateData(
repository_id=repository.id,
repository_name=repository.name.value,
repository_kind=repository.get_kind(),
Expand All @@ -146,5 +144,8 @@ async def transform_jinja2(
data=data,
)

response = await service.message_bus.rpc(message=message, response_class=TransformJinjaTemplateResponse)
return PlainTextResponse(content=response.data.rendered_template)
service: InfrahubServices = request.app.state.service

response: str = await service.workflow.execute(workflow=TRANSFORM_JINJA2_RENDER, message=message) # type: ignore[arg-type]

return PlainTextResponse(content=response)
10 changes: 10 additions & 0 deletions backend/infrahub/cli/git_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from infrahub.services.adapters.cache.redis import RedisCache
from infrahub.services.adapters.message_bus.nats import NATSMessageBus
from infrahub.services.adapters.message_bus.rabbitmq import RabbitMQMessageBus
from infrahub.services.adapters.workflow.local import WorkflowLocalExecution
from infrahub.services.adapters.workflow.worker import WorkflowWorkerExecution
from infrahub.trace import configure_trace

if TYPE_CHECKING:
Expand Down Expand Up @@ -69,6 +71,7 @@ async def start(
logging.getLogger("aio_pika").setLevel(logging.ERROR)
logging.getLogger("aiormq").setLevel(logging.ERROR)
logging.getLogger("git").setLevel(logging.ERROR)
logging.getLogger("aiosqlite").setLevel(logging.ERROR)

log.debug(f"Config file : {config_file}")
# Prevent git from interactively prompting the user for passwords if the credentials provided
Expand Down Expand Up @@ -110,6 +113,12 @@ async def start(

database = await context.get_db(retry=1)

workflow = config.OVERRIDE.workflow or (
WorkflowWorkerExecution()
if config.SETTINGS.workflow.driver == config.WorkflowDriver.WORKER
else WorkflowLocalExecution()
)

message_bus = config.OVERRIDE.message_bus or (
NATSMessageBus() if config.SETTINGS.broker.driver == config.BrokerDriver.NATS else RabbitMQMessageBus()
)
Expand All @@ -121,6 +130,7 @@ async def start(
cache=cache,
client=client,
database=database,
workflow=workflow,
message_bus=message_bus,
component_type=ComponentType.GIT_AGENT,
)
Expand Down
30 changes: 30 additions & 0 deletions backend/infrahub/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
if TYPE_CHECKING:
from infrahub.services.adapters.cache import InfrahubCache
from infrahub.services.adapters.message_bus import InfrahubMessageBus
from infrahub.services.adapters.workflow import InfrahubWorkflow


VALID_DATABASE_NAME_REGEX = r"^[a-z][a-z0-9\.]+$"
Expand Down Expand Up @@ -62,6 +63,11 @@ class CacheDriver(str, Enum):
NATS = "nats"


class WorkflowDriver(str, Enum):
LOCAL = "local"
WORKER = "worker"


class MainSettings(BaseSettings):
model_config = SettingsConfigDict(env_prefix="INFRAHUB_")
docs_index_path: str = Field(
Expand Down Expand Up @@ -220,6 +226,24 @@ def service_port(self) -> int:
return self.port or default_ports


class WorkflowSettings(BaseSettings):
model_config = SettingsConfigDict(env_prefix="INFRAHUB_WORKFLOW_")
enable: bool = True
address: str = "localhost"
port: Optional[int] = Field(default=None, ge=1, le=65535, description="Specified if running on a non default port.")
tls_enabled: bool = Field(default=False, description="Indicates if TLS is enabled for the connection")
driver: WorkflowDriver = WorkflowDriver.WORKER

@property
def api_endpoint(self) -> str:
url = "https://" if self.tls_enabled else "http://"
url += self.address
if self.port:
url += f":{self.port}"
url += "/api"
return url


class ApiSettings(BaseSettings):
model_config = SettingsConfigDict(env_prefix="INFRAHUB_API_")
cors_allow_origins: list[str] = Field(
Expand Down Expand Up @@ -337,6 +361,7 @@ class TraceSettings(BaseSettings):
class Override:
message_bus: Optional[InfrahubMessageBus] = None
cache: Optional[InfrahubCache] = None
workflow: Optional[InfrahubWorkflow] = None


@dataclass
Expand Down Expand Up @@ -395,6 +420,10 @@ def broker(self) -> BrokerSettings:
def cache(self) -> CacheSettings:
return self.active_settings.cache

@property
def workflow(self) -> WorkflowSettings:
return self.active_settings.workflow

@property
def miscellaneous(self) -> MiscellaneousSettings:
return self.active_settings.miscellaneous
Expand Down Expand Up @@ -437,6 +466,7 @@ class Settings(BaseSettings):
database: DatabaseSettings = DatabaseSettings()
broker: BrokerSettings = BrokerSettings()
cache: CacheSettings = CacheSettings()
workflow: WorkflowSettings = WorkflowSettings()
miscellaneous: MiscellaneousSettings = MiscellaneousSettings()
logging: LoggingSettings = LoggingSettings()
analytics: AnalyticsSettings = AnalyticsSettings()
Expand Down
8 changes: 6 additions & 2 deletions backend/infrahub/message_bus/messages/send_webhook_event.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
from pydantic import Field
from pydantic import BaseModel, Field

from infrahub.message_bus import InfrahubMessage


class SendWebhookEvent(InfrahubMessage):
class SendWebhookData(BaseModel):
"""Sent a webhook to an external source."""

webhook_id: str = Field(..., description="The unique ID of the webhook")
event_type: str = Field(..., description="The event type")
event_data: dict = Field(..., description="The data tied to the event")


class SendWebhookEvent(SendWebhookData, InfrahubMessage):
"""Sent a webhook to an external source."""
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from typing import Optional

from pydantic import Field
from pydantic import BaseModel, Field

from infrahub.message_bus import InfrahubMessage, InfrahubResponse, InfrahubResponseData

ROUTING_KEY = "transform.jinja.template"


class TransformJinjaTemplate(InfrahubMessage):
class TransformJinjaTemplateData(BaseModel):
"""Sent to trigger the checks for a repository to be executed."""

repository_id: str = Field(..., description="The unique ID of the Repository")
Expand All @@ -19,6 +19,10 @@ class TransformJinjaTemplate(InfrahubMessage):
commit: str = Field(..., description="The commit id to use when rendering the template")


class TransformJinjaTemplate(TransformJinjaTemplateData, InfrahubMessage):
"""Sent to trigger the checks for a repository to be executed."""


class TransformJinjaTemplateResponseData(InfrahubResponseData):
rendered_template: Optional[str] = Field(None, description="Rendered template in string format")

Expand Down
36 changes: 35 additions & 1 deletion backend/infrahub/message_bus/operations/send/telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,22 @@
from typing import Any

import httpx
from prefect import flow, task
from prefect.logging import get_run_logger

from infrahub import __version__, config
from infrahub.core import registry, utils
from infrahub.core.branch import Branch
from infrahub.core.constants import InfrahubKind
from infrahub.core.graph.schema import GRAPH_SCHEMA
from infrahub.message_bus import messages
from infrahub.services import InfrahubServices
from infrahub.services import InfrahubServices, services

TELEMETRY_KIND: str = "community"
TELEMETRY_VERSION: str = "20240524"


@task
async def gather_database_information(service: InfrahubServices, branch: Branch) -> dict: # pylint: disable=unused-argument
data: dict[str, Any] = {
"database_type": service.database.db_type.value,
Expand All @@ -34,6 +37,7 @@ async def gather_database_information(service: InfrahubServices, branch: Branch)
return data


@task
async def gather_schema_information(service: InfrahubServices, branch: Branch) -> dict: # pylint: disable=unused-argument
data: dict[str, Any] = {}
main_schema = registry.schema.get_schema_branch(name=branch.name)
Expand All @@ -44,6 +48,7 @@ async def gather_schema_information(service: InfrahubServices, branch: Branch) -
return data


@task
async def gather_feature_information(service: InfrahubServices, branch: Branch) -> dict: # pylint: disable=unused-argument
data = {}
features_to_count = [
Expand All @@ -61,6 +66,7 @@ async def gather_feature_information(service: InfrahubServices, branch: Branch)
return data


@task
async def gather_anonymous_telemetry_data(service: InfrahubServices) -> dict:
start_time = time.time()

Expand Down Expand Up @@ -112,3 +118,31 @@ async def push(
response.raise_for_status()
except httpx.HTTPError as exc:
service.log.debug(f"HTTP exception while pushing anonymous telemetry: {exc}")


@task(retries=5)
async def post_telemetry_data(service: InfrahubServices, url: str, payload: dict[str, Any]) -> None: # pylint: disable=unused-argument
"""Send the telemetry data to the specified URL, using HTTP POST."""
async with httpx.AsyncClient() as client:
response = await client.post(url=url, json=payload)
response.raise_for_status()


@flow
async def send_telemetry_push() -> None:
service = services.service

log = get_run_logger()
log.info(f"Pushing anonymous telemetry data to {config.SETTINGS.main.telemetry_endpoint}...")

data = await gather_anonymous_telemetry_data(service=service)
log.info(f"Anonymous usage telemetry gathered in {data['execution_time']} seconds. | {data}")

payload = {
"kind": TELEMETRY_KIND,
"payload_format": TELEMETRY_VERSION,
"data": data,
"checksum": hashlib.sha256(json.dumps(data).encode()).hexdigest(),
}

await post_telemetry_data(service=service, url=config.SETTINGS.main.telemetry_endpoint, payload=payload)
32 changes: 31 additions & 1 deletion backend/infrahub/message_bus/operations/send/webhook.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from typing import Any

import ujson
from prefect import flow
from prefect.logging import get_run_logger

from infrahub.exceptions import NodeNotFoundError
from infrahub.message_bus import messages
from infrahub.services import InfrahubServices
from infrahub.message_bus.messages.send_webhook_event import SendWebhookData
from infrahub.services import InfrahubServices, services
from infrahub.webhook import CustomWebhook, StandardWebhook, TransformWebhook, Webhook


Expand Down Expand Up @@ -35,3 +38,30 @@ async def event(message: messages.SendWebhookEvent, service: InfrahubServices) -
title=webhook.webhook_type,
logs={"message": "Successfully sent webhook", "severity": "INFO"},
)


@flow
async def send_webhook(message: SendWebhookData) -> None:
service = services.service
log = get_run_logger()

webhook_definition = await service.cache.get(key=f"webhook:active:{message.webhook_id}")
if not webhook_definition:
log.warning("Webhook not found")
raise NodeNotFoundError(
node_type="Webhook", identifier=message.webhook_id, message="The requested Webhook was not found"
)

webhook_data = ujson.loads(webhook_definition)
payload: dict[str, Any] = {"event_type": message.event_type, "data": message.event_data, "service": service}
webhook_map: dict[str, type[Webhook]] = {
"standard": StandardWebhook,
"custom": CustomWebhook,
"transform": TransformWebhook,
}
webhook_class = webhook_map[webhook_data["webhook_type"]]
payload.update(webhook_data["webhook_configuration"])
webhook = webhook_class(**payload)
await webhook.send()

log.info("Successfully sent webhook")
22 changes: 21 additions & 1 deletion backend/infrahub/message_bus/operations/transform/jinja.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from prefect import flow

from infrahub.git.repository import get_initialized_repo
from infrahub.log import get_logger
from infrahub.message_bus.messages.transform_jinja_template import (
TransformJinjaTemplate,
TransformJinjaTemplateData,
TransformJinjaTemplateResponse,
TransformJinjaTemplateResponseData,
)
from infrahub.services import InfrahubServices
from infrahub.services import InfrahubServices, services

log = get_logger()

Expand All @@ -28,3 +31,20 @@ async def template(message: TransformJinjaTemplate, service: InfrahubServices) -
data=TransformJinjaTemplateResponseData(rendered_template=rendered_template),
)
await service.reply(message=response, initiator=message)


@flow(persist_result=True)
async def transform_render_jinja2_template(message: TransformJinjaTemplateData) -> str:
service = services.service
repo = await get_initialized_repo(
repository_id=message.repository_id,
name=message.repository_name,
service=service,
repository_kind=message.repository_kind,
)

rendered_template = await repo.render_jinja2_template(
commit=message.commit, location=message.template_location, data={"data": message.data}
)

return rendered_template
14 changes: 13 additions & 1 deletion backend/infrahub/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
from infrahub.services.adapters.cache.redis import RedisCache
from infrahub.services.adapters.message_bus.nats import NATSMessageBus
from infrahub.services.adapters.message_bus.rabbitmq import RabbitMQMessageBus
from infrahub.services.adapters.workflow.local import WorkflowLocalExecution
from infrahub.services.adapters.workflow.worker import WorkflowWorkerExecution
from infrahub.trace import add_span_exception, configure_trace, get_traceid
from infrahub.worker import WORKER_IDENTITY

Expand All @@ -62,14 +64,24 @@ async def app_initialization(application: FastAPI) -> None:

build_component_registry()

workflow = config.OVERRIDE.workflow or (
WorkflowWorkerExecution()
if config.SETTINGS.workflow.driver == config.WorkflowDriver.WORKER
else WorkflowLocalExecution()
)

message_bus = config.OVERRIDE.message_bus or (
NATSMessageBus() if config.SETTINGS.broker.driver == config.BrokerDriver.NATS else RabbitMQMessageBus()
)
cache = config.OVERRIDE.cache or (
NATSCache() if config.SETTINGS.cache.driver == config.CacheDriver.NATS else RedisCache()
)
service = InfrahubServices(
cache=cache, database=database, message_bus=message_bus, component_type=ComponentType.API_SERVER
cache=cache,
database=database,
message_bus=message_bus,
workflow=workflow,
component_type=ComponentType.API_SERVER,
)
await service.initialize()
initialize_lock(service=service)
Expand Down
Loading

0 comments on commit 76bf288

Please sign in to comment.