1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #ifndef TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_CONTEXT_H_ |
16 | #define TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_CONTEXT_H_ |
17 | |
18 | #include <memory> |
19 | #include <vector> |
20 | |
21 | #include "absl/types/optional.h" |
22 | #include "absl/types/span.h" |
23 | #include "tensorflow/c/eager/abstract_context.h" |
24 | #include "tensorflow/c/eager/immediate_execution_distributed_manager.h" |
25 | #include "tensorflow/c/eager/immediate_execution_operation.h" |
26 | #include "tensorflow/c/eager/immediate_execution_tensor_handle.h" |
27 | #include "tensorflow/c/tensor_interface.h" |
28 | #include "tensorflow/core/framework/function.h" |
29 | #include "tensorflow/core/framework/function.pb.h" |
30 | #include "tensorflow/core/framework/numeric_types.h" |
31 | #include "tensorflow/core/framework/tensor.h" |
32 | #include "tensorflow/core/framework/types.pb.h" |
33 | #include "tensorflow/core/platform/platform.h" |
34 | #include "tensorflow/core/platform/status.h" |
35 | #include "tensorflow/core/platform/tstring.h" |
36 | #include "tensorflow/core/protobuf/config.pb.h" |
37 | #include "tensorflow/core/util/device_name_utils.h" |
38 | |
39 | namespace tensorflow { |
40 | class EagerExecutor; |
41 | class EagerContext; |
42 | class CustomDevice; |
43 | class CustomDeviceOpHandler; |
44 | class Device; |
45 | |
46 | // LINT.IfChange |
47 | // Note: Keep in sync with exported copy of enum in eager/c_api.h. |
// Policy controlling what happens when an op receives an input tensor that
// lives on a different device than the one the op will run on.
enum ContextDevicePlacementPolicy {
  // Running operations with input tensors on the wrong device will fail.
  DEVICE_PLACEMENT_EXPLICIT = 0,
  // Copy the tensor to the right device but log a warning.
  DEVICE_PLACEMENT_WARN = 1,
  // Silently copy the tensor, which has a performance cost since the operation
  // will be blocked till the copy completes. This is the default policy.
  DEVICE_PLACEMENT_SILENT = 2,
  // Placement policy which silently copies int32 tensors but not other dtypes.
  DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3,
};
59 | // LINT.ThenChange(//tensorflow/c/eager/c_api.h) |
60 | |
61 | // Abstract interface to a context. |
62 | // |
63 | // A context is responsible for creating key objects such as Tensors, |
64 | // TensorHandles & Operations. |
class ImmediateExecutionContext : public AbstractContext {
 public:
  // Optimized scalar creation functions. Each returns a newly created
  // zero-dimensional tensor holding the given value; the caller owns the
  // returned object.
  virtual AbstractTensorInterface* CreateInt64Scalar(int64_t value) = 0;
  virtual AbstractTensorInterface* CreateUint64Scalar(uint64 value) = 0;
  virtual AbstractTensorInterface* CreateInt32Scalar(int32_t value) = 0;
  virtual AbstractTensorInterface* CreateFloatScalar(float value) = 0;
  virtual AbstractTensorInterface* CreateDoubleScalar(double value) = 0;
  virtual AbstractTensorInterface* CreateHalfScalar(Eigen::half value) = 0;
  virtual AbstractTensorInterface* CreateStringScalar(tstring value) = 0;
  virtual AbstractTensorInterface* CreateComplex128Scalar(complex128 value) = 0;
  virtual AbstractTensorInterface* CreateBoolScalar(bool value) = 0;

  // Tensor creation functions.
  // Creates an uninitialized tensor of the given dtype and shape.
  virtual AbstractTensorInterface* CreateTensor(
      DataType dtype, absl::Span<const int64_t> dim_sizes) = 0;

  // Callback invoked to free a caller-provided tensor buffer. `data` and `len`
  // describe the buffer; `arg` is the opaque argument supplied at creation.
  typedef void (*MemoryReleaser)(void* data, size_t len, void* arg);

  // Create a tensor instance from the given data buffer and description.
  // `memory_releaser` will be called on destruction, and it's responsible for
  // cleaning up the underlying buffer.
  virtual AbstractTensorInterface* CreateTensor(
      DataType dtype, const int64_t* dims, int num_dims, void* data, size_t len,
      MemoryReleaser memory_releaser, void* memory_releaser_arg) = 0;

  // Create a handle to wrap and manage a Tensor.
  virtual ImmediateExecutionTensorHandle* CreateLocalHandle(
      AbstractTensorInterface* t) = 0;
  // Copy the handle to another device. On failure, `status` is set and the
  // return value is not usable.
  virtual ImmediateExecutionTensorHandle* CopyTensorHandleToDevice(
      ImmediateExecutionTensorHandle* handle, const char* device_name,
      Status* status) = 0;

  // Create an operation to perform op execution.
  ImmediateExecutionOperation* CreateOperation() override = 0;

  // Returns whether the runtime is backed by TFRT or the legacy TF Eager
  // Runtime. This is necessary to decouple runtime-dependent
  // code that is layered on top of the runtime.
  virtual bool UsesTFRT() = 0;

  // List attributes of available devices.
  virtual void ListDevices(std::vector<DeviceAttributes>* devices) = 0;

  // Add `devices` into context's device manager. Context's device manager
  // will take ownership and maintain devices' lifetime.
  virtual Status AddDevices(std::vector<std::unique_ptr<Device>> devices) = 0;

  // Block until all pending nodes are finished.
  virtual Status AsyncWait() = 0;

  // Add a function (serialized FunctionDef protocol buffer) so that it can
  // be executed as an op. Return error if a function with the same name
  // already exists.
  virtual Status AddFunctionDef(const FunctionDef& fdef) = 0;

  // Same as `AddFunctionDef`, but additionally saves the `stack_traces` under
  // the key of the function definition name (to be retrieved during function
  // instantiation).
  virtual Status AddFunctionDefWithStackTraces(
      const FunctionDef& fdef, const StackTracesMap& stack_traces) = 0;

  // Find and return an added function by its name. Returns nullptr if no
  // function with `name` has been registered.
  virtual const FunctionDef* FindFunctionDef(const string& name) const = 0;

  // Return the ParsedName of the Host CPU device.
  virtual const DeviceNameUtils::ParsedName& HostCPUParsedName() const = 0;
  virtual const string& HostCPUName() const = 0;

  // Configure soft device placement policy.
  virtual void SetAllowSoftPlacement(bool enable) = 0;

  // Configure device placement policy logging.
  virtual void SetLogDevicePlacement(bool enable) = 0;

  // Enables running eager ops as functions.
  virtual void SetRunEagerOpAsFunction(bool enable) = 0;

  // Enables rewriting jit_compile functions.
  virtual void SetJitCompileRewrite(bool enable) = 0;

  // Sets the device placement policy for the current thread.
  virtual void SetThreadLocalDevicePlacementPolicy(
      ContextDevicePlacementPolicy policy) = 0;
  // Returns the device placement policy for the current thread.
  virtual ContextDevicePlacementPolicy GetDevicePlacementPolicy() const = 0;

  // Configure graph collection in RunMetadata.
  virtual void SetShouldStoreGraphs(bool value) = 0;

  // Return the collected RunMetadata. This method will transfer the ownership
  // to the caller.
  virtual std::unique_ptr<RunMetadata> ExportRunMetadata() = 0;

  // For LLVM style RTTI.
  static bool classof(const AbstractContext* ptr) {
    return ptr->getKind() == kEager || ptr->getKind() == kTfrt;
  }

  //===--------------------------------------------------------------------===//
  // Experimental Custom Device.
  //===--------------------------------------------------------------------===//
  virtual CustomDeviceOpHandler& GetCustomDeviceOpHandler() = 0;

  // Returns whether `device_name` is registered as a custom device.
  virtual bool IsCustomDevice(const string& device_name) = 0;

  // Register a custom device. It will return an error if the device name is
  // already registered.
  // TODO(tfrt-devs): Remove this method. Let caller register it directly into
  // CustomDeviceOpHandler.
  virtual Status RegisterCustomDevice(const string& name,
                                      std::unique_ptr<CustomDevice> device) = 0;

  // Return FunctionLibraryDefinition. Transformations need to use it
  // to invoke MLIR compiler passes.
  virtual FunctionLibraryDefinition* FuncLibDef() = 0;

  // When tensor transfers across functions/eager executions using send/recv
  // ops are required, `reuse_rendezvous_for_functions_` can be set to true so
  // that function executions and eager executions use the same rendezvous
  // instance, instead of creating a new instance per function call.
  virtual void SetReuseRendezvousForFunctions(
      bool reuse_rendezvous_for_functions) = 0;

  // Resets the global rendezvous used for functions.
  virtual void ResetGlobalRendezvousForFunction() = 0;

  //===--------------------------------------------------------------------===//
  // Following are features in current TF Eager Runtime.
  // TODO(tfrt-devs): Figure out a way to deprecate following features after
  // migrated to TFRT.
  //===--------------------------------------------------------------------===//
  // Clear pending nodes in thread executors and kernel caches.
  virtual void ClearCachesAndThreadExecutors() = 0;

  // Initialize the step resource container for a training step. This is used
  // in current TF runtime. For tfrt, it is used by fallback op handler.
  virtual void StartStep() = 0;
  // Destroy the step resource container for a training step.
  virtual void EndStep() = 0;

  // Return the Eager Executor for current thread. Please note that Eager
  // Executor is only used in current TF but not in TFRT.
  virtual EagerExecutor& Executor() = 0;
  // Update the Eager Executor for current thread.
  virtual void SetExecutorForThread(EagerExecutor* executor) = 0;

  // Return a list of local tensorflow::Device*.
  // TODO(tfrt-devs): We shouldn't expose legacy device in this API.
  virtual std::vector<tensorflow::Device*> ListLocalTfDevices() = 0;

  // Return a list of all tensorflow::Device*.
  virtual std::vector<tensorflow::Device*> ListAllTfDevices() = 0;

  //===--------------------------------------------------------------------===//
  // Following are helper functions to assist integrating TFRT with current
  // TF eager runtime.
  // TODO(b/172877902): These helper functions are currently used to support
  // PyFuncOp on TFRT, and might be useful for ops that directly use low
  // level TF APIs. Remove/replace the following functions when TFRT native
  // ops are implemented.
  //===--------------------------------------------------------------------===//
  // Create an abstract tensor handle from tensorflow::Tensor.
  virtual ImmediateExecutionTensorHandle* CreateLocalHandleFromTFTensor(
      tensorflow::Tensor& t, const char* d_name) = 0;

  // Convert a TFRT TensorHandle to tensorflow::TensorHandle.
  virtual ImmediateExecutionTensorHandle* TFTensorHandleFromInterface(
      ImmediateExecutionTensorHandle* handle) = 0;

  // Test-only hook; the default implementation returns an empty list.
  virtual std::vector<std::string> GetLoggedOpsTestonly() { return {}; }

  // Get a list of the names of functions that have been registered.
  virtual std::vector<string> ListFunctionNames() = 0;

  //===--------------------------------------------------------------------===//
  // Distributed runtime related functions.
  //===--------------------------------------------------------------------===//
#if !defined(IS_MOBILE_PLATFORM)
  // Set up a multi-client distributed execution environment. Must be called on
  // all tasks in the cluster.
  // This call internally coordinates with other tasks to initialize the eager
  // context and TF server for multi-client execution.
  virtual Status EnableCollectiveOps(const ServerDef& server_def) = 0;

  // Set a distributed manager that helps set up, update, and check liveness
  // of member tasks in the cluster.
  virtual void SetDistributedManager(
      std::unique_ptr<ImmediateExecutionDistributedManager> distributed) = 0;

  virtual ImmediateExecutionDistributedManager* GetDistributedManager() = 0;
#endif  // !IS_MOBILE_PLATFORM

 protected:
  explicit ImmediateExecutionContext(AbstractContextKind kind)
      : AbstractContext(kind) {}
  ~ImmediateExecutionContext() override {}
};
265 | |
266 | namespace internal { |
267 | struct ImmediateExecutionContextDeleter { |
268 | void operator()(ImmediateExecutionContext* p) const { |
269 | if (p != nullptr) { |
270 | p->Release(); |
271 | } |
272 | } |
273 | }; |
274 | } // namespace internal |
275 | |
// Owning smart-pointer alias for ImmediateExecutionContext. Disposal goes
// through ImmediateExecutionContextDeleter, which calls Release() on the
// context instead of invoking `delete` directly.
using ImmediateContextPtr =
    std::unique_ptr<ImmediateExecutionContext,
                    internal::ImmediateExecutionContextDeleter>;
279 | |
280 | } // namespace tensorflow |
281 | |
282 | #endif // TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_CONTEXT_H_ |
283 | |