Skip to content

Commit

Permalink
feat!: match improve interface for publishers and lift bare funcs in …
Browse files Browse the repository at this point in the history
…validators
  • Loading branch information
z3z1ma committed Aug 19, 2024
1 parent a9a4c2a commit 86c6c1d
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 69 deletions.
8 changes: 6 additions & 2 deletions src/cdf/core/component/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,10 @@ def __call__(self) -> T:

@pydantic.model_validator(mode="before")
@classmethod
def _parse_metadata(cls, data: t.Any) -> t.Any:
def _parse_func(cls, data: t.Any) -> t.Any:
"""Parse node metadata."""
if inspect.isfunction(data):
data = {"main": data}
if isinstance(data, dict):
dep = data["main"]
if isinstance(dep, dict):
Expand Down Expand Up @@ -243,8 +245,10 @@ class Entrypoint(_Node, t.Generic[T], frozen=True):

@pydantic.model_validator(mode="before")
@classmethod
def _parse_metadata(cls, data: t.Any) -> t.Any:
def _parse_func(cls, data: t.Any) -> t.Any:
"""Parse node metadata."""
if inspect.isfunction(data):
data = {"main": data}
if isinstance(data, dict):
func = _unwrap_entrypoint(data["main"])
return {**_parse_metadata_from_callable(func), **data}
Expand Down
2 changes: 2 additions & 0 deletions src/cdf/core/component/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,7 @@ def get_schemas(self, destination: t.Optional["DltDestination"] = None):
def run_tests(self) -> None:
    """Execute every integration test exposed by the pipeline.

    ``self.main()`` is expected to return a 3-item sequence whose last
    item is the collection of zero-argument test callables; the first
    two items are ignored here.

    Raises:
        ValueError: If the collection of tests is empty or otherwise falsy.
    """
    _first, _second, checks = self.main()
    if checks:
        for check in checks:
            check()
        return
    raise ValueError("No tests found for pipeline")
43 changes: 24 additions & 19 deletions src/cdf/core/component/publisher.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,37 @@
import typing as t

import pydantic

from .base import Entrypoint, _get_bind_func, _unwrap_entrypoint
from .base import Entrypoint


def _ping() -> bool:
    """Default preflight check that unconditionally reports success."""
    return True


class DataPublisher(Entrypoint[t.Any], frozen=True):
class DataPublisher(
Entrypoint[
t.Tuple[
t.Callable[..., None], # run
t.Callable[..., bool], # preflight
t.Optional[t.Callable[..., None]], # success hook
t.Optional[t.Callable[..., None]], # failure hook
]
],
frozen=True,
):
"""A data publisher which pushes data to an operational system."""

preflight_check: t.Callable[..., bool] = _ping
"""A user defined function to check if the data publisher is able to publish data"""

integration_test: t.Optional[t.Callable[..., bool]] = None
"""A function to test the data publisher in an integration environment"""

@pydantic.field_validator("preflight_check", "integration_test", mode="before")
@classmethod
def _bind_ancillary(cls, value: t.Any, info: pydantic.ValidationInfo) -> t.Any:
"""Bind the active workspace to the ancillary functions."""
return _get_bind_func(info)(_unwrap_entrypoint(value))

def __call__(self, *args: t.Any, **kwargs: t.Any) -> None:
    """Publish the data.

    ``self.main(*args, **kwargs)`` must return a 4-tuple of
    ``(run, preflight, on_success, on_failure)``. The preflight callable
    is evaluated first; the run callable is then executed and exactly one
    of the optional hooks is fired depending on the outcome.

    Returns:
        Whatever the run callable returns.

    Raises:
        ValueError: If the preflight callable returns a falsy value.
        Exception: Any exception raised by the run callable is re-raised
            after the failure hook (if any) has been invoked.
    """
    publisher, pre, success, err = self.main(*args, **kwargs)
    if not pre():
        raise ValueError("Preflight check failed")
    try:
        result = publisher()
    except Exception:
        if err:
            err()
        # Bare ``raise`` keeps the original traceback intact.
        raise
    # The success hook runs outside the ``try`` so that (a) it is actually
    # reached -- a ``return`` inside ``try`` skips any ``else`` clause --
    # and (b) an error raised by the hook itself does not trigger the
    # failure hook.
    if success:
        success()
    return result
