diff --git a/tfx/components/trainer/component.py b/tfx/components/trainer/component.py
index 4efd9beb64..a1b1f38737 100644
--- a/tfx/components/trainer/component.py
+++ b/tfx/components/trainer/component.py
@@ -72,6 +72,7 @@ class Trainer(base_component.BaseComponent):
 
   def __init__(
       self,
+      statistics: Optional[types.BaseChannel] = None,
       examples: Optional[types.BaseChannel] = None,
       transformed_examples: Optional[types.BaseChannel] = None,
       transform_graph: Optional[types.BaseChannel] = None,
@@ -170,6 +171,7 @@ def run_fn(trainer.fn_args_utils.FnArgs)
     model = types.Channel(type=standard_artifacts.Model)
     model_run = types.Channel(type=standard_artifacts.ModelRun)
     spec = standard_component_specs.TrainerSpec(
+        statistics=statistics,
         examples=examples,
         transform_graph=transform_graph,
         schema=schema,
diff --git a/tfx/components/trainer/component_test.py b/tfx/components/trainer/component_test.py
index de9ea0fe9a..f26f39fd22 100644
--- a/tfx/components/trainer/component_test.py
+++ b/tfx/components/trainer/component_test.py
@@ -18,6 +18,7 @@
 from tfx.components.trainer import executor
 from tfx.dsl.components.base import executor_spec
 from tfx.orchestration import data_types
 from tfx.proto import trainer_pb2
+from tfx.types import artifact_utils
 from tfx.types import channel_utils
 from tfx.types import standard_artifacts
@@ -30,6 +31,10 @@ def setUp(self):
     super().setUp()
 
     self.examples = channel_utils.as_channel([standard_artifacts.Examples()])
+    statistics_artifact = standard_artifacts.ExampleStatistics()
+    statistics_artifact.split_names = artifact_utils.encode_split_names(
+        ['train', 'eval'])
+    self.statistics = channel_utils.as_channel([statistics_artifact])
     self.transform_graph = channel_utils.as_channel(
         [standard_artifacts.TransformGraph()])
     self.schema = channel_utils.as_channel([standard_artifacts.Schema()])
@@ -62,6 +67,26 @@ def testConstructFromModuleFile(self):
         '{"test": 10}', trainer.spec.exec_properties[
             standard_component_specs.CUSTOM_CONFIG_KEY])
 
+  def testConstructFromModuleFileWithStatistics(self):
+    module_file = '/path/to/module/file'
+    trainer = component.Trainer(
+        module_file=module_file,
+        examples=self.examples,
+        statistics=self.statistics,
+        transform_graph=self.transform_graph,
+        schema=self.schema,
+        custom_config={'test': 10})
+    self._verify_outputs(trainer)
+    self.assertEqual(
+        standard_artifacts.ExampleStatistics.TYPE_NAME,
+        trainer.inputs[standard_component_specs.STATISTICS_KEY].type_name)
+    self.assertEqual(
+        module_file,
+        trainer.spec.exec_properties[standard_component_specs.MODULE_FILE_KEY])
+    self.assertEqual(
+        '{"test": 10}', trainer.spec.exec_properties[
+            standard_component_specs.CUSTOM_CONFIG_KEY])
+
   def testConstructWithParameter(self):
     module_file = data_types.RuntimeParameter(name='module-file', ptype=str)
     n_steps = data_types.RuntimeParameter(name='n-steps', ptype=int)
diff --git a/tfx/components/trainer/executor.py b/tfx/components/trainer/executor.py
index 0d086fc295..6163a62854 100644
--- a/tfx/components/trainer/executor.py
+++ b/tfx/components/trainer/executor.py
@@ -22,6 +22,7 @@
+from tfx.components.statistics_gen import stats_artifact_utils
 from tfx.components.trainer import constants
 from tfx.components.trainer import fn_args_utils
 from tfx.components.util import udf_utils
 from tfx.dsl.components.base import base_executor
 from tfx.dsl.io import fileio
 from tfx.types import artifact_utils
@@ -87,6 +88,18 @@ class GenericExecutor(base_executor.BaseExecutor):
   def _GetFnArgs(self, input_dict: Dict[str, List[types.Artifact]],
                  output_dict: Dict[str, List[types.Artifact]],
                  exec_properties: Dict[str, Any]) -> fn_args_utils.FnArgs:
+    # Per-split example counts read from the optional statistics input.
+    # Stays None when no statistics artifact was supplied (key missing or
+    # mapped to an empty list), mirroring the hyperparameters check below.
+    num_examples = None
+    if input_dict.get(standard_component_specs.STATISTICS_KEY):
+      stats_artifact = artifact_utils.get_single_instance(
+          input_dict[standard_component_specs.STATISTICS_KEY])
+      num_examples = {}
+      for split in artifact_utils.decode_split_names(
+          stats_artifact.split_names):
+        stats = stats_artifact_utils.load_statistics(
+            stats_artifact, split).proto()
+        num_examples[split] = stats.datasets[0].num_examples
     if input_dict.get(standard_component_specs.HYPERPARAMETERS_KEY):
       hyperparameters_file = io_utils.get_only_uri_in_dir(
           artifact_utils.get_single_uri(
@@ -115,6 +128,8 @@ def _GetFnArgs(self, input_dict: Dict[str, List[types.Artifact]],
     result.model_run_dir = model_run_dir
     result.schema_file = result.schema_path
     result.hyperparameters = hyperparameters_config
+    if num_examples is not None:
+      result.num_examples = num_examples
     return result
 
   def Do(self, input_dict: Dict[str, List[types.Artifact]],
diff --git a/tfx/types/standard_component_specs.py b/tfx/types/standard_component_specs.py
index 140b1c4c21..a2d2456458 100644
--- a/tfx/types/standard_component_specs.py
+++ b/tfx/types/standard_component_specs.py
@@ -411,6 +411,9 @@ class TrainerSpec(ComponentSpec):
       HYPERPARAMETERS_KEY:
           ChannelParameter(
               type=standard_artifacts.HyperParameters, optional=True),
+      STATISTICS_KEY:
+          ChannelParameter(
+              type=standard_artifacts.ExampleStatistics, optional=True),
   }
   OUTPUTS = {
       MODEL_KEY: ChannelParameter(type=standard_artifacts.Model),