/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_C_EAGER_GRADIENTS_H_
#define TENSORFLOW_C_EAGER_GRADIENTS_H_

#include "absl/container/flat_hash_map.h"
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/tape.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"

namespace tensorflow {
namespace gradients {

// =============== Experimental C++ API for computing gradients ===============

// Sample gradient function:
//
// class AddGradientFunction : public GradientFunction {
//  public:
//   Status Compute(AbstractContext* ctx,
//                  absl::Span<AbstractTensorHandle* const> grad_outputs,
//                  absl::Span<AbstractTensorHandle*> grad_inputs) override {
//     grad_inputs[0] = grad_outputs[0];
//     grad_inputs[1] = grad_outputs[0];
//     grad_inputs[0]->Ref();
//     grad_inputs[1]->Ref();
//     return OkStatus();
//   }
//   ~AddGradientFunction() override {}
// };
//
// GradientFunction* AddRegisterer(const ForwardOperation& op) {
//   // More complex gradient functions can use inputs/attrs etc. from the
//   // forward `op`.
//   return new AddGradientFunction;
// }
//
// Status RegisterGradients(GradientRegistry* registry) {
//   return registry->Register("Add", AddRegisterer);
// }
class GradientFunction {
 public:
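  // `grad_outputs` holds the incoming gradients with respect to the forward
  // op's outputs, and implementations must populate `grad_inputs` with the
  // gradients with respect to the forward op's inputs. Entries of
  // `grad_outputs` may be null when used with `Tape` (see the note on `Tape`
  // below).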
  virtual Status Compute(AbstractContext* ctx,
                         absl::Span<AbstractTensorHandle* const> grad_outputs,
                         absl::Span<AbstractTensorHandle*> grad_inputs) = 0;
  virtual ~GradientFunction() {}
};

// Metadata from the forward operation that is made available to the
// gradient registerer to instantiate a GradientFunction.
struct ForwardOperation {
 public:
  string op_name;
  std::vector<AbstractTensorHandle*> inputs;
  std::vector<AbstractTensorHandle*> outputs;
  std::vector<int64_t> skip_input_indices;
  AttrBuilder attrs;
};
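
// For instance, a gradient function that needs the forward op's outputs for
// its backward computation can capture them in its registerer (a minimal
// sketch; `ReluGradientFunction` is illustrative):
//
//   GradientFunction* ReluRegisterer(const ForwardOperation& op) {
//     // Capture the forward outputs needed by the backward pass.
//     return new ReluGradientFunction(op.outputs);
//   }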

using GradientFunctionFactory =
    std::function<GradientFunction*(const ForwardOperation& op)>;

// Map from op name to a `GradientFunctionFactory`.
class GradientRegistry {
 public:
  Status Register(const string& op,
                  GradientFunctionFactory gradient_function_factory);
  Status Lookup(const ForwardOperation& op,
                std::unique_ptr<GradientFunction>* gradient_function) const;

 private:
  absl::flat_hash_map<string, GradientFunctionFactory> registry_;
};
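
// Example registration and lookup (a minimal sketch; `RegisterGradients` is
// the sample registerer above and `op` is the recorded `ForwardOperation`):
//
//   GradientRegistry registry;
//   TF_RETURN_IF_ERROR(RegisterGradients(&registry));
//   std::unique_ptr<GradientFunction> grad_fn;
//   TF_RETURN_IF_ERROR(registry.Lookup(op, &grad_fn));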

// TODO(srbs): Figure out if we can avoid declaring this in the public header.
// Wrapper for a tensor output of an operation executing under a tape.
//
// `GetID` returns a unique id for the wrapped tensor which is used to maintain
// a map (`tensorflow::eager::TensorTape`) from the wrapped tensor to the id of
// the op that produced it (or -1 if this tensor was watched using
// `GradientTape::Watch`). The `op_id` is simply a unique index assigned to
// each op executed under the tape. A separate map (`tensorflow::eager::OpTape`)
// maintains the map from `op_id` to an `OpTapeEntry`, which stores the
// `op_type`, inputs, outputs, and the gradient function. These data structures
// combined allow us to trace the data dependencies between operations and
// hence compute gradients.
//
// `ZerosLike` is not expected to be called and returns a nullptr. The creation
// of default zeros grads is handled by the `DefaultGradientFunction` registered
// for each op.
// TODO(srbs): We need to define `ZerosLike` here to keep the compiler happy.
// Figure out a way to avoid this.
// TODO(srbs): Should ZerosLike check-fail instead of returning nullptr?
class TapeTensor {
 public:
  explicit TapeTensor(AbstractTensorHandle* handle);
  TapeTensor(const TapeTensor& other);
  ~TapeTensor();

  int64_t GetID() const;
  tensorflow::DataType GetDType() const;

  AbstractTensorHandle* ZerosLike() const;

  AbstractTensorHandle* GetHandle() const;

 private:
  AbstractTensorHandle* handle_;
};

// A tracing/immediate-execution agnostic tape.
//
// Gradient functions defined for this tape must support handling null incoming
// gradients.
class Tape : protected eager::GradientTape<AbstractTensorHandle,
                                           GradientFunction, TapeTensor> {
 public:
  using GradientTape<AbstractTensorHandle, GradientFunction,
                     TapeTensor>::GradientTape;
  // Returns whether the tape is persistent, i.e., whether the tape will hold
  // onto its internal state after a call to `ComputeGradient`.
  using GradientTape<AbstractTensorHandle, GradientFunction,
                     TapeTensor>::IsPersistent;

  // Adds this tensor to the list of watched tensors.
  //
  // This is a no-op if the tensor is already being watched either from an
  // earlier call to `GradientTape::Watch` or being an output of an op with
  // watched inputs.
  void Watch(const AbstractTensorHandle*);
  // Records an operation with the given inputs and outputs on the tape and
  // marks all its outputs as watched if at least one input of the op is
  // watched and has a trainable dtype. `op_name` is optional and is used for
  // debugging only.
  void RecordOperation(absl::Span<AbstractTensorHandle* const> inputs,
                       absl::Span<AbstractTensorHandle* const> outputs,
                       GradientFunction* gradient_function,
                       const string& op_name = "");
  // Returns whether any tensor in a list of tensors is being watched and has
  // a trainable dtype.
  bool ShouldRecord(
      absl::Span<const AbstractTensorHandle* const> tensors) const;
  // Unwatches this tensor on the tape. Mainly used for cleanup when deleting
  // eager tensors.
  void DeleteTrace(const AbstractTensorHandle*);

  // Consumes the internal state of the tape (so it cannot be called more than
  // once unless the tape is persistent) and produces the gradient of the
  // target tensors with respect to the source tensors. `output_gradients` are
  // used if not empty and not null. The result is populated with one tensor
  // per source element.
  Status ComputeGradient(
      AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> targets,
      absl::Span<AbstractTensorHandle* const> sources,
      absl::Span<AbstractTensorHandle* const> output_gradients,
      absl::Span<AbstractTensorHandle*> result);
};
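
// Example end-to-end backward pass (a minimal sketch; `ctx`, the source
// tensors `a` and `b`, the forward result `output`, and `grad_fn` obtained
// via `GradientRegistry::Lookup` are assumed to exist):
//
//   Tape tape(/*persistent=*/false);
//   tape.Watch(a);
//   tape.Watch(b);
//   // ... execute the forward "Add" op here, producing `output` ...
//   tape.RecordOperation({a, b}, {output}, grad_fn.release(), "Add");
//   std::vector<AbstractTensorHandle*> results(2);  // one per source
//   TF_RETURN_IF_ERROR(tape.ComputeGradient(
//       ctx, /*targets=*/{output}, /*sources=*/{a, b},
//       /*output_gradients=*/{}, absl::MakeSpan(results)));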

}  // namespace gradients
}  // namespace tensorflow

#endif  // TENSORFLOW_C_EAGER_GRADIENTS_H_
177 | |