diff --git a/CMakeLists.txt b/CMakeLists.txt
index 89f5bd5e..1a786315 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -226,6 +226,16 @@ message(STATUS "NGRAPH_PLAIDML_ENABLE: ${NGRAPH_PLAIDML_ENABLE}")
 message(STATUS "NGRAPH_TARGET_ARCH: ${NGRAPH_TARGET_ARCH}")
 message(STATUS "NGRAPH_TUNE_ARCH: ${NGRAPH_TUNE_ARCH}")
 
+# Distributed support: define NGRAPH_DISTRIBUTED for the sources and expose the
+# MPI headers and libraries to the targets defined after this point.
+if(NGRAPH_DISTRIBUTED_ENABLE)
+    find_package(MPI REQUIRED)
+    add_definitions(-DNGRAPH_DISTRIBUTED)
+    include_directories(SYSTEM ${MPI_C_INCLUDE_PATH} ${MPI_CXX_INCLUDE_PATH})
+    # MPI_<lang>_LIBRARIES hold library files, not directories, so they have
+    # to be linked rather than handed to link_directories().
+    link_libraries(${MPI_C_LIBRARIES} ${MPI_CXX_LIBRARIES})
+endif()
 # Find and build ngraph - if not using pre-built one
 if (NOT USE_PRE_BUILT_NGRAPH)
     ExternalProject_Add(
diff --git a/build_ngtf.py b/build_ngtf.py
index 8e1b450e..b26ab30f 100755
--- a/build_ngtf.py
+++ b/build_ngtf.py
@@ -408,6 +408,11 @@ def main():
         " |-- venv-tf-py3 (Virtualenv directory to be used)\n",
         action="store_true")
 
+    parser.add_argument(
+        '--distributed_build',
+        help="Builds a distributed version of the nGraph components\n",
+        action="store_true")
+
     arguments = parser.parse_args()
 
     if (arguments.debug_build):
