Skip to content

Commit a4f29a0

Browse files
authored
Removing tf-ranking as a dependency untill it supports tf 2.16 (#7725)
1 parent 4975229 commit a4f29a0

File tree

4 files changed

+195
-365
lines changed

4 files changed

+195
-365
lines changed

tfx/examples/ranking/features.py

+18-17
Original file line numberDiff line numberDiff line change
@@ -17,36 +17,37 @@
1717
These names will be shared between the transform and the model.
1818
"""
1919

20-
import tensorflow as tf
21-
from tfx.examples.ranking import struct2tensor_parsing_utils
20+
# import tensorflow as tf
21+
# This is due to TF Ranking not supporting TensorFlow 2.16, We should re-enable it when support is added.
22+
# from tfx.examples.ranking import struct2tensor_parsing_utils
2223

2324
# Labels are expected to be dense. In case of a batch of ELWCs have different
2425
# number of documents, the shape of the label is [N, D], where N is the batch
2526
# size, D is the maximum number of documents in the batch. If an ELWC in the
2627
# batch has D_0 < D documents, then the value of label at D0 <= d < D must be
2728
# negative to indicate that the document is invalid.
28-
LABEL_PADDING_VALUE = -1
29+
#LABEL_PADDING_VALUE = -1
2930

3031
# Names of features in the ELWC.
31-
QUERY_TOKENS = 'query_tokens'
32-
DOCUMENT_TOKENS = 'document_tokens'
33-
LABEL = 'relevance'
32+
#QUERY_TOKENS = 'query_tokens'
33+
#DOCUMENT_TOKENS = 'document_tokens'
34+
#LABEL = 'relevance'
3435

3536
# This "feature" does not exist in the data but will be created on the fly.
36-
LIST_SIZE_FEATURE_NAME = 'example_list_size'
37+
# LIST_SIZE_FEATURE_NAME = 'example_list_size'
3738

3839

39-
def get_features():
40-
"""Defines the context features and example features spec for parsing."""
40+
#def get_features():
41+
# """Defines the context features and example features spec for parsing."""
4142

42-
context_features = [
43-
struct2tensor_parsing_utils.Feature(QUERY_TOKENS, tf.string)
44-
]
43+
# context_features = [
44+
# struct2tensor_parsing_utils.Feature(QUERY_TOKENS, tf.string)
45+
# ]
4546

46-
example_features = [
47-
struct2tensor_parsing_utils.Feature(DOCUMENT_TOKENS, tf.string)
48-
]
47+
# example_features = [
48+
# struct2tensor_parsing_utils.Feature(DOCUMENT_TOKENS, tf.string)
49+
# ]
4950

50-
label = struct2tensor_parsing_utils.Feature(LABEL, tf.int64)
51+
# label = struct2tensor_parsing_utils.Feature(LABEL, tf.int64)
5152

52-
return context_features, example_features, label
53+
# return context_features, example_features, label

tfx/examples/ranking/ranking_pipeline_e2e_test.py

+25-22
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,12 @@
1616
import unittest
1717

1818
import tensorflow as tf
19-
from tfx.examples.ranking import ranking_pipeline
20-
from tfx.orchestration import metadata
21-
from tfx.orchestration.beam.beam_dag_runner import BeamDagRunner
19+
# from tfx.orchestration import metadata
20+
# from tfx.orchestration.beam.beam_dag_runner import BeamDagRunner
21+
22+
# This is due to TF Ranking not supporting TensorFlow 2.16, We should re-enable it when support is added.
23+
# from tfx.examples.ranking import ranking_pipeline
24+
2225

2326
try:
2427
import struct2tensor # pylint: disable=g-import-not-at-top
@@ -62,23 +65,23 @@ def assertExecutedOnce(self, component) -> None:
6265
execution = tf.io.gfile.listdir(os.path.join(component_path, output))
6366
self.assertEqual(1, len(execution))
6467

65-
def testPipeline(self):
66-
BeamDagRunner().run(
67-
ranking_pipeline._create_pipeline(
68-
pipeline_name=self._pipeline_name,
69-
pipeline_root=self._tfx_root,
70-
data_root=self._data_root,
71-
module_file=self._module_file,
72-
serving_model_dir=self._serving_model_dir,
73-
metadata_path=self._metadata_path,
74-
beam_pipeline_args=['--direct_num_workers=1']))
75-
self.assertTrue(tf.io.gfile.exists(self._serving_model_dir))
76-
self.assertTrue(tf.io.gfile.exists(self._metadata_path))
68+
#def testPipeline(self):
69+
# BeamDagRunner().run(
70+
# ranking_pipeline._create_pipeline(
71+
# pipeline_name=self._pipeline_name,
72+
# pipeline_root=self._tfx_root,
73+
# data_root=self._data_root,
74+
# module_file=self._module_file,
75+
# serving_model_dir=self._serving_model_dir,
76+
# metadata_path=self._metadata_path,
77+
# beam_pipeline_args=['--direct_num_workers=1']))
78+
# self.assertTrue(tf.io.gfile.exists(self._serving_model_dir))
79+
# self.assertTrue(tf.io.gfile.exists(self._metadata_path))
7780

78-
metadata_config = metadata.sqlite_metadata_connection_config(
79-
self._metadata_path)
80-
with metadata.Metadata(metadata_config) as m:
81-
artifact_count = len(m.store.get_artifacts())
82-
execution_count = len(m.store.get_executions())
83-
self.assertGreaterEqual(artifact_count, execution_count)
84-
self.assertEqual(9, execution_count)
81+
# metadata_config = metadata.sqlite_metadata_connection_config(
82+
# self._metadata_path)
83+
# with metadata.Metadata(metadata_config) as m:
84+
# artifact_count = len(m.store.get_artifacts())
85+
# execution_count = len(m.store.get_executions())
86+
# self.assertGreaterEqual(artifact_count, execution_count)
87+
# self.assertEqual(9, execution_count)

0 commit comments

Comments
 (0)