
Commit

Validate the graph for unsupported MLIR bridge features in the ConvertGraphToTfExecutor method.

If the graph contains unsupported features, log warnings.

PiperOrigin-RevId: 707966127
tensorflower-gardener committed Dec 19, 2024
1 parent 06a7f29 commit d8d89da
Showing 9 changed files with 1,200 additions and 5 deletions.
1 change: 0 additions & 1 deletion tensorflow/compiler/mlir/BUILD
@@ -126,7 +126,6 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow:attribute_utils",
"//tensorflow/compiler/mlir/tensorflow:device_util",
"//tensorflow/compiler/mlir/tensorflow:dump_mlir_util",
"//tensorflow/compiler/mlir/tensorflow:import_model",
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
"//tensorflow/compiler/mlir/tf2xla:mlir_bridge_rollout_policy",
"//tensorflow/compiler/mlir/tf2xla/api/v2:graph_to_tf_executor",
8 changes: 6 additions & 2 deletions tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc
@@ -239,7 +239,9 @@ absl::Status MlirFunctionOptimizationPass::Run(
{kTfMlirCategory, "convert_graph_to_mlir"});

auto module_ref_status = tensorflow::tf2xla::v2::ConvertGraphToTfExecutor(
**graph, debug_info, *flib_def, import_config, &context);
**graph, debug_info, *flib_def, import_config, &context,
/*tf_name_to_mlir_name*/ nullptr, config_proto,
tensorflow::TF2XLABridgeVersion::kNominal);
mlir_function_pass_graph_conversion_count
->GetCell(absl::StatusCodeToString(module_ref_status.status().code()))
->IncrementBy(1);
@@ -389,7 +391,9 @@ absl::Status MlirV1CompatGraphOptimizationPass::Run(
import_config.restrict_functionalization_to_compiled_nodes = true;

auto module_ref_status = tensorflow::tf2xla::v2::ConvertGraphToTfExecutor(
**options.graph, debug_info, *options.flib_def, import_config, &context);
**options.graph, debug_info, *options.flib_def, import_config, &context,
/*tf_name_to_mlir_name*/ nullptr, options.session_options->config,
tensorflow::TF2XLABridgeVersion::kV1Compat);
if (!module_ref_status.ok()) {
LOG(ERROR) << "Failed to convert graph to MLIR: "
<< module_ref_status.status();
2 changes: 2 additions & 0 deletions tensorflow/compiler/mlir/tf2xla/api/v2/BUILD
@@ -339,8 +339,10 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"//tensorflow/compiler/mlir/tensorflow:translate_utils",
"//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes",
"//tensorflow/compiler/mlir/tensorflow/translate:mlir_roundtrip_flags",
"//tensorflow/compiler/mlir/tensorflow/translate:upgrade_graph",
"//tensorflow/compiler/mlir/tf2xla/internal:graph_to_tf_executor_util",
"//tensorflow/compiler/mlir/tf2xla/internal:node_order",
"//tensorflow/compiler/tf2xla:functionalize_control_flow",
"//tensorflow/compiler/tf2xla:functionalize_control_flow_util",
20 changes: 19 additions & 1 deletion tensorflow/compiler/mlir/tf2xla/api/v2/graph_to_tf_executor.cc
@@ -82,6 +82,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h"
#include "tensorflow/compiler/mlir/tf2xla/internal/graph_to_tf_executor_util.h"
#include "tensorflow/compiler/mlir/tf2xla/internal/node_order.h"
#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
@@ -118,6 +119,7 @@ limitations under the License.
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stack_frame.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
@@ -2687,7 +2689,23 @@ absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ConvertGraphToTfExecutor(
const Graph& graph, const GraphDebugInfo& debug_info,
const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs,
mlir::MLIRContext* context,
std::unordered_map<std::string, std::string>* tf_name_to_mlir_name) {
std::unordered_map<std::string, std::string>* tf_name_to_mlir_name,
const ConfigProto& config_proto,
tensorflow::TF2XLABridgeVersion bridge_version) {
if (bridge_version != tensorflow::TF2XLABridgeVersion::kNotBridgeUseCase) {
bool has_unsupported_features_in_mlir_bridge =
GraphHasUnsupportedFeaturesInMlirBridge(
graph, &flib_def, config_proto,
tensorflow::TF2XLABridgeVersion::kNominal,
/*single_core_inference_mode=*/false);
if (has_unsupported_features_in_mlir_bridge) {
LOG(WARNING)
<< "Graph contains unsupported features in MLIR bridge. "
<< "Use MLIR bridge at your own risk or disable MLIR bridge, e.g., "
<< "tf.config.experimental.disable_mlir_bridge.";
}
}

// TODO(jpienaar): Remove need to const_cast.
if (specs.upgrade_legacy) {
NodeFilter node_filter = specs.restrict_functionalization_to_compiled_nodes
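
Note: the fragment below is a minimal call-site sketch of the updated entry point, mirroring the call sites changed in mlir_graph_optimization_pass.cc above. The variables (graph, debug_info, flib_def, import_config, config_proto, context) are placeholders assumed to be populated by the caller; passing kNominal (or kV1Compat) opts the conversion into the new unsupported-feature check.

  // Sketch only: all inputs are assumed to be set up by the surrounding pass.
  absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> module_ref =
      tensorflow::tf2xla::v2::ConvertGraphToTfExecutor(
          graph, debug_info, flib_def, import_config, &context,
          /*tf_name_to_mlir_name*/ nullptr, config_proto,
          tensorflow::TF2XLABridgeVersion::kNominal);
  if (!module_ref.ok()) {
    // Conversion failures are handled by the caller as before; the new
    // feature check itself never fails the conversion, it only logs a warning.
    LOG(ERROR) << "Failed to convert graph to MLIR: " << module_ref.status();
  }
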
tensorflow/compiler/mlir/tf2xla/api/v2/graph_to_tf_executor.h
@@ -21,6 +21,7 @@ limitations under the License.
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/compiler/mlir/tf2xla/internal/graph_to_tf_executor_util.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/graph_debug_info.pb.h"
@@ -39,7 +40,10 @@ absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ConvertGraphToTfExecutor(
const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs,
mlir::MLIRContext* context,
std::unordered_map<std::string, std::string>* tf_name_to_mlir_name =
nullptr);
nullptr,
const ConfigProto& config_proto = {},
tensorflow::TF2XLABridgeVersion bridge_version =
tensorflow::TF2XLABridgeVersion::kNotBridgeUseCase);

} // namespace v2
} // namespace tf2xla
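
Because config_proto defaults to an empty ConfigProto and bridge_version defaults to kNotBridgeUseCase, existing callers that pass only the original arguments keep compiling unchanged and skip the new validation entirely; a minimal sketch, under the same placeholder assumptions as above:

  // Default invocation: bridge_version falls back to kNotBridgeUseCase,
  // so the unsupported-feature check is not run.
  auto module_or = tensorflow::tf2xla::v2::ConvertGraphToTfExecutor(
      graph, debug_info, flib_def, import_config, &context);
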
43 changes: 43 additions & 0 deletions tensorflow/compiler/mlir/tf2xla/internal/BUILD
@@ -371,3 +371,46 @@ tf_cc_test(
"@com_google_googletest//:gtest",
],
)

cc_library(
name = "graph_to_tf_executor_util",
srcs = ["graph_to_tf_executor_util.cc"],
hdrs = ["graph_to_tf_executor_util.h"],
deps = [
"//tensorflow/core:core_cpu_base",
"//tensorflow/core:framework",
"//tensorflow/core:framework_types_hdr",
"//tensorflow/core:lib",
"//tensorflow/core/common_runtime:function_body",
"//tensorflow/core/platform:enable_tf2_utils",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings:string_view",
],
)

tf_cc_test(
name = "graph_to_tf_executor_util_test",
srcs = ["graph_to_tf_executor_util_test.cc"],
deps = [
":graph_to_tf_executor_util",
"//tensorflow/cc:array_ops",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:functional_ops",
"//tensorflow/cc:ops",
"//tensorflow/cc:scope",
"//tensorflow/cc:tpu_ops",
"//tensorflow/compiler/tf2xla/ops:xla_ops",
"//tensorflow/core:core_cpu_base",
"//tensorflow/core:framework",
"//tensorflow/core:portable_gif_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/framework:tensor_testutil",
"//tensorflow/core/platform:enable_tf2_utils",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest_main",
"@local_tsl//tsl/platform:status",
"@local_xla//xla/tsl/lib/core:status_test_util",
],
)