1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/common_runtime/gradients.h" |
17 | |
18 | #include <deque> |
19 | #include <vector> |
20 | |
21 | #include "tensorflow/core/common_runtime/device.h" |
22 | #include "tensorflow/core/common_runtime/executor.h" |
23 | #include "tensorflow/core/common_runtime/graph_constructor.h" |
24 | #include "tensorflow/core/common_runtime/graph_optimizer.h" |
25 | #include "tensorflow/core/framework/function.h" |
26 | #include "tensorflow/core/framework/node_def.pb.h" |
27 | #include "tensorflow/core/framework/node_def_util.h" |
28 | #include "tensorflow/core/framework/op.h" |
29 | #include "tensorflow/core/framework/op_kernel.h" |
30 | #include "tensorflow/core/graph/algorithm.h" |
31 | #include "tensorflow/core/graph/optimizer_cse.h" |
32 | #include "tensorflow/core/lib/gtl/map_util.h" |
33 | #include "tensorflow/core/platform/macros.h" |
34 | |
35 | namespace tensorflow { |
36 | |
37 | // TODO(andydavis) Remove some of the code duplicated between this module |
38 | // and that in 'common_runtime/function.cc'. |
// A few string constants used throughout this module.
static const char* const kGradientOp = "SymbolicGradient";
static const char* const kNodeLabel = "Func";
42 | |
43 | string NodeOut::name() const { |
44 | if (index == 0) { |
45 | return node->name(); |
46 | } else { |
    return strings::StrCat(node->name(), ":", index);
48 | } |
49 | } |
50 | |
51 | DataType NodeOut::dtype() const { return node->output_type(index); } |
52 | |
53 | struct NodeOutHash { |
54 | uint64 operator()(const NodeOut& x) const { |
55 | return Hash64(reinterpret_cast<const char*>(&x.node), sizeof(Node*), |
56 | x.index); |
57 | } |
58 | }; |
59 | |
60 | struct NodeOutEq { |
61 | bool operator()(const NodeOut& x, const NodeOut& y) const { |
62 | return (x.node == y.node) && (x.index == y.index); |
63 | } |
64 | }; |
65 | |
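// Adds a node to 'g' that produces a zeros tensor shaped like 'input'.
// For DT_RESOURCE inputs, the variable is first materialized with a
// ReadVariableOp (its dtype is assumed to be DT_FLOAT here; see the TODO in
// AddSymGrad) before taking the ZerosLike.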
66 | static Node* AddZerosLike(Graph* g, NodeOut input) { |
67 | DCHECK_LT(0, input.dtype()); |
68 | DCHECK_LT(input.dtype(), DT_FLOAT_REF); |
69 | if (input.dtype() == DT_RESOURCE) { |
70 | NodeDef read_def; |
    read_def.set_name(g->NewName("Read"));
    read_def.set_op("ReadVariableOp");
    read_def.add_input(input.name());
    AddNodeAttr("dtype", DT_FLOAT, &read_def);
75 | Status s; |
76 | Node* read = g->AddNode(read_def, &s); |
77 | TF_CHECK_OK(s); |
78 | g->AddEdge(input.node, input.index, read, 0); |
79 | NodeDef ndef; |
80 | ndef.set_name(g->NewName(kNodeLabel)); |
    ndef.set_op("ZerosLike");
    ndef.add_input(read_def.name());
    AddNodeAttr("T", DT_FLOAT, &ndef);
84 | Node* ret = g->AddNode(ndef, &s); |
85 | TF_CHECK_OK(s); |
86 | g->AddEdge(read, 0, ret, 0); |
87 | return ret; |
88 | } else { |
89 | NodeDef ndef; |
90 | ndef.set_name(g->NewName(kNodeLabel)); |
    ndef.set_op("ZerosLike");
    ndef.add_input(input.name());
    AddNodeAttr("T", input.dtype(), &ndef);
94 | Status s; |
95 | Node* ret = g->AddNode(ndef, &s); |
96 | TF_CHECK_OK(s); |
97 | g->AddEdge(input.node, input.index, ret, 0); |
98 | return ret; |
99 | } |
100 | } |
101 | |
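// Adds a SymbolicGradient node for 'n' to 'g'. Given 'n' with inputs
// (x_1, ..., x_num_x) and output gradients 'grads' = (dy_1, ..., dy_num_y),
// the new node consumes (x_1, ..., x_num_x, dy_1, ..., dy_num_y) and
// produces one gradient output per input of 'n'.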
102 | static Node* AddSymGrad(Graph* g, Node* n, gtl::ArraySlice<NodeOut> grads) { |
103 | const int num_x = n->num_inputs(); |
104 | const int num_y = n->num_outputs(); |
105 | CHECK_EQ(num_y, grads.size()); |
106 | |
107 | NodeDef ndef; |
108 | ndef.set_name(g->NewName(kNodeLabel)); |
109 | ndef.set_op(kGradientOp); |
110 | |
111 | // The gradient node should have num_x + num_y inputs. |
112 | std::vector<NodeOut> n_inputs(num_x); |
113 | for (const Edge* e : n->in_edges()) { |
114 | if (e->IsControlEdge()) continue; |
115 | n_inputs[e->dst_input()] = {e->src(), e->src_output()}; |
116 | } |
117 | DataTypeVector in_types; |
118 | for (const NodeOut& nout : n_inputs) { |
119 | ndef.add_input(nout.name()); |
120 | in_types.push_back(nout.dtype()); |
121 | } |
122 | for (const NodeOut& nout : grads) { |
123 | ndef.add_input(nout.name()); |
124 | in_types.push_back(nout.dtype()); |
125 | } |
126 | CHECK_EQ(ndef.input_size(), num_x + num_y); |
127 | |
128 | AddNodeAttr("Tin" , in_types, &ndef); |
129 | |
  // The gradient node's outputs have the same types as node 'n's inputs,
  // except for resources.
132 | DataTypeVector out_types = n->input_types(); |
133 | for (int i = 0, end = out_types.size(); i < end; ++i) { |
134 | if (out_types[i] == DT_RESOURCE) { |
135 | // TODO(apassos): figure out how to get the right dtype |
136 | out_types[i] = DT_FLOAT; |
137 | } |
138 | } |
139 | AddNodeAttr("Tout" , out_types, &ndef); |
140 | NameAttrList func; |
141 | func.set_name(n->type_string()); |
142 | for (const auto& attr : n->attrs()) { |
143 | (*func.mutable_attr())[attr.first] = attr.second; |
144 | } |
145 | AddNodeAttr("f" , func, &ndef); |
146 | Status s; |
147 | Node* ret = g->AddNode(ndef, &s); |
148 | TF_CHECK_OK(s); |
149 | return ret; |
150 | } |
151 | |
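// SymbolicGradientBuilder implements reverse-mode automatic differentiation
// over a Graph: starting from the supplied output gradients
// 'y_grad_node_outputs', it walks the graph backwards, accumulating
// per-endpoint gradients in 'backprops_', and emits one SymbolicGradient
// node per forward node once all of that node's output gradients have
// arrived, stopping at the 'x_node_outputs'.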
152 | class SymbolicGradientBuilder { |
153 | public: |
154 | SymbolicGradientBuilder(gtl::ArraySlice<NodeOut> y_node_outputs, |
155 | gtl::ArraySlice<NodeOut> x_node_outputs, |
156 | gtl::ArraySlice<NodeOut> y_grad_node_outputs, |
157 | std::vector<NodeOut>* x_grad_node_outputs, |
158 | Graph* graph); |
159 | |
160 | Status Compute(); |
161 | |
162 | private: |
163 | gtl::ArraySlice<NodeOut> y_node_outputs_; |
164 | gtl::ArraySlice<NodeOut> x_node_outputs_; |
165 | gtl::ArraySlice<NodeOut> y_grad_node_outputs_; |
166 | std::vector<NodeOut>* x_grad_node_outputs_; |
167 | Graph* graph_; // Not owned. |
168 | |
  // A vector of output endpoints which represents backpropagated
  // gradients.
171 | typedef std::vector<NodeOut> BackproppedGradients; |
172 | |
173 | // backprops_ is a map from a node output to its accumulated |
174 | // gradients. When a node output has accumulated all its |
175 | // gradients, we add a node which sums them up. |
176 | std::unordered_map<NodeOut, BackproppedGradients, NodeOutHash, NodeOutEq> |
177 | backprops_; |
178 | |
  // pending_[i] is a count-down counter for the i-th node's expected
  // backprops. When pending_[i] becomes zero, we have collected all
  // backprop gradients for all outputs of the i-th node.
182 | std::vector<int> pending_; |
183 | |
  // 'ready_' keeps track of nodes that have been completely
185 | // backpropped. Initially, for every output y of the function f, we |
186 | // add dy as an input of the gradient function. |
187 | std::deque<Node*> ready_; |
188 | |
189 | // The set of node ids at which to stop backprop. |
190 | std::unordered_set<int> stop_nodes_; |
191 | |
192 | // Initialize pending_ and ready_. |
193 | void InitBackprop(); |
194 | |
  // In the original function body, there is a forward edge from 'src'
  // to 'dst'. When the backprop algorithm constructs the node
  // 'dst_grad', which computes the gradient, we need to propagate it
  // back to 'src'.
199 | void BackpropAlongEdge(const NodeOut& dst_grad, const NodeOut& src); |
200 | void BackpropZerosAlongEdge(const NodeOut& src); |
201 | |
  // Returns a node representing the sum of any backpropped gradients for
  // 'src'. This will be an AddN node if there is more than one accumulated
  // gradient. Returns zeros if there are no gradients, or the dtype is
  // DT_BOOL.
205 | NodeOut SumGradients(const NodeOut& src); |
206 | |
207 | TF_DISALLOW_COPY_AND_ASSIGN(SymbolicGradientBuilder); |
208 | }; |
209 | |
210 | SymbolicGradientBuilder::SymbolicGradientBuilder( |
211 | gtl::ArraySlice<NodeOut> y_node_outputs, |
212 | gtl::ArraySlice<NodeOut> x_node_outputs, |
213 | gtl::ArraySlice<NodeOut> y_grad_node_outputs, |
214 | std::vector<NodeOut>* x_grad_node_outputs, Graph* graph) |
215 | : y_node_outputs_(y_node_outputs), |
216 | x_node_outputs_(x_node_outputs), |
217 | y_grad_node_outputs_(y_grad_node_outputs), |
218 | x_grad_node_outputs_(x_grad_node_outputs), |
219 | graph_(graph) { |
220 | CHECK_EQ(y_node_outputs_.size(), y_grad_node_outputs.size()); |
221 | x_grad_node_outputs_->clear(); |
222 | x_grad_node_outputs_->resize(x_node_outputs_.size()); |
223 | stop_nodes_.reserve(x_node_outputs_.size()); |
224 | for (int i = 0, end = x_node_outputs_.size(); i < end; ++i) { |
225 | stop_nodes_.insert(x_node_outputs_[i].node->id()); |
226 | } |
227 | } |
228 | |
229 | void SymbolicGradientBuilder::BackpropAlongEdge(const NodeOut& dst_grad, |
230 | const NodeOut& src) { |
231 | CHECK_NOTNULL(src.node); |
232 | auto iter = backprops_.find(src); |
233 | if (iter != backprops_.end()) { |
234 | auto* grads = &iter->second; |
235 | grads->push_back(dst_grad); |
236 | if (--pending_[src.node->id()] == 0) { |
237 | ready_.push_back(src.node); |
238 | } |
239 | } |
240 | } |
241 | |
242 | void SymbolicGradientBuilder::BackpropZerosAlongEdge(const NodeOut& src) { |
243 | CHECK_NOTNULL(src.node); |
244 | auto iter = backprops_.find(src); |
245 | if (iter != backprops_.end()) { |
246 | if (--pending_[src.node->id()] == 0) { |
247 | ready_.push_back(src.node); |
248 | } |
249 | } |
250 | } |
251 | |
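// Seeds the backward pass: a traversal from the y nodes back along data
// edges counts, for every reachable node, how many gradient contributions
// to expect (pending_); the supplied dy values are then propagated along
// the in-edges of each y node, which populates the initial 'ready_' queue.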
252 | void SymbolicGradientBuilder::InitBackprop() { |
253 | pending_.resize(graph_->num_node_ids(), 0); |
254 | { |
255 | backprops_.clear(); |
256 | std::unordered_set<Node*> visited; |
257 | std::deque<Node*> queue; |
258 | for (const NodeOut& nout : y_node_outputs_) { |
259 | queue.push_back(nout.node); |
260 | visited.insert(nout.node); |
261 | } |
262 | |
    // Walk the graph from the y nodes to figure out which endpoints need
    // to be backpropped. A node's endpoints need backprop only if one of
    // the return nodes can reach the node backwards via data edges.
266 | while (!queue.empty()) { |
267 | Node* n = queue.front(); |
268 | queue.pop_front(); |
269 | for (int i = 0; i < n->num_outputs(); ++i) { |
270 | backprops_[{n, i}].clear(); |
271 | } |
272 | for (const Edge* e : n->in_edges()) { |
273 | if (e->IsControlEdge()) continue; |
274 | pending_[e->src()->id()]++; |
275 | if (visited.find(e->src()) == visited.end()) { |
276 | queue.push_back(e->src()); |
277 | visited.insert(e->src()); |
278 | } |
279 | } |
280 | } |
281 | |
    // Create entries in backprops_ for all x_node_outputs_, because they
    // will not be added by the loop above if they are not reverse reachable
    // from y_node_outputs_.
285 | for (const NodeOut& nout : x_node_outputs_) { |
286 | backprops_[{nout.node, nout.index}].clear(); |
287 | } |
288 | } |
289 | |
290 | { |
291 | const int num_y = y_grad_node_outputs_.size(); |
292 | for (int i = 0; i < num_y; ++i) { |
293 | Node* y = y_node_outputs_[i].node; |
294 | for (const Edge* e : y->in_edges()) { |
295 | if (e->IsControlEdge()) continue; |
296 | BackpropAlongEdge(y_grad_node_outputs_[i], {e->src(), e->src_output()}); |
297 | } |
298 | } |
299 | } |
300 | CHECK(!ready_.empty()); |
301 | } |
302 | |
303 | NodeOut SymbolicGradientBuilder::SumGradients(const NodeOut& src) { |
304 | const DataType dtype = src.dtype(); |
305 | auto iter = backprops_.find(src); |
306 | CHECK(iter != backprops_.end()); |
307 | const auto& grads = iter->second; |
308 | if (grads.empty() || dtype == DT_BOOL) { |
    // Nothing propagated back. The best we can come up with is zeros.
310 | Node* zero_like = AddZerosLike(graph_, src); |
311 | return {zero_like, 0}; |
312 | } |
313 | if (grads.size() == 1) { |
314 | // Just one backprop edge. |
315 | return grads[0]; |
316 | } |
  // Otherwise, sum the backpropped gradients with an AddN node.
318 | NodeDef ndef; |
319 | ndef.set_name(graph_->NewName(kNodeLabel)); |
320 | ndef.set_op("AddN" ); // N-way Add |
321 | for (const NodeOut& nout : grads) { |
322 | ndef.add_input(nout.name()); |
323 | } |
324 | AddNodeAttr("N" , static_cast<int64_t>(grads.size()), &ndef); |
325 | AddNodeAttr("T" , dtype, &ndef); |
326 | Status s; |
327 | Node* add = graph_->AddNode(ndef, &s); |
328 | TF_CHECK_OK(s); |
329 | for (size_t i = 0; i < grads.size(); ++i) { |
330 | const NodeOut& nout = grads[i]; |
331 | graph_->AddEdge(nout.node, nout.index, add, i); |
332 | } |
333 | return {add, 0}; |
334 | } |
335 | |
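// Returns true iff 'func' names a primitive op whose gradient lookup
// succeeds but yields a null creator, i.e. the op is registered without a
// gradient function.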
336 | static bool IsPrimitiveOpWithNoGrad(const string& func) { |
337 | gradient::Creator creator; |
338 | Status s = gradient::GetOpGradientCreator(func, &creator); |
339 | return s.ok() && (creator == nullptr); |
340 | } |
341 | |
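// Drains 'ready_': each node popped from the queue has received all of its
// output gradients. Those gradients are summed per output and fed into a
// newly added SymbolicGradient node; ops with no registered gradient instead
// backprop zeros, and stop nodes (the x nodes) terminate the walk.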
342 | Status SymbolicGradientBuilder::Compute() { |
343 | // Initialize backprops. |
344 | InitBackprop(); |
345 | |
346 | // Backward propagation. |
347 | gtl::InlinedVector<NodeOut, 8> dy; |
348 | while (!ready_.empty()) { |
349 | // n has collected all gradients. |
350 | Node* n = ready_.front(); |
351 | ready_.pop_front(); |
352 | |
353 | // "n" has num_x inputs and num_y outputs. |
354 | const int num_x = n->num_inputs(); |
355 | const int num_y = n->num_outputs(); |
356 | |
357 | auto iter = stop_nodes_.find(n->id()); |
358 | if (iter != stop_nodes_.end()) { |
359 | // Stop backprop. |
360 | // TODO(andydavis) Support stop nodes with more than one output. |
361 | CHECK_EQ(1, num_y); |
362 | continue; |
363 | } |
364 | |
    // dy[i] is the sum of the i-th output's backpropped gradients.
366 | dy.clear(); |
367 | dy.resize(num_y, {nullptr, 0}); |
368 | for (int i = 0; i < num_y; ++i) { |
369 | dy[i] = SumGradients({n, i}); |
370 | } |
371 | |
372 | if (IsPrimitiveOpWithNoGrad(n->type_string())) { |
373 | // No grad defined for this op: Backprop zeros along the in edges. |
374 | for (const Edge* e : n->in_edges()) { |
375 | if (e->IsControlEdge()) continue; |
376 | BackpropZerosAlongEdge({e->src(), e->src_output()}); |
377 | } |
378 | continue; |
379 | } |
380 | |
381 | // Adds a gradient node with num_x + num_y inputs and num_x |
382 | // outputs. |
383 | // TODO(andydavis) Support primitive gradient ops. |
384 | Node* grad = AddSymGrad(graph_, n, dy); |
385 | for (const Edge* e : n->in_edges()) { |
386 | if (e->IsControlEdge()) continue; |
387 | graph_->AddEdge(e->src(), e->src_output(), grad, e->dst_input()); |
388 | } |
389 | for (int i = 0; i < num_y; ++i) { |
390 | graph_->AddEdge(dy[i].node, dy[i].index, grad, num_x + i); |
391 | } |
392 | |
393 | // Backprops along the in edges. |
394 | for (const Edge* e : n->in_edges()) { |
395 | if (e->IsControlEdge()) continue; |
396 | BackpropAlongEdge({grad, e->dst_input()}, {e->src(), e->src_output()}); |
397 | } |
398 | } |
399 | |
400 | for (int i = 0, end = x_node_outputs_.size(); i < end; ++i) { |
401 | (*x_grad_node_outputs_)[i] = SumGradients(x_node_outputs_[i]); |
402 | } |
403 | |
404 | return OkStatus(); |
405 | } |
406 | |
407 | Status AddSymbolicGradients(gtl::ArraySlice<NodeOut> y_node_outputs, |
408 | gtl::ArraySlice<NodeOut> x_node_outputs, |
409 | gtl::ArraySlice<NodeOut> y_grad_node_outputs, |
410 | std::vector<NodeOut>* x_grad_node_outputs, |
411 | Graph* graph) { |
412 | SymbolicGradientBuilder builder(y_node_outputs, x_node_outputs, |
413 | y_grad_node_outputs, x_grad_node_outputs, |
414 | graph); |
415 | return builder.Compute(); |
416 | } |
417 | |
418 | } // end namespace tensorflow |
419 | |