Skip to content

Commit 77e9ffb

Browse files
Jiho Choi and tensorflower-gardener
Jiho Choi
authored and committed
Use the name scope of the forward pass in the backward pass.
PiperOrigin-RevId: 294732289 Change-Id: I15b20a033f66f4d4201b0eb6ea9be8b6cd82bf56
1 parent 7e2063e commit 77e9ffb

File tree

7 files changed

+65
-16
lines changed

7 files changed

+65
-16
lines changed

tensorflow/python/eager/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -514,6 +514,7 @@ py_library(
514514
":tape",
515515
"//tensorflow/python:array_ops",
516516
"//tensorflow/python:constant_op",
517+
"//tensorflow/python:control_flow_util",
517518
"//tensorflow/python:dtypes",
518519
"//tensorflow/python:errors",
519520
"//tensorflow/python:framework_ops",

tensorflow/python/eager/backprop.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from tensorflow.python.framework import tensor_util
3939
from tensorflow.python.ops import array_ops
4040
from tensorflow.python.ops import check_ops
41+
from tensorflow.python.ops import control_flow_util
4142
from tensorflow.python.ops import default_gradient
4243
from tensorflow.python.ops import gen_array_ops
4344
from tensorflow.python.ops import gen_math_ops
@@ -123,7 +124,7 @@ def _get_control_flow_context(self):
123124

124125

125126
def _gradient_function(op_name, attr_tuple, num_inputs, inputs, outputs,
126-
out_grads, skip_input_indices):
127+
out_grads, skip_input_indices, forward_pass_name_scope):
127128
"""Calls the gradient function of the op.
128129
129130
Args:
@@ -135,6 +136,7 @@ def _gradient_function(op_name, attr_tuple, num_inputs, inputs, outputs,
135136
out_grads: gradients of the operation wrt its outputs.
136137
skip_input_indices: a tuple that is passed to the gradient function,
137138
indicating which inputs to skip calculating the gradient for
139+
forward_pass_name_scope: the namescope of the op in the forward pass.
138140
139141
Returns:
140142
The gradients with respect to the inputs of the function, as a list.
@@ -144,7 +146,17 @@ def _gradient_function(op_name, attr_tuple, num_inputs, inputs, outputs,
144146
if grad_fn is None:
145147
return [None] * num_inputs
146148

147-
return grad_fn(mock_op, *out_grads)
149+
# This does not work with v1 TensorArrays.
150+
if ops.executing_eagerly_outside_functions(
151+
) or control_flow_util.EnableControlFlowV2(ops.get_default_graph()):
152+
if forward_pass_name_scope:
153+
gradient_name_scope = "gradient_tape/" + forward_pass_name_scope + "/"
154+
else:
155+
gradient_name_scope = "gradient_tape/"
156+
with ops.name_scope(gradient_name_scope):
157+
return grad_fn(mock_op, *out_grads)
158+
else:
159+
return grad_fn(mock_op, *out_grads)
148160

149161

150162
pywrap_tfe.TFE_Py_RegisterGradientFunction(_gradient_function)
@@ -155,7 +167,8 @@ def _must_record_gradient():
155167

156168

157169
def _record_gradient(op_name, inputs, attrs, results):
158-
return pywrap_tfe.TFE_Py_RecordGradient(op_name, inputs, attrs, results)
170+
return pywrap_tfe.TFE_Py_RecordGradient(op_name, inputs, attrs, results,
171+
ops.get_name_scope())
159172

160173

161174
execute.must_record_gradient = _must_record_gradient

tensorflow/python/eager/backprop_test.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1512,6 +1512,27 @@ def testWatchedVariablesRespectReset(self):
15121512
tape.gradient(z, z)
15131513
self.assertEqual((z,), tape.watched_variables())
15141514

1515+
def testNameScope(self):
1516+
def fn(x):
1517+
with ops.name_scope('my_scope'):
1518+
a = math_ops.cos(x)
1519+
b = math_ops.cos(x)
1520+
return math_ops.add(a, b)
1521+
1522+
@function.defun
1523+
def grad_fn(x):
1524+
return backprop.gradients_function(fn)(x)
1525+
1526+
grad_ops = grad_fn.get_concrete_function(
1527+
constant_op.constant(1.0)).graph.get_operations()
1528+
num_sin_ops_found = 0
1529+
for op in grad_ops:
1530+
if op.type == 'Sin':
1531+
num_sin_ops_found += 1
1532+
self.assertIn('gradient_tape/my_scope/', op.name)
1533+
self.assertEqual(num_sin_ops_found, 2)
1534+
1535+
15151536
class JacobianTest(test.TestCase):
15161537

15171538
def _jacobian(self, experimental_use_pfor):

tensorflow/python/eager/pywrap_tfe.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,8 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject* args);
275275

276276
// Record the gradient for a given op.
277277
PyObject* TFE_Py_RecordGradient(PyObject* op_name, PyObject* inputs,
278-
PyObject* attrs, PyObject* results);
278+
PyObject* attrs, PyObject* results,
279+
PyObject* forward_pass_name_scope);
279280

280281
// Returns all variables watched by the given tape in the order those variables
281282
// were created.

tensorflow/python/eager/pywrap_tfe_src.cc

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2916,7 +2916,8 @@ PyObject* CopySequenceSettingIndicesToNull(
29162916
}
29172917

29182918
PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
2919-
PyObject* results) {
2919+
PyObject* results,
2920+
PyObject* forward_pass_name_scope = nullptr) {
29202921
std::vector<tensorflow::int64> input_ids = MakeTensorIDList(inputs);
29212922
if (PyErr_Occurred()) return nullptr;
29222923
std::vector<tensorflow::DataType> input_dtypes = MakeTensorDtypeList(inputs);
@@ -2997,16 +2998,21 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
29972998

29982999
PyObject* num_inputs = PyLong_FromLong(PySequence_Size(inputs));
29993000

3001+
if (!forward_pass_name_scope) forward_pass_name_scope = Py_None;
3002+
30003003
TapeSetRecordOperation(
30013004
op_name, inputs, results, input_ids, input_dtypes,
3002-
[op_name, attrs, num_inputs, op_inputs, op_outputs]() {
3005+
[op_name, attrs, num_inputs, op_inputs, op_outputs,
3006+
forward_pass_name_scope]() {
30033007
Py_INCREF(op_name);
30043008
Py_INCREF(attrs);
30053009
Py_INCREF(num_inputs);
30063010
Py_INCREF(op_inputs);
30073011
Py_INCREF(op_outputs);
3012+
Py_INCREF(forward_pass_name_scope);
30083013
PyBackwardFunction* function = new PyBackwardFunction(
3009-
[op_name, attrs, num_inputs, op_inputs, op_outputs](
3014+
[op_name, attrs, num_inputs, op_inputs, op_outputs,
3015+
forward_pass_name_scope](
30103016
PyObject* output_grads,
30113017
const std::vector<tensorflow::int64>& unneeded_gradients) {
30123018
if (PyErr_Occurred()) {
@@ -3026,8 +3032,9 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
30263032
skip_input_indices.reset(Py_None);
30273033
}
30283034
tensorflow::Safe_PyObjectPtr callback_args(Py_BuildValue(
3029-
"OOOOOOO", op_name, attrs, num_inputs, op_inputs, op_outputs,
3030-
output_grads, skip_input_indices.get()));
3035+
"OOOOOOOO", op_name, attrs, num_inputs, op_inputs, op_outputs,
3036+
output_grads, skip_input_indices.get(),
3037+
forward_pass_name_scope));
30313038

30323039
tensorflow::Safe_PyObjectPtr result(
30333040
PyObject_CallObject(gradient_function, callback_args.get()));
@@ -3038,13 +3045,14 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
30383045
});
30393046
return function;
30403047
},
3041-
[op_name, attrs, num_inputs, op_inputs,
3042-
op_outputs](PyBackwardFunction* backward_function) {
3048+
[op_name, attrs, num_inputs, op_inputs, op_outputs,
3049+
forward_pass_name_scope](PyBackwardFunction* backward_function) {
30433050
Py_DECREF(op_name);
30443051
Py_DECREF(attrs);
30453052
Py_DECREF(num_inputs);
30463053
Py_DECREF(op_inputs);
30473054
Py_DECREF(op_outputs);
3055+
Py_DECREF(forward_pass_name_scope);
30483056

30493057
delete backward_function;
30503058
},
@@ -3668,12 +3676,14 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject* args) {
36683676
}
36693677

36703678
PyObject* TFE_Py_RecordGradient(PyObject* op_name, PyObject* inputs,
3671-
PyObject* attrs, PyObject* results) {
3679+
PyObject* attrs, PyObject* results,
3680+
PyObject* forward_pass_name_scope) {
36723681
if (*ThreadTapeIsStopped() || !HasAccumulatorOrTape()) {
36733682
Py_RETURN_NONE;
36743683
}
36753684

3676-
return RecordGradient(op_name, inputs, attrs, results);
3685+
return RecordGradient(op_name, inputs, attrs, results,
3686+
forward_pass_name_scope);
36773687
}
36783688

36793689
namespace {

tensorflow/python/framework/op_callbacks_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -794,7 +794,8 @@ def get_gradients():
794794
self.assertIn(_COS_OP, instrument.graph_op_types)
795795

796796
# Check the ndarrays from runtime.
797-
cos_op_outputs = instrument.graph_internal_ndarrays[_COS_OP]
797+
cos_op_outputs = instrument.graph_internal_ndarrays[b"gradient_tape/" +
798+
_COS_OP]
798799
self.assertEqual(len(cos_op_outputs), 1)
799800
self.assertAllClose(cos_op_outputs[0], np.cos(3.0 * 3.0))
800801

tensorflow/python/tfe_wrapper.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -582,9 +582,11 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
582582
});
583583
m.def("TFE_Py_RecordGradient",
584584
[](const py::handle& op_name, const py::handle& inputs,
585-
const py::handle& attrs, const py::handle& results) {
585+
const py::handle& attrs, const py::handle& results,
586+
const py::handle& forward_pass_name_scope) {
586587
return tensorflow::pyo_or_throw(TFE_Py_RecordGradient(
587-
op_name.ptr(), inputs.ptr(), attrs.ptr(), results.ptr()));
588+
op_name.ptr(), inputs.ptr(), attrs.ptr(), results.ptr(),
589+
forward_pass_name_scope.ptr()));
588590
});
589591
m.def("TFE_Py_UID", []() { return tensorflow::pyo_or_throw(TFE_Py_UID()); });
590592

0 commit comments

Comments (0)