@@ -494,7 +499,6 @@ def main():
 
     ngraph_cmake_flags = [
         "-DNGRAPH_INSTALL_PREFIX=" + artifacts_location,
-        "-DNGRAPH_DISTRIBUTED_ENABLE=FALSE",
         "-DNGRAPH_USE_CXX_ABI=" + cxx_abi,
         "-DNGRAPH_UNIT_TEST_ENABLE=NO",
         "-DNGRAPH_DEX_ONLY=TRUE",
@@ -512,6 +516,11 @@ def main():
     if (arguments.debug_build):
         ngraph_cmake_flags.extend(["-DCMAKE_BUILD_TYPE=Debug"])
 
+    if (arguments.distributed_build):
+        ngraph_cmake_flags.extend(["-DNGRAPH_DISTRIBUTED_ENABLE=TRUE"])
+    else:
+        ngraph_cmake_flags.extend(["-DNGRAPH_DISTRIBUTED_ENABLE=FALSE"])
+
     build_ngraph("./ngraph", ngraph_cmake_flags, verbosity)
 
     # Next build CMAKE options for the bridge
@@ -529,6 +538,11 @@ def main():
     if (arguments.debug_build):
         ngraph_tf_cmake_flags.extend(["-DCMAKE_BUILD_TYPE=Debug"])
 
+    if (arguments.distributed_build):
+        ngraph_tf_cmake_flags.extend(["-DNGRAPH_DISTRIBUTED_ENABLE=TRUE"])
+    else:
+        
ngraph_tf_cmake_flags.extend(["-DNGRAPH_DISTRIBUTED_ENABLE=FALSE"])
+
     # Now build the bridge
     ng_tf_whl = build_ngraph_tf(artifacts_location, "../", venv_dir,
                                 ngraph_tf_cmake_flags, verbosity)
diff --git a/examples/mnist/mnist_softmax_distributed.py b/examples/mnist/mnist_softmax_distributed.py
new file mode 100644
index 00000000..597849f2
--- /dev/null
+++ b/examples/mnist/mnist_softmax_distributed.py
@@ -0,0 +1,159 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A very simple MNIST classifier.
+
+See extensive documentation at
+https://www.tensorflow.org/get_started/mnist/beginners
+Reference to the original source code:
+https://github.com/tensorflow/tensorflow/blob/r1.2/tensorflow/examples/tutorials/mnist/mnist_softmax.py
+Add distributed feature with Horovod:
+1. hvd.init()
+2. Add distributed wrapper from hvd.DistributedOptimizer
+3. Broadcast the variables from root rank to the rest processors: hvd.BroadcastGlobalVariablesHook(0)
+4. Print the output for root rank only
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import os
+import sys
+import time
+
+from tensorflow.examples.tutorials.mnist import input_data
+
+import tensorflow as tf
+import ngraph_bridge
+import horovod.tensorflow as hvd
+learn = tf.contrib.learn
+
+FLAGS = None
+
+hvd.init()
+
+
+def main(_):
+    """Entry point handed to tf.app.run; delegates to run_mnist."""
+    run_mnist(_)
+
+
+def run_mnist(_):
+    """Train the softmax MNIST model under Horovod for a few steps.
+
+    Reads MNIST from a per-rank directory below FLAGS.data_dir, wraps the
+    gradient-descent optimizer in hvd.DistributedOptimizer, and runs training
+    steps until StopAtStepHook(last_step=10) ends the session. Accuracy and
+    the elapsed training time are printed by rank 0 only.
+    """
+    # Import data; each rank uses its own directory
+    mnist = learn.datasets.mnist.read_data_sets(
+        os.path.join(FLAGS.data_dir, 'MNIST-data-%d' % hvd.rank()),
+        one_hot=True)
+
+    # Create the model
+    with tf.name_scope("mnist_placholder"):
+        x = tf.placeholder(tf.float32, [None, 784])
+        W = tf.Variable(tf.zeros([784, 10]))
+        b = tf.Variable(tf.zeros([10]))
+        y = tf.matmul(x, W) + b
+
+    # Define loss and optimizer
+    y_ = tf.placeholder(tf.float32, [None, 10])
+
+    # The raw formulation of cross-entropy,
+    #
+    #   tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(tf.nn.softmax(y)),
+    #                                 reduction_indices=[1]))
+    #
+    # can be numerically unstable.
+    #
+    # So here we use tf.nn.softmax_cross_entropy_with_logits on the raw
+    # outputs of 'y', and then average across the batch.
+    cross_entropy = tf.reduce_mean(
+        tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
+    #global_step = tf.train.get_or_create_global_step()
+    global_step = tf.contrib.framework.get_or_create_global_step()
+    opt = tf.train.GradientDescentOptimizer(0.5)
+    # Add MPI Distributed Optimizer
+    with tf.name_scope("horovod_opt"):
+        opt = hvd.DistributedOptimizer(opt)
+    train_step = opt.minimize(cross_entropy, global_step=global_step)
+
+    # The StopAtStepHook handles stopping after running given steps.
+    hooks = [
+        hvd.BroadcastGlobalVariablesHook(0),
+        tf.train.StopAtStepHook(last_step=10)
+    ]
+
+    # Test trained model
+    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
+    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
+
+    # Enable soft placement and tracing as needed
+    config = tf.ConfigProto(
+        allow_soft_placement=True,
+        log_device_placement=True,
+        inter_op_parallelism_threads=1)
+
+    #config.graph_options.optimizer_options.global_jit_level = jit_level
+    run_metadata = tf.RunMetadata()
+
+    #init_op = tf.global_variables_initializer()
+    print("Variables initialized ...")
+
+    # The MonitoredTrainingSession takes care of session initialization
+    with tf.train.MonitoredTrainingSession(
+            hooks=hooks, config=config) as mon_sess:
+        start = time.time()
+        train_writer = tf.summary.FileWriter(FLAGS.log_dir, mon_sess.graph)
+        while not mon_sess.should_stop():
+            # Train
+            batch_xs, batch_ys = mnist.train.next_batch(100)
+            mon_sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
+
+            # Test trained model: every rank evaluates, only rank 0 reports
+            if not mon_sess.should_stop():
+                test_accuracy = mon_sess.run(
+                    accuracy,
+                    feed_dict={
+                        x: mnist.test.images,
+                        y_: mnist.test.labels
+                    })
+                if hvd.rank() == 0:
+                    print("Accuracy: ", test_accuracy)
+
+        end = time.time()
+
+    if hvd.rank() == 0:
+        print("Training time: %f seconds" % (end - start))
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--data_dir',
+        type=str,
+        default='/tmp/tensorflow/mnist/input_data',
+        help='Directory for storing input data')
+    parser.add_argument(
+        '--log_dir',
+        type=str,
+        default='/tmp/tensorflow/mnist/logs/mnist_with_summaries',
+        help='Summaries log directory')
+    FLAGS, unparsed = parser.parse_known_args()
+    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
+# run command for this distributed script
+# mpirun -np 2 python mnist_softmax_distributed.py --data_dir=/mnt/data/mnist
diff --git a/src/ngraph_encapsulate_op.cc b/src/ngraph_encapsulate_op.cc
index ec4421dc..a8240607 100644
--- 
a/src/ngraph_encapsulate_op.cc
+++ b/src/ngraph_encapsulate_op.cc
@@ -36,6 +36,10 @@
 
 #include "ngraph/runtime/interpreter/int_backend.hpp"
 
+#if defined NGRAPH_DISTRIBUTED
+#include "ngraph/distributed.hpp"
+#endif
+
 using namespace std;
 namespace ng = ngraph;
 
@@ -260,8 +264,16 @@ class NGraphEncapsulateOp : public OpKernel {
 
       // Serialize to nGraph if needed
       if (std::getenv("NGRAPH_ENABLE_SERIALIZE") != nullptr) {
-        NgraphSerialize("tf_function_" + ctx->op_kernel().name() + ".json",
-                        ng_function);
+        std::string file_name = "tf_function_" + ctx->op_kernel().name();
+#if defined NGRAPH_DISTRIBUTED
+        // Suffix the rank so that processes sharing a working directory each
+        // write their own file instead of all writing the same one.
+        ngraph::Distributed dist;
+        const int rank_id = dist.get_rank();
+        file_name += "_" + to_string(rank_id);
+#endif
+        file_name += ".json";
+        NgraphSerialize(file_name, ng_function);
       }
 
       m_ng_functions[signature] = ng_function;
diff --git a/src/ngraph_mark_for_clustering.cc b/src/ngraph_mark_for_clustering.cc
index 047ec00b..c9ff628c 100644
--- a/src/ngraph_mark_for_clustering.cc
+++ b/src/ngraph_mark_for_clustering.cc
@@ -253,7 +253,7 @@ Status MarkForClustering(Graph* graph) {
     };
     confirmation_function_map["Greater"] = SimpleConfirmationFunction();
     confirmation_function_map["GreaterEqual"] = SimpleConfirmationFunction();
-#ifdef NGRAPH_DISTRIBUTED
+#if defined NGRAPH_DISTRIBUTED
     confirmation_function_map["HorovodAllreduce"] =
         SimpleConfirmationFunction();
 #endif
@@ -389,7 +389,7 @@ Status MarkForClustering(Graph* graph) {
     type_constraint_map["FusedBatchNormGrad"]["T"] = NGraphNumericDTypes();
     type_constraint_map["Greater"]["T"] = NGraphDTypes();
     type_constraint_map["GreaterEqual"]["T"] = NGraphDTypes();
-#ifdef NGRAPH_DISTRIBUTED
+#if defined NGRAPH_DISTRIBUTED
     type_constraint_map["HorovodAllreduce"]["T"] = NGraphNumericDTypes();
 #endif
     type_constraint_map["Identity"]["T"] = NGraphDTypes();
diff --git a/src/ngraph_rewrite_pass.cc b/src/ngraph_rewrite_pass.cc
index 912cc4de..8d34ea8c 100644
--- a/src/ngraph_rewrite_pass.cc
+++ b/src/ngraph_rewrite_pass.cc
@@ -29,6 +29,10 @@
 
 #include
 
+#if defined NGRAPH_DISTRIBUTED
+#include "ngraph/distributed.hpp"
+#endif
+
 using namespace std;
 
 namespace tensorflow {
@@ -102,6 +106,17 @@ class NGraphRewritePass : public GraphOptimizationPass {
+  // Returns "<kind>_<idx>" with idx zero-padded to four digits. In
+  // NGRAPH_DISTRIBUTED builds the zero-padded rank reported by
+  // ngraph::Distributed is appended as a further "_<rank>" component.
   static std::string GraphFilenamePrefix(std::string kind, int idx) {
     std::stringstream ss;
     ss << kind << "_" << std::setfill('0') << std::setw(4) << idx;
+#if defined NGRAPH_DISTRIBUTED
+    ngraph::Distributed dist;
+    const int rank_id = dist.get_rank();
+    ss << "_" << std::setfill('0') << std::setw(4) << rank_id;
+#endif
     return ss.str();
   }
+  // Same prefix with the zero-padded sub_idx appended. The rank component (in
+  // NGRAPH_DISTRIBUTED builds) comes from the two-argument overload, giving
+  // "<kind>_<idx>_<rank>_<sub_idx>".
   static std::string GraphFilenamePrefix(std::string kind, int idx,
@@ -109,6 +124,8 @@ class NGraphRewritePass : public GraphOptimizationPass {
     std::stringstream ss;
+    // The prefix below already carries the rank in NGRAPH_DISTRIBUTED builds;
+    // appending it again here would duplicate it in the file name.
     ss << GraphFilenamePrefix(kind, idx) << "_" << std::setfill('0')
        << std::setw(4) << sub_idx;
     return ss.str();
   }
 