Skip to content
This repository was archived by the owner on Jan 3, 2023. It is now read-only.

Commit a8d6db8

Browse files
Merge branch 'master' into r0.12
2 parents cbbdf34 + a39cd8f commit a8d6db8

31 files changed

+4004
-12
lines changed

build_ngtf.py

+10
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,11 @@ def main():
7070
action="store")
7171

7272
parser.add_argument(
73+
'--enable_variables_and_optimizers',
74+
help="Ops like variable and optimizers are supported by nGraph in this version of the bridge\n",
75+
action="store_true")
76+
77+
parser.add_argument(
7378
'--use_grappler_optimizer',
7479
help="Use Grappler optimizer instead of the optimization passes\n",
7580
action="store_true")
@@ -264,6 +269,11 @@ def main():
264269
else:
265270
ngraph_tf_cmake_flags.extend(["-DNGRAPH_DISTRIBUTED_ENABLE=FALSE"])
266271

272+
if (arguments.enable_variables_and_optimizers):
273+
ngraph_tf_cmake_flags.extend(["-DNGRAPH_TF_ENABLE_VARIABLES_AND_OPTIMIZERS=TRUE"])
274+
else:
275+
ngraph_tf_cmake_flags.extend(["-DNGRAPH_TF_ENABLE_VARIABLES_AND_OPTIMIZERS=FALSE"])
276+
267277
if (arguments.use_grappler_optimizer):
268278
ngraph_tf_cmake_flags.extend(
269279
["-DNGRAPH_TF_USE_GRAPPLER_OPTIMIZER=TRUE"])

src/CMakeLists.txt

+36-3
Original file line numberDiff line numberDiff line change
@@ -43,18 +43,51 @@ set(SRC
4343
ngraph_freshness_tracker.cc
4444
ngraph_mark_for_clustering.cc
4545
ngraph_rewrite_for_tracking.cc
46+
ngraph_rewrite_pass.cc
4647
ngraph_tracked_variable.cc
4748
ngraph_utils.cc
4849
tf_graphcycles.cc
4950
tf_deadness_analysis.cc
5051
version.cc
5152
)
5253

53-
if(NGRAPH_TF_USE_GRAPPLER_OPTIMIZER)
54+
message(STATUS "NGRAPH_TF_ENABLE_VARIABLES_AND_OPTIMIZERS: ${NGRAPH_TF_ENABLE_VARIABLES_AND_OPTIMIZERS}")
55+
56+
if(NGRAPH_TF_ENABLE_VARIABLES_AND_OPTIMIZERS)
57+
# common files
58+
list(REMOVE_ITEM SRC ngraph_capture_variables.cc)
59+
list(APPEND SRC enable_variable_ops/ngraph_capture_variables.cc)
60+
61+
list(REMOVE_ITEM SRC ngraph_encapsulate_op.cc)
62+
list(APPEND SRC enable_variable_ops/ngraph_encapsulate_op.cc)
63+
64+
list(REMOVE_ITEM SRC ngraph_rewrite_for_tracking.cc)
65+
list(APPEND SRC enable_variable_ops/ngraph_rewrite_for_tracking.cc)
66+
67+
list(REMOVE_ITEM SRC ngraph_rewrite_pass.cc)
68+
list(APPEND SRC enable_variable_ops/ngraph_rewrite_pass.cc)
69+
70+
list(REMOVE_ITEM SRC ngraph_tracked_variable.cc)
71+
list(APPEND SRC enable_variable_ops/ngraph_tracked_variable.cc)
72+
73+
list(REMOVE_ITEM SRC ngraph_utils.cc)
74+
list(APPEND SRC enable_variable_ops/ngraph_utils.cc)
75+
76+
# new files
77+
list(APPEND SRC enable_variable_ops/ngraph_assign_op.cc)
78+
list(APPEND SRC enable_variable_ops/ngraph_catalog.cc)
79+
list(APPEND SRC enable_variable_ops/ngraph_enter_in_catalog.cc)
80+
list(APPEND SRC enable_variable_ops/ngraph_replace_op_utilities.cc)
81+
list(APPEND SRC enable_variable_ops/ngraph_replace_variable_modifiers.cc)
82+
list(APPEND SRC enable_variable_ops/ngraph_variable_modifiers.cc)
83+
84+
endif()
85+
86+
87+
if(NGRAPH_TF_USE_GRAPPLER_OPTIMIZER)
88+
list(REMOVE_ITEM SRC ngraph_rewrite_pass.cc)
5489
list(APPEND SRC grappler/ngraph_optimizer.cc)
5590
add_definitions(-DNGRAPH_TF_USE_GRAPPLER_OPTIMIZER)
56-
else()
57-
list(APPEND SRC ngraph_rewrite_pass.cc)
5891
endif()
5992

