diff --git a/sdk/python/kfp/cli/compile_.py b/sdk/python/kfp/cli/compile_.py index 2bd3bab18c2..9552370dc40 100644 --- a/sdk/python/kfp/cli/compile_.py +++ b/sdk/python/kfp/cli/compile_.py @@ -22,6 +22,7 @@ import click from kfp import compiler +from kfp.cli.utils import parsing from kfp.dsl import base_component from kfp.dsl import graph_component @@ -133,12 +134,19 @@ def parse_parameters(parameters: Optional[str]) -> Dict: is_flag=True, default=False, help='Whether to disable type checking.') +@click.option( + '--enable-caching/--disable-caching', + type=bool, + default=None, + help=parsing.get_param_descr(compiler.Compiler.compile, 'enable_caching'), +) def compile_( py: str, output: str, function_name: Optional[str] = None, pipeline_parameters: Optional[str] = None, disable_type_check: bool = False, + enable_caching: Optional[bool] = None, ) -> None: """Compiles a pipeline or component written in a .py file.""" pipeline_func = collect_pipeline_or_component_func( @@ -149,7 +157,8 @@ def compile_( pipeline_func=pipeline_func, pipeline_parameters=parsed_parameters, package_path=package_path, - type_check=not disable_type_check) + type_check=not disable_type_check, + enable_caching=enable_caching) click.echo(package_path) diff --git a/sdk/python/kfp/client/client.py b/sdk/python/kfp/client/client.py index f8897236343..79358143a1c 100644 --- a/sdk/python/kfp/client/client.py +++ b/sdk/python/kfp/client/client.py @@ -33,6 +33,7 @@ from kfp.client import auth from kfp.client import set_volume_credentials from kfp.client.token_credentials_base import TokenCredentialsBase +from kfp.compiler.compiler import override_caching_options from kfp.dsl import base_component from kfp.pipeline_spec import pipeline_spec_pb2 import kfp_server_api @@ -955,8 +956,8 @@ def _create_job_config( # Caching option set at submission time overrides the compile time # settings. if enable_caching is not None: - _override_caching_options(pipeline_doc.pipeline_spec, - enable_caching) + override_caching_options(pipeline_doc.pipeline_spec, + enable_caching) pipeline_spec = pipeline_doc.to_dict() pipeline_version_reference = None @@ -1676,17 +1677,3 @@ def _safe_load_yaml(stream: TextIO) -> _PipelineDoc: raise ValueError( f'The package_file {package_file} should end with one of the ' 'following formats: [.tar.gz, .tgz, .zip, .yaml, .yml].') - - -def _override_caching_options( - pipeline_spec: pipeline_spec_pb2.PipelineSpec, - enable_caching: bool, -) -> None: - """Overrides caching options. - - Args: - pipeline_spec: The PipelineSpec object to update in-place. - enable_caching: Overrides options, one of True, False. - """ - for _, task_spec in pipeline_spec.root.dag.tasks.items(): - task_spec.caching_options.enable_cache = enable_caching diff --git a/sdk/python/kfp/client/client_test.py b/sdk/python/kfp/client/client_test.py index 301ec6d119b..43b48872d5e 100644 --- a/sdk/python/kfp/client/client_test.py +++ b/sdk/python/kfp/client/client_test.py @@ -24,7 +24,8 @@ from google.protobuf import json_format from kfp.client import auth from kfp.client import client -from kfp.compiler import Compiler +from kfp.compiler.compiler import Compiler +from kfp.compiler.compiler import override_caching_options from kfp.dsl import component from kfp.dsl import pipeline from kfp.pipeline_spec import pipeline_spec_pb2 @@ -88,7 +89,7 @@ def pipeline_with_two_component(text: str = 'hi there'): pipeline_obj = yaml.safe_load(f) pipeline_spec = json_format.ParseDict( pipeline_obj, pipeline_spec_pb2.PipelineSpec()) - client._override_caching_options(pipeline_spec, True) + override_caching_options(pipeline_spec, True) pipeline_obj = json_format.MessageToDict(pipeline_spec) self.assertTrue(pipeline_obj['root']['dag']['tasks'] ['hello-word']['cachingOptions']['enableCache']) diff --git a/sdk/python/kfp/compiler/compiler.py b/sdk/python/kfp/compiler/compiler.py index a77f606e89c..3f54b587f3f 100644 --- a/sdk/python/kfp/compiler/compiler.py +++ b/sdk/python/kfp/compiler/compiler.py @@ -22,6 +22,7 @@ from kfp.compiler import pipeline_spec_builder as builder from kfp.dsl import base_component from kfp.dsl.types import type_utils +from kfp.pipeline_spec import pipeline_spec_pb2 class Compiler: @@ -53,6 +54,7 @@ def compile( pipeline_name: Optional[str] = None, pipeline_parameters: Optional[Dict[str, Any]] = None, type_check: bool = True, + enable_caching: Optional[bool] = None, ) -> None: """Compiles the pipeline or component function into IR YAML. @@ -62,6 +64,12 @@ def compile( pipeline_name: Name of the pipeline. pipeline_parameters: Map of parameter names to argument values. type_check: Whether to enable type checking of component interfaces during compilation. + enable_caching: Whether or not to enable caching for the + run. If not set, defaults to the compile time settings, which + is ``True`` for all tasks by default, while users may specify + different caching options for individual tasks. If set, the + setting applies to all tasks in the pipeline (overrides the + compile time settings). """ with type_utils.TypeCheckManager(enable=type_check): @@ -78,9 +86,26 @@ def compile( pipeline_parameters=pipeline_parameters, ) + if enable_caching is not None: + override_caching_options(pipeline_spec, enable_caching) + builder.write_pipeline_spec_to_file( pipeline_spec=pipeline_spec, pipeline_description=pipeline_func.description, platform_spec=pipeline_func.platform_spec, package_path=package_path, ) + + +def override_caching_options( + pipeline_spec: pipeline_spec_pb2.PipelineSpec, + enable_caching: bool, +) -> None: + """Overrides caching options. + + Args: + pipeline_spec: The PipelineSpec object to update in-place. + enable_caching: Overrides options, one of True, False. + """ + for _, task_spec in pipeline_spec.root.dag.tasks.items(): + task_spec.caching_options.enable_cache = enable_caching