From c305fa5aa1262e1437696bc847f33705309c4604 Mon Sep 17 00:00:00 2001 From: ddalvi Date: Wed, 18 Sep 2024 17:40:27 -0400 Subject: [PATCH] Add tests for disabling default caching var and flag --- sdk/python/kfp/cli/cli_test.py | 73 +++++++++++++++++++++- sdk/python/kfp/compiler/compiler_test.py | 79 ++++++++++++++++++++++++ 2 files changed, 151 insertions(+), 1 deletion(-) diff --git a/sdk/python/kfp/cli/cli_test.py b/sdk/python/kfp/cli/cli_test.py index 361db73a14e4..72b4fcc94565 100644 --- a/sdk/python/kfp/cli/cli_test.py +++ b/sdk/python/kfp/cli/cli_test.py @@ -162,10 +162,81 @@ def test_deprecated_command_is_found(self): def test_deprecation_warning(self): res = subprocess.run(['dsl-compile', '--help'], capture_output=True) - self.assertIn('Deprecated. Please use `kfp dsl compile` instead.)', + self.assertIn('Deprecated. Please use `;2` instead.)', res.stdout.decode('utf-8')) +class TestKfpDslCompile(unittest.TestCase): + + def invoke(self, args): + starting_args = ['dsl', 'compile'] + args = starting_args + args + runner = testing.CliRunner() + return runner.invoke(cli=cli.cli, args=args, catch_exceptions=False, obj={}) + + def test_compile_with_caching_flag_enabled(self): + with tempfile.NamedTemporaryFile(suffix='.py') as temp_pipeline: + # Write the pipeline function to a temporary file + temp_pipeline.write(b""" +from kfp import dsl + +@dsl.component +def my_component(): + pass + +@dsl.pipeline(name="tiny-pipeline") +def my_pipeline(): + my_component_task = my_component() +""") + temp_pipeline.flush() # Ensure the data is written to disk + + # Invoke the CLI command with the temporary file + result = self.invoke(['--py', temp_pipeline.name, '--output', 'test_output.yaml']) + print(result.output) # Print the command output + self.assertEqual(result.exit_code, 0) + + def test_compile_with_caching_flag_disabled(self): + with tempfile.NamedTemporaryFile(suffix='.py') as temp_pipeline: + temp_pipeline.write(b""" +from kfp import dsl + +@dsl.component +def my_component(): + pass + +@dsl.pipeline(name="tiny-pipeline") +def my_pipeline(): + my_component_task = my_component() +""") + temp_pipeline.flush() + + result = self.invoke( + ['--py', temp_pipeline.name, '--output', 'test_output.yaml', '--disable-execution-caching-by-default'] + ) + print(result.output) + self.assertEqual(result.exit_code, 0) + + def test_compile_with_caching_disabled_env_var(self): + with tempfile.NamedTemporaryFile(suffix='.py') as temp_pipeline: + temp_pipeline.write(b""" +from kfp import dsl + +@dsl.component +def my_component(): + pass + +@dsl.pipeline(name="tiny-pipeline") +def my_pipeline(): + my_component_task = my_component() +""") + temp_pipeline.flush() + + os.environ['KFP_DISABLE_EXECUTION_CACHING_BY_DEFAULT'] = 'true' + result = self.invoke(['--py', temp_pipeline.name, '--output', 'test_output.yaml']) + print(result.output) + self.assertEqual(result.exit_code, 0) + del os.environ['KFP_DISABLE_EXECUTION_CACHING_BY_DEFAULT'] + info_dict = cli.cli.to_info_dict(ctx=click.Context(cli.cli)) commands_dict = { command: list(body.get('commands', {}).keys()) diff --git a/sdk/python/kfp/compiler/compiler_test.py b/sdk/python/kfp/compiler/compiler_test.py index 7f0cfd4b98a3..723fb7a11a19 100644 --- a/sdk/python/kfp/compiler/compiler_test.py +++ b/sdk/python/kfp/compiler/compiler_test.py @@ -910,6 +910,85 @@ def my_pipeline() -> NamedTuple('Outputs', [ task = print_and_return(text='Hello') +class TestCompilePipelineCaching(unittest.TestCase): + + def test_compile_pipeline_with_caching_enabled(self): + """Test pipeline compilation with caching enabled.""" + + @dsl.component + def my_component(): + pass + + @dsl.pipeline(name="tiny-pipeline") + def my_pipeline(): + my_task = my_component() + my_task.set_caching_options(True) + + with tempfile.TemporaryDirectory() as tempdir: + output_yaml = os.path.join(tempdir, 'pipeline.yaml') + compiler.Compiler().compile(pipeline_func=my_pipeline, package_path=output_yaml) + + with open(output_yaml, 'r') as f: + pipeline_spec = yaml.safe_load(f) + + task_spec = pipeline_spec['root']['dag']['tasks']['my-component'] + caching_options = task_spec['cachingOptions'] + + self.assertTrue(caching_options['enableCache']) + + def test_compile_pipeline_with_caching_disabled(self): + """Test pipeline compilation with caching disabled.""" + + @dsl.component + def my_component(): + pass + + @dsl.pipeline(name="tiny-pipeline") + def my_pipeline(): + my_task = my_component() + my_task.set_caching_options(False) + + with tempfile.TemporaryDirectory() as tempdir: + output_yaml = os.path.join(tempdir, 'pipeline.yaml') + compiler.Compiler().compile(pipeline_func=my_pipeline, package_path=output_yaml) + + with open(output_yaml, 'r') as f: + pipeline_spec = yaml.safe_load(f) + + task_spec = pipeline_spec['root']['dag']['tasks']['my-component'] + caching_options = task_spec.get('cachingOptions', {}) + + self.assertEqual(caching_options, {}) + +class TestCompilePipelineCachingEnvVar(unittest.TestCase): + + def test_env_var_true_lowercase(self): + """Test pipeline compilation with caching disabled when env var is 'true'.""" + + @dsl.component + def my_component(): + pass + + @dsl.pipeline(name="tiny-pipeline") + def my_pipeline(): + my_task = my_component() + + os.environ['KFP_DISABLE_EXECUTION_CACHING_BY_DEFAULT'] = 'true' + + # Compile the pipeline and verify caching + with tempfile.TemporaryDirectory() as tempdir: + output_yaml = os.path.join(tempdir, 'pipeline.yaml') + compiler.Compiler().compile(pipeline_func=my_pipeline, package_path=output_yaml) + + with open(output_yaml, 'r') as f: + pipeline_spec = yaml.safe_load(f) + task_spec = pipeline_spec['root']['dag']['tasks']['my-component'] + caching_options = task_spec.get('cachingOptions', {}) + + self.assertEqual(caching_options, {}) + + del os.environ['KFP_DISABLE_EXECUTION_CACHING_BY_DEFAULT'] + class V2NamespaceAliasTest(unittest.TestCase): """Test that imports of both modules and objects are aliased (e.g. all import path variants work)."""