Skip to content

Commit 6bd59c6

Browse files
frigus02tfx-copybara
authored andcommitted
Fix or ignore some pytype errors.
PiperOrigin-RevId: 662500667
1 parent 90c8da3 commit 6bd59c6

File tree

6 files changed

+14
-14
lines changed

6 files changed

+14
-14
lines changed

tfx/dsl/input_resolution/ops/latest_policy_model_op.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -79,13 +79,13 @@ def add_downstream_artifact(
7979
"""Adds a downstream artifact to the ModelRelations."""
8080
artifact_type_name = downstream_artifact.type
8181
if _is_eval_blessed(artifact_type_name, downstream_artifact):
82-
self.model_blessing_artifacts.append(downstream_artifact)
82+
self.model_blessing_artifacts.append(downstream_artifact) # pytype: disable=container-type-mismatch # dont-delete-module-type
8383

8484
elif _is_infra_blessed(artifact_type_name, downstream_artifact):
85-
self.infra_blessing_artifacts.append(downstream_artifact)
85+
self.infra_blessing_artifacts.append(downstream_artifact) # pytype: disable=container-type-mismatch # dont-delete-module-type
8686

8787
elif artifact_type_name == ops_utils.MODEL_PUSH_TYPE_NAME:
88-
self.model_push_artifacts.append(downstream_artifact)
88+
self.model_push_artifacts.append(downstream_artifact) # pytype: disable=container-type-mismatch # dont-delete-module-type
8989

9090
def meets_policy(self, policy: Policy) -> bool:
9191
"""Checks if ModelRelations contains artifacts that meet the Policy."""
@@ -486,7 +486,7 @@ def event_filter(event):
486486
]
487487
# Set `max_num_hops` to 50, which should be enough for this use case.
488488
batch_downstream_artifacts_and_types_by_model_identifier = (
489-
mlmd_resolver.get_downstream_artifacts_by_artifacts(
489+
mlmd_resolver.get_downstream_artifacts_by_artifacts( # pytype: disable=wrong-arg-types # dont-delete-module-type
490490
batch_model_artifacts,
491491
max_num_hops=ops_utils.LATEST_POLICY_MODEL_OP_MAX_NUM_HOPS,
492492
filter_query=filter_query,

tfx/dsl/input_resolution/ops/test_utils.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ def create_examples(
256256
)
257257
self.put_execution(
258258
'ExampleGen',
259-
inputs={},
259+
inputs={}, # pytype: disable=wrong-arg-types # dont-delete-module-type
260260
outputs={'examples': self.unwrap_tfx_artifacts(examples)},
261261
contexts=contexts,
262262
connection_config=connection_config,
@@ -275,7 +275,7 @@ def transform_examples(
275275
)
276276
self.put_execution(
277277
'Transform',
278-
inputs=inputs,
278+
inputs=inputs, # pytype: disable=wrong-arg-types # dont-delete-module-type
279279
outputs={
280280
'transform_graph': self.unwrap_tfx_artifacts([transform_graph])
281281
},
@@ -298,7 +298,7 @@ def train_on_examples(
298298
inputs['transform_graph'] = self.unwrap_tfx_artifacts([transform_graph])
299299
self.put_execution(
300300
'TFTrainer',
301-
inputs=inputs,
301+
inputs=inputs, # pytype: disable=wrong-arg-types # dont-delete-module-type
302302
outputs={'model': self.unwrap_tfx_artifacts([model])},
303303
contexts=contexts,
304304
connection_config=connection_config,
@@ -325,7 +325,7 @@ def evaluator_bless_model(
325325

326326
self.put_execution(
327327
'Evaluator',
328-
inputs=inputs,
328+
inputs=inputs, # pytype: disable=wrong-arg-types # dont-delete-module-type
329329
outputs={'blessing': self.unwrap_tfx_artifacts([model_blessing])},
330330
contexts=contexts,
331331
connection_config=connection_config,
@@ -353,7 +353,7 @@ def infra_validator_bless_model(
353353

354354
self.put_execution(
355355
'InfraValidator',
356-
inputs={'model': self.unwrap_tfx_artifacts([model])},
356+
inputs={'model': self.unwrap_tfx_artifacts([model])}, # pytype: disable=wrong-arg-types # dont-delete-module-type
357357
outputs={'result': self.unwrap_tfx_artifacts([model_infra_blessing])},
358358
contexts=contexts,
359359
connection_config=connection_config,
@@ -375,7 +375,7 @@ def push_model(
375375
)
376376
self.put_execution(
377377
'ServomaticPusher',
378-
inputs={'model_export': self.unwrap_tfx_artifacts([model])},
378+
inputs={'model_export': self.unwrap_tfx_artifacts([model])}, # pytype: disable=wrong-arg-types # dont-delete-module-type
379379
outputs={'model_push': self.unwrap_tfx_artifacts([model_push])},
380380
contexts=contexts,
381381
connection_config=connection_config,

tfx/orchestration/metadata.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ def get_published_artifacts_by_type_within_context(
267267
@staticmethod
268268
def _get_legacy_producer_component_id(
269269
execution: metadata_store_pb2.Execution) -> str:
270-
return execution.properties[_EXECUTION_TYPE_KEY_COMPONENT_ID].string_value
270+
return execution.properties[_EXECUTION_TYPE_KEY_COMPONENT_ID].string_value # pytype: disable=bad-return-type # dont-delete-module-type
271271

272272
def get_qualified_artifacts(
273273
self,

tfx/orchestration/portable/importer_node_handler.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def _extract_proto_map(
4747
extract_mlmd_value = lambda v: getattr(v, v.WhichOneof('value'))
4848
return {k: extract_mlmd_value(v.field_value) for k, v in proto_map.items()}
4949

50-
def run(
50+
def run( # pytype: disable=signature-mismatch # dont-delete-module-type
5151
self, mlmd_connection: metadata.Metadata,
5252
pipeline_node: pipeline_pb2.PipelineNode,
5353
pipeline_info: pipeline_pb2.PipelineInfo,

tfx/orchestration/portable/partial_run_utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -639,7 +639,7 @@ def _get_base_pipeline_run_context(
639639
pipeline_run_contexts, key=lambda c: c.create_time_since_epoch
640640
)
641641
if not sorted_run_contexts:
642-
return None
642+
return None # pytype: disable=bad-return-type # dont-delete-module-type
643643

644644
logging.info(
645645
'base_run_id not provided. Default to latest pipeline run: %s',

tfx/orchestration/portable/resolver_node_handler.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def _extract_proto_map(
4242
extract_mlmd_value = lambda v: getattr(v, v.WhichOneof('value'))
4343
return {k: extract_mlmd_value(v.field_value) for k, v in proto_map.items()}
4444

45-
def run(
45+
def run( # pytype: disable=signature-mismatch # dont-delete-module-type
4646
self, mlmd_connection: metadata.Metadata,
4747
pipeline_node: pipeline_pb2.PipelineNode,
4848
pipeline_info: pipeline_pb2.PipelineInfo,

0 commit comments

Comments
 (0)