Skip to content

Commit

Permalink
Add tests for disabling default caching var and flag
Browse files Browse the repository at this point in the history
  • Loading branch information
DharmitD committed Sep 18, 2024
1 parent bad5d70 commit c305fa5
Show file tree
Hide file tree
Showing 2 changed files with 151 additions and 1 deletion.
73 changes: 72 additions & 1 deletion sdk/python/kfp/cli/cli_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,10 +162,81 @@ def test_deprecated_command_is_found(self):

def test_deprecation_warning(self):
res = subprocess.run(['dsl-compile', '--help'], capture_output=True)
self.assertIn('Deprecated. Please use `kfp dsl compile` instead.)',
self.assertIn('Deprecated. Please use `;2` instead.)',
res.stdout.decode('utf-8'))


class TestKfpDslCompile(unittest.TestCase):

def invoke(self, args):
starting_args = ['dsl', 'compile']
args = starting_args + args
runner = testing.CliRunner()
return runner.invoke(cli=cli.cli, args=args, catch_exceptions=False, obj={})

def test_compile_with_caching_flag_enabled(self):
with tempfile.NamedTemporaryFile(suffix='.py') as temp_pipeline:
# Write the pipeline function to a temporary file
temp_pipeline.write(b"""
from kfp import dsl
@dsl.component
def my_component():
pass
@dsl.pipeline(name="tiny-pipeline")
def my_pipeline():
my_component_task = my_component()
""")
temp_pipeline.flush() # Ensure the data is written to disk

# Invoke the CLI command with the temporary file
result = self.invoke(['--py', temp_pipeline.name, '--output', 'test_output.yaml'])
print(result.output) # Print the command output
self.assertEqual(result.exit_code, 0)

def test_compile_with_caching_flag_disabled(self):
with tempfile.NamedTemporaryFile(suffix='.py') as temp_pipeline:
temp_pipeline.write(b"""
from kfp import dsl
@dsl.component
def my_component():
pass
@dsl.pipeline(name="tiny-pipeline")
def my_pipeline():
my_component_task = my_component()
""")
temp_pipeline.flush()

result = self.invoke(
['--py', temp_pipeline.name, '--output', 'test_output.yaml', '--disable-execution-caching-by-default']
)
print(result.output)
self.assertEqual(result.exit_code, 0)

def test_compile_with_caching_disabled_env_var(self):
with tempfile.NamedTemporaryFile(suffix='.py') as temp_pipeline:
temp_pipeline.write(b"""
from kfp import dsl
@dsl.component
def my_component():
pass
@dsl.pipeline(name="tiny-pipeline")
def my_pipeline():
my_component_task = my_component()
""")
temp_pipeline.flush()

os.environ['KFP_DISABLE_EXECUTION_CACHING_BY_DEFAULT'] = 'true'
result = self.invoke(['--py', temp_pipeline.name, '--output', 'test_output.yaml'])
print(result.output)
self.assertEqual(result.exit_code, 0)
del os.environ['KFP_DISABLE_EXECUTION_CACHING_BY_DEFAULT']

info_dict = cli.cli.to_info_dict(ctx=click.Context(cli.cli))
commands_dict = {
command: list(body.get('commands', {}).keys())
Expand Down
79 changes: 79 additions & 0 deletions sdk/python/kfp/compiler/compiler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -910,6 +910,85 @@ def my_pipeline() -> NamedTuple('Outputs', [
task = print_and_return(text='Hello')


class TestCompilePipelineCaching(unittest.TestCase):

def test_compile_pipeline_with_caching_enabled(self):
"""Test pipeline compilation with caching enabled."""

@dsl.component
def my_component():
pass

@dsl.pipeline(name="tiny-pipeline")
def my_pipeline():
my_task = my_component()
my_task.set_caching_options(True)

with tempfile.TemporaryDirectory() as tempdir:
output_yaml = os.path.join(tempdir, 'pipeline.yaml')
compiler.Compiler().compile(pipeline_func=my_pipeline, package_path=output_yaml)

with open(output_yaml, 'r') as f:
pipeline_spec = yaml.safe_load(f)

task_spec = pipeline_spec['root']['dag']['tasks']['my-component']
caching_options = task_spec['cachingOptions']

self.assertTrue(caching_options['enableCache'])

def test_compile_pipeline_with_caching_disabled(self):
"""Test pipeline compilation with caching disabled."""

@dsl.component
def my_component():
pass

@dsl.pipeline(name="tiny-pipeline")
def my_pipeline():
my_task = my_component()
my_task.set_caching_options(False)

with tempfile.TemporaryDirectory() as tempdir:
output_yaml = os.path.join(tempdir, 'pipeline.yaml')
compiler.Compiler().compile(pipeline_func=my_pipeline, package_path=output_yaml)

with open(output_yaml, 'r') as f:
pipeline_spec = yaml.safe_load(f)

task_spec = pipeline_spec['root']['dag']['tasks']['my-component']
caching_options = task_spec.get('cachingOptions', {})

self.assertEqual(caching_options, {})

class TestCompilePipelineCachingEnvVar(unittest.TestCase):

def test_env_var_true_lowercase(self):
"""Test pipeline compilation with caching disabled when env var is 'true'."""

@dsl.component
def my_component():
pass

@dsl.pipeline(name="tiny-pipeline")
def my_pipeline():
my_task = my_component()

os.environ['KFP_DISABLE_EXECUTION_CACHING_BY_DEFAULT'] = 'true'

# Compile the pipeline and verify caching
with tempfile.TemporaryDirectory() as tempdir:
output_yaml = os.path.join(tempdir, 'pipeline.yaml')
compiler.Compiler().compile(pipeline_func=my_pipeline, package_path=output_yaml)

with open(output_yaml, 'r') as f:
pipeline_spec = yaml.safe_load(f)
task_spec = pipeline_spec['root']['dag']['tasks']['my-component']
caching_options = task_spec.get('cachingOptions', {})

self.assertEqual(caching_options, {})

del os.environ['KFP_DISABLE_EXECUTION_CACHING_BY_DEFAULT']

class V2NamespaceAliasTest(unittest.TestCase):
"""Test that imports of both modules and objects are aliased (e.g. all
import path variants work)."""
Expand Down

0 comments on commit c305fa5

Please sign in to comment.