From 86c6c1d6c8bc8b83d659a0e6a070b2f7566b5afa Mon Sep 17 00:00:00 2001 From: z3z1ma Date: Sun, 18 Aug 2024 23:47:54 -0700 Subject: [PATCH] feat!: match improve interface for publishers and lift bare funcs in validators --- src/cdf/core/component/base.py | 8 ++- src/cdf/core/component/pipeline.py | 2 + src/cdf/core/component/publisher.py | 43 +++++++++------- src/cdf/core/workspace.py | 79 +++++++++++------------------ 4 files changed, 63 insertions(+), 69 deletions(-) diff --git a/src/cdf/core/component/base.py b/src/cdf/core/component/base.py index 1f8ffe6..d7b9a47 100644 --- a/src/cdf/core/component/base.py +++ b/src/cdf/core/component/base.py @@ -193,8 +193,10 @@ def __call__(self) -> T: @pydantic.model_validator(mode="before") @classmethod - def _parse_metadata(cls, data: t.Any) -> t.Any: + def _parse_func(cls, data: t.Any) -> t.Any: """Parse node metadata.""" + if inspect.isfunction(data): + data = {"main": data} if isinstance(data, dict): dep = data["main"] if isinstance(dep, dict): @@ -243,8 +245,10 @@ class Entrypoint(_Node, t.Generic[T], frozen=True): @pydantic.model_validator(mode="before") @classmethod - def _parse_metadata(cls, data: t.Any) -> t.Any: + def _parse_func(cls, data: t.Any) -> t.Any: """Parse node metadata.""" + if inspect.isfunction(data): + data = {"main": data} if isinstance(data, dict): func = _unwrap_entrypoint(data["main"]) return {**_parse_metadata_from_callable(func), **data} diff --git a/src/cdf/core/component/pipeline.py b/src/cdf/core/component/pipeline.py index b41ebbd..c9ac985 100644 --- a/src/cdf/core/component/pipeline.py +++ b/src/cdf/core/component/pipeline.py @@ -42,5 +42,7 @@ def get_schemas(self, destination: t.Optional["DltDestination"] = None): def run_tests(self) -> None: """Run the integration test for the pipeline.""" _, _, tests = self.main() + if not tests: + raise ValueError("No tests found for pipeline") for test in tests: test() diff --git a/src/cdf/core/component/publisher.py b/src/cdf/core/component/publisher.py index 8b26fc5..e810ded 100644 --- a/src/cdf/core/component/publisher.py +++ b/src/cdf/core/component/publisher.py @@ -1,8 +1,6 @@ import typing as t -import pydantic - -from .base import Entrypoint, _get_bind_func, _unwrap_entrypoint +from .base import Entrypoint def _ping() -> bool: @@ -10,23 +8,30 @@ def _ping() -> bool: return bool("pong") -class DataPublisher(Entrypoint[t.Any], frozen=True): +class DataPublisher( + Entrypoint[ + t.Tuple[ + t.Callable[..., None], # run + t.Callable[..., bool], # preflight + t.Optional[t.Callable[..., None]], # success hook + t.Optional[t.Callable[..., None]], # failure hook + ] + ], + frozen=True, +): """A data publisher which pushes data to an operational system.""" - preflight_check: t.Callable[..., bool] = _ping - """A user defined function to check if the data publisher is able to publish data""" - - integration_test: t.Optional[t.Callable[..., bool]] = None - """A function to test the data publisher in an integration environment""" - - @pydantic.field_validator("preflight_check", "integration_test", mode="before") - @classmethod - def _bind_ancillary(cls, value: t.Any, info: pydantic.ValidationInfo) -> t.Any: - """Bind the active workspace to the ancillary functions.""" - return _get_bind_func(info)(_unwrap_entrypoint(value)) - def __call__(self, *args: t.Any, **kwargs: t.Any) -> None: """Publish the data""" - if not self.preflight_check(): - raise RuntimeError("Preflight check failed") - return self.main(*args, **kwargs) + publisher, pre, success, err = self.main(*args, **kwargs) + if not pre(): + raise ValueError("Preflight check failed") + try: + return publisher() + except Exception as e: + if err: + err() + raise e + else: + if success: + success() diff --git a/src/cdf/core/workspace.py b/src/cdf/core/workspace.py index 72a7d34..a8c4851 100644 --- a/src/cdf/core/workspace.py +++ b/src/cdf/core/workspace.py @@ -211,7 +211,7 @@ def _list(d: t.Dict[str, cmp.TComponent], verbose: bool = False) -> None: @cli.command("run-pipeline") @click.argument( - "pipeline", + "pipeline_name", required=False, type=click.Choice(list(self.pipelines.keys())), ) @@ -227,7 +227,6 @@ def run_pipeline( test: bool = False, ) -> None: """Run a data pipeline.""" - # Prompt for a pipeline if not specified if pipeline_name is None: pipeline_name = click.prompt( "Enter a pipeline", @@ -246,10 +245,10 @@ def run_pipeline( try: pipeline.run_tests() except Exception as e: - click.echo(f"Pipeline test failed: {e}", err=True) + click.echo(f"Pipeline test(s) failed: {e}", err=True) ctx.exit(1) else: - click.echo("Integration test passed.", err=True) + click.echo("Pipeline test(s) passed!", err=True) ctx.exit(0) start = time.time() @@ -275,62 +274,45 @@ def run_pipeline( @cli.command("run-publisher") @click.argument( - "publisher", required=False, type=click.Choice(list(self.publishers.keys())) + "publisher_name", + required=False, + type=click.Choice(list(self.publishers.keys())), ) @click.option( "--test", is_flag=True, help="Run the publishers integration test if defined.", ) - @click.option( - "--skip-preflight-check", - is_flag=True, - help="Skip the pre-check for the publisher.", - ) @click.pass_context def run_publisher( ctx: click.Context, - publisher: t.Optional[str] = None, + publisher_name: t.Optional[str] = None, test: bool = False, - skip_preflight_check: bool = False, ) -> None: """Run a data publisher.""" - # Prompt for a publisher if not specified - if publisher is None: - publisher = click.prompt( + if publisher_name is None: + publisher_name = click.prompt( "Enter a publisher", type=click.Choice(list(self.publishers.keys())), show_choices=True, ) - if publisher is None: + if publisher_name is None: raise click.BadParameter( "Publisher must be specified.", ctx=ctx, param_hint="publisher" ) - # Get the publisher definition - publisher_definition = self.publishers[publisher] - - # Run the integration test if specified - if test: - if not publisher_definition.integration_test: - raise click.UsageError( - f"Publisher `{publisher}` does not define an integration test." - ) - click.echo("Running integration test.", err=True) - if publisher_definition.integration_test(): - click.echo("Integration test passed.", err=True) - ctx.exit(0) - else: - ctx.fail("Integration test failed.") - - # Optionally run the preflight check - if not skip_preflight_check: - if not publisher_definition.preflight_check(): - ctx.fail("Preflight-check failed.") + publisher = self.publishers[publisher_name] - # Run the publisher start = time.time() - click.echo(publisher_definition()) + try: + publisher() + except Exception as e: + click.echo( + f"Publisher failed after {time.time() - start:.2f} seconds: {e}", + err=True, + ) + ctx.exit(1) + click.echo( f"Publisher process finished in {time.time() - start:.2f} seconds.", err=True, @@ -339,28 +321,29 @@ def run_publisher( @cli.command("run-operation") @click.argument( - "operation", required=False, type=click.Choice(list(self.operations.keys())) + "operation_name", + required=False, + type=click.Choice(list(self.operations.keys())), ) @click.pass_context - def run_operation(ctx: click.Context, operation: t.Optional[str] = None) -> int: + def run_operation( + ctx: click.Context, operation_name: t.Optional[str] = None + ) -> int: """Run an operation.""" - # Prompt for an operation if not specified - if operation is None: - operation = click.prompt( + if operation_name is None: + operation_name = click.prompt( "Enter an operation", type=click.Choice(list(self.operations.keys())), show_choices=True, ) - if operation is None: + if operation_name is None: raise click.BadParameter( "Operation must be specified.", ctx=ctx, param_hint="operation" ) - # Get the operation definition - operation_definition = self.operations[operation] + operation = self.operations[operation_name] - # Run the operation - ctx.exit(operation_definition()) + ctx.exit(operation()) return cli