1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #ifndef TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_ |
16 | #define TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_ |
17 | |
#include <cstddef>
#include <cstdint>
#include <memory>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h"
26 | |
27 | namespace tensorflow { |
28 | |
29 | // Abstract interface to an operation. |
30 | // This interface allows building and executing an operation in either |
31 | // tracing or immediate execution mode. |
32 | class AbstractOperation { |
33 | protected: |
34 | enum AbstractOperationKind { |
35 | kGraph, |
36 | kMlir, |
37 | kEager, |
38 | kTfrt, |
39 | kTape, |
40 | kOpHandler |
41 | }; |
42 | explicit AbstractOperation(AbstractOperationKind kind) : kind_(kind) {} |
43 | virtual ~AbstractOperation() {} |
44 | |
45 | public: |
46 | AbstractOperationKind getKind() const { return kind_; } |
47 | |
48 | // Release any underlying resources, including the interface object. |
49 | // |
50 | // WARNING: The destructor of this class is marked as protected to disallow |
51 | // clients from directly destroying this object since it may manage it's own |
52 | // lifetime through ref counting. Thus this must be allocated on the heap and |
53 | // clients MUST call Release() in order to destroy an instance of this class. |
54 | virtual void Release() = 0; |
55 | |
56 | virtual Status Reset(const char* op, const char* raw_device_name) = 0; |
57 | |
58 | virtual const string& Name() const = 0; |
59 | |
60 | // Returns the operation's device name. |
61 | // |
62 | // The value returned may be different from the one set by SetDeviceName, but |
63 | // it will be compatible with it: the name will be updated by device placement |
64 | // logic to refer to the specific device chosen. |
65 | // |
66 | // Example: If one calls `op->SetDeviceName("/device:GPU")`, the value |
67 | // returned by DeviceName should be "/device:GPU:*" until a particular GPU is |
68 | // chosen for the operation by the device placement logic in the |
69 | // executor. After that, the value returned by DeviceName will be a full |
70 | // device name such as "/job:localhost/replica:0/task:0/device:GPU:1". |
71 | virtual const string& DeviceName() const = 0; |
72 | |
73 | // Sets the operation device name. |
74 | // |
75 | // The given `name` must be parseable by DeviceNameUtils::ParseFullName, and |
76 | // the result will be used as a constraint for device placement. See the |
77 | // documentation for DeviceName for more details. |
78 | // |
79 | // The value will override the previous value - that is, no "merging" of |
80 | // existing and given constraints will be performed. |
81 | virtual Status SetDeviceName(const char* name) = 0; |
82 | |
83 | virtual Status AddInput(AbstractTensorHandle* input) = 0; |
84 | virtual Status AddInputList( |
85 | absl::Span<AbstractTensorHandle* const> inputs) = 0; |
86 | virtual Status Execute(absl::Span<AbstractTensorHandle*> retvals, |
87 | int* num_retvals) = 0; |
88 | |
89 | virtual Status SetAttrString(const char* attr_name, const char* data, |
90 | size_t length) = 0; |
91 | virtual Status SetAttrInt(const char* attr_name, int64_t value) = 0; |
92 | virtual Status SetAttrFloat(const char* attr_name, float value) = 0; |
93 | virtual Status SetAttrBool(const char* attr_name, bool value) = 0; |
94 | virtual Status SetAttrType(const char* attr_name, DataType value) = 0; |
95 | virtual Status SetAttrShape(const char* attr_name, const int64_t* dims, |
96 | const int num_dims) = 0; |
97 | virtual Status SetAttrShape(const char* attr_name, |
98 | const PartialTensorShape shape); |
99 | virtual Status SetAttrFunction(const char* attr_name, |
100 | const AbstractOperation* value) = 0; |
101 | virtual Status SetAttrFunctionName(const char* attr_name, const char* value, |
102 | size_t length) = 0; |
103 | virtual Status SetAttrTensor(const char* attr_name, |
104 | AbstractTensorInterface* tensor) = 0; |
105 | virtual Status SetAttrStringList(const char* attr_name, |
106 | const void* const* values, |
107 | const size_t* lengths, int num_values) = 0; |
108 | virtual Status SetAttrStringList(const char* attr_name, |
109 | absl::Span<string const> values); |
110 | virtual Status SetAttrFloatList(const char* attr_name, const float* values, |
111 | int num_values) = 0; |
112 | virtual Status SetAttrIntList(const char* attr_name, const int64_t* values, |
113 | int num_values) = 0; |
114 | virtual Status SetAttrTypeList(const char* attr_name, const DataType* values, |
115 | int num_values) = 0; |
116 | virtual Status SetAttrBoolList(const char* attr_name, |
117 | const unsigned char* values, |
118 | int num_values) = 0; |
119 | virtual Status SetAttrShapeList(const char* attr_name, const int64_t** dims, |
120 | const int* num_dims, int num_values) = 0; |
121 | virtual Status SetAttrFunctionList( |
122 | const char* attr_name, absl::Span<const AbstractOperation*> values) = 0; |
123 | |
124 | private: |
125 | const AbstractOperationKind kind_; |
126 | }; |
127 | |
128 | // TODO(b/193656009): Defining these in a cc file causes linker errors with |
129 | // fastbuild. |
130 | inline Status AbstractOperation::SetAttrShape(const char* attr_name, |
131 | const PartialTensorShape shape) { |
132 | return SetAttrShape(attr_name, shape.dim_sizes().data(), shape.dims()); |
133 | } |
134 | |
135 | inline Status AbstractOperation::SetAttrStringList( |
136 | const char* attr_name, absl::Span<string const> values) { |
137 | std::vector<const char*> raw_strs; |
138 | std::vector<size_t> lengths; |
139 | raw_strs.reserve(values.size()); |
140 | lengths.reserve(values.size()); |
141 | for (const auto& s : values) { |
142 | raw_strs.emplace_back(s.data()); |
143 | lengths.emplace_back(s.size()); |
144 | } |
145 | return SetAttrStringList(attr_name, |
146 | reinterpret_cast<const void**>(raw_strs.data()), |
147 | lengths.data(), values.size()); |
148 | } |
149 | |
150 | namespace internal { |
151 | struct AbstractOperationDeleter { |
152 | void operator()(AbstractOperation* p) const { |
153 | if (p != nullptr) { |
154 | p->Release(); |
155 | } |
156 | } |
157 | }; |
158 | } // namespace internal |
159 | |
160 | using AbstractOperationPtr = |
161 | std::unique_ptr<AbstractOperation, internal::AbstractOperationDeleter>; |
162 | |
163 | } // namespace tensorflow |
164 | |
165 | #endif // TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_ |
166 | |