/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_TAPE_H_
#define TENSORFLOW_C_EAGER_TAPE_H_

// Language-agnostic gradient tape. Does not perform backpropagation, just
// maintains the data structures required to do so.

#include <array>
#include <functional>
#include <memory>
#include <stack>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/config/flag_defs.h"
#include "tensorflow/core/config/flags.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace eager {

// Represents an entry in the tape.
template <typename BackwardFunction, typename TapeTensor>
struct OpTapeEntry {
  string op_type;
  std::vector<TapeTensor> output_tensor_info;
  std::vector<int64_t> input_tensor_id;

  // TODO(apassos) consider narrowing down this interface.
  BackwardFunction* backward_function;

  // Should be called before deleting the backward function. TODO(apassos) use
  // unique_ptrs to ensure this happens.
  std::function<void(BackwardFunction*)> backward_function_deleter;
};

// Map from tensor_id to internally-defined operation-id of the operation which
// produced this tensor. A value of -1 means that the tensor was directly
// watched and not the result of any operation in the tape.
using TensorTape = std::unordered_map<int64_t, int64_t>;
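
// For example (illustrative IDs only), a TensorTape containing
//   {42 -> -1, 7 -> 3}
// records that tensor 42 was watched directly via `Watch`, while tensor 7 was
// produced by the tape entry whose operation-id is 3.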

// Map from operation-id to tape entry.
template <typename BackwardFunction, typename TapeTensor>
using OpTape =
    std::unordered_map<int64_t, OpTapeEntry<BackwardFunction, TapeTensor>>;

// Operations the tape needs to perform on tensors to do backpropagation. Named
// "vspace" because a subset of these are related to a vector space, such as
// adding gradients, getting zeroes, etc. Currently cannot be implemented
// without using tensorflow python code, hence left unspecified here.
//
// Gradient is the type returned by gradient functions. In Python TF it's
// either Tensor or IndexedSlices or None, which here we map to nullptr.
// Gradients need to allow their size to be computed and they need to be
// passable to a backward function and deleted (as the backprop code creates
// lots of gradients the user is not interested in).
//
// BackwardFunction needs to be a closure which stores intermediate activations
// from the forward computation and calls a vector-jacobian product function
// (also known as adjoint function) to compute, given downstream gradients,
// upstream gradients.
//
// TODO(apassos) provide concrete template instantiations for TFE_TensorHandle
// specialization, which is blocked by quite a few things needing to loop back
// into python now.
template <typename Gradient, typename BackwardFunction, typename TapeTensor>
class VSpace {
 public:
  virtual ~VSpace() {}

  // Returns the number of elements in the gradient tensor.
  virtual int64_t NumElements(Gradient* tensor) const = 0;

  // Consumes references to the tensors in the gradient_tensors list and
  // returns a tensor with the result.
  virtual Gradient* AggregateGradients(
      gtl::ArraySlice<Gradient*> gradient_tensors) const = 0;
  // Calls the passed-in backward function.
  //
  // `unneeded_gradients` contains a sorted list of input indices for which a
  // gradient is not required.
  virtual Status CallBackwardFunction(
      const string& op_type, BackwardFunction* backward_function,
      const std::vector<int64_t>& unneeded_gradients,
      gtl::ArraySlice<Gradient*> output_gradients,
      absl::Span<Gradient*> result) const = 0;

  // Builds a tensor filled with ones with the same shape and dtype as `t`.
  virtual Status BuildOnesLike(const TapeTensor& t,
                               Gradient** result) const = 0;

  // Looks up the ID of a Gradient.
  virtual int64_t TensorId(Gradient* tensor) const = 0;

  // Converts a Gradient to a TapeTensor.
  virtual TapeTensor TapeTensorFromGradient(Gradient* gradient) const = 0;

  // Marks the given gradient as a result so it's not consumed by backward
  // functions.
  virtual void MarkAsResult(Gradient* gradient) const = 0;

  // Deletes the input tensor.
  virtual void DeleteGradient(Gradient* gradient) const = 0;
};
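
// As a rough sketch only (the real implementations currently live in the
// language bindings, e.g. the Python eager runtime), a concrete VSpace for a
// hypothetical reference-counted `MyGradient` type would be shaped like:
//
//   class MyVSpace : public VSpace<MyGradient, MyBackwardFn, MyTapeTensor> {
//    public:
//     int64_t NumElements(MyGradient* g) const override {
//       return g->shape().num_elements();
//     }
//     MyGradient* AggregateGradients(
//         gtl::ArraySlice<MyGradient*> gs) const override {
//       return AddAll(gs);  // consumes the references held in `gs`
//     }
//     // ... and so on for the remaining pure virtual methods.
//   };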

// Traces the execution of operations, doing eager garbage collection, and
// exporting a full trace so other code can do backpropagation. Not
// thread-safe.
template <typename Gradient, typename BackwardFunction, typename TapeTensor>
class GradientTape {
 public:
  // If `persistent` is true, GradientTape will not eagerly delete backward
  // functions (and hence the tensors they keep alive). Instead, everything
  // is deleted in ~GradientTape. Persistent GradientTapes are useful when
  // users want to compute multiple gradients over the same tape.
  explicit GradientTape(bool persistent) : persistent_(persistent) {}
  ~GradientTape() {
    for (const auto& pair : op_tape_) {
      pair.second.backward_function_deleter(pair.second.backward_function);
    }
  }

  // Returns whether any tensor in a list of tensors is being watched and has
  // a trainable dtype.
  bool ShouldRecord(gtl::ArraySlice<int64_t> tensor_ids,
                    gtl::ArraySlice<tensorflow::DataType> dtypes) const;

  // Adds this tensor to the list of watched tensors.
  //
  // This is a no-op if the tensor is already being watched either from an
  // earlier call to `GradientTape::Watch` or being an output of an op with
  // watched inputs.
  void Watch(int64_t tensor_id);

  // Records an operation with inputs `input_tensor_id` and outputs
  // `output_tensors` on the tape and marks all its outputs as watched if at
  // least one input of the op is watched and has trainable dtype.
  //
  // op_type is used to decide which of the incoming gradients can be left as
  // nullptr instead of building zeros when build_default_zeros_grads == true.
  void RecordOperation(
      const string& op_type, const std::vector<TapeTensor>& output_tensors,
      gtl::ArraySlice<int64_t> input_tensor_id,
      gtl::ArraySlice<tensorflow::DataType> input_dtypes,
      const std::function<BackwardFunction*()>& backward_function_getter,
      const std::function<void(BackwardFunction*)>& backward_function_deleter);

  void DeleteTrace(int64_t tensor_id);

  // Consumes the internal state of the tape (so cannot be called more than
  // once) and produces the gradient of the target tensors with respect to the
  // source tensors. The output gradients are used if not empty and not
  // null. The result is populated with one tensor per source element.
  // When running backward functions, builds zeros-like tensors for
  // incoming grads which are nullptrs, unless `build_default_zeros_grads`
  // is set to false.
  Status ComputeGradient(
      const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
      const gtl::ArraySlice<int64_t> target_tensor_ids,
      const gtl::ArraySlice<int64_t> source_tensor_ids,
      const std::unordered_map<int64_t, TapeTensor>& sources_that_are_targets,
      gtl::ArraySlice<Gradient*> output_gradients, absl::Span<Gradient*> result,
      bool build_default_zeros_grads = true);

  // Whether the tape is persistent. See ctor for detailed description.
  bool IsPersistent() const { return persistent_; }

 private:
  TensorTape tensor_tape_;
  OpTape<BackwardFunction, TapeTensor> op_tape_;
  int64_t next_op_id_{0};

  // Map from tensor id to number of remaining usages (i.e. how many entries in
  // the tape refer to it); to aid in tape garbage collection.
  std::unordered_map<int64_t, int64_t> tensor_usage_;

  // If false, all activations are deleted in the first call to
  // ComputeGradient. Else, only when this is destructed.
  bool persistent_;
};
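
// A rough usage sketch (with hypothetical concrete Gradient, BackwardFunction
// and TapeTensor types supplied by a language binding):
//
//   GradientTape<MyGradient, MyBackwardFn, MyTapeTensor> tape(
//       /*persistent=*/false);
//   tape.Watch(x_id);
//   // For each executed op whose inputs might be watched:
//   if (tape.ShouldRecord(input_ids, input_dtypes)) {
//     tape.RecordOperation(op_type, outputs, input_ids, input_dtypes,
//                          backward_fn_getter, backward_fn_deleter);
//   }
//   // Later, with `my_vspace` implementing the VSpace interface above:
//   std::vector<MyGradient*> result(source_ids.size());
//   Status s = tape.ComputeGradient(my_vspace, target_ids, source_ids,
//                                   /*sources_that_are_targets=*/{},
//                                   /*output_gradients=*/{},
//                                   absl::MakeSpan(result));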

// Describes a callback for special-cased and more efficient jvp computation.
//
// Could just be a simple typedef in ForwardAccumulator, but MSVC chokes on
// that.
template <typename Gradient>
class ForwardFunction
    : public std::function<Status(const std::vector<Gradient*>&,
                                  std::vector<Gradient*>*, bool)> {
 public:
  template <typename lambda_type>
  explicit ForwardFunction(lambda_type lambda)
      : std::function<Status(const std::vector<Gradient*>&,
                             std::vector<Gradient*>*, bool)>(lambda) {}
};
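
// For instance (an illustrative sketch, not an existing registration), a
// special-cased forward function could be constructed from a lambda as:
//
//   ForwardFunction<MyGradient> forward_fn(
//       [](const std::vector<MyGradient*>& tangents,
//          std::vector<MyGradient*>* out_tangents, bool use_batch) -> Status {
//         // Compute the op's output JVPs directly from the input tangents.
//         return OkStatus();
//       });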

// Computes Jacobian-vector products using forward-mode automatic
// differentiation.
//
// While GradientTape's RecordOperation is trivial, ForwardAccumulator's
// Accumulate runs the gradient computation immediately.
//
// Keeps references to Tensors watched via Watch and computed in Accumulate
// corresponding to output_tensors, and releases these references in its
// destructor. However, waiting until the destructor runs loses the memory
// efficiency of forward-mode autodiff. Instead, language bindings should call
// DeleteGradient as soon as a Tensor which was `Watch`ed or was an output
// Tensor passed to Accumulate goes out of scope.
//
// Not thread-safe.
template <typename Gradient, typename BackwardFunction, typename TapeTensor>
class ForwardAccumulator {
 public:
  // Does not take ownership of `vspace`, which must outlive the
  // ForwardAccumulator.
  explicit ForwardAccumulator(
      const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
      bool use_batch)
      : vspace_(vspace), use_batch_(use_batch) {
    call_state_.emplace(nullptr, false);
  }

  virtual ~ForwardAccumulator() {
    for (auto accumulated : accumulated_gradients_) {
      vspace_.DeleteGradient(accumulated.second);
    }
  }

  // Tell the forward accumulator to watch tensor_id, with a Tensor tangent
  // vector `tangent` of matching shape and dtype. Tangents are the "vector" in
  // "Jacobian-vector product"; `Watch`ing a new Tensor and immediately calling
  // FetchJVP for it would return `tangent`.
  void Watch(int64_t tensor_id, Gradient* tangent);

  // Removes the gradient associated with tensor_id. Should be called when the
  // Tensor associated with `tensor_id` is deleted.
  void DeleteGradient(int64_t tensor_id);

  // Runs forward autodiff. Should be called whenever a new operation is
  // available and the accumulator is active.
  //
  // Like GradientTape::RecordOperation, this method takes the operation type
  // `op_type` (e.g. "Add"), the operation's inputs (`input_tensors`,
  // `input_tensor_id`, and `input_dtypes`; the latter two are somewhat
  // redundant but taken as arguments to avoid repeatedly fetching these values
  // between calls to ShouldRecord and Accumulate), and its outputs
  // (`output_tensors`).
  //
  // If provided, a non-null `forward_function` will be used instead of the
  // backward function (`backward_function_getter` /
  // `backward_function_deleter`) to compute jvps for this operation. If
  // `forward_function` is null, a GradientTape is used on the backward
  // function to compute the jvp, which will waste computation when executing
  // eagerly.
  //
  // Unlike GradientTape::RecordOperation, Accumulate runs gradient computation
  // immediately. It stores the results, which feed into Accumulate for future
  // operations and may be fetched by calling FetchJVP. ForwardAccumulator
  // maintains a reference to these JVPs: if an `output_tensors` Tensor is
  // deleted, `DeleteGradient` should be called as soon as possible to free the
  // (now inaccessible) corresponding JVPs, but ForwardAccumulator's destructor
  // will release remaining references.
  //
  // This method is not thread-safe (and in general ForwardAccumulator is not
  // thread-safe).
  Status Accumulate(
      const string& op_type, const std::vector<TapeTensor>& input_tensors,
      const std::vector<TapeTensor>& output_tensors,
      gtl::ArraySlice<int64_t> input_tensor_id,
      gtl::ArraySlice<tensorflow::DataType> input_dtypes,
      const ForwardFunction<Gradient>* forward_function,
      const std::function<BackwardFunction*()>& backward_function_getter,
      const std::function<void(BackwardFunction*)>& backward_function_deleter);

  // Returns true if `Accumulate` is active somewhere above on the stack and
  // there isn't an intervening PushState. This is useful for ordering
  // ForwardAccumulators, where more deeply nested accumulators should not see
  // computations from less deeply nested accumulators.
  bool BusyAccumulating() const { return call_state_.top().accumulating; }

  // Fetches the current Jacobian-vector product associated with `tensor_id`,
  // or a nullptr if none is available.
  //
  // Returns a borrowed reference, i.e. does not run VSpace::MarkAsResult on
  // its return value. The caller should increment the reference count before
  // deleting the ForwardAccumulator or calling DeleteGradient if keeping a
  // persistent reference to a non-null result.
  Gradient* FetchJVP(int64_t tensor_id);

  // Indicates whether the forward accumulator should run on an operation with
  // the specified inputs and dtypes.
  bool ShouldRecord(gtl::ArraySlice<int64_t> tensor_ids,
                    gtl::ArraySlice<tensorflow::DataType> dtypes);

  // Temporarily push or pop transient state for this accumulator.
  //
  // Allows an accumulator which is currently processing an operation to
  // temporarily reset its state. Without pushing and popping, accumulators
  // ignore operations executed as a direct result of their own jvp
  // computations.
  void PushState() { call_state_.emplace(nullptr, false); }
  void PopState() { call_state_.pop(); }

 private:
  // Helper for Accumulate: uses a GradientTape to compute forward gradients
  // from a backward gradient function. Fills `out_grads` corresponding to
  // `output_tensors`. `out_grads` must not be null.
  //
  // Executes the backward function in order to trace its gradient, which will
  // waste computation if executing eagerly (when graph building the unneeded
  // computation is pruned). Temporarily sets `backward_tape` so that
  // Accumulate will forward op executions to the tape while the backward
  // function is running; this effectively adds the backward tape to the active
  // set (but does not require complicated callbacks to the language bindings).
  Status ForwardpropFromTape(
      const string& op_type, const std::vector<TapeTensor>& output_tensors,
      const std::function<BackwardFunction*()>& backward_function_getter,
      const std::function<void(BackwardFunction*)>& backward_function_deleter,
      const std::vector<Gradient*>& in_grads, absl::Span<Gradient*> out_grads);

  // Maps from tensor IDs to corresponding JVPs.
  std::unordered_map<int64_t, Gradient*> accumulated_gradients_;
  // Not owned; provides operations on Tensors which are currently only
  // available in language bindings (e.g. Python).
  const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace_;

  // Decides if tangents are vectorized or not
  bool use_batch_;

  struct AccumulatorCallState {
    AccumulatorCallState(
        GradientTape<Gradient, BackwardFunction, TapeTensor>* backward_tape,
        bool accumulating)
        : backward_tape(backward_tape), accumulating(accumulating) {}
    // Set temporarily while in the Accumulate method; if backward_tape is not
    // nullptr then we forward op executions to it so Accumulate can compute a
    // backward pass on its backward function.
    //
    // Not owned by the ForwardAccumulator. The method which sets
    // `backward_tape` keeps ownership.
    GradientTape<Gradient, BackwardFunction, TapeTensor>* backward_tape;
    // While the Accumulate method is running (accumulating is True), any op
    // executions not forwarded to backward_tape should be ignored.
    bool accumulating;
  };
  // A deque-backed stack, whose element references are not invalidated by
  // pushes and pops at the back.
  std::stack<AccumulatorCallState> call_state_;
};
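
// A rough usage sketch (again with hypothetical concrete types supplied by a
// language binding):
//
//   ForwardAccumulator<MyGradient, MyBackwardFn, MyTapeTensor> acc(
//       my_vspace, /*use_batch=*/false);
//   acc.Watch(x_id, dx_tangent);  // the tangent is the "vector" in the JVP
//   // For each executed op:
//   if (acc.ShouldRecord(input_ids, input_dtypes)) {
//     TF_RETURN_IF_ERROR(acc.Accumulate(
//         op_type, inputs, outputs, input_ids, input_dtypes,
//         /*forward_function=*/nullptr, backward_fn_getter,
//         backward_fn_deleter));
//   }
//   MyGradient* jvp = acc.FetchJVP(y_id);  // borrowed reference, may be null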

// Template instantiations here

inline bool IsDtypeTrainable(DataType dtype) {
  switch (dtype) {
    case DT_HALF:
    case DT_BFLOAT16:
    case DT_FLOAT:
    case DT_DOUBLE:
    case DT_COMPLEX64:
    case DT_COMPLEX128:
    case DT_RESOURCE:
    case DT_VARIANT:
      return true;
    case DT_QINT8:
    case DT_QINT16:
    case DT_QINT32:
    case DT_QUINT8:
    case DT_QUINT16:
      return tensorflow::flags::Global()
          .enable_quantized_dtypes_training.value();
    default:
      return false;
  }
}

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
bool GradientTape<Gradient, BackwardFunction, TapeTensor>::ShouldRecord(
    gtl::ArraySlice<int64_t> tensor_ids,
    gtl::ArraySlice<tensorflow::DataType> dtypes) const {
  CHECK_EQ(tensor_ids.size(), dtypes.size());
  for (int i = 0; i < tensor_ids.size(); ++i) {
    if (tensor_tape_.find(tensor_ids[i]) != tensor_tape_.end()) {
      if (IsDtypeTrainable(dtypes[i])) {
        return true;
      }
    }
  }
  return false;
}

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
void GradientTape<Gradient, BackwardFunction, TapeTensor>::Watch(
    int64_t tensor_id) {
  tensor_tape_.emplace(tensor_id, -1);
}

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
void GradientTape<Gradient, BackwardFunction, TapeTensor>::RecordOperation(
    const string& op_type, const std::vector<TapeTensor>& output_tensors,
    gtl::ArraySlice<int64_t> input_tensor_id,
    gtl::ArraySlice<tensorflow::DataType> input_dtypes,
    const std::function<BackwardFunction*()>& backward_function_getter,
    const std::function<void(BackwardFunction*)>& backward_function_deleter) {
  if (!ShouldRecord(input_tensor_id, input_dtypes)) {
    return;
  }
  std::vector<int64_t> ids;
  ids.reserve(input_tensor_id.size());
  for (int64_t i : input_tensor_id) {
    tensor_usage_[i]++;
    ids.push_back(i);
  }
  const int64_t op_id = next_op_id_++;
  std::vector<TapeTensor> tensors;
  tensors.reserve(output_tensors.size());
  for (const TapeTensor& o : output_tensors) {
    // Note: the tensor can have already been watched and hence be in the tape,
    // so we cannot check that we're inserting it here.
    tensor_tape_[o.GetID()] = op_id;
    tensor_usage_[o.GetID()] = 1;
    tensors.push_back(o);
  }
  op_tape_[op_id] = OpTapeEntry<BackwardFunction, TapeTensor>{
      op_type, std::move(tensors), std::move(ids), backward_function_getter(),
      backward_function_deleter};
}

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
void GradientTape<Gradient, BackwardFunction, TapeTensor>::DeleteTrace(
    int64_t tensor_id) {
  auto it = tensor_usage_.find(tensor_id);
  if (it == tensor_usage_.end()) {
    return;
  }
  it->second--;
  if (it->second != 0) {
    return;
  }
  tensor_usage_.erase(it);
  auto tensor_op_it = tensor_tape_.find(tensor_id);
  if (tensor_op_it == tensor_tape_.end()) {
    return;
  }
  const int64_t op_id = tensor_op_it->second;
  if (op_id == -1) {
    // Do not delete watched tensors.
    return;
  }
  tensor_tape_.erase(tensor_op_it);
  auto op_it = op_tape_.find(op_id);
  CHECK(op_it != op_tape_.end());
  for (const auto& output : op_it->second.output_tensor_info) {
    if (tensor_usage_.find(output.GetID()) != tensor_usage_.end()) {
      // Found a usage for an output, so cannot delete the op.
      return;
    }
  }
  for (int64_t id : op_it->second.input_tensor_id) {
    DeleteTrace(id);
  }
  op_it->second.backward_function_deleter(op_it->second.backward_function);
  op_tape_.erase(op_it);
}

// Terminology:
//
//  - op: a possibly composite operation, which has an entry in the tape
//  - target: y in dy/dx
//  - source: x in dy/dx
//  - tensor: one of the many inputs or outputs of an operation
//
// Below here we do the gradient algorithm. It works as follows:
//
// First we filter the tape to just the subset of operations we want to
// differentiate. In the process of doing so we count how many times each
// Tensor is used as an input to an op (so we know when we're done computing
// gradients for that Tensor). We also count, for each tape entry, how many of
// its output Tensors need gradients to be computed (Tensors which are not
// used do not need any gradients to be computed).
//
// Finally, we start a backprop stack with a set of tape entries for which we
// have all gradients available. This set usually is a subset of the set of
// targets (not all since targets which have outputs in the tape will not have
// gradients available initially).
//
// Then we repeatedly pop an entry from the stack, run its backprop, and
// update the gradients of its inputs. Once we have computed all gradients for
// a single input we can mark this input as done, and this can trigger adding
// an entry to the stack if all outputs of that entry are now done.
//
// When the stack is empty we have gradients for all tensors we're interested
// in.
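//
// As a small illustration (op names are only for the sake of the example):
// suppose the tape recorded z = Mul(x, y) and then w = Add(z, x), and we ask
// for the gradient of target w with respect to source x. Filtering counts x
// as used by two tape entries and z by one; the Add entry is missing no
// output gradients (w is the target, so a ones-like gradient for it is built
// up front) and therefore seeds the stack. Popping Add yields gradients for z
// and x; once z's single pending use is satisfied, the Mul entry joins the
// stack, and popping it contributes the remaining term of dw/dx.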

namespace {

template <typename BackwardFunction, typename TapeTensor>
struct BackpropInitialState {
  OpTape<BackwardFunction, TapeTensor> op_tape;

  // Map from tensor ID to how many references still exist for this tensor in
  // the tape.
  std::unordered_map<int64_t, int64_t> tensor_usage_counts;

  // Maps from op ID to how many output tensors of this op still need to have
  // their gradients computed.
  std::unordered_map<int64_t, int64_t> op_missing_tensor;
};

// If `persistent_tape` is true, op_tape is not changed and none of the
// backward functions are deleted.
// If `persistent_tape` is false, op_tape is cleared and backward functions
// not needed for gradient computation are deleted. Backward functions that
// are needed are copied and returned in BackpropInitialState.
template <typename BackwardFunction, typename TapeTensor>
BackpropInitialState<BackwardFunction, TapeTensor> PrepareBackprop(
    gtl::ArraySlice<int64_t> target, const TensorTape& tensor_tape,
    OpTape<BackwardFunction, TapeTensor>* op_tape,
    const std::unordered_set<int64_t>& sources_set, bool persistent_tape) {
  std::vector<int64_t> tensor_stack;
  tensor_stack.reserve(target.size());
  for (auto t : target) {
    tensor_stack.push_back(t);
  }
  BackpropInitialState<BackwardFunction, TapeTensor> result;
  while (!tensor_stack.empty()) {
    int64_t tensor_id = tensor_stack.back();
    tensor_stack.pop_back();
    auto op_id_it = tensor_tape.find(tensor_id);
    if (op_id_it == tensor_tape.end()) {
      continue;
    }
    int64_t op_id = op_id_it->second;
    auto op_it = op_tape->find(op_id);
    auto result_op_it = result.op_tape.find(op_id);
    if (op_id == -1 || op_it == op_tape->end() ||
        result_op_it != result.op_tape.end()) {
      continue;
    }
    CHECK(result.op_tape.emplace(op_id, op_it->second).second);
    for (auto it : op_it->second.input_tensor_id) {
      auto count_it = result.tensor_usage_counts.find(it);
      if (count_it != result.tensor_usage_counts.end()) {
        count_it->second++;
      } else {
        result.tensor_usage_counts[it] = 1;
        if (tensor_tape.find(it) != tensor_tape.end()) {
          tensor_stack.push_back(it);
        }
      }
    }
    if (!persistent_tape) {
      op_tape->erase(op_it);
    }
  }
  for (auto& pair : result.tensor_usage_counts) {
    auto it = tensor_tape.find(pair.first);
    if (it != tensor_tape.end() && it->second != -1) {
      result.op_missing_tensor[it->second] += 1;
    }
  }
  if (!persistent_tape) {
    // Call destructors for all unneeded gradient functions and
    // clear the op_tape. We can clear the tape because ownership of
    // backward functions that will be used for gradient computation
    // has been transferred to `result`.
    for (const auto& op_pair : *op_tape) {
      op_pair.second.backward_function_deleter(
          op_pair.second.backward_function);
    }
    op_tape->clear();
  }
  return result;
}

template <typename BackwardFunction, typename TapeTensor>
std::vector<int64_t> InitialStack(
    const OpTape<BackwardFunction, TapeTensor>& op_tape,
    const std::unordered_map<int64_t, int64_t>& op_missing_tensor) {
  std::vector<int64_t> result;
  for (auto& op_entry : op_tape) {
    if (op_missing_tensor.find(op_entry.first) == op_missing_tensor.end()) {
      result.push_back(op_entry.first);
    }
  }
  return result;
}

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
Status InitialGradients(
    const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
    gtl::ArraySlice<int64_t> target_tensor_ids,
    const std::unordered_map<int64_t, TapeTensor>& sources_that_are_targets,
    gtl::ArraySlice<Gradient*> output_gradients, const TensorTape& tensor_tape,
    const OpTape<BackwardFunction, TapeTensor>& op_tape,
    std::unordered_map<int64_t, std::vector<Gradient*>>* result) {
  for (int i = 0, end = target_tensor_ids.size(); i < end; ++i) {
    const int64_t id = target_tensor_ids[i];
    if (output_gradients.empty() || output_gradients[i] == nullptr) {
      auto tensor_it = tensor_tape.find(id);
      if (tensor_it != tensor_tape.end() && tensor_it->second != -1) {
        auto op_it = op_tape.find(tensor_it->second);
        if (op_it == op_tape.end()) {
          return errors::Internal(
              "Internal state of the gradient tape is invalid: "
              "failed to find operation producing a tensor");
        }
        bool found = false;
        for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) {
          if (op_it->second.output_tensor_info[j].GetID() == id) {
            found = true;
            Gradient* ones_like = nullptr;
            TF_RETURN_IF_ERROR(vspace.BuildOnesLike(
                op_it->second.output_tensor_info[j], &ones_like));
            (*result)[id].push_back(ones_like);
            break;
          }
        }
        if (!found) {
          return errors::Internal(
              "Internal state of the gradient tape is invalid: "
              "none of the operation's outputs match the expected tensor");
        }
      } else {
        // This target tensor was not generated by any operation recorded on
        // the tape, so no gradient needs to be computed from it unless this
        // target is also a source.
        auto source_tensor = sources_that_are_targets.find(id);
        if (source_tensor != sources_that_are_targets.end()) {
          Gradient* ones_like = nullptr;
          TF_RETURN_IF_ERROR(
              vspace.BuildOnesLike(source_tensor->second, &ones_like));
          (*result)[id].push_back(ones_like);
        }
      }
    } else {
      (*result)[id].push_back(output_gradients[i]);
    }
  }
  return OkStatus();
}

// TODO(agarwal): use an automatic mechanism for handling None arguments to
// gradient functions.
//
// Some gradient functions can accept None arguments for gradients. The
// following maps the operation name to the indices at which the corresponding
// gradient function can accept None values. e.g. FusedBatchNorm outputs 5
// values and hence receives 5 gradient values during backprop. However the
// gradient function uses only the first of those values and ignores the rest.
// The entry, "FusedBatchNorm": [1, 2, 3, 4], indicates that only the gradient
// corresponding to index 0 is used, and the gradient values at indices 1-4 are
// ignored (and hence can be None). The backprop algorithm can then leverage
// this by not constructing zeros to pass for those indices.
std::unordered_map<string, std::unordered_set<int>>*
FunctionsAcceptingNoneForIndicesMap() {
  static auto* const m =
      new std::unordered_map<string, std::unordered_set<int>>({
          {"SoftmaxCrossEntropyWithLogits", {1}},
          {"SparseSoftmaxCrossEntropyWithLogits", {1}},
          {"FusedBatchNorm", {1, 2, 3, 4}},
      });
  return m;
}

}  // namespace

// If over kMinAggregateCount gradients are accumulated and the total
// memory consumption is over kMinAggregateBytes, do an early aggregation
// so as to release the gradient tensor to save memory.
constexpr int kMinAggregateCount = 4;
constexpr int kMinAggregateBytes = 128 * 1024 * 1024;
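//
// Illustrative arithmetic (assuming 4 bytes per element, as the size check in
// ComputeGradient does): five pending gradients of 16M elements each occupy
// 5 * 16M * 4 bytes = 320 MiB > 128 MiB, so they are summed into a single
// tensor early instead of being kept alive until the consuming op's backprop
// runs.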

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
    const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
    const gtl::ArraySlice<int64_t> target_tensor_ids,
    const gtl::ArraySlice<int64_t> source_tensor_ids,
    const std::unordered_map<int64_t, TapeTensor>& sources_that_are_targets,
    gtl::ArraySlice<Gradient*> output_gradients, absl::Span<Gradient*> result,
    bool build_default_zeros_grads) {
  std::unordered_set<int64_t> sources_set(source_tensor_ids.begin(),
                                          source_tensor_ids.end());
  BackpropInitialState<BackwardFunction, TapeTensor> state = PrepareBackprop(
      target_tensor_ids, tensor_tape_, &op_tape_, sources_set, persistent_);
  std::vector<int64_t> op_stack =
      InitialStack(state.op_tape, state.op_missing_tensor);
  std::unordered_map<int64_t, std::vector<Gradient*>> gradients;
  Status s = InitialGradients(vspace, target_tensor_ids,
                              sources_that_are_targets, output_gradients,
                              tensor_tape_, state.op_tape, &gradients);
  auto cleanup = gtl::MakeCleanup([this, &state]() {
    if (!persistent_) {
      // Release all backprop functions
      for (const auto& pair : state.op_tape) {
        pair.second.backward_function_deleter(pair.second.backward_function);
      }
    }
  });
  if (!s.ok()) {
    return s;
  }

  std::unordered_map<int64_t, int64_t> gradients_size;
  // TODO(apassos) multiple threads could be dequeuing from op_stack at the
  // same time, for better CPU backprop performance.
  VLOG(1) << "Initial stack:";
  if (VLOG_IS_ON(1)) {
    for (auto t : op_stack) {
      VLOG(1) << "  " << t;
    }
  }
  while (!op_stack.empty()) {
    const int64_t op = op_stack.back();
    VLOG(1) << "Popped " << op;
    op_stack.pop_back();
    auto op_it = state.op_tape.find(op);
    if (op_it == state.op_tape.end()) {
      // It is possible for ops to end up on the stack if they are unrelated to
      // the target; we should just skip them.
      continue;
    }
    auto trace = std::move(op_it->second);
    state.op_tape.erase(op_it);
    std::vector<Gradient*> out_gradients;
    out_gradients.reserve(trace.output_tensor_info.size());
    std::vector<int64_t> unneeded_gradients;
    for (int i = 0, end = trace.input_tensor_id.size(); i < end; i++) {
      const auto& in_tensor_id = trace.input_tensor_id[i];
      if (tensor_tape_.find(in_tensor_id) == tensor_tape_.end() &&
          sources_set.find(in_tensor_id) == sources_set.end()) {
        unneeded_gradients.push_back(i);
      }
    }

    bool any_gradient_nonzero = false;
    std::vector<int> zero_indices;
    for (int i = 0, end = trace.output_tensor_info.size(); i < end; ++i) {
      const int64_t id = trace.output_tensor_info[i].GetID();
      auto grad_it = gradients.find(id);
      if (grad_it == gradients.end()) {
        out_gradients.push_back(nullptr);
        if (build_default_zeros_grads) {
          auto func_name_it =
              FunctionsAcceptingNoneForIndicesMap()->find(trace.op_type);
          if (func_name_it == FunctionsAcceptingNoneForIndicesMap()->end() ||
              func_name_it->second.find(i) == func_name_it->second.end()) {
            zero_indices.push_back(i);
          }
        }
      } else {
        any_gradient_nonzero = true;
        Gradient* new_gradients = nullptr;
        if (grad_it->second.size() == 1) {
          new_gradients = grad_it->second.at(0);
        } else {
          new_gradients = vspace.AggregateGradients(grad_it->second);
        }
        if (sources_set.find(grad_it->first) == sources_set.end()) {
          gradients.erase(grad_it);
        } else {
          grad_it->second.clear();
          grad_it->second.push_back(new_gradients);
          vspace.MarkAsResult(new_gradients);
        }
        out_gradients.push_back(new_gradients);
      }
    }
    VLOG(1) << "Calling gradient function for '" << trace.op_type << "'";
    std::vector<Gradient*> in_gradients(trace.input_tensor_id.size());
    DCHECK(build_default_zeros_grads || zero_indices.empty());
    if (any_gradient_nonzero) {
      for (const auto i : zero_indices) {
        out_gradients[i] = trace.output_tensor_info[i].ZerosLike();
      }
      Status s;
      s = vspace.CallBackwardFunction(trace.op_type, trace.backward_function,
                                      unneeded_gradients, out_gradients,
                                      absl::MakeSpan(in_gradients));
      if (!persistent_) {
        trace.backward_function_deleter(trace.backward_function);
      }
      if (!s.ok()) {
        return s;
      }
    } else {
      if (!persistent_) {
        trace.backward_function_deleter(trace.backward_function);
      }
      for (Gradient* grad : out_gradients) {
        if (grad != nullptr) {
          vspace.DeleteGradient(grad);
        }
      }
    }
    for (int i = 0, end = in_gradients.size(); i < end; ++i) {
      const int64_t id = trace.input_tensor_id[i];
      if (in_gradients[i] != nullptr) {
        auto& unaggregated_grads = gradients[id];
        unaggregated_grads.push_back(in_gradients[i]);
        if (unaggregated_grads.size() > kMinAggregateCount) {
          auto size_it = gradients_size.find(id);
          int64_t size;
          if (size_it == gradients_size.end()) {
            size = vspace.NumElements(unaggregated_grads[0]);
            gradients_size.emplace(id, size);
          } else {
            size = size_it->second;
          }
          if (unaggregated_grads.size() * size * 4 > kMinAggregateBytes) {
            Gradient* grad = vspace.AggregateGradients(unaggregated_grads);
            unaggregated_grads.clear();
            unaggregated_grads.push_back(grad);
          }
        }
      }
      auto usage_count_it = state.tensor_usage_counts.find(id);
      if (usage_count_it == state.tensor_usage_counts.end()) {
        VLOG(1) << "Tensor " << id << " not used";
        continue;
      }
      usage_count_it->second--;
      if (usage_count_it->second > 0) {
        VLOG(1) << "Tensor " << id << " usage count "
                << usage_count_it->second;
        continue;
      }
      auto tape_it = tensor_tape_.find(id);
      if (tape_it == tensor_tape_.end()) {
        VLOG(1) << "Tensor " << id
                << " has no associated op. Deleting gradient";
        auto grad_it = gradients.find(id);
        if (grad_it != gradients.end()) {
          for (auto g : grad_it->second) {
            vspace.DeleteGradient(g);
          }
          gradients.erase(grad_it);
        }
        continue;
      }
      const int64_t op_id = tape_it->second;
      if (op_id == -1) {
        VLOG(1) << "Tensor " << id << " is source";
        continue;
      }
      auto missing_it = state.op_missing_tensor.find(op_id);
      if (missing_it != state.op_missing_tensor.end()) {
        missing_it->second--;
        VLOG(1) << "Op " << op_id << " missing " << missing_it->second
                << " output gradients";
        if (missing_it->second == 0) {
          op_stack.insert(op_stack.begin(), op_id);
        }
      }
    }
  }
  if (!state.op_tape.empty()) {
    return tensorflow::errors::Internal("Invalid tape state.");
  }
  if (result.size() != source_tensor_ids.size()) {
    return errors::Internal("Expected result Span to be of size ",
                            source_tensor_ids.size(), " found ", result.size(),
                            " in call to Tape::ComputeGradient.");
  }
  std::unordered_set<int64_t> used_gradient_ids(source_tensor_ids.size());
  for (int i = 0; i < source_tensor_ids.size(); i++) {
    int64_t tensor_id = source_tensor_ids[i];
    auto grad_it = gradients.find(tensor_id);
    if (grad_it == gradients.end()) {
      result[i] = nullptr;
    } else {
      if (grad_it->second.size() > 1) {
        Gradient* grad = vspace.AggregateGradients(grad_it->second);
        grad_it->second.clear();
        grad_it->second.push_back(grad);
      }
      result[i] = grad_it->second[0];
      used_gradient_ids.insert(tensor_id);
    }
  }
  VLOG(1) << "Final gradients size: "
          << gradients.size() - used_gradient_ids.size();
  for (const auto& grad_pair : gradients) {
    if (used_gradient_ids.find(grad_pair.first) == used_gradient_ids.end()) {
      for (const auto& g : grad_pair.second) {
        vspace.DeleteGradient(g);
      }
    }
  }
  return OkStatus();
}

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
bool ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ShouldRecord(
    gtl::ArraySlice<int64_t> tensor_ids,
    gtl::ArraySlice<tensorflow::DataType> dtypes) {
  if (call_state_.top().backward_tape != nullptr) {
    // If we're forwarding Accumulate calls to backward_tape's RecordOperation,
    // we should also delegate ShouldRecord.
    return call_state_.top().backward_tape->ShouldRecord(tensor_ids, dtypes);
  }
  if (call_state_.top().accumulating) {
    return false;
  }
  for (int i = 0; i < tensor_ids.size(); ++i) {
    if (accumulated_gradients_.find(tensor_ids[i]) !=
        accumulated_gradients_.end()) {
      if (IsDtypeTrainable(dtypes[i])) {
        return true;
      }
    }
  }
  return false;
}

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
Status
ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ForwardpropFromTape(
    const string& op_type, const std::vector<TapeTensor>& output_tensors,
    const std::function<BackwardFunction*()>& backward_function_getter,
    const std::function<void(BackwardFunction*)>& backward_function_deleter,
    const std::vector<Gradient*>& in_grads, absl::Span<Gradient*> out_grads) {
  /* This function is approximately equivalent to this Python code:

  forwardprop_aids = tf.ones_like(output_tensors)
  with tf.GradientTape() as g:
    g.watch(forwardprop_aids)
    grad = backward_function(forwardprop_aids)
  forward_grads = g.gradient(grad, forwardprop_aids, output_gradients=in_grads)
  accumulated_gradients_[ID(output_tensors)] = forward_grads
  */
  std::unique_ptr<GradientTape<Gradient, BackwardFunction, TapeTensor>> tape(
      new GradientTape<Gradient, BackwardFunction, TapeTensor>(false));
  AccumulatorCallState& call_state = call_state_.top();
  call_state.backward_tape = tape.get();
  auto pop_backward_tape =
      gtl::MakeCleanup([&call_state] { call_state.backward_tape = nullptr; });
  std::vector<Gradient*> forwardprop_aids;
  std::vector<int64_t> sources;
  std::unordered_set<int64_t> sources_set;
  sources.reserve(output_tensors.size());
  for (const TapeTensor& output_tensor : output_tensors) {
    // Ownership of `aid` transferred to CallBackwardFunction below.
    Gradient* aid;
    if (output_tensor.GetDType() == tensorflow::DT_VARIANT) {
      // Note: Needs to be zeros rather than ones since there's currently no
      // ones_like for variants.
      aid = output_tensor.ZerosLike();
    } else {
      // TODO(allenl): Figure out why using zeros_like everywhere causes issues
      // for some gradient functions and if there's another way to work around
      // it (e.g. conds instead of ifs). The value shouldn't really matter.
      TF_RETURN_IF_ERROR(vspace_.BuildOnesLike(output_tensor, &aid));
    }
    if (TF_PREDICT_FALSE(aid == nullptr)) {
      return tensorflow::errors::Internal(
          "Failed to create ones tensor for tensor ", output_tensor.GetID(),
          " with dtype ", output_tensor.GetDType());
    }
    forwardprop_aids.push_back(aid);
    int64_t aid_id = vspace_.TensorId(aid);
    sources.push_back(aid_id);
    sources_set.insert(aid_id);
    tape->Watch(aid_id);
  }
  std::vector<Gradient*> grad(in_grads.size());
  auto delete_grad = gtl::MakeCleanup([&grad, this] {
    for (Gradient* tensor : grad) {
      this->vspace_.DeleteGradient(tensor);
    }
  });
  {
    std::vector<int64_t> unneeded_gradients;
    std::unique_ptr<BackwardFunction, std::function<void(BackwardFunction*)>>
        backward_function(backward_function_getter(),
                          backward_function_deleter);
    TF_RETURN_IF_ERROR(vspace_.CallBackwardFunction(
        op_type, backward_function.get(), unneeded_gradients, forwardprop_aids,
        absl::MakeSpan(grad)));
  }

  // Stop the tape from recording
  pop_backward_tape.release()();

  std::vector<int64_t> targets;
  std::vector<Gradient*> used_in_grads;
  // We may end up with slightly fewer elements than we reserve, but
  // grad.size() should be a reasonably tight upper bound.
  targets.reserve(grad.size());
  used_in_grads.reserve(grad.size());
  std::unordered_map<int64_t, TapeTensor> sources_that_are_targets;
  for (int grad_index = 0, end = grad.size(); grad_index < end; ++grad_index) {
    Gradient* grad_tensor = grad[grad_index];
    if (grad_tensor != nullptr) {
      int64_t tensor_id = vspace_.TensorId(grad_tensor);
      targets.push_back(tensor_id);
      if (sources_set.find(tensor_id) != sources_set.end()) {
        sources_that_are_targets.emplace(
            tensor_id, vspace_.TapeTensorFromGradient(grad_tensor));
      }
      Gradient* in_grad = in_grads[grad_index];
      if (in_grad != nullptr) {
        // ComputeGradient steals a reference
        vspace_.MarkAsResult(in_grad);
      }
      used_in_grads.push_back(in_grad);
    }
  }

  return tape->ComputeGradient(vspace_, targets, sources,
                               sources_that_are_targets, used_in_grads,
                               out_grads);
}

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
Status ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Accumulate(
    const string& op_type, const std::vector<TapeTensor>& input_tensors,
    const std::vector<TapeTensor>& output_tensors,
    gtl::ArraySlice<int64_t> input_tensor_id,
    gtl::ArraySlice<tensorflow::DataType> input_dtypes,
    const ForwardFunction<Gradient>* forward_function,
    const std::function<BackwardFunction*()>& backward_function_getter,
    const std::function<void(BackwardFunction*)>& backward_function_deleter) {
  if (call_state_.top().backward_tape != nullptr) {
    // If backward_tape is not null, then this call to Accumulate is the result
    // of a still-active call to Accumulate which is running operations. We
    // forward these operations to backward_tape so the outer Accumulate call
    // can do its work.
    //
    // Rather than re-entering and delegating Accumulate like this, we could
    // instead allow ForwardAccumulator some control over the current tape set
    // (so it can deactivate itself and activate its GradientTape). Currently
    // that is managed by the language binding and would require relatively
    // messy callbacks.
    call_state_.top().backward_tape->RecordOperation(
        op_type, output_tensors, input_tensor_id, input_dtypes,
        backward_function_getter, backward_function_deleter);
    return OkStatus();
  }
  if (!ShouldRecord(input_tensor_id, input_dtypes)) {
    return OkStatus();
  }

  // We may need to allocate zero inputs for trainable dtypes we don't have
  // JVPs for. Make sure they get cleaned up.
  std::vector<Gradient*> new_zeros;
  auto delete_new_zeros = gtl::MakeCleanup([&new_zeros, this] {
    for (Gradient* tensor : new_zeros) {
      this->vspace_.DeleteGradient(tensor);
    }
  });
  std::vector<Gradient*> in_grads;
  in_grads.reserve(input_tensors.size());
  for (int target_index = 0; target_index < input_tensors.size();
       ++target_index) {
    const auto current_grad =
        accumulated_gradients_.find(input_tensors[target_index].GetID());
    if (current_grad == accumulated_gradients_.end()) {
      if (IsDtypeTrainable(input_tensors[target_index].GetDType())) {
        // ForwardAccumulator defaults to zeros for unwatched Tensors, unlike
        // GradientTape which uses ones.
        Gradient* zero = input_tensors[target_index].ZerosLike();
        new_zeros.push_back(zero);
        in_grads.push_back(zero);
      } else {
        in_grads.push_back(nullptr);
      }
    } else {
      in_grads.push_back(current_grad->second);
    }
  }

  // Avoid infinite recursion. Whichever forward function we run, it'll end up
  // executing ops, and we don't want to watch those with this accumulator.
  call_state_.emplace(nullptr, true);
  auto pop_call_state = gtl::MakeCleanup([this] { this->call_state_.pop(); });

  std::vector<Gradient*> forward_grads;
  if (forward_function == nullptr) {
    // We have no special-cased forward gradient. Fall back to running the
    // backward function under a gradient tape.
    forward_grads.resize(output_tensors.size());
    TF_RETURN_IF_ERROR(ForwardpropFromTape(
        op_type, output_tensors, backward_function_getter,
        backward_function_deleter, in_grads, absl::MakeSpan(forward_grads)));
  } else {
    TF_RETURN_IF_ERROR(
        (*forward_function)(in_grads, &forward_grads, use_batch_));
  }
  for (int i = 0; i < forward_grads.size(); ++i) {
    if (forward_grads[i] != nullptr) {
      int64_t tensor_id = output_tensors[i].GetID();
      auto existing = accumulated_gradients_.find(tensor_id);
      if (existing != accumulated_gradients_.end()) {
        // This is a somewhat odd case to be in, since it means we have two
        // operations which supposedly both created the same Tensor. It comes
        // up in recompute_grad, where the gradients have the same value.
        // However, only the original gradient is connected to everything else,
        // so we should still use that.
        vspace_.DeleteGradient(forward_grads[i]);
      } else {
        accumulated_gradients_[output_tensors[i].GetID()] = forward_grads[i];
      }
    }
  }
  return OkStatus();
}

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
void ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Watch(
    int64_t tensor_id, Gradient* tangent) {
  typename std::unordered_map<int64_t, Gradient*>::iterator existing =
      accumulated_gradients_.find(tensor_id);
  vspace_.MarkAsResult(tangent);
  if (existing == accumulated_gradients_.end()) {
    accumulated_gradients_.emplace(tensor_id, tangent);
  } else {
    std::array<Gradient*, 2> to_aggregate;
    to_aggregate[0] = tangent;
    to_aggregate[1] = existing->second;
    // AggregateGradients steals a reference to each of its arguments. We
    // MarkAsResult on `tangent` above so we don't steal a reference to it.
    existing->second = vspace_.AggregateGradients(to_aggregate);
  }
}

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
void ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::DeleteGradient(
    int64_t tensor_id) {
  auto existing = accumulated_gradients_.find(tensor_id);
  if (existing != accumulated_gradients_.end()) {
    vspace_.DeleteGradient(existing->second);
    accumulated_gradients_.erase(existing);
  }
}

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
Gradient* ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::FetchJVP(
    int64_t tensor_id) {
  auto lookup = accumulated_gradients_.find(tensor_id);
  if (lookup == accumulated_gradients_.end()) {
    return nullptr;
  } else {
    return lookup->second;
  }
}

}  // namespace eager
}  // namespace tensorflow

#endif  // TENSORFLOW_C_EAGER_TAPE_H_