diff --git a/metaflow/plugins/cards/card_decorator.py b/metaflow/plugins/cards/card_decorator.py index a65e5d83a8..7006997e5e 100644 --- a/metaflow/plugins/cards/card_decorator.py +++ b/metaflow/plugins/cards/card_decorator.py @@ -76,6 +76,12 @@ class CardDecorator(StepDecorator): card_creator = None + _config_values = None + + _config_file_name = None + + task_finished_decos = 0 + def __init__(self, *args, **kwargs): super(CardDecorator, self).__init__(*args, **kwargs) self._task_datastore = None @@ -106,6 +112,25 @@ def _set_card_counts_per_step(cls, step_name, total_count): def _increment_step_counter(cls): cls.step_counter += 1 + @classmethod + def _increment_completed_counter(cls): + cls.task_finished_decos += 1 + + @classmethod + def _set_config_values(cls, config_values): + cls._config_values = config_values + + @classmethod + def _set_config_file_name(cls, flow): + # Only create a config file from the very first card decorator. + if cls._config_values and not cls._config_file_name: + with tempfile.NamedTemporaryFile( + mode="w", encoding="utf-8", delete=False + ) as config_file: + config_value = dump_config_values(flow) + json.dump(config_value, config_file) + cls._config_file_name = config_file.name + def step_init( self, flow, graph, step_name, decorators, environment, flow_datastore, logger ): @@ -116,11 +141,13 @@ def step_init( # We check for configuration options. We do this here before they are # converted to properties. - self._config_values = [ - (config.name, ConfigInput.make_key_name(config.name)) - for _, config in flow._get_parameters() - if config.IS_CONFIG_PARAMETER - ] + self._set_config_values( + [ + (config.name, ConfigInput.make_key_name(config.name)) + for _, config in flow._get_parameters() + if config.IS_CONFIG_PARAMETER + ] + ) self.card_options = self.attributes["options"] @@ -159,15 +186,11 @@ def task_pre_step( # If we have configs, we need to dump them to a file so we can re-use them # when calling the card creation subprocess. - if self._config_values: - with tempfile.NamedTemporaryFile( - mode="w", encoding="utf-8", delete=False - ) as config_file: - config_value = dump_config_values(flow) - json.dump(config_value, config_file) - self._config_file_name = config_file.name - else: - self._config_file_name = None + # Since a step can contain multiple card decorators, and all the card creation processes + # will reference the same config file (because of how the CardCreator is created (only single class instance)), + # we need to ensure that a single config file is being referenced for all card create commands. + # This config file will be removed when the last card decorator has finished creating its card. + self._set_config_file_name(flow) card_type = self.attributes["type"] card_class = get_card_class(card_type) @@ -246,12 +269,7 @@ def task_finished( self.card_creator.create(mode="render", final=True, **create_options) self.card_creator.create(mode="refresh", final=True, **create_options) - # Unlink the config file if it exists - if self._config_file_name: - try: - os.unlink(self._config_file_name) - except Exception as e: - pass + self._cleanup(step_name) @staticmethod def _options(mapping): @@ -286,3 +304,18 @@ def _create_top_level_args(self, flow): top_level_options["local-config-file"] = self._config_file_name return list(self._options(top_level_options)) + + def task_exception( + self, exception, step_name, flow, graph, retry_count, max_user_code_retries + ): + self._cleanup(step_name) + + def _cleanup(self, step_name): + self._increment_completed_counter() + if self.task_finished_decos == self.total_decos_on_step[step_name]: + # Unlink the config file if it exists + if self._config_file_name: + try: + os.unlink(self._config_file_name) + except Exception as e: + pass