diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index cd08f0ef97e..ebbe03013fc 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -592,6 +592,7 @@ cc_library( ":auto_mixed_precision", ":auto_parallel", ":concat_cast_fusing", + ":split_concat_fuse", ":constant_folding", ":custom_graph_optimizer_registry", ":debug_stripper", @@ -1211,4 +1212,53 @@ tf_cc_test( "//tensorflow/core:testlib", "//tensorflow/core/grappler/utils:grappler_test", ], -) \ No newline at end of file +) + +cc_library( + name = "split_concat_fuse", + srcs = ["split_concat_fuse.cc"], + hdrs = ["split_concat_fuse.h"], + visibility = ["//visibility:public"], + deps = [ + ":graph_optimizer", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:mutable_graph_view", + "//tensorflow/core/grappler:op_types", + "//tensorflow/core/grappler:utils", + "//tensorflow/core/grappler/costs:graph_properties", + "//tensorflow/core/grappler:graph_view", + "//tensorflow/core/grappler/utils:graph_view", + "//tensorflow/core/grappler/utils:symbolic_shapes", + "//tensorflow/core/grappler/utils:topological_sort", + "//tensorflow/core/grappler/utils:pattern_utils", + ], +) + +tf_cc_test( + name = "split_concat_fuse_test", + srcs = ["split_concat_fuse_test.cc"], + deps = [ + ":split_concat_fuse", + ":dependency_optimizer", + "//tensorflow/cc:array_ops_internal", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", + "//tensorflow/core:all_kernels", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:direct_session", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/grappler:grappler_item", 
+ "//tensorflow/core/grappler:utils", + "//tensorflow/core/grappler/utils:grappler_test", + "//tensorflow/cc:client_session" + ], +) diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc index b1de0c33631..67320e2625e 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc @@ -40,6 +40,7 @@ limitations under the License. #include "tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h" #include "tensorflow/core/grappler/optimizers/remapper.h" #include "tensorflow/core/grappler/optimizers/concat_cast_fusing.h" +#include "tensorflow/core/grappler/optimizers/split_concat_fuse.h" #include "tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.h" #include "tensorflow/core/grappler/optimizers/shape_optimizer.h" #include "tensorflow/core/grappler/utils/canonicalizer.h" @@ -213,6 +214,7 @@ std::unique_ptr MetaOptimizer::MakeNewOptimizer( MK_OPT("pin_to_host", new PinToHostOptimizer(cfg_.pin_to_host_optimization())); MK_OPT("concat_cast_fusing", new ConcatCastFusing()); + MK_OPT("split_concat_fuse", new SplitConcatFuse()); return std::unique_ptr(); } @@ -304,6 +306,7 @@ Status MetaOptimizer::InitializeOptimizers( optimizers->push_back(MakeUnique( cfg_.scoped_allocator_optimization(), cfg_.scoped_allocator_opts())); } + optimizers->push_back(MakeUnique()); optimizers->push_back(MakeUnique()); return InitializeCustomGraphOptimizers(std::set(), optimizers); diff --git a/tensorflow/core/grappler/optimizers/split_concat_fuse.cc b/tensorflow/core/grappler/optimizers/split_concat_fuse.cc new file mode 100644 index 00000000000..86d30f385b5 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/split_concat_fuse.cc @@ -0,0 +1,196 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define EIGEN_USE_THREADS + +#include "tensorflow/core/grappler/optimizers/split_concat_fuse.h" + +#include "tensorflow/core/grappler/graph_view.h" +#include "tensorflow/core/grappler/op_types.h" +#include "tensorflow/core/grappler/optimizers/evaluation_utils.h" +#include "tensorflow/core/grappler/utils/graph_view.h" +#include "tensorflow/core/grappler/utils/topological_sort.h" +#include "tensorflow/core/util/dump_graph.h" + +namespace tensorflow { +namespace grappler { +namespace { + +constexpr char kFusedSplitConcat[] = "_FusedSplitConcat"; + +struct Context { + explicit Context(GrapplerItem* item, Status* status) + : nodes_to_preserve(item->NodesToPreserve()), + graph_view(&item->graph, status), + graph_properties(*item), + inferred_graph_properties(false) {} + + std::unordered_set nodes_to_preserve; + utils::MutableGraphView graph_view; + GraphProperties graph_properties; + bool inferred_graph_properties; +}; + +struct SplitWithConcat { + SplitWithConcat() = default; + SplitWithConcat(int split_id, int concat_id) + : split_id(split_id), concat_id(concat_id){} + + int split_id = -1; + int concat_id = -1; +}; + +bool FindSplitWithConcat(const Context& ctx, int node_index, SplitWithConcat* matched) { + const auto* split_node_view = ctx.graph_view.GetNode(node_index); // split node + if (split_node_view->NumControllingFanins() > 0 || + split_node_view->NumControlledFanouts() > 0) return false; + + const auto* node_def = split_node_view->node(); + if (node_def == nullptr) return false; + if 
(!IsSplit(*node_def)) return false; + if (split_node_view->NumRegularFanouts() < 2) return false; + const auto& split_fanouts = split_node_view->GetRegularFanout(0); + const auto* concat_node_view = split_fanouts[0].node_view(); // concat node + const auto* concat_node_def = concat_node_view->node(); + if (!IsConcat(*concat_node_def)) return false; + + const SplitWithConcat pattern{node_index, + concat_node_view->node_index()}; + *matched = pattern; + + return true; +} +} + +SplitConcatFuse::SplitConcatFuse(RewriterConfig::Toggle opt_level, + DeviceBase* cpu_device) + : opt_level_(opt_level), cpu_device_(cpu_device) { + resource_mgr_.reset(new ResourceMgr()); +} + +SplitConcatFuse::SplitConcatFuse(DeviceBase* cpu_device) + : SplitConcatFuse(RewriterConfig::ON, cpu_device) {} + +Status AddSplitConcatFuseNode(Context* ctx, + int i, + const GraphDef* graph, + const SplitWithConcat& matched, + std::vector& invalidated_nodes, + std::vector& nodes_to_delete) { + + const auto* node_view = ctx->graph_view.GetNode(matched.split_id); + const auto& fused_node = graph->node(matched.split_id); + const auto* concat_view = ctx->graph_view.GetNode(matched.concat_id); + + VLOG(3) << "Optimizing fused Split Concat node " << SummarizeNodeDef(fused_node); + + const NodeDef& split = graph->node(matched.split_id); + const NodeDef& concat = graph->node(matched.concat_id); + const std::size_t split_num_inputs = node_view->NumRegularFanins(); + const int concat_num_inputs = concat_view->NumRegularFanins(); + const int split_num_fanouts = concat_view->NumRegularFanouts(); + + VLOG(3) << "Fuse " << split.op() << " with Concat: " + << " concat_name= " << concat.name(); + + NodeDef fused_op; + fused_op.set_op(kFusedSplitConcat); + fused_op.set_name(concat.name()); + fused_op.set_device(split.device()); + + // Add inputs + fused_op.add_input(split.input(0)); // 0: split_dim for split + fused_op.add_input(split.input(1)); // 1: value + fused_op.add_input(concat.input(concat_num_inputs - 1)); // 
3: axis for concat + + auto* attrs = fused_op.mutable_attr(); + auto& split_attr = split.attr(); + auto& concat_attr = concat.attr(); + + // Add attributes + (*attrs)["num_split"] = split_attr.at("num_split"); // 0: num_split + (*attrs)["T"] = split_attr.at("T"); // 1: T + (*attrs)["N"] = concat_attr.at("N"); // 2: N + (*attrs)["Tidx"] = concat_attr.at("Tidx"); // 3: Tidx + + utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder(); + Status status; + mutation->AddNode(std::move(fused_op), &status); + TF_RETURN_IF_ERROR(status); + TF_RETURN_IF_ERROR(mutation->Apply()); + invalidated_nodes[matched.concat_id] = true; + nodes_to_delete[matched.split_id] = true; + + return Status::OK(); +} + +Status SplitConcatFuse::Optimize(Cluster* cluster, const GrapplerItem& item, GraphDef* optimized_graph) { + if(cpu_device_ == nullptr){ + owned_device_.reset(new DeviceSimple()); + cpu_device_ = owned_device_.get(); + } + + GrapplerItem mutable_item = item; + Status status; + TF_RETURN_IF_ERROR(status); + Context ctx(&mutable_item, &status); + TF_RETURN_IF_ERROR(status); + TF_RETURN_IF_ERROR(ctx.graph_view.SortTopologically(false, {})); + const int num_nodes = item.graph.node_size(); + const GraphDef* graph = ctx.graph_view.graph(); + + std::vector invalidated_nodes(num_nodes); // Nodes changed into fused op + std::vector nodes_to_delete(num_nodes); // Fused nodes that are no longer needed + + VLOG(3) << "Before Split Concat graph rewrites: " << graph->DebugString(); + + for(int i = 0; i < num_nodes; ++i){ + if (invalidated_nodes[i] || nodes_to_delete[i]) { + continue; + } + + SplitWithConcat fused_split_concat; + if(FindSplitWithConcat(ctx, i, &fused_split_concat)) { + const auto* node_view = ctx.graph_view.GetNode(i); + const auto& fused_node = graph->node(i); + string op_name = fused_node.op(); + TF_RETURN_IF_ERROR(AddSplitConcatFuseNode(&ctx, + i, + graph, + fused_split_concat, + invalidated_nodes, + nodes_to_delete)); + } + } + + // Remove invalidated nodes + 
utils::Mutation* mutation = ctx.graph_view.GetMutationBuilder(); + for (int i = 0; i < num_nodes; ++i){ + if(nodes_to_delete[i]) { + mutation->RemoveNode(ctx.graph_view.GetNode(i)); + } + } + TF_RETURN_IF_ERROR(mutation->Apply()); + + *optimized_graph = mutable_item.graph; + VLOG(3) << "After Split Concat graph rewrites: " << optimized_graph->DebugString(); + + return Status::OK(); +} + +void SplitConcatFuse::Feedback(Cluster* cluster, const GrapplerItem& item, + const GraphDef& optimized_graph, double result) { + // Nothing to do for SplitConcatFuse +} + +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/split_concat_fuse.h b/tensorflow/core/grappler/optimizers/split_concat_fuse.h new file mode 100644 index 00000000000..8c43eaea6e5 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/split_concat_fuse.h @@ -0,0 +1,56 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_SPLIT_CONCAT_FUSE_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_SPLIT_CONCAT_FUSE_H_ + +#include "tensorflow/core/grappler/costs/graph_properties.h" +#include "tensorflow/core/grappler/optimizers/graph_optimizer.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/protobuf/rewriter_config.pb.h" + +namespace tensorflow { +namespace grappler { + +// SplitConcatFuse optimization for a graph +class SplitConcatFuse : public GraphOptimizer { + public: + + SplitConcatFuse() = default; + explicit SplitConcatFuse(DeviceBase* cpu_device); + SplitConcatFuse(RewriterConfig::Toggle opt_level, DeviceBase* cpu_device); + ~SplitConcatFuse() override {} + + string name() const override { return "split_concat_fuse"; }; + + bool UsesFunctionLibrary() const override { return false; } + + Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* optimized_graph) override; + + void Feedback(Cluster* cluster, const GrapplerItem& item, + const GraphDef& optimized_graph, double result) override; + + + RewriterConfig::Toggle opt_level_; + DeviceBase* cpu_device_; + std::unique_ptr owned_device_; + + std::unique_ptr resource_mgr_; + GraphDef* graph_; + std::unique_ptr node_map_; +}; + + +} // end namespace grappler +} // end namespace tensorflow +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_SPLIT_CONCAT_FUSE_H_ diff --git a/tensorflow/core/grappler/optimizers/split_concat_fuse_test.cc b/tensorflow/core/grappler/optimizers/split_concat_fuse_test.cc new file mode 100644 index 00000000000..310dbba1a55 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/split_concat_fuse_test.cc @@ -0,0 +1,343 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/split_concat_fuse.h" + +#include "tensorflow/cc/client/client_session.h" +#include "tensorflow/cc/ops/array_ops.h" +#include "tensorflow/cc/ops/array_ops_internal.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/optimizers/model_pruner.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/grappler/utils/grappler_test.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/session.h" + +namespace tensorflow { +namespace grappler { +namespace { + +//----------------------------------------------------------------------------// +// Functional tests are below. 
// +//----------------------------------------------------------------------------// + +namespace SplitConcatFuseTestDefs{ + typedef std::tuple< + DataType, + std::vector, // Input tensor size + std::vector, + long long int // num_of_splits + > SplitConcatFuseTestParams; + std::vector dataTypes{ + DataType::DT_FLOAT + }; + std::vector> SCRIPT_SIZES = {{1024, 50, 256}}; + std::vector> SCRIPT_SPLIT_CONCAT = {{2, 0}}; + std::vector SCRIPT_NUM_OF_SPLITS = {8}; + + std::vector> SIZES_2D = {{216, 192}, {6, 36}, {48, 144}, {216, 36}}; + std::vector> SPLIT_DIM_CONCAT_AXIS_2D = {{0, 1}, {1, 0}, {1, 1}, {0, 0}, + {-1, -1}, {-1, 0}, {0, -1}}; + std::vector NUM_OF_SPLITS_2D = {2, 3}; + + std::vector> SIZES_3D = {{36, 6, 48}, {48, 48, 48}, {18, 216, 6}}; + std::vector> SPLIT_DIM_CONCAT_AXIS_3D = {{0, 0}, {0, 1}, {1, 0}, {1, 1}, + {0, 2}, {2, 0}, {2, 2}, {2, 1}, {1, 2}, + {-2, 0}, {-1, 0}}; + std::vector NUM_OF_SPLITS_3D = {2, 3}; + + std::vector> SIZES_4D = {{18, 6, 192, 48}, {36, 36, 36, 36}, {6, 18, 6, 36}}; + std::vector> SPLIT_DIM_CONCAT_AXIS_4D = {{0, 0}, {0, 1}, {1, 0}, {1, 1}, + {0, 2}, {2, 0}, {2, 2}, {2, 1}, {1, 2}, + {0, 3}, {3, 0}, {1, 3}, {3, 1}, {2, 3}, {3, 2}, {3, 3}, + {-3, 0}, {-1, 3}}; + std::vector NUM_OF_SPLITS_4D = {2, 3}; +} // namespace SplitConcatFuseTestDefs + +using namespace SplitConcatFuseTestDefs; +class SplitConcatFuseTest : +public ::testing::WithParamInterface, +public GrapplerTest { + public: + static std::string getTestCaseName(::testing::TestParamInfo obj){ + DataType dtype; + std::vector input_size; + std::vector split_dim_concat_axis; + long long int num_split; + std::tie(dtype, input_size, split_dim_concat_axis, num_split) = obj.param; + + std::ostringstream result; + result << "SplitConcatFuse_DataType_"; + switch(dtype) { + case DataType::DT_FLOAT: + result << "FLOAT"; + break; + default: + result << "UNRECOGNISED_TYPE"; + } + result << "_InputSize"; + for (auto &x : input_size){ + result << "_" << x; + } + if(split_dim_concat_axis[0] < 
0){ + result << "_SplitDim_negative_" << abs(split_dim_concat_axis[0]); + } else{ + result << "_SplitDim_" << split_dim_concat_axis[0]; + } + if(split_dim_concat_axis[1] < 0){ + result << "_ConcatAxis_negative_" << abs(split_dim_concat_axis[1]); + } else{ + result << "_ConcatAxis_" << split_dim_concat_axis[1]; + } + result << "_NumSplit_" << num_split; + return result.str(); + } + + void SetUp(){ + std::tie(dtype, input_size, split_dim_concat_axis, num_of_splits) = this->GetParam(); + std::vector input_names; + GraphDef ref_graph; + input = Tensor(dtype, TensorShape(tensorflow::gtl::ArraySlice(input_size.data(), input_size.size()))); + switch(dtype){ + case DataType::DT_FLOAT: + input.flat() = input.flat().template setRandom>(); // input + break; + default: + GTEST_FAIL() << "Unexpected DataType"; + } + split_dim_tensor = Tensor((int32)split_dim_concat_axis[0]); + concat_axis_tensor = Tensor((int32)split_dim_concat_axis[1]); + } + + protected: + void Validate(std::vector tensor, std::vector tensor_expected){ + EXPECT_EQ(dtype, tensor_expected[0].dtype()); + EXPECT_EQ(dtype, tensor[0].dtype()); + test::ExpectTensorEqual(tensor_expected[0], tensor[0]); + } + + // Test definition (straight from Params, filled in Setup) + DataType dtype; + std::vector input_size; + std::vector split_dim_concat_axis; + long long int num_of_splits; + + Tensor input; + Tensor split_dim_tensor; + Tensor concat_axis_tensor; + + GraphDef want; +}; + +class SplitConcatFuseTestSimpleFusing : public SplitConcatFuseTest{ + public: + void RunAndValidate(){ + tensorflow::Scope root = tensorflow::Scope::NewRootScope(); + + auto value = + ops::Const(root.WithOpName("value"), input); + auto concat_axis_op = + ops::Const(root.WithOpName("axis"), concat_axis_tensor); + auto split_dim_op = + ops::Const(root.WithOpName("split_dim"), split_dim_tensor); + auto s = + ops::Split(root.WithOpName("split"), split_dim_op, value, num_of_splits); + + if(num_of_splits == 2){ + auto concat_out = + 
ops::Concat(root.WithOpName("concat"), {s[0], s[1]}, concat_axis_op); + } else if( num_of_splits == 3){ + auto concat_out = + ops::Concat(root.WithOpName("concat"), {s[0], s[1], s[2]}, concat_axis_op); + } else if( num_of_splits == 4){ + auto concat_out = + ops::Concat(root.WithOpName("concat"), {s[0], s[1], s[2], s[3]}, concat_axis_op); + } else if( num_of_splits == 8){ + auto concat_out = + ops::Concat(root.WithOpName("concat"), {s[0], s[1], s[2], s[3], s[4], s[5], s[6], s[7]}, concat_axis_op); + } else{ + GTEST_FAIL() << "This num of splits is not coded in tests, try between 2 and 4."; + } + + GrapplerItem item; + item.fetch.push_back("concat"); + TF_CHECK_OK(root.ToGraphDef(&item.graph)); + + SplitConcatFuse optimizer(nullptr); + GraphDef output; + Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output); + TF_EXPECT_OK(status); + + std::vector fetch = {"concat"}; + auto tensors_expected = EvaluateNodes(item.graph, fetch); + auto tensors = EvaluateNodes(output, fetch); + Validate(tensors, tensors_expected); + } +}; + +TEST_P(SplitConcatFuseTestSimpleFusing, CompareWithRefs){ + SetUp(); + RunAndValidate(); +} + +INSTANTIATE_TEST_CASE_P(SCRIPTplit, SplitConcatFuseTestSimpleFusing, + ::testing::Combine( + ::testing::ValuesIn(dataTypes), + ::testing::ValuesIn(SCRIPT_SIZES), + ::testing::ValuesIn(SCRIPT_SPLIT_CONCAT), + ::testing::ValuesIn(SCRIPT_NUM_OF_SPLITS)), + SplitConcatFuseTestSimpleFusing::getTestCaseName); + +INSTANTIATE_TEST_CASE_P(Split2D, SplitConcatFuseTestSimpleFusing, + ::testing::Combine( + ::testing::ValuesIn(dataTypes), + ::testing::ValuesIn(SIZES_2D), + ::testing::ValuesIn(SPLIT_DIM_CONCAT_AXIS_2D), + ::testing::ValuesIn(NUM_OF_SPLITS_2D)), + SplitConcatFuseTestSimpleFusing::getTestCaseName); + +INSTANTIATE_TEST_CASE_P(Split3D, SplitConcatFuseTestSimpleFusing, + ::testing::Combine( + ::testing::ValuesIn(dataTypes), + ::testing::ValuesIn(SIZES_3D), + ::testing::ValuesIn(SPLIT_DIM_CONCAT_AXIS_3D), + 
::testing::ValuesIn(NUM_OF_SPLITS_3D)), + SplitConcatFuseTestSimpleFusing::getTestCaseName); + +INSTANTIATE_TEST_CASE_P(Split4D, SplitConcatFuseTestSimpleFusing, + ::testing::Combine( + ::testing::ValuesIn(dataTypes), + ::testing::ValuesIn(SIZES_4D), + ::testing::ValuesIn(SPLIT_DIM_CONCAT_AXIS_4D), + ::testing::ValuesIn(NUM_OF_SPLITS_4D)), + SplitConcatFuseTestSimpleFusing::getTestCaseName); + +//----------------------------------------------------------------------------// +// Performance benchmarks are below. // +//----------------------------------------------------------------------------// + +template +static Graph* SplitConcatFuse(bool fused, std::vector input_shape, int split_d, int concat_a, int num_split){ + Graph* g = new Graph(OpRegistry::Global()); + DataType dt = DataTypeToEnum::v(); + Tensor split_dim(DT_INT32, TensorShape({})); + split_dim.scalar()() = split_d; + Tensor concat_axis(DT_INT32, TensorShape({})); + concat_axis.scalar()() = concat_a; + Tensor input(dt, TensorShape(tensorflow::gtl::ArraySlice(input_shape.data(), input_shape.size()))); + input.flat().setRandom(); + + if(fused){ + Node* split_concat; + TF_CHECK_OK(NodeBuilder(g->NewName("splitconcat"), "_FusedSplitConcat") + .Input(test::graph::Constant(g, split_dim)) + .Input(test::graph::Constant(g, input)) + .Input(test::graph::Constant(g, concat_axis)) + .Attr("num_split", num_split) + .Attr("T", dt) + .Attr("N", num_split) + .Attr("Tidx", DT_INT32) + .Finalize(g, &split_concat)); + + return g; + } else{ + Node* split; + TF_CHECK_OK(NodeBuilder(g->NewName("split"), "Split") + .Input(test::graph::Constant(g, split_dim)) + .Input(test::graph::Constant(g, input)) + .Attr("num_split", num_split) + .Attr("T", dt) + .Finalize(g, &split)); + + + std::vector out_list; + for (int i = 0; i < num_split; ++i){ + Output buf(split, i); + out_list.push_back(buf.node()); + } + + Node* concat; + TF_CHECK_OK(NodeBuilder(g->NewName("concat"), "Concat") + .Input(test::graph::Constant(g, concat_axis)) + 
.Input(out_list) + .Attr("N", num_split) + .Attr("T", dt) + .Finalize(g, &concat)); + + return g; + } +} + +using fp32 = float; + +#define BM_NAME(name, FUSED, T, split_name, concat_name, NUM_SPLIT, input_shape_name) \ + name##_##FUSED##_##T##_##split_name##_##concat_name##_##NUM_SPLIT##_##input_shape_name + +#define BM_SplitConcatFuseBenchmark(FUSED, T, INPUT_SHAPE, split_name, SPLIT_DIM, concat_name, CONCAT_AXIS, NUM_SPLIT, input_shape_name, LABEL) \ + static void BM_NAME(BM_SplitConcatFuse, FUSED, T, split_name, concat_name, NUM_SPLIT, input_shape_name)(int iters){ \ + testing::StopTiming(); \ + std::string base_label = FUSED ? "fused_split_concat" : "not_fused_split_concat"; \ + size_t input_shape_size = INPUT_SHAPE.size(); \ + int all_elements = 1; \ + for (int i = 0; i < input_shape_size; ++i){ \ + all_elements *= INPUT_SHAPE[i]; \ + } \ + testing::SetLabel(base_label + "_" + LABEL); \ + testing::BytesProcessed(static_cast(iters) * all_elements * sizeof(T)); \ + testing::StartTiming(); \ + test::Benchmark("cpu", SplitConcatFuse(FUSED, INPUT_SHAPE, SPLIT_DIM, CONCAT_AXIS, NUM_SPLIT)) \ + .Run(iters); \ + } \ + BENCHMARK(BM_NAME(BM_SplitConcatFuse, FUSED, T, split_name, concat_name, NUM_SPLIT, input_shape_name)); + +#define DLRM_BENCHMARK(FUSED) \ + BM_SplitConcatFuseBenchmark(FUSED, fp32, (std::vector{1024, 50, 256}), 2, 2, 0, 0, 8, 1024x50x256, "float_1024x50x356_split_2_concat_0_numsplit_8"); \ + BM_SplitConcatFuseBenchmark(FUSED, fp32, (std::vector{1024, 50, 256}), neg_1, -1, 0, 0, 8, 1024x50x256, "float_1024x50x356_split_-1_concat_0_numsplit_8"); + +#define EQUAL_BENCHMARK(FUSED) \ + /* split_dim == concat_axis */ \ + BM_SplitConcatFuseBenchmark(FUSED, fp32, (std::vector{1024, 50, 256}), 0, 0, 0, 0, 8, 1024x50x256, "float_1024x50x356_split_0_concat_0_numsplit_8"); \ + BM_SplitConcatFuseBenchmark(FUSED, fp32, (std::vector{128, 512, 2, 16}), 3, 3, neg_1, -1, 2, 128x512x2z16, "float_128x512x2x16_split_3_concat_-1_numsplit_2"); \ + 
BM_SplitConcatFuseBenchmark(FUSED, fp32, (std::vector{192, 64, 96}), neg_1, -1, 2, 2, 3, 192x64x96, "float_192x64x96_split_-1_concat_2_numsplit_3"); + +#define SPLIT_MAJOR_BENCHMARK(FUSED) \ + /* split_dim > concat_axis */ \ + BM_SplitConcatFuseBenchmark(FUSED, fp32, (std::vector{1024, 50, 256}), 2, 2, 1, 1, 2, 1024x50x256, "float_1024x50x256_split_2_concat_1_numsplit_2"); \ + BM_SplitConcatFuseBenchmark(FUSED, fp32, (std::vector{1024, 512}), 1, 1, 0, 0, 2, 1024x512, "float_1024x512_split_1_concat_0_numsplit_2"); \ + BM_SplitConcatFuseBenchmark(FUSED, fp32, (std::vector{384, 192, 64, 6}), 3, 3, neg_3, -3, 3, 384x192x64x6, "float_384x192x64x6_split_3_concat_-3_numsplit_3"); + +#define CONCAT_MAJOR_BENCHMARK(FUSED) \ + /* concat_axis > split_dim */ \ + BM_SplitConcatFuseBenchmark(FUSED, fp32, (std::vector{1024, 50, 256}), 0, 0, 2, 2, 8, 1024x50x256, "float_1024x50x356_split_0_concat_2_numsplit_8"); \ + BM_SplitConcatFuseBenchmark(FUSED, fp32, (std::vector{1024, 50, 256, 16}), neg_3, -3, 3, 3, 2, 1024x50z256x16, "float_1024x50x256x16_split_-3_concat_3_numsplit_2"); \ + BM_SplitConcatFuseBenchmark(FUSED, fp32, (std::vector{384, 192, 6}), 0, 0, 1, 1, 3, 384x192x6, "float_384x192x6_split_0_concat_1_numsplit_1"); + +DLRM_BENCHMARK(true) +DLRM_BENCHMARK(false) + +EQUAL_BENCHMARK(true) +EQUAL_BENCHMARK(false) + +SPLIT_MAJOR_BENCHMARK(true) +SPLIT_MAJOR_BENCHMARK(false) + +CONCAT_MAJOR_BENCHMARK(true) +CONCAT_MAJOR_BENCHMARK(false) + +} // namespace +} // namespace grappler +} // namespace tensorflow \ No newline at end of file diff --git a/tensorflow/core/kernels/split_op.cc b/tensorflow/core/kernels/split_op.cc index da5701836f6..4dbb63d9ff8 100644 --- a/tensorflow/core/kernels/split_op.cc +++ b/tensorflow/core/kernels/split_op.cc @@ -24,9 +24,11 @@ limitations under the License. 
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/kernels/split_lib.h" +#include "tensorflow/core/kernels/concat_lib.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/util/work_sharder.h" +#include "tensorflow/core/platform/prefetch.h" #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" #include "tensorflow/core/kernels/gpu_device_array.h" @@ -437,4 +439,184 @@ TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL); #endif // TENSORFLOW_USE_SYCL +template +class FusedSplitConcatOp : public SplitOpBase { + public: + typedef SplitOpBase Base; + + typedef std::vector::ConstMatrix>> + ConstMatrixVector; + + explicit FusedSplitConcatOp(OpKernelConstruction* c) : Base(c) { + // Concat attributes. + OP_REQUIRES_OK(c, c->GetAttr("N", &n_concat_)); + OP_REQUIRES_OK(c, c->GetAttr("Tidx", &axis_dtype_)); + OP_REQUIRES_OK(c, c->GetAttr("num_split", &num_split_)); + } + + auto tf_tensor_to_vector(Tensor tensor, int32_t tensorSize){ + int32_t* tensor_ptr = tensor.flat().data(); + std::vector v(tensor_ptr, tensor_ptr + tensorSize); + return v; + } + + auto create_dx(std::vector input){ + const int v_size = input.size(); + std::vector return_vec({1}); + for(int i = 1; i < v_size; ++i){ + return_vec.push_back(return_vec[i - 1] * input[v_size - i]); + } + std::reverse(return_vec.begin(), return_vec.end()); + return return_vec; + } + + std::vector to_vector(auto input){ + std::vector return_vec; + for(int i = input.size() - 1; i >= 0; --i){ + return_vec.push_back(input[i]); + } + std::reverse(return_vec.begin(), return_vec.end()); + return return_vec; + } + + void Compute(OpKernelContext* context) override { + Tensor input = context->input(1); + const TensorShape& input_shape = input.shape(); + Tensor split_dim_tensor = context->input(0); + auto split_dim = split_dim_tensor.scalar()(); + Tensor 
concat_axis_tensor = context->input(2); + auto concat_axis = concat_axis_tensor.scalar()(); + + split_dim = split_dim < 0 ? split_dim + input_shape.dims() : split_dim; + concat_axis = concat_axis < 0 ? concat_axis + input_shape.dims() : concat_axis; + + // Split dim and Concat axis equal case + if (split_dim == concat_axis) { + TensorShape output_shape(input_shape); + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); + + auto out_ptr = output->flat().data(); + auto in_ptr = input.flat().data(); + auto cpu_device = context->eigen_cpu_device(); + auto worker = [&](int64_t begin, int64_t end) -> void { + int64_t range = end-begin; + std::memcpy(out_ptr + begin, in_ptr + begin, range * 4); + }; + const Eigen::TensorOpCost cost(4, 4, 1); + cpu_device.parallelFor(output->NumElements(), cost, worker ); + + return; + } else { + + // Split dim and Concat axis not equal case + auto shape = input_shape.dim_sizes(); + int concat_cat_dim_size = shape[concat_axis]; + int split_size = shape[split_dim] / num_split_; + int size_dot = 1; + std::vector size_ranges; + + for(int i = 0; i < shape.size(); ++i){ + size_ranges.push_back(size_dot); + size_dot = size_dot * shape[shape.size() - i - 1]; + } + + std::reverse(size_ranges.begin(), size_ranges.end()); + auto new_shape = shape; + new_shape[split_dim] = split_size; + new_shape[concat_axis] = new_shape[concat_axis] * num_split_; + TensorShape output_shape(new_shape); + + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); + + std::vector new_dx = create_dx(to_vector(new_shape)); + std::vector old_dx = create_dx(to_vector(shape)); + auto output_flat_data = output->flat().data(); + auto input_flat_data = input.flat().data(); + + // calculate maximum number of copying in parallel + auto split_max_stride = 1; + for(int i = shape.size() - 1; i >= split_dim; i--) { + split_max_stride *= shape[i]; + } + split_max_stride /= num_split_; 
// maximum number of data that can be taken before striding occurs + + auto concat_max_stride = 1; + for(int i = shape.size() - 1; i >= concat_axis; i--) { + concat_max_stride *= shape[i]; // maximum number of data that can be inserted before data from other strides is inserted + } + + if (concat_max_stride > split_max_stride) { + auto worker_number = size_dot / split_max_stride / num_split_; + auto cpu_device = context->eigen_cpu_device(); + auto worker = [&](int64_t begin, int64_t end) -> void { + int64_t dot_begin = begin * split_max_stride * num_split_; + std::vector original_indexes(split_dim); + std::div_t dv{}; + for(auto i = begin; i < end; i++) { + dv.rem = dot_begin; + for(int x = 0; x <= split_dim - 1; ++x){ + dv = std::div(dv.rem, size_ranges[x]); + original_indexes[x] = dv.quot; + } + + int new_flat = 0; + for(int x = 0; x <= split_dim - 1; ++x){ + new_flat += original_indexes[x] * new_dx[x]; + } + + for(int j = 0; j < num_split_; j++) { + memcpy(output_flat_data + new_flat, input_flat_data + dot_begin, split_max_stride * 4); + new_flat += concat_max_stride / num_split_; + dot_begin += split_max_stride; + } + } + }; + + const Eigen::TensorOpCost cost(split_max_stride * num_split_ * 4, split_max_stride * num_split_ * 4, (split_max_stride * num_split_) + (split_dim - 1) * 15); + cpu_device.parallelFor(worker_number, cost, worker); + } else { + auto worker_number = size_dot / concat_max_stride / num_split_; + auto cpu_device = context->eigen_cpu_device(); + + auto worker = [&](int64_t begin, int64_t end) -> void { + std::vector original_indexes(concat_axis+1); + for(auto i = begin; i < end; i++) { + int64_t dot_begin = i * concat_max_stride; + int number_of_splits = dot_begin / split_max_stride; + int old_flat = dot_begin + number_of_splits * (num_split_ - 1) * split_max_stride; + auto new_flat = dot_begin * num_split_; + + for(int j = 0; j < num_split_; j++) { + memcpy(output_flat_data + new_flat, input_flat_data + old_flat, concat_max_stride * 4); + 
old_flat += split_max_stride; + new_flat += concat_max_stride; + } + } + }; + + const Eigen::TensorOpCost cost(concat_max_stride * num_split_ * 4, concat_max_stride * num_split_ * 4, (concat_max_stride * num_split_) + (concat_axis) * 15); + cpu_device.parallelFor(worker_number, cost, worker); + } + } + } + private: + int n_concat_; + DataType axis_dtype_; + int num_split_; +}; + +#define REGISTER_SPLITCONCAT(type) \ +REGISTER_KERNEL_BUILDER(Name("_FusedSplitConcat") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .HostMemory("split_dim"), \ + FusedSplitConcatOp) + +// TF_CALL_POD_STRING_TYPES(REGISTER_SPLITCONCAT); +//TF_CALL_ALL_TYPES(REGISTER_SPLITCONCAT); +REGISTER_SPLITCONCAT(float); + +#undef REGISTER_SPLITCONCAT } // end namespace tensorflow diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index a034e831895..4100ac1769a 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -741,6 +741,44 @@ REGISTER_OP("SplitV") return Status::OK(); }); +REGISTER_OP("_FusedSplitConcat") + .Input("split_dim: int32") + .Input("value: T") + .Input("axis: Tidx") + .Output("output: T") + .Attr("num_split: int >= 1") + .Attr("T: type") + // Concat attributes + .Attr("N: int >= 2") + .Attr("Tidx: {int32, int64} = DT_INT32") + .SetShapeFn([](InferenceContext* c) { + DimensionHandle split_dimension; + ShapeHandle input = c->input(1); + TF_RETURN_IF_ERROR(c->MakeDimForScalarInputWithNegativeIndexing( + 0, c->Rank(input), &split_dimension)); + int num_split = c->num_outputs(); + ShapeHandle out; + if (!c->ValueKnown(split_dimension)) { + if (c->RankKnown(input)) { + out = c->UnknownShapeOfRank(c->Rank(input)); + } else { + out = c->UnknownShape(); + } + } else { + int64 split_dim = c->Value(split_dimension); + TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, split_dim + 1, &input)); + DimensionHandle split_dim_size; + TF_RETURN_WITH_CONTEXT_IF_ERROR( + c->Divide(c->Dim(input, split_dim), num_split, + true /* 
evenly_divisible */, &split_dim_size), + "Number of ways to split should evenly divide the split dimension"); + TF_RETURN_IF_ERROR( + c->ReplaceDim(input, split_dim, split_dim_size, &out)); + } + for (int i = 0; i < num_split; ++i) c->set_output(i, out); + return Status::OK(); + }); + // -------------------------------------------------------------------------- REGISTER_OP("Const") .Output("output: dtype")