1/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_C_EAGER_GRADIENTS_INTERNAL_H_
17#define TENSORFLOW_C_EAGER_GRADIENTS_INTERNAL_H_
18
19#include "tensorflow/c/eager/gradients.h"
20
21namespace tensorflow {
22namespace gradients {
23namespace internal {
24
25// Helper functions which delegate to `AbstractOperation`, update
26// the state of the ForwardOperation and call the tape as appropriate.
27// These APIs are mainly to facilitate testing and are subject to change.
28
29// Records the op name in the `ForwardOperation`.
30Status Reset(AbstractOperation*, const char* op, const char* raw_device_name,
31 ForwardOperation*);
32
33// Records the inputs in the `ForwardOperation`.
34Status AddInput(AbstractOperation*, AbstractTensorHandle*, ForwardOperation*);
35Status AddInputList(AbstractOperation*,
36 absl::Span<AbstractTensorHandle* const> inputs,
37 ForwardOperation*);
38
39// Sets the attrs in the `ForwardOperation`.
40Status SetAttrString(AbstractOperation*, const char* attr_name,
41 const char* data, size_t length, ForwardOperation*);
42Status SetAttrInt(AbstractOperation*, const char* attr_name, int64_t value,
43 ForwardOperation*);
44Status SetAttrFloat(AbstractOperation*, const char* attr_name, float value,
45 ForwardOperation*);
46Status SetAttrBool(AbstractOperation*, const char* attr_name, bool value,
47 ForwardOperation*);
48Status SetAttrType(AbstractOperation*, const char* attr_name, DataType value,
49 ForwardOperation*);
50Status SetAttrShape(AbstractOperation*, const char* attr_name,
51 const int64_t* dims, const int num_dims, ForwardOperation*);
52Status SetAttrFunction(AbstractOperation*, const char* attr_name,
53 const AbstractOperation* value, ForwardOperation*);
54Status SetAttrFunctionName(AbstractOperation*, const char* attr_name,
55 const char* value, size_t length, ForwardOperation*);
56Status SetAttrTensor(AbstractOperation*, const char* attr_name,
57 AbstractTensorInterface* tensor, ForwardOperation*);
58Status SetAttrStringList(AbstractOperation*, const char* attr_name,
59 const void* const* values, const size_t* lengths,
60 int num_values, ForwardOperation*);
61Status SetAttrFloatList(AbstractOperation*, const char* attr_name,
62 const float* values, int num_values, ForwardOperation*);
63Status SetAttrIntList(AbstractOperation*, const char* attr_name,
64 const int64_t* values, int num_values, ForwardOperation*);
65Status SetAttrTypeList(AbstractOperation*, const char* attr_name,
66 const DataType* values, int num_values,
67 ForwardOperation*);
68Status SetAttrBoolList(AbstractOperation*, const char* attr_name,
69 const unsigned char* values, int num_values,
70 ForwardOperation*);
71Status SetAttrShapeList(AbstractOperation*, const char* attr_name,
72 const int64_t** dims, const int* num_dims,
73 int num_values, ForwardOperation*);
74Status SetAttrFunctionList(AbstractOperation*, const char* attr_name,
75 absl::Span<const AbstractOperation*> values,
76 ForwardOperation*);
77
78// Make the call to `Tape::RecordOperation`.
79Status Execute(AbstractOperation*, AbstractContext*,
80 absl::Span<AbstractTensorHandle*> retvals, int* num_retvals,
81 ForwardOperation*, Tape*, const GradientRegistry&);
82
83} // namespace internal
84} // namespace gradients
85} // namespace tensorflow
86
87#endif // TENSORFLOW_C_EAGER_GRADIENTS_INTERNAL_H_
88