1/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
#include "tensorflow/c/eager/gradients.h"

#include <cstddef>
#include <utility>

#include "absl/strings/str_cat.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
24
25namespace tensorflow {
26namespace gradients {
27namespace {
28
29// TODO(b/172558015): Using the pointer address as the identifier for the tensor
30// may lead to collisions. Introduce another way to get a unique id for this
31// tensor.
32int64_t ToId(const AbstractTensorHandle* t) {
33 return static_cast<int64_t>(reinterpret_cast<uintptr_t>(t));
34}
35
36Status ZerosLike(AbstractContext* ctx, AbstractTensorHandle* t,
37 AbstractTensorHandle** result) {
38 AbstractOperationPtr op(ctx->CreateOperation());
39 TF_RETURN_IF_ERROR(op->Reset("ZerosLike", /*raw_device_name=*/nullptr));
40 if (isa<tracing::TracingOperation>(op.get())) {
41 TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingOperation>(op.get())->SetOpName(
42 absl::StrCat("ZerosLike", ToId(t)).c_str()));
43 }
44 TF_RETURN_IF_ERROR(op->AddInput(t));
45 int num_outputs = 1;
46 std::vector<AbstractTensorHandle*> outputs(num_outputs);
47 TF_RETURN_IF_ERROR(
48 op->Execute(absl::Span<AbstractTensorHandle*>(outputs), &num_outputs));
49 *result = outputs[0];
50 return OkStatus();
51}
52} // namespace
53
54Status GradientRegistry::Register(
55 const string& op_name, GradientFunctionFactory gradient_function_factory) {
56 auto iter = registry_.find(op_name);
57 if (iter != registry_.end()) {
58 const string error_msg = "Gradient already exists for op: " + op_name + ".";
59 return errors::AlreadyExists(error_msg);
60 }
61 registry_.insert({op_name, gradient_function_factory});
62 return OkStatus();
63}
64Status GradientRegistry::Lookup(
65 const ForwardOperation& op,
66 std::unique_ptr<GradientFunction>* gradient_function) const {
67 auto iter = registry_.find(op.op_name);
68 if (iter == registry_.end()) {
69 const string error_msg = "No gradient defined for op: " + op.op_name + ".";
70 return errors::NotFound(error_msg);
71 }
72 gradient_function->reset(iter->second(op));
73 return OkStatus();
74}
75
76TapeTensor::TapeTensor(AbstractTensorHandle* handle) : handle_(handle) {
77 handle_->Ref();
78}
79TapeTensor::TapeTensor(const TapeTensor& other) {
80 handle_ = other.handle_;
81 handle_->Ref();
82}
83TapeTensor::~TapeTensor() { handle_->Unref(); }
84
85int64_t TapeTensor::GetID() const { return ToId(handle_); }
86
87tensorflow::DataType TapeTensor::GetDType() const {
88 return handle_->DataType();
89}
90AbstractTensorHandle* TapeTensor::GetHandle() const { return handle_; }
91
92AbstractTensorHandle* TapeTensor::ZerosLike() const { return nullptr; }
93
// Vector-space implementation used by `Tape::ComputeGradient`. It tells the
// shared `eager::GradientTape` machinery how to aggregate gradients, invoke
// backward functions and build default gradients, all in terms of
// `AbstractTensorHandle`s. Ops needed for these tasks (e.g. `AddN`,
// `OnesLike`) are created through `ctx_`.
class TapeVSpace
    : public eager::VSpace<AbstractTensorHandle, GradientFunction, TapeTensor> {
 public:
  explicit TapeVSpace(AbstractContext* ctx) : ctx_(ctx) {}
  ~TapeVSpace() override {}

  // Returns the number of elements in the gradient tensor.
  int64_t NumElements(AbstractTensorHandle* tensor) const override;

  // Consumes references to the tensors in the gradient_tensors list and returns
  // a tensor with the result.
  AbstractTensorHandle* AggregateGradients(
      gtl::ArraySlice<AbstractTensorHandle*> gradient_tensors) const override;

  // Calls the passed-in backward function.
  // op_type is the op's name provided in RecordOperation.
  Status CallBackwardFunction(
      const string& op_type, GradientFunction* gradient_function,
      const std::vector<int64_t>& unneeded_gradients,
      gtl::ArraySlice<AbstractTensorHandle*> output_gradients,
      absl::Span<AbstractTensorHandle*> result) const override;

  // Builds a tensor filled with ones with the same shape and dtype as `t`.
  Status BuildOnesLike(const TapeTensor& t,
                       AbstractTensorHandle** result) const override;

  // Looks up the ID of a Gradient.
  int64_t TensorId(AbstractTensorHandle* tensor) const override;

  // Converts a Gradient to a TapeTensor.
  TapeTensor TapeTensorFromGradient(AbstractTensorHandle* g) const override;

  // No-op in this implementation.
  void MarkAsResult(AbstractTensorHandle* gradient) const override;

  // Releases the reference held on `gradient`.
  void DeleteGradient(AbstractTensorHandle* gradient) const override;

 private:
  // The context where the aggregation op `Add` is to be created.
  AbstractContext* ctx_;
};
134
// Returns the number of elements in the gradient tensor.
int64_t TapeVSpace::NumElements(AbstractTensorHandle* tensor) const {
  // TODO(srbs): It seems like this is used only for performance optimization
  // and not for correctness. The only downside of keeping this 1 seems to be
  // that the gradient accumulation is unbounded and we will never
  // aggressively aggregate accumulated gradients to recover memory.
  // Revisit and fix.
  // NOTE(review): `tensor` is deliberately unused; every tensor reports a
  // size of 1 (see the TODO above).
  return 1;
}
144
145// Consumes references to the tensors in the gradient_tensors list and returns
146// a tensor with the result.
147AbstractTensorHandle* TapeVSpace::AggregateGradients(
148 gtl::ArraySlice<AbstractTensorHandle*> gradient_tensors) const {
149 if (gradient_tensors.size() == 1) {
150 return gradient_tensors[0];
151 }
152
153 AbstractOperationPtr op(ctx_->CreateOperation());
154 Status s = op->Reset("AddN", /*raw_device_name=*/nullptr);
155 if (!s.ok()) {
156 return nullptr;
157 }
158 s = op->AddInputList(gradient_tensors);
159 if (!s.ok()) {
160 return nullptr;
161 }
162
163 int num_outputs = 1;
164 std::vector<AbstractTensorHandle*> outputs(num_outputs);
165 s = op->Execute(absl::Span<AbstractTensorHandle*>(outputs), &num_outputs);
166 if (!s.ok()) {
167 return nullptr;
168 }
169 return outputs[0];
170}
171
172// Calls the passed-in backward function.
173// op_type is the op's name provided in RecordOperation.
174Status TapeVSpace::CallBackwardFunction(
175 const string& op_type, GradientFunction* gradient_function,
176 const std::vector<int64_t>& unneeded_gradients,
177 gtl::ArraySlice<AbstractTensorHandle*> output_gradients,
178 absl::Span<AbstractTensorHandle*> result) const {
179 if (gradient_function == nullptr) {
180 return errors::InvalidArgument(
181 "Provided null gradient_function for '", op_type, "'.\n",
182 "If the intent is to treat this op as non-differentiable consider "
183 "using RegisterNotDifferentiable or "
184 "NotDifferentiableGradientFunction.");
185 }
186 return gradient_function->Compute(ctx_, output_gradients, result);
187}
188
189Status TapeVSpace::BuildOnesLike(const TapeTensor& t,
190 AbstractTensorHandle** result) const {
191 AbstractOperationPtr op(ctx_->CreateOperation());
192 TF_RETURN_IF_ERROR(op->Reset("OnesLike", /*raw_device_name=*/nullptr));
193 if (isa<tracing::TracingOperation>(op.get())) {
194 TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingOperation>(op.get())->SetOpName(
195 absl::StrCat("OnesLike", ToId(t.GetHandle())).c_str()));
196 }
197 TF_RETURN_IF_ERROR(op->AddInput(t.GetHandle()));
198 int num_outputs = 1;
199 std::vector<AbstractTensorHandle*> outputs(num_outputs);
200 TF_RETURN_IF_ERROR(
201 op->Execute(absl::Span<AbstractTensorHandle*>(outputs), &num_outputs));
202 *result = outputs[0];
203 return OkStatus();
204}
205
206// Looks up the ID of a Gradient.
207int64_t TapeVSpace::TensorId(AbstractTensorHandle* tensor) const {
208 return ToId(tensor);
209}
210
211// Converts a Gradient to a TapeTensor.
212TapeTensor TapeVSpace::TapeTensorFromGradient(AbstractTensorHandle* g) const {
213 return TapeTensor(g);
214}
215
216void TapeVSpace::MarkAsResult(AbstractTensorHandle* gradient) const {}
217
218void TapeVSpace::DeleteGradient(AbstractTensorHandle* gradient) const {
219 gradient->Unref();
220}
221
222void Tape::Watch(const AbstractTensorHandle* t) {
223 GradientTape::Watch(ToId(t));
224}
225void Tape::RecordOperation(absl::Span<AbstractTensorHandle* const> inputs,
226 absl::Span<AbstractTensorHandle* const> outputs,
227 GradientFunction* gradient_function,
228 const string& op_name) {
229 std::vector<int64_t> input_ids(inputs.size());
230 std::vector<tensorflow::DataType> input_dtypes(inputs.size());
231 for (int i = 0; i < inputs.size(); i++) {
232 input_ids[i] = ToId(inputs[i]);
233 input_dtypes[i] = inputs[i]->DataType();
234 }
235 std::vector<TapeTensor> tape_tensors;
236 tape_tensors.reserve(outputs.size());
237 for (auto t : outputs) {
238 tape_tensors.push_back(TapeTensor(t));
239 }
240 GradientTape::RecordOperation(
241 op_name, tape_tensors, input_ids, input_dtypes,
242 [gradient_function]() -> GradientFunction* { return gradient_function; },
243 [](GradientFunction* ptr) {
244 if (ptr) {
245 delete ptr;
246 }
247 });
248}
249bool Tape::ShouldRecord(
250 absl::Span<const AbstractTensorHandle* const> tensors) const {
251 std::vector<int64_t> tensor_ids(tensors.size());
252 std::vector<tensorflow::DataType> tensor_dtypes(tensors.size());
253 for (int i = 0; i < tensors.size(); i++) {
254 tensor_ids[i] = ToId(tensors[i]);
255 tensor_dtypes[i] = tensors[i]->DataType();
256 }
257 return GradientTape::ShouldRecord(tensor_ids, tensor_dtypes);
258}
259void Tape::DeleteTrace(const AbstractTensorHandle* t) {
260 GradientTape::DeleteTrace(ToId(t));
261}
262
263std::vector<int64_t> MakeTensorIDList(
264 absl::Span<AbstractTensorHandle* const> tensors) {
265 std::vector<int64_t> ids(tensors.size());
266 for (int i = 0; i < tensors.size(); i++) {
267 ids[i] = ToId(tensors[i]);
268 }
269 return ids;
270}
271
272Status Tape::ComputeGradient(
273 AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> targets,
274 absl::Span<AbstractTensorHandle* const> sources,
275 absl::Span<AbstractTensorHandle* const> output_gradients,
276 absl::Span<AbstractTensorHandle*> result) {
277 TapeVSpace vspace(ctx);
278 std::vector<int64_t> target_tensor_ids = MakeTensorIDList(targets);
279 std::vector<int64_t> source_tensor_ids = MakeTensorIDList(sources);
280 tensorflow::gtl::FlatSet<int64_t> sources_set(source_tensor_ids.begin(),
281 source_tensor_ids.end());
282 std::unordered_map<int64_t, TapeTensor> sources_that_are_targets;
283 for (int i = 0; i < target_tensor_ids.size(); ++i) {
284 int64_t target_id = target_tensor_ids[i];
285 if (sources_set.find(target_id) != sources_set.end()) {
286 auto tensor = targets[i];
287 sources_that_are_targets.insert(
288 std::make_pair(target_id, TapeTensor(tensor)));
289 }
290 }
291
292 TF_RETURN_IF_ERROR(GradientTape::ComputeGradient(
293 vspace, target_tensor_ids, source_tensor_ids, sources_that_are_targets,
294 output_gradients, result, /*build_default_zeros_grads*/ false));
295 return OkStatus();
296}
297
298// Helper functions which delegate to `AbstractOperation`, update
299// the state of the ForwardOperation and call the tape as appropriate.
300// These APIs are mainly to facilitate testing and are subject to change.
301namespace internal {
302Status Reset(AbstractOperation* op_, const char* op,
303 const char* raw_device_name, ForwardOperation* forward_op_) {
304 forward_op_->op_name = op;
305 forward_op_->attrs.Reset(op);
306 return op_->Reset(op, raw_device_name);
307}
308Status AddInput(AbstractOperation* op_, AbstractTensorHandle* input,
309 ForwardOperation* forward_op_) {
310 TF_RETURN_IF_ERROR(op_->AddInput(input));
311 forward_op_->inputs.push_back(input);
312 return OkStatus();
313}
314Status AddInputList(AbstractOperation* op_,
315 absl::Span<AbstractTensorHandle* const> inputs,
316 ForwardOperation* forward_op_) {
317 TF_RETURN_IF_ERROR(op_->AddInputList(inputs));
318 for (auto input : inputs) {
319 forward_op_->inputs.push_back(input);
320 }
321 return OkStatus();
322}
323
324Status SetAttrString(AbstractOperation* op_, const char* attr_name,
325 const char* data, size_t length,
326 ForwardOperation* forward_op_) {
327 forward_op_->attrs.Set(attr_name, StringPiece(data, length));
328 return op_->SetAttrString(attr_name, data, length);
329}
330Status SetAttrInt(AbstractOperation* op_, const char* attr_name, int64_t value,
331 ForwardOperation* forward_op_) {
332 forward_op_->attrs.Set(attr_name, static_cast<int64_t>(value));
333 return op_->SetAttrInt(attr_name, value);
334}
335Status SetAttrFloat(AbstractOperation* op_, const char* attr_name, float value,
336 ForwardOperation* forward_op_) {
337 forward_op_->attrs.Set(attr_name, value);
338 return op_->SetAttrFloat(attr_name, value);
339}
340Status SetAttrBool(AbstractOperation* op_, const char* attr_name, bool value,
341 ForwardOperation* forward_op_) {
342 forward_op_->attrs.Set(attr_name, value);
343 return op_->SetAttrBool(attr_name, value);
344}
345Status SetAttrType(AbstractOperation* op_, const char* attr_name,
346 DataType value, ForwardOperation* forward_op_) {
347 forward_op_->attrs.Set(attr_name, value);
348 return op_->SetAttrType(attr_name, value);
349}
// Records a shape attribute on `forward_op_` and sets it on the op.
// A negative `num_dims` denotes an unknown rank; ranks above
// TensorShape::MaxDimensions() are rejected before any proto is built.
Status SetAttrShape(AbstractOperation* op_, const char* attr_name,
                    const int64_t* dims, const int num_dims,
                    ForwardOperation* forward_op_) {
  if (num_dims > TensorShape::MaxDimensions()) {
    return errors::InvalidArgument("Value specified for `", attr_name, "` has ",
                                   num_dims,
                                   " dimensions which is over the limit of ",
                                   TensorShape::MaxDimensions(), ".");
  }
  TensorShapeProto proto;
  if (num_dims < 0) {
    // Unknown rank.
    proto.set_unknown_rank(true);
  } else {
    for (int d = 0; d < num_dims; ++d) {
      proto.add_dim()->set_size(dims[d]);
    }
  }

  forward_op_->attrs.Set(attr_name, proto);
  return op_->SetAttrShape(attr_name, dims, num_dims);
}
// Unimplemented: function-valued attributes are not yet supported on the
// gradient-recording path.
Status SetAttrFunction(AbstractOperation* op_, const char* attr_name,
                       const AbstractOperation* value,
                       ForwardOperation* forward_op_) {
  return tensorflow::errors::Unimplemented(
      "SetAttrFunction has not been implemented yet.");
}
// Unimplemented: function-name attributes are not yet supported on the
// gradient-recording path.
Status SetAttrFunctionName(AbstractOperation* op_, const char* attr_name,
                           const char* value, size_t length,
                           ForwardOperation* forward_op_) {
  return tensorflow::errors::Unimplemented(
      "SetAttrFunctionName has not been implemented "
      "yet.");
}
// Unimplemented: tensor-valued attributes are not yet supported on the
// gradient-recording path.
Status SetAttrTensor(AbstractOperation* op_, const char* attr_name,
                     AbstractTensorInterface* tensor,
                     ForwardOperation* forward_op_) {
  return tensorflow::errors::Unimplemented(
      "SetAttrTensor has not been implemented yet.");
}
390Status SetAttrStringList(AbstractOperation* op_, const char* attr_name,
391 const void* const* values, const size_t* lengths,
392 int num_values, ForwardOperation* forward_op_) {
393 std::vector<StringPiece> v(num_values);
394 for (int i = 0; i < num_values; ++i) {
395 v[i] = StringPiece(static_cast<const char*>(values[i]), lengths[i]);
396 }
397 forward_op_->attrs.Set(attr_name, v);
398 return op_->SetAttrStringList(attr_name, values, lengths, num_values);
399}
400Status SetAttrFloatList(AbstractOperation* op_, const char* attr_name,
401 const float* values, int num_values,
402 ForwardOperation* forward_op_) {
403 forward_op_->attrs.Set(attr_name,
404 gtl::ArraySlice<const float>(values, num_values));
405 return op_->SetAttrFloatList(attr_name, values, num_values);
406}
407Status SetAttrIntList(AbstractOperation* op_, const char* attr_name,
408 const int64_t* values, int num_values,
409 ForwardOperation* forward_op_) {
410 forward_op_->attrs.Set(
411 attr_name, gtl::ArraySlice<const int64_t>(
412 reinterpret_cast<const int64_t*>(values), num_values));
413 return op_->SetAttrIntList(attr_name, values, num_values);
414}
415Status SetAttrTypeList(AbstractOperation* op_, const char* attr_name,
416 const DataType* values, int num_values,
417 ForwardOperation* forward_op_) {
418 forward_op_->attrs.Set(attr_name,
419 gtl::ArraySlice<const DataType>(values, num_values));
420 return op_->SetAttrTypeList(attr_name, values, num_values);
421}
422Status SetAttrBoolList(AbstractOperation* op_, const char* attr_name,
423 const unsigned char* values, int num_values,
424 ForwardOperation* forward_op_) {
425 std::unique_ptr<bool[]> b(new bool[num_values]);
426 for (int i = 0; i < num_values; ++i) {
427 b[i] = values[i];
428 }
429 forward_op_->attrs.Set(attr_name,
430 gtl::ArraySlice<const bool>(b.get(), num_values));
431 return op_->SetAttrBoolList(attr_name, values, num_values);
432}
// Records a list-of-shapes attribute on `forward_op_` and sets it on the op.
// Each entry of `dims`/`num_dims` describes one shape; a negative rank marks
// the corresponding shape as unknown-rank. Ranks above
// TensorShape::MaxDimensions() abort the whole call with InvalidArgument.
Status SetAttrShapeList(AbstractOperation* op_, const char* attr_name,
                        const int64_t** dims, const int* num_dims,
                        int num_values, ForwardOperation* forward_op_) {
  std::unique_ptr<TensorShapeProto[]> proto(new TensorShapeProto[num_values]);
  for (int i = 0; i < num_values; ++i) {
    const auto num_dims_i = num_dims[i];

    if (num_dims_i > TensorShape::MaxDimensions()) {
      return errors::InvalidArgument(
          strings::StrCat("Value specified for `", attr_name, "` has ",
                          num_dims_i, " dimensions which is over the limit of ",
                          TensorShape::MaxDimensions(), "."));
    }
    if (num_dims_i < 0) {
      // Unknown rank for this entry.
      proto[i].set_unknown_rank(true);
    } else {
      const int64_t* dims_i = dims[i];
      auto proto_i = &proto[i];
      for (int d = 0; d < num_dims_i; ++d) {
        proto_i->add_dim()->set_size(dims_i[d]);
      }
    }
  }
  forward_op_->attrs.Set(
      attr_name, gtl::ArraySlice<TensorShapeProto>(proto.get(), num_values));
  return op_->SetAttrShapeList(attr_name, dims, num_dims, num_values);
}
460Status SetAttrFunctionList(AbstractOperation* op_, const char* attr_name,
461 absl::Span<const AbstractOperation*> values,
462 ForwardOperation* forward_op_) {
463 return tensorflow::errors::Unimplemented(
464 "SetAttrFunctionList has not been "
465 "implemented yet.");
466}
// Executes `op_`, records its outputs on `forward_op_`, looks up the
// registered gradient function for the op, and records the whole forward
// operation on `tape`. Fails with the Lookup error if no gradient is
// registered for the op.
Status Execute(AbstractOperation* op_, AbstractContext* ctx,
               absl::Span<AbstractTensorHandle*> retvals, int* num_retvals,
               ForwardOperation* forward_op_, Tape* tape,
               const GradientRegistry& registry) {
  TF_RETURN_IF_ERROR(op_->Execute(retvals, num_retvals));
  for (int i = 0; i < *num_retvals; i++) {
    // TODO(srbs): Manage refcount of ForwardOperation's inputs/outputs.
    forward_op_->outputs.push_back(retvals[i]);
  }
  // TODO(b/166669239): This is needed to support AttrBuilder::Get for string
  // attributes. Number type attrs and DataType attrs work fine without this.
  // Consider getting rid of this and making the behavior between number types
  // and string consistent.
  forward_op_->attrs.BuildNodeDef();
  std::unique_ptr<GradientFunction> gradient_fn;
  TF_RETURN_IF_ERROR(registry.Lookup(*forward_op_, &gradient_fn));
  // Ownership of the gradient function is released to the tape, which deletes
  // it via the deleter installed in Tape::RecordOperation.
  tape->RecordOperation(forward_op_->inputs, retvals, gradient_fn.release(),
                        op_->Name());
  return OkStatus();
}
487} // namespace internal
488
489} // namespace gradients
490} // namespace tensorflow
491