/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/gradients.h"

#include <deque>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/optimizer_cse.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

// TODO(andydavis) Remove some of the code duplicated between this module
// and that in 'common_runtime/function.cc'.
// A few string constants used throughout this module.
static const char* const kGradientOp = "SymbolicGradient";
static const char* const kNodeLabel = "Func";

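// Returns the tensor-name form of this endpoint: just the node name for
// output 0, otherwise "<node name>:<output index>".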
string NodeOut::name() const {
  if (index == 0) {
    return node->name();
  } else {
    return strings::StrCat(node->name(), ":", index);
  }
}

DataType NodeOut::dtype() const { return node->output_type(index); }

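// Hash and equality functors that let NodeOut be used as the key of the
// unordered containers below (see backprops_).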
struct NodeOutHash {
  uint64 operator()(const NodeOut& x) const {
    return Hash64(reinterpret_cast<const char*>(&x.node), sizeof(Node*),
                  x.index);
  }
};

struct NodeOutEq {
  bool operator()(const NodeOut& x, const NodeOut& y) const {
    return (x.node == y.node) && (x.index == y.index);
  }
};

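// Adds a node to 'g' that produces a zeros tensor shaped like 'input'. For
// DT_RESOURCE inputs the variable is first read with a ReadVariableOp; its
// dtype is assumed to be DT_FLOAT, matching the hard-coded attr below.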
static Node* AddZerosLike(Graph* g, NodeOut input) {
  DCHECK_LT(0, input.dtype());
  DCHECK_LT(input.dtype(), DT_FLOAT_REF);
  if (input.dtype() == DT_RESOURCE) {
    NodeDef read_def;
    read_def.set_name(g->NewName("Read"));
    read_def.set_op("ReadVariableOp");
    read_def.add_input(input.name());
    AddNodeAttr("dtype", DT_FLOAT, &read_def);
    Status s;
    Node* read = g->AddNode(read_def, &s);
    TF_CHECK_OK(s);
    g->AddEdge(input.node, input.index, read, 0);
    NodeDef ndef;
    ndef.set_name(g->NewName(kNodeLabel));
    ndef.set_op("ZerosLike");
    ndef.add_input(read_def.name());
    AddNodeAttr("T", DT_FLOAT, &ndef);
    Node* ret = g->AddNode(ndef, &s);
    TF_CHECK_OK(s);
    g->AddEdge(read, 0, ret, 0);
    return ret;
  } else {
    NodeDef ndef;
    ndef.set_name(g->NewName(kNodeLabel));
    ndef.set_op("ZerosLike");
    ndef.add_input(input.name());
    AddNodeAttr("T", input.dtype(), &ndef);
    Status s;
    Node* ret = g->AddNode(ndef, &s);
    TF_CHECK_OK(s);
    g->AddEdge(input.node, input.index, ret, 0);
    return ret;
  }
}

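// Adds a SymbolicGradient node to 'g' that wraps the op of 'n'. Its inputs
// are n's num_x inputs followed by the num_y gradients in 'grads' (one per
// output of n); its num_x outputs are the gradients with respect to n's
// inputs.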
static Node* AddSymGrad(Graph* g, Node* n, gtl::ArraySlice<NodeOut> grads) {
  const int num_x = n->num_inputs();
  const int num_y = n->num_outputs();
  CHECK_EQ(num_y, grads.size());

  NodeDef ndef;
  ndef.set_name(g->NewName(kNodeLabel));
  ndef.set_op(kGradientOp);

  // The gradient node should have num_x + num_y inputs.
  std::vector<NodeOut> n_inputs(num_x);
  for (const Edge* e : n->in_edges()) {
    if (e->IsControlEdge()) continue;
    n_inputs[e->dst_input()] = {e->src(), e->src_output()};
  }
  DataTypeVector in_types;
  for (const NodeOut& nout : n_inputs) {
    ndef.add_input(nout.name());
    in_types.push_back(nout.dtype());
  }
  for (const NodeOut& nout : grads) {
    ndef.add_input(nout.name());
    in_types.push_back(nout.dtype());
  }
  CHECK_EQ(ndef.input_size(), num_x + num_y);

  AddNodeAttr("Tin", in_types, &ndef);

  // The gradient node's outputs have the same types as the node 'n's
  // inputs, except for resources.
  DataTypeVector out_types = n->input_types();
  for (int i = 0, end = out_types.size(); i < end; ++i) {
    if (out_types[i] == DT_RESOURCE) {
      // TODO(apassos): figure out how to get the right dtype
      out_types[i] = DT_FLOAT;
    }
  }
  AddNodeAttr("Tout", out_types, &ndef);
  NameAttrList func;
  func.set_name(n->type_string());
  for (const auto& attr : n->attrs()) {
    (*func.mutable_attr())[attr.first] = attr.second;
  }
  AddNodeAttr("f", func, &ndef);
  Status s;
  Node* ret = g->AddNode(ndef, &s);
  TF_CHECK_OK(s);
  return ret;
}

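// Builds the symbolic gradient sub-graph for a function body: starting from
// the given gradients of y_node_outputs, it walks the graph backwards,
// summing the gradients that flow into each node output and emitting one
// SymbolicGradient node per visited op, until every x_node_output has
// received its gradient.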
class SymbolicGradientBuilder {
 public:
  SymbolicGradientBuilder(gtl::ArraySlice<NodeOut> y_node_outputs,
                          gtl::ArraySlice<NodeOut> x_node_outputs,
                          gtl::ArraySlice<NodeOut> y_grad_node_outputs,
                          std::vector<NodeOut>* x_grad_node_outputs,
                          Graph* graph);

  Status Compute();

 private:
  gtl::ArraySlice<NodeOut> y_node_outputs_;
  gtl::ArraySlice<NodeOut> x_node_outputs_;
  gtl::ArraySlice<NodeOut> y_grad_node_outputs_;
  std::vector<NodeOut>* x_grad_node_outputs_;
  Graph* graph_;  // Not owned.

  // A vector of output endpoints which represents backpropagated gradients.
  typedef std::vector<NodeOut> BackproppedGradients;

  // backprops_ is a map from a node output to its accumulated
  // gradients. When a node output has accumulated all its
  // gradients, we add a node which sums them up.
  std::unordered_map<NodeOut, BackproppedGradients, NodeOutHash, NodeOutEq>
      backprops_;

  // pending_[i] is a count-down counter for the i-th node's expected
  // backprops. When pending_[i] becomes zero, we have collected all
  // backprop gradients for all outputs of the i-th node.
  std::vector<int> pending_;

  // 'ready_' keeps track of nodes that have been completely
  // backpropped. Initially, for every output y of the function f, we
  // add dy as an input of the gradient function.
  std::deque<Node*> ready_;

  // The set of node ids at which to stop backprop.
  std::unordered_set<int> stop_nodes_;

  // Initialize pending_ and ready_.
  void InitBackprop();

  // In the original function body, there is a forward edge from 'src' to
  // 'dst'. When the backprop algorithm constructs the node 'dst_grad',
  // which computes the gradient, we need to propagate it to 'src'.
  void BackpropAlongEdge(const NodeOut& dst_grad, const NodeOut& src);
  void BackpropZerosAlongEdge(const NodeOut& src);

  // Returns a node representing the sum of any backpropped gradients for
  // 'src'. This will be an AddN node if there is more than one accumulated
  // gradient. Returns zeros if there are no gradients, or if the dtype is
  // DT_BOOL.
  NodeOut SumGradients(const NodeOut& src);

  TF_DISALLOW_COPY_AND_ASSIGN(SymbolicGradientBuilder);
};

SymbolicGradientBuilder::SymbolicGradientBuilder(
    gtl::ArraySlice<NodeOut> y_node_outputs,
    gtl::ArraySlice<NodeOut> x_node_outputs,
    gtl::ArraySlice<NodeOut> y_grad_node_outputs,
    std::vector<NodeOut>* x_grad_node_outputs, Graph* graph)
    : y_node_outputs_(y_node_outputs),
      x_node_outputs_(x_node_outputs),
      y_grad_node_outputs_(y_grad_node_outputs),
      x_grad_node_outputs_(x_grad_node_outputs),
      graph_(graph) {
  CHECK_EQ(y_node_outputs_.size(), y_grad_node_outputs.size());
  x_grad_node_outputs_->clear();
  x_grad_node_outputs_->resize(x_node_outputs_.size());
  stop_nodes_.reserve(x_node_outputs_.size());
  for (int i = 0, end = x_node_outputs_.size(); i < end; ++i) {
    stop_nodes_.insert(x_node_outputs_[i].node->id());
  }
}

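// Records 'dst_grad' as one gradient contribution flowing into 'src'. Node
// outputs that have no entry in backprops_ lie on no path to the ys and are
// ignored.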
void SymbolicGradientBuilder::BackpropAlongEdge(const NodeOut& dst_grad,
                                                const NodeOut& src) {
  CHECK_NOTNULL(src.node);
  auto iter = backprops_.find(src);
  if (iter != backprops_.end()) {
    auto* grads = &iter->second;
    grads->push_back(dst_grad);
    if (--pending_[src.node->id()] == 0) {
      ready_.push_back(src.node);
    }
  }
}

void SymbolicGradientBuilder::BackpropZerosAlongEdge(const NodeOut& src) {
  CHECK_NOTNULL(src.node);
  auto iter = backprops_.find(src);
  if (iter != backprops_.end()) {
    if (--pending_[src.node->id()] == 0) {
      ready_.push_back(src.node);
    }
  }
}

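// Seeds the backward pass. A reverse walk from the y outputs marks every
// node output that can influence some y (these get entries in backprops_)
// and counts how many gradient contributions each node should expect
// (pending_). The initial dy values are then propagated to the producers of
// the ys, which populates ready_.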
void SymbolicGradientBuilder::InitBackprop() {
  pending_.resize(graph_->num_node_ids(), 0);
  {
    backprops_.clear();
    std::unordered_set<Node*> visited;
    std::deque<Node*> queue;
    for (const NodeOut& nout : y_node_outputs_) {
      queue.push_back(nout.node);
      visited.insert(nout.node);
    }

    // Going forward to figure out which endpoints need to be backprop-ed.
    // A node's endpoints need to be backprop-ed only if one of the
    // return nodes can reach backwards to the node via data edges.
    while (!queue.empty()) {
      Node* n = queue.front();
      queue.pop_front();
      for (int i = 0; i < n->num_outputs(); ++i) {
        backprops_[{n, i}].clear();
      }
      for (const Edge* e : n->in_edges()) {
        if (e->IsControlEdge()) continue;
        pending_[e->src()->id()]++;
        if (visited.find(e->src()) == visited.end()) {
          queue.push_back(e->src());
          visited.insert(e->src());
        }
      }
    }

    // Create entries in backprops_ for all x_node_outputs_, because they
    // will not be added in the above loop if they are not reverse reachable
    // from y_node_outputs_.
    for (const NodeOut& nout : x_node_outputs_) {
      backprops_[{nout.node, nout.index}].clear();
    }
  }

  {
    const int num_y = y_grad_node_outputs_.size();
    for (int i = 0; i < num_y; ++i) {
      Node* y = y_node_outputs_[i].node;
      for (const Edge* e : y->in_edges()) {
        if (e->IsControlEdge()) continue;
        BackpropAlongEdge(y_grad_node_outputs_[i],
                          {e->src(), e->src_output()});
      }
    }
  }
  CHECK(!ready_.empty());
}

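// Combines the gradients accumulated for 'src' into a single endpoint: one
// contribution is returned as-is, several are summed with an AddN node, and
// none (or a non-summable DT_BOOL dtype) falls back to zeros.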
NodeOut SymbolicGradientBuilder::SumGradients(const NodeOut& src) {
  const DataType dtype = src.dtype();
  auto iter = backprops_.find(src);
  CHECK(iter != backprops_.end());
  const auto& grads = iter->second;
  if (grads.empty() || dtype == DT_BOOL) {
    // Nothing propagated back. The best we can come up with is zeros.
    Node* zero_like = AddZerosLike(graph_, src);
    return {zero_like, 0};
  }
  if (grads.size() == 1) {
    // Just one backprop edge.
    return grads[0];
  }
  // Otherwise, adds backprop-ed gradients.
  NodeDef ndef;
  ndef.set_name(graph_->NewName(kNodeLabel));
  ndef.set_op("AddN");  // N-way Add
  for (const NodeOut& nout : grads) {
    ndef.add_input(nout.name());
  }
  AddNodeAttr("N", static_cast<int64_t>(grads.size()), &ndef);
  AddNodeAttr("T", dtype, &ndef);
  Status s;
  Node* add = graph_->AddNode(ndef, &s);
  TF_CHECK_OK(s);
  for (size_t i = 0; i < grads.size(); ++i) {
    const NodeOut& nout = grads[i];
    graph_->AddEdge(nout.node, nout.index, add, i);
  }
  return {add, 0};
}

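// Returns true if 'func' names a primitive op that is registered in the
// gradient registry with a null creator, i.e. an op explicitly declared to
// have no gradient function.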
static bool IsPrimitiveOpWithNoGrad(const string& func) {
  gradient::Creator creator;
  Status s = gradient::GetOpGradientCreator(func, &creator);
  return s.ok() && (creator == nullptr);
}

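// Runs the backward pass: pops each node whose expected gradients have all
// arrived and sums them per output. x nodes stop the walk, ops with no
// registered gradient push zeros to their inputs, and every other op gets a
// SymbolicGradient node whose outputs are propagated along the op's in
// edges. Finally, the accumulated gradients of each requested x output are
// summed into x_grad_node_outputs_.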
Status SymbolicGradientBuilder::Compute() {
  // Initialize backprops.
  InitBackprop();

  // Backward propagation.
  gtl::InlinedVector<NodeOut, 8> dy;
  while (!ready_.empty()) {
    // n has collected all gradients.
    Node* n = ready_.front();
    ready_.pop_front();

    // "n" has num_x inputs and num_y outputs.
    const int num_x = n->num_inputs();
    const int num_y = n->num_outputs();

    auto iter = stop_nodes_.find(n->id());
    if (iter != stop_nodes_.end()) {
      // Stop backprop.
      // TODO(andydavis) Support stop nodes with more than one output.
      CHECK_EQ(1, num_y);
      continue;
    }

    // dy[i] is the sum of the i-th output's backpropped gradients.
    dy.clear();
    dy.resize(num_y, {nullptr, 0});
    for (int i = 0; i < num_y; ++i) {
      dy[i] = SumGradients({n, i});
    }

    if (IsPrimitiveOpWithNoGrad(n->type_string())) {
      // No grad defined for this op: Backprop zeros along the in edges.
      for (const Edge* e : n->in_edges()) {
        if (e->IsControlEdge()) continue;
        BackpropZerosAlongEdge({e->src(), e->src_output()});
      }
      continue;
    }

    // Adds a gradient node with num_x + num_y inputs and num_x outputs.
    // TODO(andydavis) Support primitive gradient ops.
    Node* grad = AddSymGrad(graph_, n, dy);
    for (const Edge* e : n->in_edges()) {
      if (e->IsControlEdge()) continue;
      graph_->AddEdge(e->src(), e->src_output(), grad, e->dst_input());
    }
    for (int i = 0; i < num_y; ++i) {
      graph_->AddEdge(dy[i].node, dy[i].index, grad, num_x + i);
    }

    // Backprops along the in edges.
    for (const Edge* e : n->in_edges()) {
      if (e->IsControlEdge()) continue;
      BackpropAlongEdge({grad, e->dst_input()}, {e->src(), e->src_output()});
    }
  }

  for (int i = 0, end = x_node_outputs_.size(); i < end; ++i) {
    (*x_grad_node_outputs_)[i] = SumGradients(x_node_outputs_[i]);
  }

  return OkStatus();
}

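// A minimal usage sketch (illustrative only; 'graph', 'y_node', 'x_node' and
// 'dy_node' are placeholders the caller is expected to provide):
//
//   std::vector<NodeOut> x_grads;
//   TF_RETURN_IF_ERROR(AddSymbolicGradients({{y_node, 0}}, {{x_node, 0}},
//                                           {{dy_node, 0}}, &x_grads, graph));
//   // x_grads[0] now refers to the graph output computing dy/dx.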
Status AddSymbolicGradients(gtl::ArraySlice<NodeOut> y_node_outputs,
                            gtl::ArraySlice<NodeOut> x_node_outputs,
                            gtl::ArraySlice<NodeOut> y_grad_node_outputs,
                            std::vector<NodeOut>* x_grad_node_outputs,
                            Graph* graph) {
  SymbolicGradientBuilder builder(y_node_outputs, x_node_outputs,
                                  y_grad_node_outputs, x_grad_node_outputs,
                                  graph);
  return builder.Compute();
}

}  // end namespace tensorflow