Add tests for disabling default caching var and flag
Signed-off-by: ddalvi <[email protected]>
DharmitD committed Sep 19, 2024
1 parent 600624d commit eedb220
Showing 6 changed files with 138 additions and 7 deletions.
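
For orientation, the sketch below (not part of the diff) shows the behavior these tests cover: disabling execution caching by default through the KFP_DISABLE_EXECUTION_CACHING_BY_DEFAULT environment variable, with the --disable-execution-caching-by-default flag of `kfp dsl compile` as the CLI equivalent. Note that, per the pipeline_context.py change further down, the default is read into a class attribute when kfp is imported, so the variable must be set before the import. File names here are arbitrary, and this assumes a standard kfp SDK install.

    # Minimal sketch, assuming a standard kfp install; 'pipeline.yaml' is an arbitrary output path.
    import os

    # Must be set before importing kfp: the diff reads the variable into
    # Pipeline._execution_caching_default at import time.
    os.environ['KFP_DISABLE_EXECUTION_CACHING_BY_DEFAULT'] = 'true'

    from kfp import compiler, dsl  # imported after the env var is set


    @dsl.component
    def my_component():
        pass


    @dsl.pipeline(name='tiny-pipeline')
    def my_pipeline():
        my_component()


    # CLI equivalent exercised by the new cli_test.py cases:
    #   kfp dsl compile --py pipeline.py --output test_output.yaml \
    #       --disable-execution-caching-by-default
    compiler.Compiler().compile(
        pipeline_func=my_pipeline, package_path='pipeline.yaml')
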
76 changes: 76 additions & 0 deletions sdk/python/kfp/cli/cli_test.py
@@ -166,6 +166,82 @@ def test_deprecation_warning(self):
res.stdout.decode('utf-8'))


class TestKfpDslCompile(unittest.TestCase):

    def invoke(self, args):
        starting_args = ['dsl', 'compile']
        args = starting_args + args
        runner = testing.CliRunner()
        return runner.invoke(
            cli=cli.cli, args=args, catch_exceptions=False, obj={})

    def test_compile_with_caching_flag_enabled(self):
        with tempfile.NamedTemporaryFile(suffix='.py') as temp_pipeline:
            # Write the pipeline function to a temporary file
            temp_pipeline.write(b"""
from kfp import dsl
@dsl.component
def my_component():
    pass
@dsl.pipeline(name="tiny-pipeline")
def my_pipeline():
    my_component_task = my_component()
""")
            temp_pipeline.flush()  # Ensure the data is written to disk

            # Invoke the CLI command with the temporary file
            result = self.invoke(
                ['--py', temp_pipeline.name, '--output', 'test_output.yaml'])
            print(result.output)  # Print the command output
            self.assertEqual(result.exit_code, 0)

    def test_compile_with_caching_flag_disabled(self):
        with tempfile.NamedTemporaryFile(suffix='.py') as temp_pipeline:
            temp_pipeline.write(b"""
from kfp import dsl
@dsl.component
def my_component():
    pass
@dsl.pipeline(name="tiny-pipeline")
def my_pipeline():
    my_component_task = my_component()
""")
            temp_pipeline.flush()

            result = self.invoke([
                '--py', temp_pipeline.name, '--output', 'test_output.yaml',
                '--disable-execution-caching-by-default'
            ])
            print(result.output)
            self.assertEqual(result.exit_code, 0)

    def test_compile_with_caching_disabled_env_var(self):
        with tempfile.NamedTemporaryFile(suffix='.py') as temp_pipeline:
            temp_pipeline.write(b"""
from kfp import dsl
@dsl.component
def my_component():
    pass
@dsl.pipeline(name="tiny-pipeline")
def my_pipeline():
    my_component_task = my_component()
""")
            temp_pipeline.flush()

            os.environ['KFP_DISABLE_EXECUTION_CACHING_BY_DEFAULT'] = 'true'
            result = self.invoke(
                ['--py', temp_pipeline.name, '--output', 'test_output.yaml'])
            print(result.output)
            self.assertEqual(result.exit_code, 0)
            del os.environ['KFP_DISABLE_EXECUTION_CACHING_BY_DEFAULT']


info_dict = cli.cli.to_info_dict(ctx=click.Context(cli.cli))
commands_dict = {
    command: list(body.get('commands', {}).keys())
1 change: 1 addition & 0 deletions sdk/python/kfp/cli/compile_.py
@@ -26,6 +26,7 @@
from kfp.dsl import graph_component
from kfp.dsl.pipeline_context import Pipeline


def is_pipeline_func(func: Callable) -> bool:
    """Checks if a function is a pipeline function.
53 changes: 53 additions & 0 deletions sdk/python/kfp/compiler/compiler_test.py
@@ -910,6 +910,59 @@ def my_pipeline() -> NamedTuple('Outputs', [
task = print_and_return(text='Hello')


class TestCompilePipelineCaching(unittest.TestCase):

    def test_compile_pipeline_with_caching_enabled(self):
        """Test pipeline compilation with caching enabled."""

        @dsl.component
        def my_component():
            pass

        @dsl.pipeline(name='tiny-pipeline')
        def my_pipeline():
            my_task = my_component()
            my_task.set_caching_options(True)

        with tempfile.TemporaryDirectory() as tempdir:
            output_yaml = os.path.join(tempdir, 'pipeline.yaml')
            compiler.Compiler().compile(
                pipeline_func=my_pipeline, package_path=output_yaml)

            with open(output_yaml, 'r') as f:
                pipeline_spec = yaml.safe_load(f)

            task_spec = pipeline_spec['root']['dag']['tasks']['my-component']
            caching_options = task_spec['cachingOptions']

            self.assertTrue(caching_options['enableCache'])

    def test_compile_pipeline_with_caching_disabled(self):
        """Test pipeline compilation with caching disabled."""

        @dsl.component
        def my_component():
            pass

        @dsl.pipeline(name='tiny-pipeline')
        def my_pipeline():
            my_task = my_component()
            my_task.set_caching_options(False)

        with tempfile.TemporaryDirectory() as tempdir:
            output_yaml = os.path.join(tempdir, 'pipeline.yaml')
            compiler.Compiler().compile(
                pipeline_func=my_pipeline, package_path=output_yaml)

            with open(output_yaml, 'r') as f:
                pipeline_spec = yaml.safe_load(f)

            task_spec = pipeline_spec['root']['dag']['tasks']['my-component']
            caching_options = task_spec.get('cachingOptions', {})

            self.assertEqual(caching_options, {})


class V2NamespaceAliasTest(unittest.TestCase):
    """Test that imports of both modules and objects are aliased (e.g. all
    import path variants work)."""
6 changes: 3 additions & 3 deletions sdk/python/kfp/dsl/base_component.py
@@ -48,7 +48,7 @@ def __init__(self, component_spec: structures.ComponentSpec):
# (backend) system to pass a value.
self._component_inputs = {
input_name for input_name, input_spec in (
self.component_spec.inputs or {}).items()
self.component_spec.inputs or {}).items()
if not type_utils.is_task_final_status_type(input_spec.type)
}

@@ -102,7 +102,7 @@ def __call__(self, *args, **kwargs) -> pipeline_task.PipelineTask:
component_spec=self.component_spec,
args=task_inputs,
execute_locally=pipeline_context.Pipeline.get_default_pipeline() is
None,
None,
execution_caching_default=pipeline_context.Pipeline.get_execution_caching_default(),
)

@@ -130,7 +130,7 @@ def execute(self, **kwargs):
def required_inputs(self) -> List[str]:
return [
input_name for input_name, input_spec in (
self.component_spec.inputs or {}).items()
self.component_spec.inputs or {}).items()
if not input_spec.optional
]

7 changes: 4 additions & 3 deletions sdk/python/kfp/dsl/pipeline_context.py
@@ -14,15 +14,14 @@
"""Definition for Pipeline."""

import functools
import os
from typing import Callable, Optional

from kfp.dsl import component_factory
from kfp.dsl import pipeline_task
from kfp.dsl import tasks_group
from kfp.dsl import utils

import os


def pipeline(func: Optional[Callable] = None,
             *,
@@ -107,7 +106,9 @@ def get_default_pipeline():
    # or the env var KFP_DISABLE_EXECUTION_CACHING_BY_DEFAULT.
    # align with click's treatment of env vars for boolean flags.
    # per click doc, "1", "true", "t", "yes", "y", and "on" are all converted to True
    _execution_caching_default = not str(os.getenv('KFP_DISABLE_EXECUTION_CACHING_BY_DEFAULT')).strip().lower() in {"1", "true", "t", "yes", "y", "on"}
    _execution_caching_default = not str(
        os.getenv('KFP_DISABLE_EXECUTION_CACHING_BY_DEFAULT')).strip().lower(
        ) in {'1', 'true', 't', 'yes', 'y', 'on'}

    @staticmethod
    def get_execution_caching_default():
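
As a standalone illustration of the comment above (the names _TRUTHY and caching_enabled_by_default are invented here for illustration, not part of the SDK), the variable is parsed the way click treats boolean environment variables, and an unset variable leaves caching enabled because str(None) is 'None':

    import os

    # Mirrors click's truthy values for boolean env vars, per the comment in the diff.
    _TRUTHY = {'1', 'true', 't', 'yes', 'y', 'on'}


    def caching_enabled_by_default() -> bool:
        # Hypothetical helper; kfp itself evaluates this once, as a class attribute, at import time.
        raw = os.getenv('KFP_DISABLE_EXECUTION_CACHING_BY_DEFAULT')
        # str(None) == 'None', so an unset variable keeps the default of caching enabled.
        return str(raw).strip().lower() not in _TRUTHY


    os.environ['KFP_DISABLE_EXECUTION_CACHING_BY_DEFAULT'] = 'yes'
    print(caching_enabled_by_default())  # False: 'yes' is truthy, so caching is disabled by default

    os.environ['KFP_DISABLE_EXECUTION_CACHING_BY_DEFAULT'] = '0'
    print(caching_enabled_by_default())  # True: '0' is not in the truthy set
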
2 changes: 1 addition & 1 deletion sdk/python/kfp/dsl/pipeline_task.py
@@ -168,7 +168,7 @@ def validate_placeholder_types(
task_name=self._task_spec.name,
is_artifact_list=output_spec.is_artifact_list,
) for output_name, output_spec in (
component_spec.outputs or {}).items()
component_spec.outputs or {}).items()
}

self._inputs = args
