1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_H_ |
16 | #define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_H_ |
17 | |
18 | #include <algorithm> |
19 | #include <cstddef> |
20 | #include <map> |
21 | #include <memory> |
22 | #include <queue> |
23 | #include <string> |
24 | #include <unordered_set> |
25 | #include <vector> |
26 | |
27 | #include "absl/container/flat_hash_map.h" |
28 | #include "absl/types/optional.h" |
29 | #include "tensorflow/c/eager/immediate_execution_context.h" |
30 | #include "tensorflow/core/common_runtime/composite_device.h" |
31 | #include "tensorflow/core/common_runtime/device_factory.h" |
32 | #include "tensorflow/core/common_runtime/device_mgr.h" |
33 | #include "tensorflow/core/common_runtime/eager/custom_device.h" |
34 | #include "tensorflow/core/common_runtime/eager/custom_device_op_handler.h" |
35 | #include "tensorflow/core/common_runtime/eager/eager_executor.h" |
36 | #include "tensorflow/core/common_runtime/eager/kernel_and_device.h" |
37 | #include "tensorflow/core/common_runtime/function.h" |
38 | #include "tensorflow/core/common_runtime/process_function_library_runtime.h" |
39 | #include "tensorflow/core/common_runtime/rendezvous_mgr.h" |
40 | #include "tensorflow/core/example/example.pb.h" |
41 | #include "tensorflow/core/framework/collective.h" |
42 | #include "tensorflow/core/framework/function.h" |
43 | #include "tensorflow/core/framework/log_memory.h" |
44 | #include "tensorflow/core/framework/rendezvous.h" |
45 | #include "tensorflow/core/framework/tensor.h" |
46 | #include "tensorflow/core/lib/core/status.h" |
47 | #include "tensorflow/core/lib/core/stringpiece.h" |
48 | #include "tensorflow/core/lib/core/threadpool.h" |
49 | #include "tensorflow/core/lib/gtl/flatmap.h" |
50 | #include "tensorflow/core/lib/gtl/flatset.h" |
51 | #include "tensorflow/core/lib/gtl/inlined_vector.h" |
52 | #include "tensorflow/core/lib/gtl/map_util.h" |
53 | #include "tensorflow/core/platform/casts.h" |
54 | #include "tensorflow/core/platform/env.h" |
55 | #include "tensorflow/core/platform/fingerprint.h" |
56 | #include "tensorflow/core/platform/mutex.h" |
57 | #include "tensorflow/core/platform/platform.h" |
58 | #include "tensorflow/core/platform/status.h" |
59 | #include "tensorflow/core/platform/thread_annotations.h" |
60 | #include "tensorflow/core/platform/threadpool.h" |
61 | #include "tensorflow/core/public/session_options.h" |
62 | #include "tensorflow/core/public/version.h" |
63 | #include "tensorflow/core/util/device_name_utils.h" |
64 | |
65 | // "tensorflow/core/platform/platform.h" must be included first before using |
66 | // IS_MOBILE_PLATFORM. |
67 | #if !defined(IS_MOBILE_PLATFORM) |
68 | #include "tensorflow/core/distributed_runtime/eager/eager_client.h" |
69 | #include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h" |
70 | #include "tensorflow/core/distributed_runtime/server_lib.h" |
71 | #include "tensorflow/core/distributed_runtime/worker_cache.h" |
72 | #include "tensorflow/core/distributed_runtime/worker_env.h" |
73 | #endif // !IS_MOBILE_PLATFORM |
74 | |
75 | namespace tensorflow { |
76 | |
77 | namespace eager { |
// We need this forward declaration because there is a circular dependency:
79 | // Context -> RemoteMgr -> TensorHandle -> Context. |
80 | // TODO(fishx): Remove this once we remove Context dependency in TensorHandle. |
81 | class RemoteMgr; |
82 | } // namespace eager |
83 | |
84 | class TensorHandle; |
85 | class EagerOperation; |
86 | |
87 | class EagerContext : public ImmediateExecutionContext, public core::RefCounted { |
88 | public: |
89 | static constexpr uint64 kInvalidContextId = 0; |
90 | |
91 | static uint64 NewContextId() { |
92 | uint64 context_id = random::New64(); |
93 | while (context_id == kInvalidContextId) { |
94 | context_id = random::New64(); |
95 | } |
96 | return context_id; |
97 | } |
98 | |
99 | EagerContext( |
100 | const SessionOptions& opts, |
101 | ContextDevicePlacementPolicy default_device_placement_policy, bool async, |
102 | /*const*/ DeviceMgr* device_mgr, bool device_mgr_owned, |
103 | /*const*/ Rendezvous* rendezvous, |
104 | DistributedFunctionLibraryRuntime* cluster_flr = nullptr, |
105 | CollectiveExecutorMgrInterface* collective_executor_mgr = nullptr, |
106 | bool run_eager_op_as_function = false, bool jit_compile_rewrite = false); |
107 | |
108 | void Release() override { Unref(); } |
109 | |
110 | AbstractTensorInterface* CreateInt64Scalar(int64_t value) override; |
111 | AbstractTensorInterface* CreateUint64Scalar(uint64 value) override; |
112 | AbstractTensorInterface* CreateInt32Scalar(int32_t value) override; |
113 | AbstractTensorInterface* CreateFloatScalar(float value) override; |
114 | AbstractTensorInterface* CreateDoubleScalar(double value) override; |
115 | AbstractTensorInterface* CreateHalfScalar(Eigen::half value) override; |
116 | AbstractTensorInterface* CreateStringScalar( |
117 | tensorflow::tstring value) override; |
118 | AbstractTensorInterface* CreateComplex128Scalar( |
119 | tensorflow::complex128 value) override; |
120 | AbstractTensorInterface* CreateBoolScalar(bool value) override; |
121 | |
122 | AbstractTensorInterface* CreateTensor( |
123 | DataType dtype, absl::Span<const int64_t> dim_sizes) override; |
124 | AbstractTensorInterface* CreateTensor(DataType dtype, const int64_t* dims, |
125 | int num_dims, void* data, size_t len, |
126 | MemoryReleaser memory_releaser, |
127 | void* memory_releaser_arg) override; |
128 | |
129 | ImmediateExecutionTensorHandle* CreateLocalHandle( |
130 | AbstractTensorInterface* t) override; |
131 | // Create an abstract tensor handle from tensorflow::Tensor. |
132 | ImmediateExecutionTensorHandle* CreateLocalHandleFromTFTensor( |
133 | tensorflow::Tensor& t, const char* d_name) override; |
134 | ImmediateExecutionTensorHandle* CopyTensorHandleToDevice( |
135 | ImmediateExecutionTensorHandle* handle, const char* device_name, |
136 | Status* status) override; |
137 | ImmediateExecutionOperation* CreateOperation() override; |
138 | |
  // This is a virtual helper function to convert a TFRT TensorHandle to a
  // tensorflow::TensorHandle. In the current runtime EagerContext it simply
  // forwards the input, since the input tensor handle is already a
  // tensorflow::TensorHandle.
143 | ImmediateExecutionTensorHandle* TFTensorHandleFromInterface( |
144 | ImmediateExecutionTensorHandle* handle) override; |
145 | |
146 | Status RegisterFunction(AbstractFunction* f) override; |
147 | |
148 | bool UsesTFRT() override; |
149 | |
150 | bool RunEagerOpAsFunction() const; |
151 | |
152 | void SetRunEagerOpAsFunction(bool enable) override; |
153 | |
154 | bool JitCompileRewrite() const; |
155 | |
156 | void SetJitCompileRewrite(bool enable) override; |
157 | |
158 | void ListDevices(std::vector<DeviceAttributes>* devices) override; |
159 | |
160 | Status AddDevices(std::vector<std::unique_ptr<Device>> devices) override; |
161 | |
162 | thread::ThreadPool* GetThreadPool() { return thread_pool_.get(); } |
163 | |
164 | // Returns the function library runtime for the given device. |
165 | FunctionLibraryRuntime* func_lib(const Device* d) const { |
166 | return pflr_->GetFLR(d->name()); |
167 | } |
168 | |
169 | ProcessFunctionLibraryRuntime* pflr() const { return pflr_.get(); } |
170 | |
171 | std::function<void(std::function<void()>)>* runner() { return &runner_; } |
172 | |
  // Specifies an executor for this thread.
174 | void SetExecutorForThread(EagerExecutor* executor) override; |
175 | |
176 | const std::shared_ptr<std::vector<DeviceType>> prioritized_device_type_list() |
177 | const { |
178 | mutex_lock l(device_type_list_mu_); |
179 | return prioritized_device_type_list_; |
180 | } |
181 | |
182 | // Clear pending nodes in thread executors and kernel caches. |
183 | void ClearCachesAndThreadExecutors() override; |
184 | // Clear pending nodes in default executor and kernel caches. |
185 | void ClearCachesAndDefaultExecutor(); |
186 | |
187 | // Sets the device placement policy for the current thread. |
188 | void SetThreadLocalDevicePlacementPolicy( |
189 | ContextDevicePlacementPolicy policy) override; |
190 | |
191 | // Returns the device placement policy for the current thread. |
192 | ContextDevicePlacementPolicy GetDevicePlacementPolicy() const override; |
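
  // A minimal usage sketch for the two methods above (illustrative only):
  // make the current thread fail on implicit cross-device copies rather than
  // silently allowing them.
  //
  //   ctx->SetThreadLocalDevicePlacementPolicy(DEVICE_PLACEMENT_EXPLICIT);
  //   DCHECK_EQ(ctx->GetDevicePlacementPolicy(), DEVICE_PLACEMENT_EXPLICIT);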
193 | |
194 | // Select an appropriate device for an operation. |
195 | // |
196 | // Given the preferred device for the operation, and the node_def, finds the |
197 | // best suitable device for the operation in this context. |
198 | // |
199 | // The preferred device is specified as a `ParsedName` containing the elements |
200 | // (details) that the resulting device should match. If there are no such |
201 | // devices, and the context currently allows soft device placement, a suitable |
202 | // device not matching `preferred` will be chosen. |
203 | // |
  // The chosen device is stored in the `out` argument. The argument is not
  // modified unless this method returns `OkStatus()`.
206 | Status SelectDevice(DeviceNameUtils::ParsedName preferred, |
207 | const NodeDef& ndef, Device** out) const; |
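
  // A minimal usage sketch (illustrative only; `ndef` stands for a NodeDef
  // built elsewhere): pick a device for an op, preferring GPU:0.
  //
  //   DeviceNameUtils::ParsedName preferred;
  //   DeviceNameUtils::ParseFullName("/device:GPU:0", &preferred);
  //   Device* device = nullptr;
  //   TF_RETURN_IF_ERROR(ctx->SelectDevice(preferred, ndef, &device));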
208 | |
209 | // TODO(mdan): Rename to ContainsFunction. |
210 | bool FindFunctionByName(const string& name) const; |
211 | |
212 | Status FindFunctionOpData(const string& name, |
213 | const tensorflow::OpRegistrationData** op_data); |
214 | |
215 | const FunctionDef* FindFunctionDef(const string& name) const override; |
216 | |
217 | Device* HostCPU() const { return host_cpu_device_; } |
218 | Device* CanonicalDevice(Device* d) const { |
219 | return HostCPU() == d ? nullptr : d; |
220 | } |
221 | const DeviceNameUtils::ParsedName& HostCPUParsedName() const override { |
222 | return HostCPU()->parsed_name(); |
223 | } |
224 | |
225 | const string& HostCPUName() const override { return HostCPU()->name(); } |
226 | |
227 | GraphCollector* GetGraphCollector() { return &graph_collector_; } |
228 | |
229 | EagerExecutor& Executor() override; |
230 | |
  // Adds the given `fdef` to the local FunctionLibraryDefinition, and adds an
  // entry to the KernelAndDevice cache for it if one does not already exist.
233 | Status AddFunctionDef(const FunctionDef& fdef) override; |
234 | |
235 | Status AddFunctionDefWithStackTraces( |
236 | const FunctionDef& fdef, const StackTracesMap& stack_traces) override; |
237 | |
  // `library` contains all FunctionDefs and GradientDefs needed to expand
  // `fdef`. They are added to the local FunctionLibraryDefinition as well, but
  // there is no need to add them to the KernelAndDevice cache since they won't
  // be executed as KernelAndDevices.
242 | Status AddFunctionDef(const FunctionDef& fdef, |
243 | const FunctionDefLibrary& library, |
244 | bool add_to_local_only = false, |
245 | const StackTracesMap& stack_traces = {}); |
246 | |
247 | const FunctionDef* GetFunctionDef(const string& function_name); |
248 | |
249 | std::vector<string> ListFunctionNames() override; |
250 | |
251 | Status RemoveFunction(const string& func) override; |
252 | |
  // Waits for pending nodes to finish in local executors (including the
  // context's default executor and thread executors) and in executors on
  // remote workers. Returns the combined status of the remote executors. If
  // there are multiple errors, the Status code will match the first remote
  // executor that has errors, and the error message will be combined from all
  // executors.
258 | Status SyncExecutors(); |
259 | |
260 | Status AsyncWait() override { return SyncExecutors(); } |
261 | |
262 | core::RefCountPtr<KernelAndDevice> GetCachedKernel(Fprint128 cache_key); |
263 | Device* GetCachedDevice(Fprint128 device_cache_key); |
264 | |
265 | void AddKernelToCache(Fprint128 cache_key, KernelAndDevice* kernel); |
266 | void AddDeviceToCache(Fprint128 device_cache_key, Device* device); |
267 | |
268 | bool LogDevicePlacement() const { return log_device_placement_; } |
269 | void SetLogDevicePlacement(bool enable) override { |
270 | log_device_placement_ = enable; |
271 | } |
272 | |
  // When tensor transfers across functions/eager executions using send/recv
  // ops are required, `reuse_rendezvous_for_functions_` can be set to true so
  // that function executions and eager executions use the same rendezvous
  // instance, instead of creating a new instance per function call.
277 | void SetReuseRendezvousForFunctions( |
278 | bool reuse_rendezvous_for_functions) override { |
279 | reuse_rendezvous_for_functions_ = reuse_rendezvous_for_functions; |
280 | } |
281 | bool GetReuseRendezvousForFunctions() const { |
282 | return reuse_rendezvous_for_functions_; |
283 | } |
284 | mutex* reuse_rendezvous_for_functions_mu() { |
285 | return &reuse_rendezvous_for_functions_mu_; |
286 | } |
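
  // A minimal usage sketch for the setters above (illustrative only): share a
  // single rendezvous across a group of eager op and function launches, then
  // restore the default behavior.
  //
  //   ctx->SetReuseRendezvousForFunctions(true);
  //   // ... launch functions/eager ops that exchange tensors via send/recv ...
  //   ctx->SetReuseRendezvousForFunctions(false);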
287 | |
288 | bool AllowSoftPlacement() const { return allow_soft_placement_; } |
289 | void SetAllowSoftPlacement(bool enable) override { |
290 | allow_soft_placement_ = enable; |
291 | } |
292 | bool LogMemory() const { return log_memory_; } |
293 | |
294 | Rendezvous* GetRendezvous() const { return rendezvous_; } |
295 | |
296 | void ResetGlobalRendezvousForFunction() override { |
297 | mutex_lock l(global_rendezvous_mu_); |
298 | // Remove the global rendezvous instance from the local rendezvous table |
299 | // if it uses local rendezvous type, which forces EagerContext to create a |
300 | // new local rendezvous instance in the table. |
301 | local_rendezvous_table_->Remove(-1); |
302 | global_rendezvous_for_functions_ = |
303 | core::RefCountPtr<Rendezvous>(CreateRendezvous(-1)); |
304 | } |
305 | |
  // Returns the status of the global_rendezvous_for_functions_'s underlying
  // LocalRendezvous. If the underlying Rendezvous is not in the
  // local_rendezvous_table_, returns OK.
309 | Status GetGlobalRendezvousForFunctionLocalRendezvousStatus(); |
310 | |
311 | // Returns a function which maps from step_id to rendezvous. This closure |
312 | // respects the value of `SetReuseRendezvousForFunctions` at the time the |
313 | // closure was created, which allows the setting to be toggled around async op |
314 | // launches. |
315 | // |
316 | // The caller of the returned function owns a reference to the resulting |
317 | // Rendezvous. |
318 | std::function<Rendezvous*(int64_t)> RendezvousCreator() { |
    // There is an implicit assumption that the global_rendezvous_for_functions_
    // is always an IntraProcessRendezvous, to match the behavior of the
    // EagerContext's rendezvous.
    // Ref: tensorflow/c/eager/c_api.cc;l=143;rcl=396387348
    // If a cross-process kernel needs a rendezvous, a new
    // InterProcessRendezvous should be created.
325 | if (reuse_rendezvous_for_functions_ && rendezvous_creator_ == nullptr && |
326 | #if !defined(IS_MOBILE_PLATFORM) |
327 | worker_env_ == nullptr && |
328 | #endif |
329 | remote_device_mgr() == nullptr) { |
330 | return [this](int64_t step_id) { |
331 | mutex_lock l(global_rendezvous_mu_); |
332 | global_rendezvous_for_functions_->Ref(); |
333 | return global_rendezvous_for_functions_.get(); |
334 | }; |
335 | } else { |
336 | return [this](int64_t step_id) { return CreateRendezvous(step_id); }; |
337 | } |
338 | } |
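
  // A minimal usage sketch (illustrative only): the caller owns a reference to
  // the Rendezvous produced by the returned closure and must Unref() it when
  // done.
  //
  //   std::function<Rendezvous*(int64_t)> create_rendezvous =
  //       ctx->RendezvousCreator();
  //   Rendezvous* r = create_rendezvous(/*step_id=*/42);
  //   // ... use `r` for send/recv traffic within this step ...
  //   r->Unref();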
339 | |
340 | CollectiveExecutorMgrInterface* collective_executor_mgr() { |
341 | return collective_executor_mgr_.Get(); |
342 | } |
343 | std::unique_ptr<CollectiveExecutor::Handle> GetCollectiveExecutorHandle() { |
344 | return std::unique_ptr<CollectiveExecutor::Handle>( |
345 | new CollectiveExecutor::Handle( |
346 | collective_executor_mgr()->FindOrCreate(0), true /*inherit_ref*/)); |
347 | } |
348 | |
349 | tensorflow::DeviceMgr* local_device_mgr() const { |
350 | return local_device_manager_.Get(); |
351 | } |
352 | const tensorflow::DynamicDeviceMgr* remote_device_mgr() const { |
353 | return remote_device_manager_.Get(); |
354 | } |
355 | |
356 | tensorflow::DynamicDeviceMgr* GetOwnedRemoteDeviceMgr() { |
357 | return remote_device_manager_.GetOwned(); |
358 | } |
359 | |
360 | std::vector<Device*> ListLocalTfDevices() override { |
361 | return local_device_mgr()->ListDevices(); |
362 | } |
363 | |
364 | std::vector<Device*> ListAllTfDevices() override; |
365 | |
366 | // TODO(apassos) clean up RunMetadata storage. |
367 | mutex* MetadataMu() TF_LOCK_RETURNED(metadata_mu_) { return &metadata_mu_; } |
368 | bool ShouldStoreGraphs() TF_LOCKS_EXCLUDED(metadata_mu_); |
369 | void SetShouldStoreGraphs(bool value) override; |
370 | RunMetadata* RunMetadataProto() TF_EXCLUSIVE_LOCKS_REQUIRED(metadata_mu_) { |
371 | return run_metadata_.get(); |
372 | } |
373 | std::unique_ptr<RunMetadata> ExportRunMetadata() override |
374 | TF_LOCKS_EXCLUDED(metadata_mu_); |
375 | |
376 | void StartStep() override; |
377 | void EndStep() override; |
378 | ScopedStepContainer* StepContainer(); |
379 | |
380 | FunctionLibraryDefinition* FuncLibDef() { return &func_lib_def_; } |
381 | |
382 | #if !defined(IS_MOBILE_PLATFORM) |
  // Assigns the EagerClient pointer to `client` based on the given device or
  // task name, and increments the refcount of the client. The reference
  // ownership is transferred to the caller, and the unref happens
  // automatically when the RefCountPtr object is destroyed on the caller's
  // side. `client` must not be initialized or holding a reference to another
  // object before calling this method.
389 | Status GetClient(Device* device, |
390 | core::RefCountPtr<eager::EagerClient>* client); |
391 | Status GetClient(const DeviceNameUtils::ParsedName& device_name, |
392 | core::RefCountPtr<eager::EagerClient>* client); |
393 | Status GetClient(const string& remote_task, |
394 | core::RefCountPtr<eager::EagerClient>* client); |
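
  // A minimal usage sketch (illustrative only; `remote_task` is a placeholder
  // task name such as "/job:worker/replica:0/task:1"): the RefCountPtr unrefs
  // the client automatically when it goes out of scope.
  //
  //   core::RefCountPtr<eager::EagerClient> client;
  //   TF_RETURN_IF_ERROR(ctx->GetClient(remote_task, &client));
  //   // ... issue eager service RPCs through `client` ...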
395 | |
396 | uint64 GetContextId() const; |
397 | uint64 GetContextViewId() const; |
398 | void IncrementContextViewId(); |
399 | |
400 | Status EnableCollectiveOps(const ServerDef& server_def) override; |
401 | |
402 | // TODO(nareshmodi): Encapsulate remote state into a separate |
403 | // class/struct. |
404 | // |
405 | // Enables the eager context to communicate with remote devices. When |
406 | // initializing with this method, this context will be the primary context, |
407 | // which will kill all its remote contexts in shutdown. |
408 | // |
409 | // - server: A ServerInterface that exports the tensorflow.WorkerService. |
410 | // Note that this class expects the server to already have been started. |
411 | // - remote_eager_workers: A cache from which we can get "EagerClient"s to |
412 | // communicate with remote eager services. |
413 | // - remote_device_mgr: A DeviceMgr* which contains all remote devices |
414 | // (should contain no local devices). |
415 | // - remote_contexts: A vector containing task names. |
416 | // TODO(b/184375824): clean up parameter order for better readability. |
417 | Status InitializeRemoteMaster( |
418 | std::unique_ptr<ServerInterface> server, WorkerEnv* worker_env, |
419 | std::shared_ptr<WorkerSession> worker_session, |
420 | std::unique_ptr<eager::EagerClientCache> remote_eager_workers, |
421 | std::unique_ptr<DynamicDeviceMgr> remote_device_manager, |
422 | const std::vector<string>& remote_contexts, uint64 context_id, |
423 | /*const*/ Rendezvous* r, /*const*/ DeviceMgr* local_device_mgr, |
424 | int keep_alive_secs, DistributedFunctionLibraryRuntime* cluster_flr, |
425 | std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>> |
426 | remote_mgr); |
427 | |
  // Updates an existing master context with a new set of remote workers (i.e.,
  // a new "view" of cluster membership). Similar to InitializeRemoteMaster,
  // but this keeps the current context_id and increments the context_view_id,
  // keeps the current resource manager so that resources from the previous
  // view can still be accessed, and automatically registers existing functions
  // if there are newly added hosts.
434 | Status UpdateRemoteMaster( |
435 | uint64 context_id, |
436 | std::unique_ptr<eager::EagerClientCache> remote_eager_workers, |
437 | const std::vector<string>& add_remote_contexts, |
438 | const std::vector<string>& remove_remote_contexts); |
439 | |
  // Similar to InitializeRemoteMaster, but this context will not kill remote
  // contexts in shutdown.
442 | Status InitializeRemoteWorker( |
443 | std::unique_ptr<eager::EagerClientCache> remote_eager_workers, |
444 | DynamicDeviceMgr* remote_device_mgr, |
445 | const std::vector<string>& remote_contexts, uint64 context_id, |
446 | uint64 context_view_id, |
447 | std::function<Rendezvous*(const int64_t)> rendezvous_creator, |
448 | DistributedFunctionLibraryRuntime* cluster_flr, |
449 | std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>> |
450 | remote_mgr, |
451 | std::function<void()> resource_deallocator); |
452 | |
  // Similar to InitializeRemoteWorker, but reuses the existing context and
  // increments the context_view_id.
455 | Status UpdateRemoteWorker( |
456 | std::unique_ptr<eager::EagerClientCache> remote_eager_workers, |
457 | const std::vector<string>& remote_contexts, uint64 context_id); |
458 | |
459 | Status StoreCollectiveOpsServer( |
460 | std::unique_ptr<ServerInterface> new_server, DeviceMgr* device_mgr, |
461 | CollectiveExecutorMgrInterface* rpc_collective_executor_mgr); |
462 | |
463 | // For the specified remote worker, preprocess and set its device filters. |
464 | Status SetRemoteDeviceFilters(const string& remote_worker, |
465 | const std::vector<string>& device_filters); |
466 | |
467 | // For the specified remote worker, apply the stored device filters to the |
468 | // list of device attributes following these rules: |
469 | // (1) if the remote worker does not have device filters, all devices are |
470 | // visible to the worker; |
471 | // (2) if the device is on the remote worker, then it is visible; |
472 | // (3) if the device matches at least one device filter, then it is visible. |
473 | // The result is saved as a boolean vector of the same length (i.e., |
474 | // filtered_device_mask) indicating whether each of the devices is visible to |
475 | // the remote worker. |
476 | void FilterDevicesForRemoteWorkers( |
477 | const string& remote_worker, |
478 | const protobuf::RepeatedPtrField<DeviceAttributes>& device_attrs, |
479 | std::vector<bool>* filtered_device_mask); |
480 | |
481 | // TODO(fishx): Remove the custom deleter once we remove forward declaration. |
482 | const std::unique_ptr<eager::RemoteMgr, |
483 | std::function<void(eager::RemoteMgr*)>>& |
484 | RemoteMgr() { |
485 | return remote_mgr_; |
486 | } |
487 | |
  // If true, then tensors should be shipped across processes via the
  // EagerService.Enqueue(SendTensorOp). If false, _Send/_Recv ops should be
  // used instead (which in turn use WorkerService.RecvTensor RPCs).
491 | bool UseSendTensorRPC() { return use_send_tensor_rpc_; } |
492 | |
493 | tensorflow::ServerInterface* GetServer() { return server_.get(); } |
494 | |
495 | // For LLVM style RTTI. |
496 | static bool classof(const AbstractContext* ptr) { |
497 | return ptr->getKind() == kEager; |
498 | } |
499 | |
500 | // Function to support distributed C API. |
501 | void SetDistributedManager( |
502 | std::unique_ptr<ImmediateExecutionDistributedManager> distributed) |
503 | override { |
504 | distributed_manager_ = std::move(distributed); |
505 | } |
506 | ImmediateExecutionDistributedManager* GetDistributedManager() override { |
507 | return distributed_manager_.get(); |
508 | } |
509 | |
510 | // May only be used during multi-client setup so that a RemoteRendezvous |
511 | // can be initialized instead of defaulting to the IntraProcessRendezvous. |
512 | void SetWorkerEnv(WorkerEnv* worker_env, |
513 | std::shared_ptr<WorkerSession> worker_session); |
514 | #endif // IS_MOBILE_PLATFORM |
515 | |
516 | // Closes remote eager contexts, waits for all RPCs to finish, and |
517 | // destroys the EagerClientCache. No RPCs can be made through this context |
518 | // after this method has been called. |
519 | // This method exists to aid a clean shutdown. It causes all RPCs to finish |
520 | // and remote TensorHandles to release their references to this context. |
521 | // To avoid deadlocks, this method must not be called on the thread |
522 | // processing RPCs because it makes RPCs and waits for their completion. |
523 | // |
524 | // On mobile, it just cleans the caches. |
525 | void WaitForAndCloseRemoteContexts(); |
526 | |
527 | bool PinSmallOpsToCPU() const { return pin_small_ops_to_cpu_; } |
528 | |
529 | tensorflow::Env* TFEnv() const { return env_; } |
530 | |
531 | Status FindDeviceFromName(const char* device_name, Device** device) const; |
532 | |
533 | Status FindCompositeDeviceFromName(StringPiece device_name, |
534 | CompositeDevice** device) const; |
535 | |
536 | bool IsCustomDevice(const string& device_name) override; |
537 | |
538 | Status RegisterCustomDevice(const string& name, |
539 | std::unique_ptr<CustomDevice> device) override; |
540 | |
541 | CustomDeviceOpHandler& GetCustomDeviceOpHandler() override { |
542 | return custom_device_op_handler_; |
  }
544 | |
545 | // Find or create a composite device with the given `underlying_devices` and |
546 | // `device_name` (if not empty). |
547 | Status FindOrCreateCompositeDevice( |
548 | const std::vector<string>& underlying_devices, const string& device_name, |
549 | CompositeDevice** composite_device); |
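
  // A minimal usage sketch (illustrative only; the device names are
  // placeholders): group two worker CPUs behind a single composite device.
  //
  //   CompositeDevice* composite = nullptr;
  //   TF_RETURN_IF_ERROR(ctx->FindOrCreateCompositeDevice(
  //       {"/job:worker/replica:0/task:0/device:CPU:0",
  //        "/job:worker/replica:0/task:1/device:CPU:0"},
  //       /*device_name=*/"", &composite));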
550 | |
551 | bool OnSameTask(const Device* first, const Device* second) const; |
  // Gets the CPU device on the same task as `device`.
553 | Status CPUDeviceOnTask(const Device* device, Device** cpu_device) const; |
554 | |
555 | const SessionOptions& session_options() const { return opts_; } |
556 | void InitPrioritizedDeviceTypeList(); |
557 | |
  // Re-assigns the cluster FLR and re-initializes devices and FLR in the
  // process FLR.
559 | void UpdateClusterFLRAndInitDevices( |
560 | DistributedFunctionLibraryRuntime* cluster_flr); |
561 | |
  // A constant representing the step id used for the global rendezvous.
  // This is used to distinguish whether a user-specified step id should be
  // set. The step id value kGlobalRendezvousId is reserved and should not be
  // specified by the user.
566 | static const int64_t kGlobalRendezvousId; |
567 | |
568 | private: |
569 | // The class for wrapping a map of step_id to local rendezvous instances. |
570 | class LocalRendezvousTable { |
571 | public: |
572 | LocalRendezvousTable() = default; |
573 | ~LocalRendezvousTable(); |
574 | |
575 | IntraProcessRendezvous* FindOrCreate(int64_t step_id, |
576 | DeviceMgr* device_mgr); |
577 | IntraProcessRendezvous* Find(int64_t step_id); |
578 | void Remove(int64_t step_id); |
579 | void CleanUpAll(); |
580 | |
581 | private: |
582 | mutable mutex table_lock_; |
583 | absl::flat_hash_map<int64_t, IntraProcessRendezvous*> table_ |
584 | TF_GUARDED_BY(table_lock_); |
585 | }; |
586 | |
587 | Rendezvous* CreateRendezvous(int64_t step_id) const { |
588 | if (rendezvous_creator_ != nullptr) { |
      VLOG(6) << "Creating rendezvous using the rendezvous_creator_.";
590 | return rendezvous_creator_(step_id); |
591 | } |
592 | |
593 | #if !defined(IS_MOBILE_PLATFORM) |
594 | if (worker_env_ != nullptr && worker_env_->rendezvous_mgr != nullptr) { |
      VLOG(6) << "Creating rendezvous using the worker_env's rendezvous_mgr.";
596 | auto* remote_r = worker_env_->rendezvous_mgr->Find(step_id); |
597 | remote_r->Initialize(worker_session_.get()).IgnoreError(); |
598 | return remote_r; |
599 | } |
600 | #endif |
601 | |
602 | if (remote_device_mgr() == nullptr) { |
      VLOG(6) << "Creating rendezvous using local_device_mgr.";
604 | return local_rendezvous_table_->FindOrCreate(step_id, local_device_mgr()); |
605 | } |
606 | |
607 | return nullptr; |
608 | } |
609 | |
610 | ~EagerContext() override; |
611 | |
612 | Status MaybeRegisterFunctionRemotely(const FunctionDef& fdef); |
613 | Status RegisterExistingFunctionsOnRemoteWorkers( |
614 | const std::vector<string>& remote_workers); |
615 | |
616 | void ResetPFLR(const DeviceMgr* device_mgr, Env* env, |
617 | const ConfigProto* config, int graph_def_version, |
618 | const FunctionLibraryDefinition* lib_def, |
619 | const OptimizerOptions& optimizer_options, |
620 | thread::ThreadPool* thread_pool = nullptr, |
621 | DistributedFunctionLibraryRuntime* cluster_flr = nullptr); |
622 | |
623 | void ResetClusterFLR(DistributedFunctionLibraryRuntime* cluster_flr); |
624 | void UpdateGlobalRendezvousDeviceManager(tensorflow::DeviceMgr* device_mgr); |
625 | |
626 | void ClearResourceContainer(const string& name); |
627 | |
628 | template <typename T> |
629 | struct OwnedOrUnownedHelper { |
630 | public: |
631 | OwnedOrUnownedHelper() {} |
632 | explicit OwnedOrUnownedHelper(T* object, const bool owned = false) { |
633 | Reset(object, owned); |
634 | } |
635 | |
636 | void Reset(std::unique_ptr<T> object) { |
637 | owned_object = std::move(object); |
638 | unowned_object_ptr = nullptr; |
639 | } |
640 | |
641 | void Reset(T* object, const bool owned = false) { |
642 | if (owned) { |
643 | owned_object.reset(object); |
644 | unowned_object_ptr = nullptr; |
645 | } else { |
646 | owned_object.reset(nullptr); |
647 | unowned_object_ptr = object; |
648 | } |
649 | } |
650 | |
651 | bool Owned() const { return owned_object != nullptr; } |
652 | |
653 | T* GetOwned() const { return owned_object.get(); } |
654 | T* Get() const { |
655 | return owned_object ? owned_object.get() : unowned_object_ptr; |
656 | } |
657 | |
658 | std::unique_ptr<T> owned_object = nullptr; |
659 | T* unowned_object_ptr = nullptr; |
660 | }; |
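
  // A minimal sketch of how OwnedOrUnownedHelper is used (illustrative only):
  //
  //   OwnedOrUnownedHelper<DeviceMgr> mgr;
  //   mgr.Reset(device_mgr, /*owned=*/device_mgr_owned);
  //   DeviceMgr* raw = mgr.Get();         // valid whether owned or unowned
  //   DeviceMgr* owned = mgr.GetOwned();  // nullptr when the manager is unowned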
661 | |
662 | SessionOptions opts_; |
663 | const ContextDevicePlacementPolicy default_device_placement_policy_; |
664 | |
665 | // Note: we cannot use C++11 thread_local here as there is no concept of a |
666 | // thread-local-object-local variable in C++11. |
667 | mutable mutex policy_map_mu_; |
668 | std::unordered_map<std::thread::id, ContextDevicePlacementPolicy> |
669 | device_placement_policy_ TF_GUARDED_BY(policy_map_mu_); |
670 | |
671 | // This device manager maintains only the local devices on this worker. |
672 | OwnedOrUnownedHelper<DeviceMgr> local_device_manager_; |
673 | // Maintain copy of all previously created local device managers. |
674 | std::vector<std::unique_ptr<DeviceMgr>> old_local_device_managers_; |
675 | |
676 | // Unowned DynamicDeviceMgr is set on remote worker to allow running |
677 | // multi-device function on remote worker. |
678 | // This device manager maintains all the devices (including both local and |
679 | // remote to this worker) in the cluster. |
680 | OwnedOrUnownedHelper<DynamicDeviceMgr> remote_device_manager_; |
681 | |
682 | Device* host_cpu_device_; // Owned by device_manager |
683 | mutable mutex device_type_list_mu_; |
684 | std::shared_ptr<std::vector<DeviceType>> prioritized_device_type_list_ |
685 | TF_GUARDED_BY(device_type_list_mu_); |
686 | Rendezvous* rendezvous_; |
687 | std::function<Rendezvous*(const int64_t)> rendezvous_creator_; |
688 | CustomDeviceOpHandler custom_device_op_handler_; |
689 | |
690 | mutable mutex composite_devices_mu_; |
691 | // Maps from the fingerprint of a set of device names to a virtual |
692 | // CompositeDevice. |
693 | // TODO(b/145922293): Consider taking device names as keys. |
694 | absl::flat_hash_map<uint64, std::unique_ptr<CompositeDevice>> |
695 | composite_devices_ ABSL_GUARDED_BY(composite_devices_mu_); |
696 | |
697 | FunctionLibraryDefinition func_lib_def_{OpRegistry::Global(), {}}; |
698 | |
699 | std::unique_ptr<thread::ThreadPool> thread_pool_; |
700 | |
701 | // EagerContext owns the DistributedFunctionLibraryRuntime( |
702 | // EagerClusterFunctionLibraryRuntime) if using EagerService for remote |
703 | // function execution (lazy_copy_function_remote_inputs_=true). |
704 | OwnedOrUnownedHelper<DistributedFunctionLibraryRuntime> cluster_flr_; |
705 | // One FunctionLibraryRuntime per device. |
706 | // func_libs[i] is the FunctionLibraryRuntime corresponding to |
707 | // session->devices[i]. |
708 | std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_; |
709 | |
710 | std::function<void(std::function<void()>)> runner_; |
711 | |
712 | mutex cache_mu_; |
713 | mutex device_cache_mu_; |
714 | struct RegisteredFunction : public core::RefCounted { |
715 | ~RegisteredFunction() override {} |
716 | |
717 | std::unique_ptr<std::vector<Fprint128>> cached_kernel_keys; |
718 | }; |
719 | std::unordered_map<Fprint128, core::RefCountPtr<KernelAndDevice>, |
720 | Fprint128Hasher> |
721 | kernel_cache_ TF_GUARDED_BY(cache_mu_); |
722 | std::unordered_map<string, RegisteredFunction*> registered_functions_ |
723 | TF_GUARDED_BY(cache_mu_); |
724 | absl::flat_hash_map<Fprint128, Device*, Fprint128Hasher> device_cache_ |
725 | TF_GUARDED_BY(device_cache_mu_); |
726 | |
727 | // Whether we should compute RunMetadata. |
728 | std::atomic<bool> should_store_graphs_{false}; |
729 | mutex metadata_mu_; |
730 | std::unique_ptr<RunMetadata> run_metadata_ TF_GUARDED_BY(metadata_mu_); |
731 | GraphCollector graph_collector_; |
732 | std::atomic<bool> log_device_placement_; |
733 | std::atomic<bool> allow_soft_placement_; |
734 | |
735 | // Information related to step containers. |
736 | std::atomic<int> num_active_steps_; |
737 | std::unique_ptr<ScopedStepContainer> step_container_ |
738 | TF_GUARDED_BY(metadata_mu_); |
739 | |
740 | EagerExecutor default_executor_; |
741 | mutable mutex executor_map_mu_; |
742 | // Not owned. |
743 | std::unordered_map<std::thread::id, EagerExecutor*> thread_local_executor_ |
744 | TF_GUARDED_BY(executor_map_mu_); |
745 | std::unordered_map<std::thread::id, std::unordered_set<EagerExecutor*>> |
746 | has_cleanup_ TF_GUARDED_BY(executor_map_mu_); |
747 | |
748 | const bool log_memory_; |
749 | |
  // The table of local rendezvous instances for intra-process communication.
  // This makes sure only one local rendezvous instance exists per step id.
752 | std::unique_ptr<LocalRendezvousTable> local_rendezvous_table_; |
753 | |
754 | // Whether to use same rendezvous instance across function/eager executions. |
755 | std::atomic<bool> reuse_rendezvous_for_functions_{false}; |
756 | mutable mutex global_rendezvous_mu_; |
757 | core::RefCountPtr<Rendezvous> global_rendezvous_for_functions_ |
758 | TF_GUARDED_BY(global_rendezvous_mu_); |
759 | mutex reuse_rendezvous_for_functions_mu_; |
760 | |
761 | Env* const env_; |
762 | |
763 | OwnedOrUnownedHelper<CollectiveExecutorMgrInterface> collective_executor_mgr_; |
764 | |
765 | #if !defined(IS_MOBILE_PLATFORM) |
766 | std::vector<string> GetRemoteContexts() TF_LOCKS_EXCLUDED(remote_state_mu_); |
767 | bool IsRemoteContextsEmpty() TF_LOCKS_EXCLUDED(remote_state_mu_); |
768 | void CloseAndClearAllRemoteContexts(); |
769 | void CloseRemoteContexts(const std::vector<string>& remote_contexts, |
770 | uint64 context_id, uint64 context_view_id); |
771 | |
772 | // TODO(b/184375824): clean up parameter order for better readability. |
773 | Status SetMasterContextState( |
774 | std::unique_ptr<ServerInterface> server, WorkerEnv* worker_env, |
775 | std::shared_ptr<WorkerSession> worker_session, |
776 | std::unique_ptr<eager::EagerClientCache> remote_eager_workers, |
777 | std::unique_ptr<DynamicDeviceMgr> remote_device_manager, |
778 | uint64 context_id, uint64 context_view_id, /*const*/ Rendezvous* r, |
779 | /*const*/ DeviceMgr* local_device_mgr, int keep_alive_secs, |
780 | DistributedFunctionLibraryRuntime* cluster_flr, |
781 | std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>> |
782 | remote_mgr); |
783 | |
  // The server_ is not marked as const (even though conceptually it should
  // be) because we release it when the context is destroyed.
787 | std::unique_ptr<ServerInterface> server_; |
788 | WorkerEnv* worker_env_ = nullptr; |
789 | std::shared_ptr<WorkerSession> worker_session_; |
790 | |
791 | mutable mutex remote_state_mu_; |
792 | |
793 | uint64 context_id_ TF_GUARDED_BY(remote_state_mu_); |
  // The view id of an eager context should be set to 0 when the context is
  // created, and incremented each time the context with the same context_id
  // gets updated. The view id should be consistent between master and workers.
797 | uint64 context_view_id_ TF_GUARDED_BY(remote_state_mu_); |
798 | std::vector<string> remote_contexts_ TF_GUARDED_BY(remote_state_mu_); |
799 | std::unique_ptr<eager::EagerClientCache> remote_eager_workers_ |
800 | TF_GUARDED_BY(remote_state_mu_); |
801 | |
802 | int keep_alive_secs_ TF_GUARDED_BY(remote_state_mu_); |
803 | std::atomic<int> sleep_for_secs_; |
804 | |
805 | std::unique_ptr<Thread> keep_alive_thread_; |
806 | mutex keep_alive_thread_shutdown_mu_; |
807 | condition_variable keep_alive_thread_cv_; |
808 | bool shutting_down_ TF_GUARDED_BY(keep_alive_thread_shutdown_mu_) = false; |
809 | |
810 | std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>> |
811 | remote_mgr_; |
812 | bool is_master_ TF_GUARDED_BY(remote_state_mu_); |
813 | |
814 | // Maps from a remote worker to a list of parsed device filters. |
815 | std::unordered_map<string, std::vector<DeviceNameUtils::ParsedName>> |
816 | cluster_device_filters_ TF_GUARDED_BY(remote_state_mu_); |
817 | |
818 | // A distributed manager that helps setup, update, and check liveness of |
819 | // member tasks in the cluster. |
820 | std::unique_ptr<ImmediateExecutionDistributedManager> distributed_manager_; |
821 | |
822 | #endif // IS_MOBILE_PLATFORM |
823 | |
  // For a multi-device function, the target device of each input is unknown
825 | // until the function is instantiated on the default function device. |
826 | // If false, eagerly copy all remote inputs to the default function device; |
827 | // if true, lazily copy remote inputs to their target devices to avoid |
828 | // redundant copies. |
829 | bool lazy_copy_function_remote_inputs_ = false; |
830 | bool use_send_tensor_rpc_; |
831 | const bool pin_small_ops_to_cpu_; |
832 | |
  // Function that will be invoked in the destructor to deallocate resources
  // related to this context.
835 | std::function<void()> resource_deallocator_ = nullptr; |
836 | bool run_eager_op_as_function_; |
837 | bool jit_compile_rewrite_; |
838 | }; |
839 | |
840 | inline EagerContext* ContextFromInterface(ImmediateExecutionContext* context) { |
841 | return down_cast<EagerContext*>(context); |
842 | } |
843 | |
844 | namespace internal { |
845 | struct EagerContextDeleter { |
846 | void operator()(EagerContext* p) const { |
847 | if (p != nullptr) { |
848 | p->Release(); |
849 | } |
850 | } |
851 | }; |
852 | } // namespace internal |
853 | |
854 | using EagerContextPtr = |
855 | std::unique_ptr<EagerContext, internal::EagerContextDeleter>; |
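
// A minimal usage sketch (illustrative only; `opts`, `device_mgr`, and
// `rendezvous` are assumed to exist): wrap a freshly created context so that
// Release() is called automatically when the pointer goes out of scope.
//
//   EagerContextPtr ctx(new EagerContext(
//       opts, DEVICE_PLACEMENT_SILENT, /*async=*/false, device_mgr,
//       /*device_mgr_owned=*/false, rendezvous));
//   ctx->SetLogDevicePlacement(true);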
856 | |
857 | // Sets the EagerContext owned by the current Python eager Context (see |
858 | // TFE_Py_SetEagerContext in python/eager/pywrap_tfe.h). This is always called |
859 | // in tandem with TFE_Py_SetEagerContext (but not called by it, because its |
860 | // py_context argument is opaque). |
861 | // |
862 | // Do not use this function in production. It is only intended for testing. |
863 | // (see _reset_context in context.py). |
864 | // |
865 | // Not thread-safe. |
866 | void SetCEagerContext(EagerContext* ctx); |
867 | |
868 | // Returns the EagerContext owned by the current Python eager Context (see |
869 | // TFE_Py_SetEagerContext in pywrap_tfe.h). |
870 | // |
871 | // Not thread-safe. |
872 | EagerContext* GetCEagerContext(); |
873 | |
874 | } // namespace tensorflow |
875 | |
876 | #endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_H_ |
877 | |