79 changes: 31 additions & 48 deletions src/cdf/core/workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def _list(d: t.Dict[str, cmp.TComponent], verbose: bool = False) -> None:

@cli.command("run-pipeline")
@click.argument(
"pipeline",
"pipeline_name",
required=False,
type=click.Choice(list(self.pipelines.keys())),
)
Expand All @@ -227,7 +227,6 @@ def run_pipeline(
test: bool = False,
) -> None:
"""Run a data pipeline."""
# Prompt for a pipeline if not specified
if pipeline_name is None:
pipeline_name = click.prompt(
"Enter a pipeline",
Expand All @@ -246,10 +245,10 @@ def run_pipeline(
try:
pipeline.run_tests()
except Exception as e:
click.echo(f"Pipeline test failed: {e}", err=True)
click.echo(f"Pipeline test(s) failed: {e}", err=True)
ctx.exit(1)
else:
click.echo("Integration test passed.", err=True)
click.echo("Pipeline test(s) passed!", err=True)
ctx.exit(0)

start = time.time()
Expand All @@ -275,62 +274,45 @@ def run_pipeline(

@cli.command("run-publisher")
@click.argument(
"publisher", required=False, type=click.Choice(list(self.publishers.keys()))
"publisher_name",
required=False,
type=click.Choice(list(self.publishers.keys())),
)
@click.option(
"--test",
is_flag=True,
help="Run the publishers integration test if defined.",
)
@click.option(
"--skip-preflight-check",
is_flag=True,
help="Skip the pre-check for the publisher.",
)
@click.pass_context
def run_publisher(
ctx: click.Context,
publisher: t.Optional[str] = None,
publisher_name: t.Optional[str] = None,
test: bool = False,
skip_preflight_check: bool = False,
) -> None:
"""Run a data publisher."""
# Prompt for a publisher if not specified
if publisher is None:
publisher = click.prompt(
if publisher_name is None:
publisher_name = click.prompt(
"Enter a publisher",
type=click.Choice(list(self.publishers.keys())),
show_choices=True,
)
if publisher is None:
if publisher_name is None:
raise click.BadParameter(
"Publisher must be specified.", ctx=ctx, param_hint="publisher"
)

# Get the publisher definition
publisher_definition = self.publishers[publisher]

# Run the integration test if specified
if test:
if not publisher_definition.integration_test:
raise click.UsageError(
f"Publisher `{publisher}` does not define an integration test."
)
click.echo("Running integration test.", err=True)
if publisher_definition.integration_test():
click.echo("Integration test passed.", err=True)
ctx.exit(0)
else:
ctx.fail("Integration test failed.")

# Optionally run the preflight check
if not skip_preflight_check:
if not publisher_definition.preflight_check():
ctx.fail("Preflight-check failed.")
publisher = self.publishers[publisher_name]

# Run the publisher
start = time.time()
click.echo(publisher_definition())
try:
publisher()
except Exception as e:
click.echo(
f"Publisher failed after {time.time() - start:.2f} seconds: {e}",
err=True,
)
ctx.exit(1)

click.echo(
f"Publisher process finished in {time.time() - start:.2f} seconds.",
err=True,
Expand All @@ -339,28 +321,29 @@ def run_publisher(

@cli.command("run-operation")
@click.argument(
"operation", required=False, type=click.Choice(list(self.operations.keys()))
"operation_name",
required=False,
type=click.Choice(list(self.operations.keys())),
)
@click.pass_context
def run_operation(ctx: click.Context, operation: t.Optional[str] = None) -> int:
def run_operation(
ctx: click.Context, operation_name: t.Optional[str] = None
) -> int:
"""Run an operation."""
# Prompt for an operation if not specified
if operation is None:
operation = click.prompt(
if operation_name is None:
operation_name = click.prompt(
"Enter an operation",
type=click.Choice(list(self.operations.keys())),
show_choices=True,
)
if operation is None:
if operation_name is None:
raise click.BadParameter(
"Operation must be specified.", ctx=ctx, param_hint="operation"
)

# Get the operation definition
operation_definition = self.operations[operation]
operation = self.operations[operation_name]

# Run the operation
ctx.exit(operation_definition())
ctx.exit(operation())

return cli

Expand Down

0 comments on commit 86c6c1d

Please sign in to comment.