1/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#ifndef TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_CONTEXT_H_
16#define TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_CONTEXT_H_
17
18#include <memory>
19#include <vector>
20
21#include "absl/types/optional.h"
22#include "absl/types/span.h"
23#include "tensorflow/c/eager/abstract_context.h"
24#include "tensorflow/c/eager/immediate_execution_distributed_manager.h"
25#include "tensorflow/c/eager/immediate_execution_operation.h"
26#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
27#include "tensorflow/c/tensor_interface.h"
28#include "tensorflow/core/framework/function.h"
29#include "tensorflow/core/framework/function.pb.h"
30#include "tensorflow/core/framework/numeric_types.h"
31#include "tensorflow/core/framework/tensor.h"
32#include "tensorflow/core/framework/types.pb.h"
33#include "tensorflow/core/platform/platform.h"
34#include "tensorflow/core/platform/status.h"
35#include "tensorflow/core/platform/tstring.h"
36#include "tensorflow/core/protobuf/config.pb.h"
37#include "tensorflow/core/util/device_name_utils.h"
38
39namespace tensorflow {
40class EagerExecutor;
41class EagerContext;
42class CustomDevice;
43class CustomDeviceOpHandler;
44class Device;
45
// LINT.IfChange
// Note: Keep in sync with exported copy of enum in eager/c_api.h.
//
// Policy describing what happens when an operation receives an input tensor
// that lives on a different device than the one the operation runs on. The
// numeric values are part of the contract with the C API copy referenced
// above, so they must not be renumbered.
enum ContextDevicePlacementPolicy {
  // Running operations with input tensors on the wrong device will fail.
  DEVICE_PLACEMENT_EXPLICIT = 0,
  // Copy the tensor to the right device but log a warning.
  DEVICE_PLACEMENT_WARN = 1,
  // Silently copy the tensor, which has a performance cost since the operation
  // will be blocked till the copy completes. This is the default policy.
  DEVICE_PLACEMENT_SILENT = 2,
  // Placement policy which silently copies int32 tensors but not other dtypes.
  DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3,
};
// LINT.ThenChange(//tensorflow/c/eager/c_api.h)
60
61// Abstract interface to a context.
62//
63// A context is responsible for creating key objects such as Tensors,
64// TensorHandles & Operations.
65class ImmediateExecutionContext : public AbstractContext {
66 public:
67 // Optimized scalar creation functions
68 virtual AbstractTensorInterface* CreateInt64Scalar(int64_t value) = 0;
69 virtual AbstractTensorInterface* CreateUint64Scalar(uint64 value) = 0;
70 virtual AbstractTensorInterface* CreateInt32Scalar(int32_t value) = 0;
71 virtual AbstractTensorInterface* CreateFloatScalar(float value) = 0;
72 virtual AbstractTensorInterface* CreateDoubleScalar(double value) = 0;
73 virtual AbstractTensorInterface* CreateHalfScalar(Eigen::half value) = 0;
74 virtual AbstractTensorInterface* CreateStringScalar(tstring value) = 0;
75 virtual AbstractTensorInterface* CreateComplex128Scalar(complex128 value) = 0;
76 virtual AbstractTensorInterface* CreateBoolScalar(bool value) = 0;
77
78 // Tensor creation functions
79 virtual AbstractTensorInterface* CreateTensor(
80 DataType dtype, absl::Span<const int64_t> dim_sizes) = 0;
81
82 typedef void (*MemoryReleaser)(void* data, size_t len, void* arg);
83
84 // Create a tensor instance from the given data buffer and description.
85 // `memory_releaser` will be called on destruction, and it's responsible for
86 // cleaning up the underlying buffer.
87 virtual AbstractTensorInterface* CreateTensor(
88 DataType dtype, const int64_t* dims, int num_dims, void* data, size_t len,
89 MemoryReleaser memory_releaser, void* memory_releaser_arg) = 0;
90
91 // Create a handle to wrap and manage a Tensor
92 virtual ImmediateExecutionTensorHandle* CreateLocalHandle(
93 AbstractTensorInterface* t) = 0;
94 // Copy the handle to another device.
95 virtual ImmediateExecutionTensorHandle* CopyTensorHandleToDevice(
96 ImmediateExecutionTensorHandle* handle, const char* device_name,
97 Status* status) = 0;
98
99 // Create an operation to perform op execution
100 ImmediateExecutionOperation* CreateOperation() override = 0;
101
102 // Returns whether the runtime is backed by TFRT or the legacy TF Eager
103 // Runtime. This is necessary to decouple runtime-dependent
104 // code that is layered on top of the runtime.
105 virtual bool UsesTFRT() = 0;
106
107 // List attributes of available devices
108 virtual void ListDevices(std::vector<DeviceAttributes>* devices) = 0;
109
110 // Add `devices` into context's device manager. Context's device manager
111 // will take ownership and maintain devices' lifetime.
112 virtual Status AddDevices(std::vector<std::unique_ptr<Device>> devices) = 0;
113
114 // Block until all pending nodes are finished.
115 virtual Status AsyncWait() = 0;
116
117 // Add a function (serialized FunctionDef protocol buffer) so that it can
118 // be executed as an op. Return error if the function with the same name
119 // already exists.
120 virtual Status AddFunctionDef(const FunctionDef& fdef) = 0;
121
122 // Same as `AddFunctionDef`, but additionally saves the `stack_traces` under
123 // the key of the function definition name (to be retrieved during function
124 // instantiation).
125 virtual Status AddFunctionDefWithStackTraces(
126 const FunctionDef& fdef, const StackTracesMap& stack_traces) = 0;
127
128 // Find and return a added function by its name.
129 virtual const FunctionDef* FindFunctionDef(const string& name) const = 0;
130
131 // Return the ParsedName of Host CPU device.
132 virtual const DeviceNameUtils::ParsedName& HostCPUParsedName() const = 0;
133 virtual const string& HostCPUName() const = 0;
134
135 // Configure soft device placement policy.
136 virtual void SetAllowSoftPlacement(bool enable) = 0;
137
138 // Configure device placement policy logging.
139 virtual void SetLogDevicePlacement(bool enable) = 0;
140
141 // Enables running eager ops as functions.
142 virtual void SetRunEagerOpAsFunction(bool enable) = 0;
143
144 // Enables rewriting jit_compile functions.
145 virtual void SetJitCompileRewrite(bool enable) = 0;
146
147 // Sets the device placement policy for the current thread.
148 virtual void SetThreadLocalDevicePlacementPolicy(
149 ContextDevicePlacementPolicy policy) = 0;
150 // Returns the device placement policy for the current thread.
151 virtual ContextDevicePlacementPolicy GetDevicePlacementPolicy() const = 0;
152
153 // Configure graph collection in RunMetadata.
154 virtual void SetShouldStoreGraphs(bool value) = 0;
155
156 // Return the collected RunMetadata. This method will transfer the ownership
157 // to the caller.
158 virtual std::unique_ptr<RunMetadata> ExportRunMetadata() = 0;
159
160 // For LLVM style RTTI.
161 static bool classof(const AbstractContext* ptr) {
162 return ptr->getKind() == kEager || ptr->getKind() == kTfrt;
163 }
164
165 //===--------------------------------------------------------------------===//
166 // Experimental Custom Device.
167 //===--------------------------------------------------------------------===//
168 virtual CustomDeviceOpHandler& GetCustomDeviceOpHandler() = 0;
169
170 // Returns whether `device_name` is registered as a custom device.
171 virtual bool IsCustomDevice(const string& device_name) = 0;
172
173 // Register a custom device. It will return error is the device name is
174 // already registered.
175 // TODO(tfrt-devs): Remove this method. Let caller register it directly into
176 // CustomDeviceOpHandler.
177 virtual Status RegisterCustomDevice(const string& name,
178 std::unique_ptr<CustomDevice> device) = 0;
179
180 // Return FunctionLibraryDefinition. Transformations need to use it to use it
181 // to invoke MLIR compiler passes.
182 virtual FunctionLibraryDefinition* FuncLibDef() = 0;
183
184 // When tensor transfer across functions/eager executions using send/recv ops
185 // are required, `reuse_rendezvous_for_functions_` can be set to true so that
186 // function executions and eager executions use the same rendezvous instance,
187 // instead of creating new instance per function calls.
188 virtual void SetReuseRendezvousForFunctions(
189 bool reuse_rendezvous_for_functions) = 0;
190
191 // Resets the global rendezvous used for functions.
192 virtual void ResetGlobalRendezvousForFunction() = 0;
193
194 //===--------------------------------------------------------------------===//
195 // Following are features in current TF Eager Runtime.
196 // TODO(tfrt-devs): Figure out a way to deprecate following features after
197 // migrated to TFRT.
198 //===--------------------------------------------------------------------===//
199 // Clear pending nodes in thread executors and kernel caches.
200 virtual void ClearCachesAndThreadExecutors() = 0;
201
202 // Initialize the step resource container for a training step. This is used
203 // in current TF runtime. For tfrt, it is used by fallback op handler.
204 virtual void StartStep() = 0;
205 // Destroy the step resource container for a training step.
206 virtual void EndStep() = 0;
207
208 // Return the Eager Executor for current thread. Please note that Eager
209 // Executor is only used in current TF but not in TFRT.
210 virtual EagerExecutor& Executor() = 0;
211 // Update the Eager Executor for current thread.
212 virtual void SetExecutorForThread(EagerExecutor* executor) = 0;
213
214 // Return a list of local tensorflow::Device*.
215 // TODO(tfrt-devs): We shouldn't expose legacy device in this API.
216 virtual std::vector<tensorflow::Device*> ListLocalTfDevices() = 0;
217
218 // Return a list of all tensorflow::Device*.
219 virtual std::vector<tensorflow::Device*> ListAllTfDevices() = 0;
220
221 //===--------------------------------------------------------------------===//
222 // Following are helper functions to assist integrating TFRT with current
223 // TF eager runtime.
224 // TODO(b/172877902): These helper functions are currently used to support
225 // PyFuncOp on TFRT, and might be useful for ops that directly use low
226 // level TF APIs. Remove/replace the following functions when TFRT native
227 // ops are implemented.
228 //===--------------------------------------------------------------------===//
229 // Create an abstract tensor handle from tensorflow::Tensor.
230 virtual ImmediateExecutionTensorHandle* CreateLocalHandleFromTFTensor(
231 tensorflow::Tensor& t, const char* d_name) = 0;
232
233 // Convert a TFRT TensorHandle to tensorflow::TensorHandle.
234 virtual ImmediateExecutionTensorHandle* TFTensorHandleFromInterface(
235 ImmediateExecutionTensorHandle* handle) = 0;
236
237 virtual std::vector<std::string> GetLoggedOpsTestonly() { return {}; }
238
239 // Get a list of the names of functions that have been registered.
240 virtual std::vector<string> ListFunctionNames() = 0;
241
242 //===--------------------------------------------------------------------===//
243 // Distributed runtime related functions.
244 //===--------------------------------------------------------------------===//
245#if !defined(IS_MOBILE_PLATFORM)
246 // Set up a multi-client distributed execution environment. Must be called on
247 // all tasks in the cluster.
248 // This call internally coordinates with other tasks to initialize the eager
249 // context and TF server for multi-client execution.
250 virtual Status EnableCollectiveOps(const ServerDef& server_def) = 0;
251
252 // Set a distributed manager that helps set up, update, and check liveness
253 // of member tasks in the cluster.
254 virtual void SetDistributedManager(
255 std::unique_ptr<ImmediateExecutionDistributedManager> distributed) = 0;
256
257 virtual ImmediateExecutionDistributedManager* GetDistributedManager() = 0;
258#endif // !IS_MOBILE_PLATFORM
259
260 protected:
261 explicit ImmediateExecutionContext(AbstractContextKind kind)
262 : AbstractContext(kind) {}
263 ~ImmediateExecutionContext() override {}
264};
265
266namespace internal {
267struct ImmediateExecutionContextDeleter {
268 void operator()(ImmediateExecutionContext* p) const {
269 if (p != nullptr) {
270 p->Release();
271 }
272 }
273};
274} // namespace internal
275
276using ImmediateContextPtr =
277 std::unique_ptr<ImmediateExecutionContext,
278 internal::ImmediateExecutionContextDeleter>;
279
280} // namespace tensorflow
281
282#endif // TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_CONTEXT_H_
283