6093
message(STATUS "NGRAPH_TF_USE_GRAPPLER_OPTIMIZER: ${NGRAPH_TF_USE_GRAPPLER_OPTIMIZER}")
+185
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
/*******************************************************************************
2+
* Copyright 2019 Intel Corporation
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use thi0s file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*******************************************************************************/
16+
17+
#include "tensorflow/core/common_runtime/dma_helper.h"
18+
#include "tensorflow/core/framework/op.h"
19+
#include "tensorflow/core/framework/op_kernel.h"
20+
#include "tensorflow/core/framework/resource_mgr.h"
21+
#include "tensorflow/core/lib/strings/strcat.h"
22+
23+
#include "tensorflow/core/framework/op_kernel.h"
24+
#include "tensorflow/core/framework/tensor_types.h"
25+
#include "tensorflow/core/platform/default/logging.h"
26+
27+
#include "ngraph/event_tracing.hpp"
28+
#include "ngraph/runtime/backend.hpp"
29+
#include "ngraph_catalog.h"
30+
#include "ngraph_freshness_tracker.h"
31+
#include "ngraph_timer.h"
32+
#include "ngraph_utils.h"
33+
#include "ngraph_var.h"
34+
35+
using namespace std;
36+
namespace ng = ngraph;
37+
38+
namespace tensorflow {
39+
40+
namespace ngraph_bridge {
41+
42+
/* -------------------------------------------------
43+
//
44+
// NGraphAssignOp
45+
//
46+
---------------------------------------------------*/
47+
48+
// Computes *input[0] = input[1]
49+
class NGraphAssignOp : public OpKernel {
50+
private:
51+
bool just_looking_;
52+
bool copy_to_tf_;
53+
int ng_graph_id_;
54+
static int s_instance_count;
55+
int my_instance_id{0};
56+
57+
// TODO(malikshr): Do we need these attributes, exist in TF Assign ops
58+
// use_exclusive_lock_, validate_shape_, relax_constraints_;
59+
60+
public:
61+
explicit NGraphAssignOp(OpKernelConstruction* context)
62+
: OpKernel(context), just_looking_(false), copy_to_tf_(false) {
63+
OP_REQUIRES_OK(context, context->GetAttr("just_looking", &just_looking_));
64+
OP_REQUIRES_OK(context, context->GetAttr("copy_to_tf", &copy_to_tf_));
65+
OP_REQUIRES_OK(context, context->GetAttr("ngraph_graph_id", &ng_graph_id_));
66+
67+
NGRAPH_VLOG(4) << "NGraphAssign:: Constructor called for: " << def().name()
68+
<< ",just looking " << PrintBool(just_looking_)
69+
<< ",copy-to-tf " << PrintBool(copy_to_tf_) << " ,Graph ID "
70+
<< ng_graph_id_;
71+
72+
OP_REQUIRES(context, IsRefType(context->input_type(0)),
73+
errors::InvalidArgument("lhs input needs to be a ref type"));
74+
my_instance_id = s_instance_count;
75+
s_instance_count++;
76+
}
77+
78+
void Compute(OpKernelContext* context) override {
79+
std::ostringstream oss;
80+
oss << "Execute: Assign_" << my_instance_id << ": " << name();
81+
ngraph::Event event_compute(oss.str(), name(), "");
82+
83+
NGRAPH_VLOG(4) << "NGraphAssign:: Compute called for: " << def().name()
84+
<< " ,just looking " << PrintBool(just_looking_)
85+
<< " ,copy-to-tf " << PrintBool(copy_to_tf_) << " ,Graph ID "
86+
<< ng_graph_id_;
87+
88+
bool log_copies = false;
89+
OP_REQUIRES_OK(context, IsCopyLogEnabled(ng_graph_id_, log_copies));
90+
std::stringstream copy_log_str;
91+
copy_log_str << "KERNEL[" << type_string() << "]: " << name()
92+
<< " ,Copy_TF " << PrintBool(copy_to_tf_) << " ,Just_Looking "
93+
<< PrintBool(just_looking_) << "\n";
94+
int number_of_copies = 0;
95+
96+
bool ref_exists = NGraphCatalog::ExistsInInputVariableSharedNameMap(
97+
ng_graph_id_, def().name(), 0);
98+
if (!ref_exists) {
99+
OP_REQUIRES(context, ref_exists,
100+
errors::Internal(
101+
"Caught exception : RefInput to NGAssign not found \n"));
102+
}
103+
string get_ref_var_name = NGraphCatalog::GetInputVariableSharedName(
104+
ng_graph_id_, def().name(), 0);
105+
106+
NGraphVar* var;
107+
OP_REQUIRES_OK(context,
108+
context->resource_manager()->Lookup<NGraphVar>(
109+
context->resource_manager()->default_container(),
110+
get_ref_var_name, &var));
111+
112+
const Tensor& rhs = context->input(1);
113+
114+
// We always return the input ref.
115+
context->forward_ref_input_to_ref_output(0, 0);
116+
117+
// get the nGraphTensor
118+
shared_ptr<ngraph::runtime::Tensor> ng_tensor_to_assign = var->ng_tensor();
119+
120+
// DO NOT CARE ABOUT SYNCING AS WE ARE ALWAYS SETTING THE NGTENSOR
121+
122+
// Get input[1]
123+
string valkey = to_string(ng_graph_id_) + "_" + def().input(1);
124+
bool valref_exists = NGraphCatalog::ExistsInEncapOutputTensorMap(valkey);
125+
if (valref_exists) {
126+
// Value is from encap
127+
NGRAPH_VLOG(4) << "NGraphAssign::Getting from catalog: " << valkey;
128+
auto ng_val = NGraphCatalog::GetTensorFromEncapOutputTensorMap(valkey);
129+
ng_tensor_to_assign->copy_from(*ng_val);
130+
} else {
131+
number_of_copies++;
132+
copy_log_str << " COPY_INP_VAL[0]";
133+
NGRAPH_VLOG(4) << "NGraphAssign::Getting from TF : " << valkey;
134+
void* tf_src_ptr = (void*)DMAHelper::base(&rhs);
135+
ng_tensor_to_assign->write(
136+
tf_src_ptr, 0, ng_tensor_to_assign->get_element_count() *
137+
ng_tensor_to_assign->get_element_type().size());
138+
}
139+
140+
mutex_lock l(*context->input_ref_mutex(0));
141+
Tensor old_lhs = context->mutable_input(0, /* lock_held */ true);
142+
143+
if (copy_to_tf_) {
144+
number_of_copies++;
145+
copy_log_str << " COPY_TF ";
146+
ReadNGTensor(ng_tensor_to_assign, &old_lhs);
147+
148+
if (!just_looking_) {
149+
// Some tf op might update the ng-tensor value so mark it stale
150+
copy_log_str << " SET_SYNC ";
151+
var->sync_ng_tensor(true);
152+
}
153+
}
154+
155+
copy_log_str << " Number of copies " << number_of_copies << "\n";
156+
if (log_copies) {
157+
cout << copy_log_str.str();
158+
}
159+
160+
// Unref Var
161+
var->Unref();
162+
event_compute.Stop();
163+
ngraph::Event::write_trace(event_compute);
164+
}
165+
};
166+
167+
int NGraphAssignOp::s_instance_count = 0;
168+
169+
REGISTER_OP("NGraphAssign")
170+
.Input("ref: Ref(T)")
171+
.Input("value: T")
172+
.Output("output_ref: Ref(T)")
173+
.Attr("T: type")
174+
.Attr("validate_shape: bool = true")
175+
.Attr("use_locking: bool = true")
176+
.Attr("just_looking: bool = false")
177+
.Attr("copy_to_tf: bool = false")
178+
.Attr("ngraph_graph_id: int");
179+
180+
REGISTER_KERNEL_BUILDER(Name("NGraphAssign").Device(DEVICE_CPU),
181+
NGraphAssignOp);
182+
183+
} // namespace ngraph_bridge
184+
185+
} // namespace tensorflow
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
/*******************************************************************************
2+
* Copyright 2017-2019 Intel Corporation
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*******************************************************************************/
16+
17+
#include "tensorflow/core/graph/graph.h"
18+
#include "tensorflow/core/graph/node_builder.h"
19+
20+
#include "ngraph_api.h"
21+
#include "ngraph_capture_variables.h"
22+
#include "ngraph_replace_op_utilities.h"
23+
#include "ngraph_utils.h"
24+
25+
using namespace std;
26+
27+
namespace tensorflow {
28+
29+
namespace ngraph_bridge {
30+
31+
//
32+
// Utility function to check if placement on the NGRAPH device has been
33+
// requested.
34+
//
35+
// FIXME(amprocte): stubbed out for now because NGRAPH device is gone.
36+
//
37+
static bool NGraphPlacementRequested(const Node* node) { return true; }
38+
39+
//
40+
// Main entry point for the variable-capture.
41+
//
42+
Status CaptureVariables(Graph* graph, std::vector<string> skip_these_nodes) {
43+
const static std::map<
44+
const string,
45+
const pair<string,
46+
function<Status(
47+
Graph * graph, Node * node, Node * *replacement,
48+
const string replacement_node_name,
49+
const string replacement_op_type, const bool just_looking,
50+
const bool outputs_ng_supported, const int graph_id,
51+
const bool is_backend_set)>>>
52+
CAPTURE_REPLACE_OP_MAP{
53+
{"ApplyGradientDescent", std::make_pair("NGraphApplyGradientDescent",
54+
ReplaceApplyGradientDescent)},
55+
{"Assign", std::make_pair("NGraphAssign", ReplaceAssign)},
56+
{"AssignAdd", std::make_pair("NGraphAssignAdd", ReplaceAssign)},
57+
{"AssignSub", std::make_pair("NGraphAssignSub", ReplaceAssign)},
58+
{"VariableV2", std::make_pair("NGraphVariable", ReplaceVariable)}};
59+
60+
std::vector<Node*> replaced_nodes;
61+
for (auto node : graph->op_nodes()) {
62+
if (NGraphPlacementRequested(node)) {
63+
auto itr = CAPTURE_REPLACE_OP_MAP.find(node->type_string());
64+
if (itr != CAPTURE_REPLACE_OP_MAP.end()) {
65+
NGRAPH_VLOG(1) << "Capturing: " << node->name();
66+
Node* replacement;
67+
68+
// Create the replacement node
69+
TF_RETURN_IF_ERROR((itr->second.second)(graph, node, &replacement,
70+
node->name(), itr->second.first,
71+
false, false, 0, false));
72+
73+
std::vector<const Edge*> edges;
74+
75+
NGRAPH_VLOG(4) << "Replacing Node " << node->DebugString() << " with "
76+
<< replacement->DebugString();
77+
78+
TF_RETURN_IF_ERROR(ReplaceInputControlEdges(graph, node, replacement));
79+
TF_RETURN_IF_ERROR(ReplaceOutputEdges(graph, node, replacement));
80+
81+
replaced_nodes.push_back(node);
82+
}
83+
84+
} // end of checking NGraphPlacementRequested
85+
} // end of looping through nodes in the graph
86+
87+
for (auto node : replaced_nodes) {
88+
NGRAPH_VLOG(4) << "Removing: " << node->name();
89+
graph->RemoveNode(node);
90+
}
91+
92+
return Status::OK();
93+
}
94+
95+
} // namespace ngraph_bridge
96+
97+
} // namespace tensorflow

0 commit comments

Comments
 (0)