/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_C_EAGER_GRADIENTS_H_
#define TENSORFLOW_C_EAGER_GRADIENTS_H_

#include "absl/container/flat_hash_map.h"
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/tape.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"

namespace tensorflow {
namespace gradients {

// =============== Experimental C++ API for computing gradients ===============

// Sample gradient function:
//
// class AddGradientFunction : public GradientFunction {
//  public:
//   Status Compute(AbstractContext* ctx,
//                  absl::Span<AbstractTensorHandle* const> grad_outputs,
//                  absl::Span<AbstractTensorHandle*> grad_inputs) override {
//     grad_inputs[0] = grad_outputs[0];
//     grad_inputs[1] = grad_outputs[0];
//     grad_inputs[0]->Ref();
//     grad_inputs[1]->Ref();
//     return OkStatus();
//   }
//   ~AddGradientFunction() override {}
// };
//
// GradientFunction* AddRegisterer(const ForwardOperation& op) {
//   // More complex gradient functions can use inputs/attrs etc. from the
//   // forward `op`.
//   return new AddGradientFunction;
// }
//
// Status RegisterGradients(GradientRegistry* registry) {
//   return registry->Register("Add", AddRegisterer);
// }
class GradientFunction {
 public:
  // `grad_outputs` holds the incoming gradients, i.e. the gradients with
  // respect to the forward op's outputs. `Compute` populates `grad_inputs`
  // with the gradients with respect to the forward op's inputs (see the
  // sample above, which `Ref`s the handles it forwards).
  virtual Status Compute(AbstractContext* ctx,
                         absl::Span<AbstractTensorHandle* const> grad_outputs,
                         absl::Span<AbstractTensorHandle*> grad_inputs) = 0;
  virtual ~GradientFunction() {}
};

// Metadata from the forward operation that is made available to the
// gradient registerer to instantiate a GradientFunction.
struct ForwardOperation {
 public:
  string op_name;
  std::vector<AbstractTensorHandle*> inputs;
  std::vector<AbstractTensorHandle*> outputs;
  std::vector<int64_t> skip_input_indices;
  AttrBuilder attrs;
};

using GradientFunctionFactory =
    std::function<GradientFunction*(const ForwardOperation& op)>;
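
// A factory that needs values from the forward pass can capture them from the
// `ForwardOperation` it receives. A minimal sketch (the `MulGradientFunction`
// class is an illustrative assumption, not part of this header):
//
// GradientFunction* MulRegisterer(const ForwardOperation& op) {
//   // Mul's gradients need the forward inputs: d(x*y)/dx = y, d(x*y)/dy = x.
//   return new MulGradientFunction(op.inputs[0], op.inputs[1]);
// }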

// Map from op name to a `GradientFunctionFactory`.
class GradientRegistry {
 public:
  Status Register(const string& op,
                  GradientFunctionFactory gradient_function_factory);
  Status Lookup(const ForwardOperation& op,
                std::unique_ptr<GradientFunction>* gradient_function) const;

 private:
  absl::flat_hash_map<string, GradientFunctionFactory> registry_;
};
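
// Sample registry usage (a sketch; `AddRegisterer` is the factory from the
// example above and `forward_op` is assumed to be a populated
// `ForwardOperation` whose `op_name` is "Add"):
//
// GradientRegistry registry;
// TF_RETURN_IF_ERROR(registry.Register("Add", AddRegisterer));
// std::unique_ptr<GradientFunction> grad_fn;
// TF_RETURN_IF_ERROR(registry.Lookup(forward_op, &grad_fn));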

// TODO(srbs): Figure out if we can avoid declaring this in the public header.
// Wrapper for a tensor output of an operation executing under a tape.
//
// `GetID` returns a unique id for the wrapped tensor which is used to maintain
// a map (`tensorflow::eager::TensorTape`) from the wrapped tensor to the id of
// the op that produced it (or -1 if this tensor was watched using
// `GradientTape::Watch`). The op_id is simply a unique index assigned to each
// op executed under the tape. A separate map (`tensorflow::eager::OpTape`)
// maintains the map from `op_id` to an `OpTapeEntry`, which stores the
// `op_type`, inputs, outputs and the gradient function. Together, these data
// structures allow us to trace the data dependencies between operations and
// hence compute gradients.
//
// `ZerosLike` is not expected to be called and returns nullptr. The creation
// of default zeros grads is handled by the `DefaultGradientFunction`
// registered for each op.
// TODO(srbs): We need to define `ZerosLike` here to keep the compiler happy.
// Figure out a way to avoid this.
// TODO(srbs): Should ZerosLike check-fail instead of returning nullptr?
class TapeTensor {
 public:
  explicit TapeTensor(AbstractTensorHandle* handle);
  TapeTensor(const TapeTensor& other);
  ~TapeTensor();

  int64_t GetID() const;
  tensorflow::DataType GetDType() const;

  AbstractTensorHandle* ZerosLike() const;

  AbstractTensorHandle* GetHandle() const;

 private:
  AbstractTensorHandle* handle_;
};

// A tracing/immediate-execution agnostic tape.
//
// Gradient functions defined for this tape must support handling null incoming
// gradients.
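//
// Sample usage (a minimal sketch; `ctx`, the handles `x` and `y`, and the
// gradient function `mul_grad_fn` are assumed to exist, with `y` produced by
// a forward "Mul" op on `x` and `x`):
//
// Tape tape(/*persistent=*/false);
// tape.Watch(x);
// // ... execute the forward op to produce `y` ...
// tape.RecordOperation(/*inputs=*/{x, x}, /*outputs=*/{y}, mul_grad_fn, "Mul");
// std::vector<AbstractTensorHandle*> grads(1);
// TF_RETURN_IF_ERROR(tape.ComputeGradient(
//     ctx, /*targets=*/{y}, /*sources=*/{x}, /*output_gradients=*/{},
//     absl::MakeSpan(grads)));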
class Tape : protected eager::GradientTape<AbstractTensorHandle,
                                           GradientFunction, TapeTensor> {
 public:
  using GradientTape<AbstractTensorHandle, GradientFunction,
                     TapeTensor>::GradientTape;
  // Returns whether the tape is persistent, i.e., whether the tape will hold
  // onto its internal state after a call to `ComputeGradient`.
  using GradientTape<AbstractTensorHandle, GradientFunction,
                     TapeTensor>::IsPersistent;

  // Adds this tensor to the list of watched tensors.
  //
  // This is a no-op if the tensor is already being watched, either from an
  // earlier call to `GradientTape::Watch` or from being an output of an op
  // with watched inputs.
  void Watch(const AbstractTensorHandle*);
  // Records an operation with the given inputs and outputs on the tape and
  // marks all its outputs as watched if at least one input of the op is
  // watched and has a trainable dtype. `op_name` is optional and is used for
  // debugging only.
  void RecordOperation(absl::Span<AbstractTensorHandle* const> inputs,
                       absl::Span<AbstractTensorHandle* const> outputs,
                       GradientFunction* gradient_function,
                       const string& op_name = "");
  // Returns whether any tensor in a list of tensors is being watched and has
  // a trainable dtype.
  bool ShouldRecord(
      absl::Span<const AbstractTensorHandle* const> tensors) const;
  // Unwatches this tensor on the tape. Mainly used for cleanup when deleting
  // eager tensors.
  void DeleteTrace(const AbstractTensorHandle*);

  // Consumes the internal state of the tape (so it cannot be called more than
  // once unless the tape is persistent) and produces the gradients of the
  // target tensors with respect to the source tensors. The output gradients
  // are used if not empty and not null. The result is populated with one
  // tensor per target element.
  Status ComputeGradient(
      AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> targets,
      absl::Span<AbstractTensorHandle* const> sources,
      absl::Span<AbstractTensorHandle* const> output_gradients,
      absl::Span<AbstractTensorHandle*> result);
};

}  // namespace gradients
}  // namespace tensorflow

#endif  // TENSORFLOW_C_EAGER_GRADIENTS_H_