Skip to content

Commit a50def4

Browse files
committed
Adding TFX IR and node_id as arguments to BaseComponent.
PiperOrigin-RevId: 364924597
1 parent c456fbb commit a50def4

File tree

4 files changed

+38
-12
lines changed

4 files changed

+38
-12
lines changed

tfx/dsl/compiler/compiler.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
"""Compiles a TFX pipeline into a TFX DSL IR proto."""
15-
import json
16-
import re
1715
from typing import cast, Iterable, List, Mapping
1816

1917
from tfx import types
@@ -228,12 +226,6 @@ def _compile_node(
228226
compiler_utils.set_runtime_parameter_pb(
229227
parameter_value.runtime_parameter, value.name, value.ptype,
230228
value.default)
231-
elif isinstance(value, str) and re.search(
232-
data_types.RUNTIME_PARAMETER_PATTERN, value):
233-
runtime_param = json.loads(value)
234-
compiler_utils.set_runtime_parameter_pb(
235-
parameter_value.runtime_parameter, runtime_param.name,
236-
runtime_param.ptype, runtime_param.default)
237229
else:
238230
try:
239231
data_types_utils.set_metadata_value(parameter_value.field_value,

tfx/orchestration/kubeflow/base_component.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from tfx.orchestration.kubeflow import utils
4040
from tfx.orchestration.kubeflow.proto import kubeflow_pb2
4141
from tfx.orchestration.launcher import base_component_launcher
42+
from tfx.proto.orchestration import pipeline_pb2
4243
from tfx.utils import json_utils
4344

4445
from google.protobuf import json_format
@@ -70,6 +71,7 @@ def __init__(
7071
tfx_image: Text,
7172
kubeflow_metadata_config: Optional[kubeflow_pb2.KubeflowMetadataConfig],
7273
component_config: base_component_config.BaseComponentConfig,
74+
tfx_ir: pipeline_pb2.Pipeline,
7375
pod_labels_to_attach: Optional[Dict[Text, Text]] = None):
7476
"""Creates a new Kubeflow-based component.
7577
@@ -89,6 +91,7 @@ def __init__(
8991
kubeflow_metadata_config: Configuration settings for connecting to the
9092
MLMD store in a Kubeflow cluster.
9193
component_config: Component config to launch the component.
94+
tfx_ir: The TFX intermediate representation of the pipeline.
9295
pod_labels_to_attach: Optional dict of pod labels to attach to the
9396
GKE pod.
9497
"""
@@ -117,6 +120,10 @@ def __init__(
117120
serialized_component,
118121
'--component_config',
119122
json_utils.dumps(component_config),
123+
'--tfx_ir',
124+
json_format.MessageToJson(tfx_ir),
125+
'--node_id',
126+
component.id,
120127
]
121128

122129
if pipeline.enable_cache:

tfx/orchestration/kubeflow/base_component_test.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from tfx.orchestration.kubeflow import base_component
3232
from tfx.orchestration.kubeflow.proto import kubeflow_pb2
3333
from tfx.orchestration.launcher import in_process_component_launcher
34+
from tfx.proto.orchestration import pipeline_pb2
3435
from tfx.types import channel_utils
3536
from tfx.types import standard_artifacts
3637

@@ -60,6 +61,7 @@ def setUp(self):
6061

6162
self._metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
6263
self._metadata_config.mysql_db_service_host.environment_variable = 'MYSQL_SERVICE_HOST'
64+
self._tfx_ir = pipeline_pb2.Pipeline()
6365
with dsl.Pipeline('test_pipeline'):
6466
self.component = base_component.BaseComponent(
6567
component=statistics_gen,
@@ -72,6 +74,7 @@ def setUp(self):
7274
tfx_image='container_image',
7375
kubeflow_metadata_config=self._metadata_config,
7476
component_config=None,
77+
tfx_ir=self._tfx_ir,
7578
)
7679
self.tfx_component = statistics_gen
7780

@@ -159,6 +162,7 @@ def setUp(self):
159162

160163
self._metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
161164
self._metadata_config.mysql_db_service_host.environment_variable = 'MYSQL_SERVICE_HOST'
165+
self._tfx_ir = pipeline_pb2.Pipeline()
162166
with dsl.Pipeline('test_pipeline'):
163167
self.example_gen = base_component.BaseComponent(
164168
component=example_gen,
@@ -170,7 +174,8 @@ def setUp(self):
170174
pipeline_root=test_pipeline_root,
171175
tfx_image='container_image',
172176
kubeflow_metadata_config=self._metadata_config,
173-
component_config=None)
177+
component_config=None,
178+
tfx_ir=self._tfx_ir)
174179
self.statistics_gen = base_component.BaseComponent(
175180
component=statistics_gen,
176181
component_launcher_class=in_process_component_launcher
@@ -182,6 +187,7 @@ def setUp(self):
182187
tfx_image='container_image',
183188
kubeflow_metadata_config=self._metadata_config,
184189
component_config=None,
190+
tfx_ir=self._tfx_ir
185191
)
186192

187193
self.tfx_example_gen = example_gen
@@ -221,6 +227,10 @@ def testContainerOpArguments(self):
221227
formatted_statistics_gen,
222228
'--component_config',
223229
'null',
230+
'--tfx_ir',
231+
'{}',
232+
'--node_id',
233+
'StatisticsGen.foo',
224234
]
225235
example_gen_expected_args = [
226236
'--pipeline_name',
@@ -243,6 +253,10 @@ def testContainerOpArguments(self):
243253
formatted_example_gen,
244254
'--component_config',
245255
'null',
256+
'--tfx_ir',
257+
'{}',
258+
'--node_id',
259+
'CsvExampleGen',
246260
]
247261
try:
248262
self.assertEqual(

tfx/orchestration/kubeflow/kubeflow_dag_runner.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,15 @@
1919

2020
import os
2121
import re
22-
from typing import Callable, Dict, List, Optional, Text, Type
22+
from typing import Callable, Dict, List, Optional, Text, Type, cast
2323

24+
from absl import logging
2425
from kfp import compiler
2526
from kfp import dsl
2627
from kfp import gcp
2728
from kubernetes import client as k8s_client
28-
2929
from tfx import version
30+
from tfx.dsl.compiler import compiler as tfx_compiler
3031
from tfx.orchestration import data_types
3132
from tfx.orchestration import pipeline as tfx_pipeline
3233
from tfx.orchestration import tfx_runner
@@ -38,9 +39,11 @@
3839
from tfx.orchestration.launcher import base_component_launcher
3940
from tfx.orchestration.launcher import in_process_component_launcher
4041
from tfx.orchestration.launcher import kubernetes_component_launcher
42+
from tfx.proto.orchestration import pipeline_pb2
4143
from tfx.utils import json_utils
4244
from tfx.utils import telemetry_utils
4345

46+
4447
# OpFunc represents the type of a function that takes as input a
4548
# dsl.ContainerOp and returns the same object. Common operations such as adding
4649
# k8s secrets, mounting volumes, specifying the use of TPUs and so on can be
@@ -249,9 +252,11 @@ def __init__(
249252
if config and not isinstance(config, KubeflowDagRunnerConfig):
250253
raise TypeError('config must be type of KubeflowDagRunnerConfig.')
251254
super(KubeflowDagRunner, self).__init__(config or KubeflowDagRunnerConfig())
255+
self._config = cast(KubeflowDagRunnerConfig, self._config)
252256
self._output_dir = output_dir or os.getcwd()
253257
self._output_filename = output_filename
254258
self._compiler = compiler.Compiler()
259+
self._tfx_compiler = tfx_compiler.Compiler()
255260
self._params = [] # List of dsl.PipelineParam used in this pipeline.
256261
self._deduped_parameter_names = set() # Set of unique param names used.
257262
if pod_labels_to_attach is None:
@@ -307,6 +312,7 @@ def _construct_pipeline_graph(self, pipeline: tfx_pipeline.Pipeline,
307312
pipeline_root: dsl.PipelineParam representing the pipeline root.
308313
"""
309314
component_to_kfp_op = {}
315+
tfx_ir = self._generate_tfx_ir(pipeline)
310316

311317
# Assumption: There is a partial ordering of components in the list, i.e.,
312318
# if component A depends on component B and C, then A appears after B and C
@@ -332,13 +338,20 @@ def _construct_pipeline_graph(self, pipeline: tfx_pipeline.Pipeline,
332338
tfx_image=self._config.tfx_image,
333339
kubeflow_metadata_config=self._config.kubeflow_metadata_config,
334340
component_config=component_config,
335-
pod_labels_to_attach=self._pod_labels_to_attach)
341+
pod_labels_to_attach=self._pod_labels_to_attach,
342+
tfx_ir=tfx_ir)
336343

337344
for operator in self._config.pipeline_operator_funcs:
338345
kfp_component.container_op.apply(operator)
339346

340347
component_to_kfp_op[component] = kfp_component.container_op
341348

349+
def _generate_tfx_ir(
350+
self, pipeline: tfx_pipeline.Pipeline) -> pipeline_pb2.Pipeline:
351+
result = self._tfx_compiler.compile(pipeline)
352+
logging.info('Generated pipeline:\n %s', result)
353+
return result
354+
342355
def run(self, pipeline: tfx_pipeline.Pipeline):
343356
"""Compiles and outputs a Kubeflow Pipeline YAML definition file.
344357

0 commit comments

Comments
 (0)