1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #ifndef TENSORFLOW_C_EAGER_TAPE_H_ |
16 | #define TENSORFLOW_C_EAGER_TAPE_H_ |
17 | |
18 | // Language-agnostic gradient tape. Does not perform backpropagation, just |
19 | // maintains the data structures required to do so. |
20 | |
#include <array>
#include <functional>
#include <memory>
#include <stack>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/core/config/flag_defs.h"
27 | #include "tensorflow/core/config/flags.h" |
28 | #include "tensorflow/core/framework/tensor_shape.h" |
29 | #include "tensorflow/core/framework/types.h" |
30 | #include "tensorflow/core/lib/gtl/array_slice.h" |
31 | #include "tensorflow/core/lib/gtl/cleanup.h" |
32 | #include "tensorflow/core/lib/gtl/flatmap.h" |
33 | #include "tensorflow/core/lib/gtl/flatset.h" |
34 | #include "tensorflow/core/platform/errors.h" |
35 | #include "tensorflow/core/platform/types.h" |
36 | |
37 | namespace tensorflow { |
38 | namespace eager { |
39 | |
40 | // Represents an entry in the tape. |
41 | template <typename BackwardFunction, typename TapeTensor> |
42 | struct OpTapeEntry { |
43 | string op_type; |
44 | std::vector<TapeTensor> output_tensor_info; |
45 | std::vector<int64_t> input_tensor_id; |
46 | |
47 | // TODO(apassos) consider narrowing down this interface. |
48 | BackwardFunction* backward_function; |
49 | |
  // Deletes `backward_function`; invoke this rather than deleting the pointer
  // directly. TODO(apassos) use unique_ptrs to ensure this happens.
52 | std::function<void(BackwardFunction*)> backward_function_deleter; |
53 | }; |
54 | |
55 | // Map from tensor_id to internally-defined operation-id of the operation which |
56 | // produced this tensor. A value of -1 means that the tensor was directly |
57 | // watched and not the result of any operation in the tape. |
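// For example, after watching tensor 7 and then recording an op (assigned
// op-id 0) which consumes tensor 7 and produces tensor 9, the map would be
// {7: -1, 9: 0}.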
58 | using TensorTape = std::unordered_map<int64_t, int64_t>; |
59 | |
60 | // Map from operation-id to tape entry. |
61 | template <typename BackwardFunction, typename TapeTensor> |
62 | using OpTape = |
63 | std::unordered_map<int64_t, OpTapeEntry<BackwardFunction, TapeTensor>>; |
64 | |
65 | // Operations the tape needs to perform on tensors to do backpropagation. Named |
66 | // "vspace" because a subset of these are related to a vector space, such as |
67 | // adding gradients, getting zeroes, etc. Currently cannot be implemented |
68 | // without using tensorflow python code, hence left unspecified here. |
69 | // |
70 | // Gradient is the type returned by gradient functions. In Python TF it's either |
71 | // Tensor or IndexedSlices or None, which here we map to nullptr. Gradients need |
72 | // to allow their size to be computed and they need to be passable to a backward |
73 | // function and deleted (as the backprop code creates lots of gradients the user |
74 | // is not interested in). |
75 | // |
76 | // BackwardFunction needs to be a closure which stores intermediate activations |
77 | // from the forward computation and calls a vector-jacobian product function |
78 | // (also known as adjoint function) to compute, given downstream gradients, |
79 | // upstream gradients. |
80 | // |
81 | // TODO(apassos) provide concrete template instantiations for TFE_TensorHandle |
82 | // specialization, which is blocked by quite a few things needing to loop back |
83 | // into python now. |
84 | template <typename Gradient, typename BackwardFunction, typename TapeTensor> |
85 | class VSpace { |
86 | public: |
87 | virtual ~VSpace() {} |
88 | |
89 | // Returns the number of elements in the gradient tensor. |
90 | virtual int64_t NumElements(Gradient* tensor) const = 0; |
91 | |
92 | // Consumes references to the tensors in the gradient_tensors list and returns |
93 | // a tensor with the result. |
94 | virtual Gradient* AggregateGradients( |
95 | gtl::ArraySlice<Gradient*> gradient_tensors) const = 0; |
96 | |
97 | // Calls the passed-in backward function. |
98 | // |
  // `unneeded_gradients` contains a sorted list of input indices for which a
100 | // gradient is not required. |
101 | virtual Status CallBackwardFunction( |
102 | const string& op_type, BackwardFunction* backward_function, |
103 | const std::vector<int64_t>& unneeded_gradients, |
104 | gtl::ArraySlice<Gradient*> output_gradients, |
105 | absl::Span<Gradient*> result) const = 0; |
106 | |
107 | // Builds a tensor filled with ones with the same shape and dtype as `t`. |
108 | virtual Status BuildOnesLike(const TapeTensor& t, |
109 | Gradient** result) const = 0; |
110 | |
111 | // Looks up the ID of a Gradient. |
112 | virtual int64_t TensorId(Gradient* tensor) const = 0; |
113 | |
114 | // Converts a Gradient to a TapeTensor. |
115 | virtual TapeTensor TapeTensorFromGradient(Gradient* gradient) const = 0; |
116 | |
  // Marks the given gradient as a result so it's not consumed by backward
  // functions.
119 | virtual void MarkAsResult(Gradient* gradient) const = 0; |
120 | |
121 | // Deletes the input tensor. |
122 | virtual void DeleteGradient(Gradient* gradient) const = 0; |
123 | }; |
124 | |
125 | // Traces the execution of operations, doing eager garbage collection, and |
126 | // exporting a full trace so other code can do backpropagation. Not thread-safe. |
127 | template <typename Gradient, typename BackwardFunction, typename TapeTensor> |
128 | class GradientTape { |
129 | public: |
130 | // If `persistent` is true, GradientTape will not eagerly delete backward |
131 | // functions (and hence the tensors they keep alive). Instead, everything |
132 | // is deleted in ~GradientTape. Persistent GradientTapes are useful when |
133 | // users want to compute multiple gradients over the same tape. |
134 | explicit GradientTape(bool persistent) : persistent_(persistent) {} |
135 | ~GradientTape() { |
136 | for (const auto& pair : op_tape_) { |
137 | pair.second.backward_function_deleter(pair.second.backward_function); |
138 | } |
139 | } |
140 | |
141 | // Returns whether any tensor in a list of tensors is being watched and has |
142 | // a trainable dtype. |
143 | bool ShouldRecord(gtl::ArraySlice<int64_t> tensor_ids, |
144 | gtl::ArraySlice<tensorflow::DataType> dtypes) const; |
145 | |
146 | // Adds this tensor to the list of watched tensors. |
147 | // |
  // This is a no-op if the tensor is already being watched, either from an
  // earlier call to `GradientTape::Watch` or from being an output of an op
  // with watched inputs.
151 | void Watch(int64_t tensor_id); |
152 | |
153 | // Records an operation with inputs `input_tensor_id` and outputs |
154 | // `output_tensors` on the tape and marks all its outputs as watched if at |
155 | // least one input of the op is watched and has trainable dtype. |
156 | // |
157 | // op_type is used to decide which of the incoming gradients can be left as |
158 | // nullptr instead of building zeros when build_default_zeros_grads == true. |
159 | void RecordOperation( |
160 | const string& op_type, const std::vector<TapeTensor>& output_tensors, |
161 | gtl::ArraySlice<int64_t> input_tensor_id, |
162 | gtl::ArraySlice<tensorflow::DataType> input_dtypes, |
163 | const std::function<BackwardFunction*()>& backward_function_getter, |
164 | const std::function<void(BackwardFunction*)>& backward_function_deleter); |
165 | |
166 | void DeleteTrace(int64_t tensor_id); |
167 | |
  // Consumes the internal state of the tape (so, unless the tape is
  // persistent, this cannot be called more than once) and produces the
  // gradient of the target tensors with respect to the source tensors. The
  // output gradients are used if not empty and not null. The result is
  // populated with one tensor per target element.
  // When running backward functions, builds zeros-like tensors for
  // incoming grads which are nullptrs, unless `build_default_zeros_grads`
  // is set to false.
175 | Status ComputeGradient( |
176 | const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace, |
177 | const gtl::ArraySlice<int64_t> target_tensor_ids, |
178 | const gtl::ArraySlice<int64_t> source_tensor_ids, |
      const std::unordered_map<int64_t, TapeTensor>& sources_that_are_targets,
180 | gtl::ArraySlice<Gradient*> output_gradients, absl::Span<Gradient*> result, |
181 | bool build_default_zeros_grads = true); |
182 | |
183 | // Whether the tape is persistent. See ctor for detailed description. |
184 | bool IsPersistent() const { return persistent_; } |
185 | |
186 | private: |
187 | TensorTape tensor_tape_; |
188 | OpTape<BackwardFunction, TapeTensor> op_tape_; |
189 | int64_t next_op_id_{0}; |
190 | |
  // Map from tensor id to the number of remaining usages (i.e., how many
  // entries in the tape refer to it), used to aid in tape garbage collection.
193 | std::unordered_map<int64_t, int64_t> tensor_usage_; |
194 | |
  // If false, all activations are deleted in the first call to
  // ComputeGradient; otherwise they are deleted only when the tape itself is
  // destroyed.
197 | bool persistent_; |
198 | }; |
199 | |
200 | // Describes a callback for special-cased and more efficient jvp computation. |
201 | // |
202 | // Could just be a simple typedef in ForwardAccumulator, but MSVC chokes on |
203 | // that. |
204 | template <typename Gradient> |
205 | class ForwardFunction |
206 | : public std::function<Status(const std::vector<Gradient*>&, |
207 | std::vector<Gradient*>*, bool)> { |
208 | public: |
209 | template <typename lambda_type> |
210 | explicit ForwardFunction(lambda_type lambda) |
211 | : std::function<Status(const std::vector<Gradient*>&, |
212 | std::vector<Gradient*>*, bool)>(lambda) {} |
213 | }; |
214 | |
215 | // Computes Jacobian-vector products using forward-mode automatic |
216 | // differentiation. |
217 | // |
218 | // While GradientTape's RecordOperation is trivial, ForwardAccumulator's |
219 | // Accumulate runs the gradient computation immediately. |
220 | // |
// Keeps references to the Tensors watched via Watch and to the JVPs computed
// in Accumulate for `output_tensors`, and releases these references in its
// destructor. However, waiting until the destructor runs loses the memory
// efficiency of forward-mode autodiff. Instead, language bindings should call
// DeleteGradient as soon as a Tensor which was `Watch`ed or was an output
// Tensor passed to Accumulate goes out of scope.
227 | // |
228 | // Not thread-safe. |
229 | template <typename Gradient, typename BackwardFunction, typename TapeTensor> |
230 | class ForwardAccumulator { |
231 | public: |
232 | // Does not take ownership of `vspace`, which must outlive the |
233 | // ForwardAccumulator. |
234 | explicit ForwardAccumulator( |
235 | const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace, |
236 | bool use_batch) |
237 | : vspace_(vspace), use_batch_(use_batch) { |
238 | call_state_.emplace(nullptr, false); |
239 | } |
240 | |
241 | virtual ~ForwardAccumulator() { |
242 | for (auto accumulated : accumulated_gradients_) { |
243 | vspace_.DeleteGradient(accumulated.second); |
244 | } |
245 | } |
246 | |
247 | // Tell the forward accumulator to watch tensor_id, with a Tensor tangent |
248 | // vector `tangent` of matching shape and dtype. Tangents are the "vector" in |
249 | // "Jacobian-vector product"; `Watch`ing a new Tensor and immediately calling |
250 | // FetchJVP for it would return `tangent`. |
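  //
  // Calling Watch for a tensor_id which already has a tangent aggregates
  // (sums) the new tangent with the existing one.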
251 | void Watch(int64_t tensor_id, Gradient* tangent); |
252 | |
253 | // Removes the gradient associated with tensor_id. Should be called when the |
254 | // Tensor associated with `tensor_id` is deleted. |
255 | void DeleteGradient(int64_t tensor_id); |
256 | |
257 | // Runs forward autodiff. Should be called whenever a new operation is |
258 | // available and the accumulator is active. |
259 | // |
260 | // Like GradientTape::RecordOperation, this method takes the operation type |
261 | // `op_type` (e.g. "Add"), the operation's inputs (`input_tensors`, |
262 | // `input_tensor_id`, and `input_dtypes`; the latter two are somewhat |
263 | // redundant but taken as arguments to avoid repeatedly fetching these values |
  // between calls to ShouldRecord and Accumulate), and its outputs
265 | // (`output_tensors`). |
266 | // |
267 | // If provided, a non-null `forward_function` will be used instead of the |
268 | // backward function (`backward_function_getter` / |
  // `backward_function_deleter`) to compute JVPs for this operation. If
  // `forward_function` is null, a GradientTape is used on the backward
  // function to compute the JVP, which will waste computation when executing
  // eagerly.
272 | // |
273 | // Unlike GradientTape::RecordOperation, Accumulate runs gradient computation |
274 | // immediately. It stores the results, which feed into Accumulate for future |
275 | // operations and may be fetched by calling FetchJVP. ForwardAccumulator |
276 | // maintains a reference to these JVPs: if an `output_tensors` Tensor is |
277 | // deleted, `DeleteGradient` should be called as soon as possible to free the |
278 | // (now inaccessible) corresponding JVPs, but ForwardAccumulator's destructor |
279 | // will release remaining references. |
280 | // |
281 | // This method is not thread-safe (and in general ForwardAccumulator is not |
282 | // thread-safe). |
283 | Status Accumulate( |
284 | const string& op_type, const std::vector<TapeTensor>& input_tensors, |
285 | const std::vector<TapeTensor>& output_tensors, |
286 | gtl::ArraySlice<int64_t> input_tensor_id, |
287 | gtl::ArraySlice<tensorflow::DataType> input_dtypes, |
288 | const ForwardFunction<Gradient>* forward_function, |
289 | const std::function<BackwardFunction*()>& backward_function_getter, |
290 | const std::function<void(BackwardFunction*)>& backward_function_deleter); |
291 | |
292 | // Returns true if `Accumulate` is active somewhere above on the stack and |
293 | // there isn't an intervening PushState. This is useful for ordering |
294 | // ForwardAccumulators, where more deeply nested accumulators should not see |
295 | // computations from less deeply nested accumulators. |
296 | bool BusyAccumulating() const { return call_state_.top().accumulating; } |
297 | |
  // Fetches the current Jacobian-vector product associated with `tensor_id`,
  // or nullptr if none is available.
300 | // |
301 | // Returns a borrowed reference, i.e. does not run VSpace::MarkAsResult on its |
302 | // return value. The caller should increment the reference count before |
303 | // deleting the ForwardAccumulator or calling DeleteGradient if keeping a |
304 | // persistent reference to a non-null result. |
305 | Gradient* FetchJVP(int64_t tensor_id); |
306 | |
307 | // Indicates whether the forward accumulator should run on an operation with |
308 | // the specified inputs and dtypes. |
309 | bool ShouldRecord(gtl::ArraySlice<int64_t> tensor_ids, |
310 | gtl::ArraySlice<tensorflow::DataType> dtypes); |
311 | |
312 | // Temporarily push or pop transient state for this accumulator. |
313 | // |
314 | // Allows an accumulator which is currently processing an operation to |
315 | // temporarily reset its state. Without pushing and popping, accumulators |
  // ignore operations executed as a direct result of their own JVP
317 | // computations. |
318 | void PushState() { call_state_.emplace(nullptr, false); } |
319 | void PopState() { call_state_.pop(); } |
320 | |
321 | private: |
322 | // Helper for Accumulate: uses a GradientTape to compute forward gradients |
323 | // from a backward gradient function. Fills `out_grads` corresponding to |
324 | // `output_tensors`. `out_grads` must not be null. |
325 | // |
326 | // Executes the backward function in order to trace its gradient, which will |
327 | // waste computation if executing eagerly (when graph building the unneeded |
328 | // computation is pruned). Temporarily sets `backward_tape` so that |
329 | // Accumulate will forward op executions to the tape while the backward |
330 | // function is running; this effectively adds the backward tape to the active |
331 | // set (but does not require complicated callbacks to the language bindings). |
332 | Status ForwardpropFromTape( |
333 | const string& op_type, const std::vector<TapeTensor>& output_tensors, |
334 | const std::function<BackwardFunction*()>& backward_function_getter, |
335 | const std::function<void(BackwardFunction*)>& backward_function_deleter, |
336 | const std::vector<Gradient*>& in_grads, absl::Span<Gradient*> out_grads); |
337 | |
338 | // Maps from tensor IDs to corresponding JVPs. |
339 | std::unordered_map<int64_t, Gradient*> accumulated_gradients_; |
340 | // Not owned; provides operations on Tensors which are currently only |
341 | // available in language bindings (e.g. Python). |
342 | const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace_; |
343 | |
  // Whether tangents are vectorized (batched); forwarded as the `use_batch`
  // argument to any special-cased ForwardFunction.
345 | bool use_batch_; |
346 | |
347 | struct AccumulatorCallState { |
348 | AccumulatorCallState( |
349 | GradientTape<Gradient, BackwardFunction, TapeTensor>* backward_tape, |
350 | bool accumulating) |
351 | : backward_tape(backward_tape), accumulating(accumulating) {} |
352 | // Set temporarily while in the Accumulate method; if backward_tape is not |
353 | // nullptr then we forward op executions to it so Accumulate can compute a |
354 | // backward pass on its backward function. |
355 | // |
356 | // Not owned by the ForwardAccumulator. The method which sets |
357 | // `backward_tape` keeps ownership. |
358 | GradientTape<Gradient, BackwardFunction, TapeTensor>* backward_tape; |
    // While the Accumulate method is running (accumulating is true), any op
    // executions not forwarded to backward_tape should be ignored.
361 | bool accumulating; |
362 | }; |
363 | // A deque-backed stack, whose element references are not invalidated by |
364 | // pushes and pops at the back. |
365 | std::stack<AccumulatorCallState> call_state_; |
366 | }; |
367 | |
// Template method definitions follow.
369 | |
370 | inline bool IsDtypeTrainable(DataType dtype) { |
371 | switch (dtype) { |
372 | case DT_HALF: |
373 | case DT_BFLOAT16: |
374 | case DT_FLOAT: |
375 | case DT_DOUBLE: |
376 | case DT_COMPLEX64: |
377 | case DT_COMPLEX128: |
378 | case DT_RESOURCE: |
379 | case DT_VARIANT: |
380 | return true; |
381 | case DT_QINT8: |
382 | case DT_QINT16: |
383 | case DT_QINT32: |
384 | case DT_QUINT8: |
385 | case DT_QUINT16: |
386 | return tensorflow::flags::Global() |
387 | .enable_quantized_dtypes_training.value(); |
388 | default: |
389 | return false; |
390 | } |
391 | } |
392 | |
393 | template <typename Gradient, typename BackwardFunction, typename TapeTensor> |
394 | bool GradientTape<Gradient, BackwardFunction, TapeTensor>::ShouldRecord( |
395 | gtl::ArraySlice<int64_t> tensor_ids, |
396 | gtl::ArraySlice<tensorflow::DataType> dtypes) const { |
397 | CHECK_EQ(tensor_ids.size(), dtypes.size()); |
398 | for (int i = 0; i < tensor_ids.size(); ++i) { |
399 | if (tensor_tape_.find(tensor_ids[i]) != tensor_tape_.end()) { |
400 | if (IsDtypeTrainable(dtypes[i])) { |
401 | return true; |
402 | } |
403 | } |
404 | } |
405 | return false; |
406 | } |
407 | |
408 | template <typename Gradient, typename BackwardFunction, typename TapeTensor> |
409 | void GradientTape<Gradient, BackwardFunction, TapeTensor>::Watch( |
410 | int64_t tensor_id) { |
411 | tensor_tape_.emplace(tensor_id, -1); |
412 | } |
413 | |
414 | template <typename Gradient, typename BackwardFunction, typename TapeTensor> |
415 | void GradientTape<Gradient, BackwardFunction, TapeTensor>::RecordOperation( |
416 | const string& op_type, const std::vector<TapeTensor>& output_tensors, |
417 | gtl::ArraySlice<int64_t> input_tensor_id, |
418 | gtl::ArraySlice<tensorflow::DataType> input_dtypes, |
419 | const std::function<BackwardFunction*()>& backward_function_getter, |
420 | const std::function<void(BackwardFunction*)>& backward_function_deleter) { |
421 | if (!ShouldRecord(input_tensor_id, input_dtypes)) { |
422 | return; |
423 | } |
424 | std::vector<int64_t> ids; |
425 | ids.reserve(input_tensor_id.size()); |
426 | for (int64_t i : input_tensor_id) { |
427 | tensor_usage_[i]++; |
428 | ids.push_back(i); |
429 | } |
430 | const int64_t op_id = next_op_id_++; |
431 | std::vector<TapeTensor> tensors; |
432 | tensors.reserve(output_tensors.size()); |
433 | for (const TapeTensor& o : output_tensors) { |
    // Note: the tensor can have already been watched and hence be in the
    // tape, so we cannot CHECK that the insertion here is new.
436 | tensor_tape_[o.GetID()] = op_id; |
437 | tensor_usage_[o.GetID()] = 1; |
438 | tensors.push_back(o); |
439 | } |
440 | op_tape_[op_id] = OpTapeEntry<BackwardFunction, TapeTensor>{ |
441 | op_type, std::move(tensors), std::move(ids), backward_function_getter(), |
442 | backward_function_deleter}; |
443 | } |
444 | |
445 | template <typename Gradient, typename BackwardFunction, typename TapeTensor> |
446 | void GradientTape<Gradient, BackwardFunction, TapeTensor>::DeleteTrace( |
447 | int64_t tensor_id) { |
448 | auto it = tensor_usage_.find(tensor_id); |
449 | if (it == tensor_usage_.end()) { |
450 | return; |
451 | } |
452 | it->second--; |
453 | if (it->second != 0) { |
454 | return; |
455 | } |
456 | tensor_usage_.erase(it); |
457 | auto tensor_op_it = tensor_tape_.find(tensor_id); |
458 | if (tensor_op_it == tensor_tape_.end()) { |
459 | return; |
460 | } |
461 | const int64_t op_id = tensor_op_it->second; |
462 | if (op_id == -1) { |
463 | // Do not delete watched tensors. |
464 | return; |
465 | } |
466 | tensor_tape_.erase(tensor_op_it); |
467 | auto op_it = op_tape_.find(op_id); |
468 | CHECK(op_it != op_tape_.end()); |
469 | for (const auto& output : op_it->second.output_tensor_info) { |
470 | if (tensor_usage_.find(output.GetID()) != tensor_usage_.end()) { |
471 | // Found a usage for an output, so cannot delete the op. |
472 | return; |
473 | } |
474 | } |
475 | for (int64_t id : op_it->second.input_tensor_id) { |
476 | DeleteTrace(id); |
477 | } |
478 | op_it->second.backward_function_deleter(op_it->second.backward_function); |
479 | op_tape_.erase(op_it); |
480 | } |
481 | |
482 | // Terminology: |
483 | // |
484 | // - op: a possibly composite operation, which has an entry in the tape |
//  - target: y in dy/dx, i.e. a tensor being differentiated
//  - source: x in dy/dx, i.e. a tensor with respect to which we differentiate
487 | // - tensor: one of the many inputs or outputs of an operation |
488 | // |
489 | // Below here we do the gradient algorithm. It works as follows: |
490 | // |
491 | // First we filter the tape to just the subset of operations we want to |
492 | // differentiate. In the process of doing so we count how many times each Tensor |
493 | // is used as an input to an op (so we know when we're done computing gradients |
494 | // for that Tensor). We also count, for each tape entry, how many of its output |
495 | // Tensors need gradients to be computed (Tensors which are not used do not need |
496 | // any gradients to be computed). |
497 | // |
// Next, we seed a backprop stack with the set of tape entries for which all
// output gradients are already available. This usually covers the ops
// producing targets (though not all of them, since a target which also feeds
// other ops in the tape does not have its full gradient available initially).
502 | // |
503 | // Then we repeatedly pop an entry from the stack, run its backprop, and update |
504 | // the gradients of its inputs. Once we have computed all gradients for a single |
505 | // input we can mark this input as done, and this can trigger adding an entry to |
506 | // the stack if all outputs of that entry are now done. |
507 | // |
508 | // When the stack is empty we have gradients for all tensors we're interested |
509 | // in. |
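//
// As a small example, suppose the tape holds y = Add(x, x) (op-id 0) and
// z = Mul(y, y) (op-id 1), x is watched, and we want dz/dx. Filtering counts
// two usages each for x and y, and records op 0 in op_missing_tensor as
// waiting on one output gradient (y's). Op 1 has no pending output gradients,
// so it seeds the stack. Popping op 1 produces a dz/dy contribution per Mul
// input; once both usages of y are accounted for, op 0's pending count drops
// to zero and it joins the stack. Popping op 0 then yields two dz/dx
// contributions, which are aggregated into the final gradient for x.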
510 | |
511 | namespace { |
512 | |
513 | template <typename BackwardFunction, typename TapeTensor> |
514 | struct BackpropInitialState { |
515 | OpTape<BackwardFunction, TapeTensor> op_tape; |
516 | |
517 | // Map from tensor ID to how many references still exist for this tensor in |
518 | // the tape. |
519 | std::unordered_map<int64_t, int64_t> tensor_usage_counts; |
520 | |
521 | // Maps from op ID to how many output tensors of this op still need to have |
522 | // their gradients computed. |
523 | std::unordered_map<int64_t, int64_t> op_missing_tensor; |
524 | }; |
525 | |
// If `persistent_tape` is true, op_tape is not changed and none of the
// backward functions are deleted.
// If `persistent_tape` is false, op_tape is cleared: backward functions that
// are needed for the gradient computation are copied into (and returned in)
// BackpropInitialState, and the rest are deleted.
531 | template <typename BackwardFunction, typename TapeTensor> |
532 | BackpropInitialState<BackwardFunction, TapeTensor> PrepareBackprop( |
533 | gtl::ArraySlice<int64_t> target, const TensorTape& tensor_tape, |
534 | OpTape<BackwardFunction, TapeTensor>* op_tape, |
535 | const std::unordered_set<int64_t>& sources_set, bool persistent_tape) { |
536 | std::vector<int64_t> tensor_stack; |
537 | tensor_stack.reserve(target.size()); |
538 | for (auto t : target) { |
539 | tensor_stack.push_back(t); |
540 | } |
541 | BackpropInitialState<BackwardFunction, TapeTensor> result; |
542 | while (!tensor_stack.empty()) { |
543 | int64_t tensor_id = tensor_stack.back(); |
544 | tensor_stack.pop_back(); |
545 | auto op_id_it = tensor_tape.find(tensor_id); |
546 | if (op_id_it == tensor_tape.end()) { |
547 | continue; |
548 | } |
549 | int64_t op_id = op_id_it->second; |
550 | auto op_it = op_tape->find(op_id); |
551 | auto result_op_it = result.op_tape.find(op_id); |
552 | if (op_id == -1 || op_it == op_tape->end() || |
553 | result_op_it != result.op_tape.end()) { |
554 | continue; |
555 | } |
556 | CHECK(result.op_tape.emplace(op_id, op_it->second).second); |
    for (int64_t in_id : op_it->second.input_tensor_id) {
      auto count_it = result.tensor_usage_counts.find(in_id);
      if (count_it != result.tensor_usage_counts.end()) {
        count_it->second++;
      } else {
        result.tensor_usage_counts[in_id] = 1;
        if (tensor_tape.find(in_id) != tensor_tape.end()) {
          tensor_stack.push_back(in_id);
        }
      }
    }
568 | if (!persistent_tape) { |
569 | op_tape->erase(op_it); |
570 | } |
571 | } |
572 | for (auto& pair : result.tensor_usage_counts) { |
573 | auto it = tensor_tape.find(pair.first); |
574 | if (it != tensor_tape.end() && it->second != -1) { |
575 | result.op_missing_tensor[it->second] += 1; |
576 | } |
577 | } |
578 | if (!persistent_tape) { |
579 | // Call destructors for all unneeded gradient functions and |
580 | // clear the op_tape. We can clear the tape because ownership of |
581 | // backward functions that will be used for gradient computation |
582 | // has been transferred to `result`. |
583 | for (const auto& op_pair : *op_tape) { |
584 | op_pair.second.backward_function_deleter( |
585 | op_pair.second.backward_function); |
586 | } |
587 | op_tape->clear(); |
588 | } |
589 | return result; |
590 | } |
591 | |
592 | template <typename BackwardFunction, typename TapeTensor> |
593 | std::vector<int64_t> InitialStack( |
594 | const OpTape<BackwardFunction, TapeTensor>& op_tape, |
595 | const std::unordered_map<int64_t, int64_t>& op_missing_tensor) { |
596 | std::vector<int64_t> result; |
597 | for (auto& op_entry : op_tape) { |
598 | if (op_missing_tensor.find(op_entry.first) == op_missing_tensor.end()) { |
599 | result.push_back(op_entry.first); |
600 | } |
601 | } |
602 | return result; |
603 | } |
604 | |
605 | template <typename Gradient, typename BackwardFunction, typename TapeTensor> |
606 | Status InitialGradients( |
607 | const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace, |
608 | gtl::ArraySlice<int64_t> target_tensor_ids, |
609 | const std::unordered_map<int64_t, TapeTensor>& sources_that_are_targets, |
610 | gtl::ArraySlice<Gradient*> output_gradients, const TensorTape& tensor_tape, |
611 | const OpTape<BackwardFunction, TapeTensor>& op_tape, |
612 | std::unordered_map<int64_t, std::vector<Gradient*>>* result) { |
613 | for (int i = 0, end = target_tensor_ids.size(); i < end; ++i) { |
614 | const int64_t id = target_tensor_ids[i]; |
615 | if (output_gradients.empty() || output_gradients[i] == nullptr) { |
616 | auto tensor_it = tensor_tape.find(id); |
617 | if (tensor_it != tensor_tape.end() && tensor_it->second != -1) { |
618 | auto op_it = op_tape.find(tensor_it->second); |
619 | if (op_it == op_tape.end()) { |
          return errors::Internal(
              "Internal state of the gradient tape is invalid: "
              "failed to find operation producing a tensor");
623 | } |
624 | bool found = false; |
625 | for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) { |
626 | if (op_it->second.output_tensor_info[j].GetID() == id) { |
627 | found = true; |
628 | Gradient* ones_like = nullptr; |
629 | TF_RETURN_IF_ERROR(vspace.BuildOnesLike( |
630 | op_it->second.output_tensor_info[j], &ones_like)); |
631 | (*result)[id].push_back(ones_like); |
632 | break; |
633 | } |
634 | } |
635 | if (!found) { |
          return errors::Internal(
              "Internal state of the gradient tape is invalid: "
              "none of the operation's outputs match the expected tensor");
639 | } |
640 | } else { |
641 | // This target tensor was not generated by any operation recorded on |
642 | // the tape, so no gradient needs to be computed from it unless this |
643 | // target is also a source. |
644 | auto source_tensor = sources_that_are_targets.find(id); |
645 | if (source_tensor != sources_that_are_targets.end()) { |
646 | Gradient* ones_like = nullptr; |
647 | TF_RETURN_IF_ERROR( |
648 | vspace.BuildOnesLike(source_tensor->second, &ones_like)); |
649 | (*result)[id].push_back(ones_like); |
650 | } |
651 | } |
652 | } else { |
653 | (*result)[id].push_back(output_gradients[i]); |
654 | } |
655 | } |
656 | return OkStatus(); |
657 | } |
658 | |
659 | // TODO(agarwal): use an automatic mechanism for handling None arguments to |
660 | // gradient functions. |
661 | // |
662 | // Some gradient functions can accept None arguments for gradients. The |
663 | // following maps the operation name to the indices at which the corresponding |
664 | // gradient function can accept None values. e.g. FusedBatchNorm outputs 5 |
// values and hence receives 5 gradient values during backprop. However, the
// gradient function uses only the first of those values and ignores the rest.
// The entry "FusedBatchNorm": [1, 2, 3, 4] indicates that only the gradient
// corresponding to index 0 is used, and the gradient values at indices 1-4 are
669 | // ignored (and hence can be None). The backprop algorithm can then leverage |
670 | // this by not constructing zeros to pass for those indices. |
671 | std::unordered_map<string, std::unordered_set<int>>* |
672 | FunctionsAcceptingNoneForIndicesMap() { |
673 | static auto* const m = |
674 | new std::unordered_map<string, std::unordered_set<int>>({ |
675 | {"SoftmaxCrossEntropyWithLogits" , {1}}, |
676 | {"SparseSoftmaxCrossEntropyWithLogits" , {1}}, |
677 | {"FusedBatchNorm" , {1, 2, 3, 4}}, |
678 | }); |
679 | return m; |
680 | } |
681 | |
682 | } // namespace |
683 | |
// If over kMinAggregateCount gradients are accumulated and the total
// memory consumption is over kMinAggregateBytes, do an early aggregation
// so as to release the gradient tensors and save memory.
687 | constexpr int kMinAggregateCount = 4; |
688 | constexpr int kMinAggregateBytes = 128 * 1024 * 1024; |
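
// For example, with five pending gradients of 16M elements each, the size
// estimate in ComputeGradient below (which assumes 4 bytes per element) gives
// 5 * 16M * 4B = 320MB > kMinAggregateBytes, so the five pending gradients
// would be summed early into a single tensor.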
689 | |
690 | template <typename Gradient, typename BackwardFunction, typename TapeTensor> |
691 | Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient( |
692 | const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace, |
693 | const gtl::ArraySlice<int64_t> target_tensor_ids, |
694 | const gtl::ArraySlice<int64_t> source_tensor_ids, |
695 | const std::unordered_map<int64_t, TapeTensor>& sources_that_are_targets, |
696 | gtl::ArraySlice<Gradient*> output_gradients, absl::Span<Gradient*> result, |
697 | bool build_default_zeros_grads) { |
698 | std::unordered_set<int64_t> sources_set(source_tensor_ids.begin(), |
699 | source_tensor_ids.end()); |
700 | BackpropInitialState<BackwardFunction, TapeTensor> state = PrepareBackprop( |
701 | target_tensor_ids, tensor_tape_, &op_tape_, sources_set, persistent_); |
702 | std::vector<int64_t> op_stack = |
703 | InitialStack(state.op_tape, state.op_missing_tensor); |
704 | std::unordered_map<int64_t, std::vector<Gradient*>> gradients; |
705 | Status s = InitialGradients(vspace, target_tensor_ids, |
706 | sources_that_are_targets, output_gradients, |
707 | tensor_tape_, state.op_tape, &gradients); |
708 | auto cleanup = gtl::MakeCleanup([this, &state]() { |
709 | if (!persistent_) { |
710 | // Release all backprop functions |
711 | for (const auto& pair : state.op_tape) { |
712 | pair.second.backward_function_deleter(pair.second.backward_function); |
713 | } |
714 | } |
715 | }); |
716 | if (!s.ok()) { |
717 | return s; |
718 | } |
719 | |
720 | std::unordered_map<int64_t, int64_t> gradients_size; |
721 | // TODO(apassos) multiple threads could be dequeuing from op_stack at the same |
722 | // time, for better CPU backprop performance. |
723 | VLOG(1) << "Initial stack:" ; |
724 | if (VLOG_IS_ON(1)) { |
725 | for (auto t : op_stack) { |
726 | VLOG(1) << " " << t; |
727 | } |
728 | } |
729 | while (!op_stack.empty()) { |
730 | const int64_t op = op_stack.back(); |
731 | VLOG(1) << "Popped " << op; |
732 | op_stack.pop_back(); |
733 | auto op_it = state.op_tape.find(op); |
734 | if (op_it == state.op_tape.end()) { |
735 | // It is possible for ops to end up on the stack if they are unrelated to |
736 | // the target; we should just skip them. |
737 | continue; |
738 | } |
739 | auto trace = std::move(op_it->second); |
740 | state.op_tape.erase(op_it); |
741 | std::vector<Gradient*> out_gradients; |
742 | out_gradients.reserve(trace.output_tensor_info.size()); |
743 | std::vector<int64_t> unneeded_gradients; |
744 | for (int i = 0, end = trace.input_tensor_id.size(); i < end; i++) { |
745 | const auto& in_tensor_id = trace.input_tensor_id[i]; |
746 | if (tensor_tape_.find(in_tensor_id) == tensor_tape_.end() && |
747 | sources_set.find(in_tensor_id) == sources_set.end()) { |
748 | unneeded_gradients.push_back(i); |
749 | } |
750 | } |
751 | |
752 | bool any_gradient_nonzero = false; |
753 | std::vector<int> zero_indices; |
754 | for (int i = 0, end = trace.output_tensor_info.size(); i < end; ++i) { |
755 | const int64_t id = trace.output_tensor_info[i].GetID(); |
756 | auto grad_it = gradients.find(id); |
757 | if (grad_it == gradients.end()) { |
758 | out_gradients.push_back(nullptr); |
759 | if (build_default_zeros_grads) { |
760 | auto func_name_it = |
761 | FunctionsAcceptingNoneForIndicesMap()->find(trace.op_type); |
762 | if (func_name_it == FunctionsAcceptingNoneForIndicesMap()->end() || |
763 | func_name_it->second.find(i) == func_name_it->second.end()) { |
764 | zero_indices.push_back(i); |
765 | } |
766 | } |
767 | } else { |
768 | any_gradient_nonzero = true; |
769 | Gradient* new_gradients = nullptr; |
770 | if (grad_it->second.size() == 1) { |
771 | new_gradients = grad_it->second.at(0); |
772 | } else { |
773 | new_gradients = vspace.AggregateGradients(grad_it->second); |
774 | } |
775 | if (sources_set.find(grad_it->first) == sources_set.end()) { |
776 | gradients.erase(grad_it); |
777 | } else { |
778 | grad_it->second.clear(); |
779 | grad_it->second.push_back(new_gradients); |
780 | vspace.MarkAsResult(new_gradients); |
781 | } |
782 | out_gradients.push_back(new_gradients); |
783 | } |
784 | } |
785 | VLOG(1) << "Calling gradient function for '" << trace.op_type << "'" ; |
786 | std::vector<Gradient*> in_gradients(trace.input_tensor_id.size()); |
787 | DCHECK(build_default_zeros_grads || zero_indices.empty()); |
788 | if (any_gradient_nonzero) { |
789 | for (const auto i : zero_indices) { |
790 | out_gradients[i] = trace.output_tensor_info[i].ZerosLike(); |
791 | } |
      Status s = vspace.CallBackwardFunction(
          trace.op_type, trace.backward_function, unneeded_gradients,
          out_gradients, absl::MakeSpan(in_gradients));
796 | if (!persistent_) { |
797 | trace.backward_function_deleter(trace.backward_function); |
798 | } |
799 | if (!s.ok()) { |
800 | return s; |
801 | } |
802 | } else { |
803 | if (!persistent_) { |
804 | trace.backward_function_deleter(trace.backward_function); |
805 | } |
806 | for (Gradient* grad : out_gradients) { |
807 | if (grad != nullptr) { |
808 | vspace.DeleteGradient(grad); |
809 | } |
810 | } |
811 | } |
812 | for (int i = 0, end = in_gradients.size(); i < end; ++i) { |
813 | const int64_t id = trace.input_tensor_id[i]; |
814 | if (in_gradients[i] != nullptr) { |
815 | auto& unaggregated_grads = gradients[id]; |
816 | unaggregated_grads.push_back(in_gradients[i]); |
817 | if (unaggregated_grads.size() > kMinAggregateCount) { |
818 | auto size_it = gradients_size.find(id); |
819 | int64_t size; |
820 | if (size_it == gradients_size.end()) { |
821 | size = vspace.NumElements(unaggregated_grads[0]); |
822 | gradients_size.emplace(id, size); |
823 | } else { |
824 | size = size_it->second; |
825 | } |
826 | if (unaggregated_grads.size() * size * 4 > kMinAggregateBytes) { |
827 | Gradient* grad = vspace.AggregateGradients(unaggregated_grads); |
828 | unaggregated_grads.clear(); |
829 | unaggregated_grads.push_back(grad); |
830 | } |
831 | } |
832 | } |
833 | auto usage_count_it = state.tensor_usage_counts.find(id); |
834 | if (usage_count_it == state.tensor_usage_counts.end()) { |
835 | VLOG(1) << "Tensor " << id << " not used" ; |
836 | continue; |
837 | } |
838 | usage_count_it->second--; |
839 | if (usage_count_it->second > 0) { |
840 | VLOG(1) << "Tensor " << id << " usage count " << usage_count_it->second; |
841 | continue; |
842 | } |
843 | auto tape_it = tensor_tape_.find(id); |
844 | if (tape_it == tensor_tape_.end()) { |
845 | VLOG(1) << "Tensor " << id |
846 | << " has no associated op. Deleting gradient" ; |
847 | auto grad_it = gradients.find(id); |
848 | if (grad_it != gradients.end()) { |
849 | for (auto g : grad_it->second) { |
850 | vspace.DeleteGradient(g); |
851 | } |
852 | gradients.erase(grad_it); |
853 | } |
854 | continue; |
855 | } |
856 | const int64_t op_id = tape_it->second; |
857 | if (op_id == -1) { |
858 | VLOG(1) << "Tensor " << id << " is source" ; |
859 | continue; |
860 | } |
861 | auto missing_it = state.op_missing_tensor.find(op_id); |
862 | if (missing_it != state.op_missing_tensor.end()) { |
863 | missing_it->second--; |
864 | VLOG(1) << "Op " << op_id << " missing " << missing_it->second |
865 | << " output gradients" ; |
866 | if (missing_it->second == 0) { |
867 | op_stack.insert(op_stack.begin(), op_id); |
868 | } |
869 | } |
870 | } |
871 | } |
872 | if (!state.op_tape.empty()) { |
    return tensorflow::errors::Internal("Invalid tape state.");
874 | } |
875 | if (result.size() != source_tensor_ids.size()) { |
876 | return errors::Internal("Expected result Span to be of size " , |
877 | source_tensor_ids.size(), " found " , result.size(), |
878 | " in call to Tape::ComputeGradient." ); |
879 | } |
880 | std::unordered_set<int64_t> used_gradient_ids(source_tensor_ids.size()); |
881 | for (int i = 0; i < source_tensor_ids.size(); i++) { |
882 | int64_t tensor_id = source_tensor_ids[i]; |
883 | auto grad_it = gradients.find(tensor_id); |
884 | if (grad_it == gradients.end()) { |
885 | result[i] = nullptr; |
886 | } else { |
887 | if (grad_it->second.size() > 1) { |
888 | Gradient* grad = vspace.AggregateGradients(grad_it->second); |
889 | grad_it->second.clear(); |
890 | grad_it->second.push_back(grad); |
891 | } |
892 | result[i] = grad_it->second[0]; |
893 | used_gradient_ids.insert(tensor_id); |
894 | } |
895 | } |
896 | VLOG(1) << "Final gradients size: " |
897 | << gradients.size() - used_gradient_ids.size(); |
898 | for (const auto& grad_pair : gradients) { |
899 | if (used_gradient_ids.find(grad_pair.first) == used_gradient_ids.end()) { |
900 | for (const auto& g : grad_pair.second) { |
901 | vspace.DeleteGradient(g); |
902 | } |
903 | } |
904 | } |
905 | return OkStatus(); |
906 | } |
907 | |
908 | template <typename Gradient, typename BackwardFunction, typename TapeTensor> |
909 | bool ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ShouldRecord( |
910 | gtl::ArraySlice<int64_t> tensor_ids, |
911 | gtl::ArraySlice<tensorflow::DataType> dtypes) { |
912 | if (call_state_.top().backward_tape != nullptr) { |
913 | // If we're forwarding Accumulate calls to backward_tape's RecordOperation, |
914 | // we should also delegate ShouldRecord. |
915 | return call_state_.top().backward_tape->ShouldRecord(tensor_ids, dtypes); |
916 | } |
917 | if (call_state_.top().accumulating) { |
918 | return false; |
919 | } |
920 | for (int i = 0; i < tensor_ids.size(); ++i) { |
921 | if (accumulated_gradients_.find(tensor_ids[i]) != |
922 | accumulated_gradients_.end()) { |
923 | if (IsDtypeTrainable(dtypes[i])) { |
924 | return true; |
925 | } |
926 | } |
927 | } |
928 | return false; |
929 | } |
930 | |
931 | template <typename Gradient, typename BackwardFunction, typename TapeTensor> |
932 | Status |
933 | ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ForwardpropFromTape( |
934 | const string& op_type, const std::vector<TapeTensor>& output_tensors, |
935 | const std::function<BackwardFunction*()>& backward_function_getter, |
936 | const std::function<void(BackwardFunction*)>& backward_function_deleter, |
937 | const std::vector<Gradient*>& in_grads, absl::Span<Gradient*> out_grads) { |
938 | /* This function is approximately equivalent to this Python code: |
939 | |
940 | forwardprop_aids = tf.ones_like(output_tensors) |
941 | with tf.GradientTape() as g: |
942 | g.watch(forwardprop_aids) |
943 | grad = backward_function(forwardprop_aids) |
944 | forward_grads = g.gradient(grad, forwardprop_aids, output_gradients=in_grads) |
945 | accumulated_gradients_[ID(output_tensors)] = forward_grads |
946 | */ |
  std::unique_ptr<GradientTape<Gradient, BackwardFunction, TapeTensor>> tape(
      new GradientTape<Gradient, BackwardFunction, TapeTensor>(
          /*persistent=*/false));
949 | AccumulatorCallState& call_state = call_state_.top(); |
950 | call_state.backward_tape = tape.get(); |
951 | auto pop_backward_tape = |
952 | gtl::MakeCleanup([&call_state] { call_state.backward_tape = nullptr; }); |
953 | std::vector<Gradient*> forwardprop_aids; |
954 | std::vector<int64_t> sources; |
955 | std::unordered_set<int64_t> sources_set; |
956 | sources.reserve(output_tensors.size()); |
957 | for (const TapeTensor& output_tensor : output_tensors) { |
958 | // Ownership of `aid` transferred to CallBackwardFunction below. |
959 | Gradient* aid; |
960 | if (output_tensor.GetDType() == tensorflow::DT_VARIANT) { |
961 | // Note: Needs to be zeros rather than ones since there's currently no |
962 | // ones_like for variants. |
963 | aid = output_tensor.ZerosLike(); |
964 | } else { |
965 | // TODO(allenl): Figure out why using zeros_like everywhere causes issues |
966 | // for some gradient functions and if there's another way to work around |
967 | // it (e.g. conds instead of ifs). The value shouldn't really matter. |
968 | TF_RETURN_IF_ERROR(vspace_.BuildOnesLike(output_tensor, &aid)); |
969 | } |
970 | if (TF_PREDICT_FALSE(aid == nullptr)) { |
971 | return tensorflow::errors::Internal( |
972 | "Failed to create ones tensor for tensor " , output_tensor.GetID(), |
973 | " with dtype " , output_tensor.GetDType()); |
974 | } |
975 | forwardprop_aids.push_back(aid); |
976 | int64_t aid_id = vspace_.TensorId(aid); |
977 | sources.push_back(aid_id); |
978 | sources_set.insert(aid_id); |
979 | tape->Watch(aid_id); |
980 | } |
981 | std::vector<Gradient*> grad(in_grads.size()); |
982 | auto delete_grad = gtl::MakeCleanup([&grad, this] { |
983 | for (Gradient* tensor : grad) { |
984 | this->vspace_.DeleteGradient(tensor); |
985 | } |
986 | }); |
987 | { |
988 | std::vector<int64_t> unneeded_gradients; |
989 | std::unique_ptr<BackwardFunction, std::function<void(BackwardFunction*)>> |
990 | backward_function(backward_function_getter(), |
991 | backward_function_deleter); |
992 | TF_RETURN_IF_ERROR(vspace_.CallBackwardFunction( |
993 | op_type, backward_function.get(), unneeded_gradients, forwardprop_aids, |
994 | absl::MakeSpan(grad))); |
995 | } |
996 | |
  // Stop the tape from recording: release() disarms the cleanup and returns
  // its callable, which we invoke immediately to reset backward_tape.
  pop_backward_tape.release()();
999 | |
1000 | std::vector<int64_t> targets; |
1001 | std::vector<Gradient*> used_in_grads; |
1002 | // We may end up with slightly fewer elements than we reserve, but grad.size() |
1003 | // should be a reasonably tight upper bound. |
1004 | targets.reserve(grad.size()); |
1005 | used_in_grads.reserve(grad.size()); |
1006 | std::unordered_map<int64_t, TapeTensor> sources_that_are_targets; |
1007 | for (int grad_index = 0, end = grad.size(); grad_index < end; ++grad_index) { |
1008 | Gradient* grad_tensor = grad[grad_index]; |
1009 | if (grad_tensor != nullptr) { |
1010 | int64_t tensor_id = vspace_.TensorId(grad_tensor); |
1011 | targets.push_back(tensor_id); |
1012 | if (sources_set.find(tensor_id) != sources_set.end()) { |
1013 | sources_that_are_targets.emplace( |
1014 | tensor_id, vspace_.TapeTensorFromGradient(grad_tensor)); |
1015 | } |
1016 | Gradient* in_grad = in_grads[grad_index]; |
1017 | if (in_grad != nullptr) { |
1018 | // ComputeGradient steals a reference |
1019 | vspace_.MarkAsResult(in_grad); |
1020 | } |
1021 | used_in_grads.push_back(in_grad); |
1022 | } |
1023 | } |
1024 | |
1025 | return tape->ComputeGradient(vspace_, targets, sources, |
1026 | sources_that_are_targets, used_in_grads, |
1027 | out_grads); |
1028 | } |
1029 | |
1030 | template <typename Gradient, typename BackwardFunction, typename TapeTensor> |
1031 | Status ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Accumulate( |
1032 | const string& op_type, const std::vector<TapeTensor>& input_tensors, |
1033 | const std::vector<TapeTensor>& output_tensors, |
1034 | gtl::ArraySlice<int64_t> input_tensor_id, |
1035 | gtl::ArraySlice<tensorflow::DataType> input_dtypes, |
1036 | const ForwardFunction<Gradient>* forward_function, |
1037 | const std::function<BackwardFunction*()>& backward_function_getter, |
1038 | const std::function<void(BackwardFunction*)>& backward_function_deleter) { |
1039 | if (call_state_.top().backward_tape != nullptr) { |
1040 | // If backward_tape is not null, then this call to Accumulate is the result |
1041 | // of a still-active call to Accumulate which is running operations. We |
1042 | // forward these operations to backward_tape so the outer Accumulate call |
1043 | // can do its work. |
1044 | // |
1045 | // Rather than re-entering and delegating Accumulate like this, we could |
1046 | // instead allow ForwardAccumulator some control over the current tape set |
1047 | // (so it can deactivate itself and activate its GradientTape). Currently |
1048 | // that is managed by the language binding and would require relatively |
1049 | // messy callbacks. |
1050 | call_state_.top().backward_tape->RecordOperation( |
1051 | op_type, output_tensors, input_tensor_id, input_dtypes, |
1052 | backward_function_getter, backward_function_deleter); |
1053 | return OkStatus(); |
1054 | } |
1055 | if (!ShouldRecord(input_tensor_id, input_dtypes)) { |
1056 | return OkStatus(); |
1057 | } |
1058 | |
1059 | // We may need to allocate zero inputs for trainable dtypes we don't have JVPs |
1060 | // for. Make sure they get cleaned up. |
1061 | std::vector<Gradient*> new_zeros; |
1062 | auto delete_new_zeros = gtl::MakeCleanup([&new_zeros, this] { |
1063 | for (Gradient* tensor : new_zeros) { |
1064 | this->vspace_.DeleteGradient(tensor); |
1065 | } |
1066 | }); |
1067 | std::vector<Gradient*> in_grads; |
1068 | in_grads.reserve(input_tensors.size()); |
1069 | for (int target_index = 0; target_index < input_tensors.size(); |
1070 | ++target_index) { |
1071 | const auto current_grad = |
1072 | accumulated_gradients_.find(input_tensors[target_index].GetID()); |
1073 | if (current_grad == accumulated_gradients_.end()) { |
1074 | if (IsDtypeTrainable(input_tensors[target_index].GetDType())) { |
1075 | // ForwardAccumulator defaults to zeros for unwatched Tensors, unlike |
1076 | // GradientTape which uses ones. |
1077 | Gradient* zero = input_tensors[target_index].ZerosLike(); |
1078 | new_zeros.push_back(zero); |
1079 | in_grads.push_back(zero); |
1080 | } else { |
1081 | in_grads.push_back(nullptr); |
1082 | } |
1083 | } else { |
1084 | in_grads.push_back(current_grad->second); |
1085 | } |
1086 | } |
1087 | |
1088 | // Avoid infinite recursion. Whichever forward function we run, it'll end up |
1089 | // executing ops, and we don't want to watch those with this accumulator. |
1090 | call_state_.emplace(nullptr, true); |
1091 | auto pop_call_state = gtl::MakeCleanup([this] { this->call_state_.pop(); }); |
1092 | |
1093 | std::vector<Gradient*> forward_grads; |
1094 | if (forward_function == nullptr) { |
1095 | // We have no special-cased forward gradient. Fall back to running the |
1096 | // backward function under a gradient tape. |
1097 | forward_grads.resize(output_tensors.size()); |
1098 | TF_RETURN_IF_ERROR(ForwardpropFromTape( |
1099 | op_type, output_tensors, backward_function_getter, |
1100 | backward_function_deleter, in_grads, absl::MakeSpan(forward_grads))); |
1101 | } else { |
1102 | TF_RETURN_IF_ERROR( |
1103 | (*forward_function)(in_grads, &forward_grads, use_batch_)); |
1104 | } |
1105 | for (int i = 0; i < forward_grads.size(); ++i) { |
1106 | if (forward_grads[i] != nullptr) { |
1107 | int64_t tensor_id = output_tensors[i].GetID(); |
1108 | auto existing = accumulated_gradients_.find(tensor_id); |
1109 | if (existing != accumulated_gradients_.end()) { |
1110 | // This is a somewhat odd case to be in, since it means we have two |
1111 | // operations which supposedly both created the same Tensor. It comes up |
1112 | // in recompute_grad, where the gradients have the same value. However, |
1113 | // only the original gradient is connected to everything else, so we |
1114 | // should still use that. |
1115 | vspace_.DeleteGradient(forward_grads[i]); |
1116 | } else { |
        accumulated_gradients_[tensor_id] = forward_grads[i];
1118 | } |
1119 | } |
1120 | } |
1121 | return OkStatus(); |
1122 | } |
1123 | |
1124 | template <typename Gradient, typename BackwardFunction, typename TapeTensor> |
1125 | void ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Watch( |
1126 | int64_t tensor_id, Gradient* tangent) { |
1127 | typename std::unordered_map<int64_t, Gradient*>::iterator existing = |
1128 | accumulated_gradients_.find(tensor_id); |
1129 | vspace_.MarkAsResult(tangent); |
1130 | if (existing == accumulated_gradients_.end()) { |
1131 | accumulated_gradients_.emplace(tensor_id, tangent); |
1132 | } else { |
1133 | std::array<Gradient*, 2> to_aggregate; |
1134 | to_aggregate[0] = tangent; |
1135 | to_aggregate[1] = existing->second; |
1136 | // AggregateGradients steals a reference to each of its arguments. We |
1137 | // MarkAsResult on `tangent` above so we don't steal a reference to it. |
1138 | existing->second = vspace_.AggregateGradients(to_aggregate); |
1139 | } |
1140 | } |
1141 | |
1142 | template <typename Gradient, typename BackwardFunction, typename TapeTensor> |
1143 | void ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::DeleteGradient( |
1144 | int64_t tensor_id) { |
1145 | auto existing = accumulated_gradients_.find(tensor_id); |
1146 | if (existing != accumulated_gradients_.end()) { |
1147 | vspace_.DeleteGradient(existing->second); |
1148 | accumulated_gradients_.erase(existing); |
1149 | } |
1150 | } |
1151 | |
1152 | template <typename Gradient, typename BackwardFunction, typename TapeTensor> |
1153 | Gradient* ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::FetchJVP( |
1154 | int64_t tensor_id) { |
1155 | auto lookup = accumulated_gradients_.find(tensor_id); |
1156 | if (lookup == accumulated_gradients_.end()) { |
1157 | return nullptr; |
1158 | } else { |
1159 | return lookup->second; |
1160 | } |
1161 | } |
1162 | |
1163 | } // namespace eager |
1164 | } // namespace tensorflow |
1165 | |
1166 | #endif // TENSORFLOW_C_EAGER_TAPE_H_ |
1167 | |