/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_H_

#include <algorithm>
#include <cstddef>
#include <map>
#include <memory>
#include <queue>
#include <string>
#include <unordered_set>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/types/optional.h"
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/core/common_runtime/composite_device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/custom_device.h"
#include "tensorflow/core/common_runtime/eager/custom_device_op_handler.h"
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/platform.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/threadpool.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/device_name_utils.h"

// "tensorflow/core/platform/platform.h" must be included first before using
// IS_MOBILE_PLATFORM.
#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#endif  // !IS_MOBILE_PLATFORM

namespace tensorflow {

namespace eager {
// We need this forward declaration because we have a circular dependency:
// Context -> RemoteMgr -> TensorHandle -> Context.
// TODO(fishx): Remove this once we remove the Context dependency in
// TensorHandle.
class RemoteMgr;
}  // namespace eager

class TensorHandle;
class EagerOperation;

class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
 public:
  static constexpr uint64 kInvalidContextId = 0;

  static uint64 NewContextId() {
    uint64 context_id = random::New64();
    while (context_id == kInvalidContextId) {
      context_id = random::New64();
    }
    return context_id;
  }

  EagerContext(
      const SessionOptions& opts,
      ContextDevicePlacementPolicy default_device_placement_policy, bool async,
      /*const*/ DeviceMgr* device_mgr, bool device_mgr_owned,
      /*const*/ Rendezvous* rendezvous,
      DistributedFunctionLibraryRuntime* cluster_flr = nullptr,
      CollectiveExecutorMgrInterface* collective_executor_mgr = nullptr,
      bool run_eager_op_as_function = false, bool jit_compile_rewrite = false);

  void Release() override { Unref(); }

  AbstractTensorInterface* CreateInt64Scalar(int64_t value) override;
  AbstractTensorInterface* CreateUint64Scalar(uint64 value) override;
  AbstractTensorInterface* CreateInt32Scalar(int32_t value) override;
  AbstractTensorInterface* CreateFloatScalar(float value) override;
  AbstractTensorInterface* CreateDoubleScalar(double value) override;
  AbstractTensorInterface* CreateHalfScalar(Eigen::half value) override;
  AbstractTensorInterface* CreateStringScalar(
      tensorflow::tstring value) override;
  AbstractTensorInterface* CreateComplex128Scalar(
      tensorflow::complex128 value) override;
  AbstractTensorInterface* CreateBoolScalar(bool value) override;

  AbstractTensorInterface* CreateTensor(
      DataType dtype, absl::Span<const int64_t> dim_sizes) override;
  AbstractTensorInterface* CreateTensor(DataType dtype, const int64_t* dims,
                                        int num_dims, void* data, size_t len,
                                        MemoryReleaser memory_releaser,
                                        void* memory_releaser_arg) override;

  ImmediateExecutionTensorHandle* CreateLocalHandle(
      AbstractTensorInterface* t) override;
  // Creates an abstract tensor handle from a tensorflow::Tensor.
  ImmediateExecutionTensorHandle* CreateLocalHandleFromTFTensor(
      tensorflow::Tensor& t, const char* d_name) override;
  ImmediateExecutionTensorHandle* CopyTensorHandleToDevice(
      ImmediateExecutionTensorHandle* handle, const char* device_name,
      Status* status) override;
  ImmediateExecutionOperation* CreateOperation() override;

  // This is a virtual helper function to convert a TFRT TensorHandle to a
  // tensorflow::TensorHandle. In the current runtime EagerContext it simply
  // forwards the input, since the input tensor handle is already a
  // tensorflow::TensorHandle.
  ImmediateExecutionTensorHandle* TFTensorHandleFromInterface(
      ImmediateExecutionTensorHandle* handle) override;

  Status RegisterFunction(AbstractFunction* f) override;

  bool UsesTFRT() override;

  bool RunEagerOpAsFunction() const;

  void SetRunEagerOpAsFunction(bool enable) override;

  bool JitCompileRewrite() const;

  void SetJitCompileRewrite(bool enable) override;

  void ListDevices(std::vector<DeviceAttributes>* devices) override;

  Status AddDevices(std::vector<std::unique_ptr<Device>> devices) override;

  thread::ThreadPool* GetThreadPool() { return thread_pool_.get(); }

  // Returns the function library runtime for the given device.
  FunctionLibraryRuntime* func_lib(const Device* d) const {
    return pflr_->GetFLR(d->name());
  }

  ProcessFunctionLibraryRuntime* pflr() const { return pflr_.get(); }

  std::function<void(std::function<void()>)>* runner() { return &runner_; }

  // Specifies an executor for this thread.
  void SetExecutorForThread(EagerExecutor* executor) override;

  const std::shared_ptr<std::vector<DeviceType>> prioritized_device_type_list()
      const {
    mutex_lock l(device_type_list_mu_);
    return prioritized_device_type_list_;
  }

  // Clear pending nodes in thread executors and kernel caches.
  void ClearCachesAndThreadExecutors() override;
  // Clear pending nodes in default executor and kernel caches.
  void ClearCachesAndDefaultExecutor();

  // Sets the device placement policy for the current thread.
  void SetThreadLocalDevicePlacementPolicy(
      ContextDevicePlacementPolicy policy) override;

  // Returns the device placement policy for the current thread.
  ContextDevicePlacementPolicy GetDevicePlacementPolicy() const override;

  // Select an appropriate device for an operation.
  //
  // Given the preferred device for the operation and the node_def, finds the
  // best suitable device for the operation in this context.
  //
  // The preferred device is specified as a `ParsedName` containing the
  // elements (details) that the resulting device should match. If there are no
  // such devices, and the context currently allows soft device placement, a
  // suitable device not matching `preferred` will be chosen.
  //
  // The chosen device is stored in the `out` argument. The argument is not
  // modified unless this method returns `OkStatus()`.
  Status SelectDevice(DeviceNameUtils::ParsedName preferred,
                      const NodeDef& ndef, Device** out) const;
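
  // A minimal usage sketch (hypothetical; assumes `ctx` is an EagerContext*
  // and `ndef` is a NodeDef in the caller's scope):
  //
  //   DeviceNameUtils::ParsedName preferred;
  //   DeviceNameUtils::ParseFullName("/device:GPU:0", &preferred);
  //   Device* device = nullptr;
  //   TF_RETURN_IF_ERROR(ctx->SelectDevice(preferred, ndef, &device));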

  // TODO(mdan): Rename to ContainsFunction.
  bool FindFunctionByName(const string& name) const;

  Status FindFunctionOpData(const string& name,
                            const tensorflow::OpRegistrationData** op_data);

  const FunctionDef* FindFunctionDef(const string& name) const override;

  Device* HostCPU() const { return host_cpu_device_; }
  Device* CanonicalDevice(Device* d) const {
    return HostCPU() == d ? nullptr : d;
  }
  const DeviceNameUtils::ParsedName& HostCPUParsedName() const override {
    return HostCPU()->parsed_name();
  }

  const string& HostCPUName() const override { return HostCPU()->name(); }

  GraphCollector* GetGraphCollector() { return &graph_collector_; }

  EagerExecutor& Executor() override;

  // Adds the given `fdef` to the local FunctionLibraryDefinition, and adds an
  // entry to the KernelAndDevice cache for it if one does not already exist.
  Status AddFunctionDef(const FunctionDef& fdef) override;

  Status AddFunctionDefWithStackTraces(
      const FunctionDef& fdef, const StackTracesMap& stack_traces) override;

  // `library` contains all FunctionDefs and GradientDefs needed to expand
  // `fdef`. Adds it to the local FunctionLibraryDefinition as well, but there
  // is no need to add it to the KernelAndDevice cache since these functions
  // won't be executed as KernelAndDevices.
  Status AddFunctionDef(const FunctionDef& fdef,
                        const FunctionDefLibrary& library,
                        bool add_to_local_only = false,
                        const StackTracesMap& stack_traces = {});
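
  // A minimal usage sketch (hypothetical; assumes `ctx` is an EagerContext*
  // and `fdef` is a valid FunctionDef):
  //
  //   TF_RETURN_IF_ERROR(ctx->AddFunctionDef(fdef));
  //   DCHECK(ctx->FindFunctionByName(fdef.signature().name()));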

  const FunctionDef* GetFunctionDef(const string& function_name);

  std::vector<string> ListFunctionNames() override;

  Status RemoveFunction(const string& func) override;

  // Waits for pending nodes to finish in local executors (including the
  // context default executor and thread executors) and executors on remote
  // workers. Returns the combined status of the remote executors. If there are
  // multiple errors, the Status code will be the same as that of the first
  // remote executor with errors, and the error message will be combined from
  // all executors.
  Status SyncExecutors();

  Status AsyncWait() override { return SyncExecutors(); }

  core::RefCountPtr<KernelAndDevice> GetCachedKernel(Fprint128 cache_key);
  Device* GetCachedDevice(Fprint128 device_cache_key);

  void AddKernelToCache(Fprint128 cache_key, KernelAndDevice* kernel);
  void AddDeviceToCache(Fprint128 device_cache_key, Device* device);

  bool LogDevicePlacement() const { return log_device_placement_; }
  void SetLogDevicePlacement(bool enable) override {
    log_device_placement_ = enable;
  }

  // When tensor transfers across functions/eager executions using send/recv
  // ops are required, `reuse_rendezvous_for_functions_` can be set to true so
  // that function executions and eager executions use the same rendezvous
  // instance, instead of creating a new instance per function call.
  void SetReuseRendezvousForFunctions(
      bool reuse_rendezvous_for_functions) override {
    reuse_rendezvous_for_functions_ = reuse_rendezvous_for_functions;
  }
  bool GetReuseRendezvousForFunctions() const {
    return reuse_rendezvous_for_functions_;
  }
  mutex* reuse_rendezvous_for_functions_mu() {
    return &reuse_rendezvous_for_functions_mu_;
  }
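
  // A minimal usage sketch (hypothetical): enable rendezvous reuse around a
  // pair of function executions that communicate through send/recv ops, then
  // restore the default behaviour:
  //
  //   ctx->SetReuseRendezvousForFunctions(true);
  //   // ... run the producer and consumer functions ...
  //   ctx->SetReuseRendezvousForFunctions(false);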

  bool AllowSoftPlacement() const { return allow_soft_placement_; }
  void SetAllowSoftPlacement(bool enable) override {
    allow_soft_placement_ = enable;
  }
  bool LogMemory() const { return log_memory_; }

  Rendezvous* GetRendezvous() const { return rendezvous_; }

  void ResetGlobalRendezvousForFunction() override {
    mutex_lock l(global_rendezvous_mu_);
    // Remove the global rendezvous instance from the local rendezvous table
    // if it uses local rendezvous type, which forces EagerContext to create a
    // new local rendezvous instance in the table.
    local_rendezvous_table_->Remove(-1);
    global_rendezvous_for_functions_ =
        core::RefCountPtr<Rendezvous>(CreateRendezvous(-1));
  }

  // Returns the status of the LocalRendezvous underlying
  // global_rendezvous_for_functions_. If the underlying Rendezvous is not in
  // local_rendezvous_table_, returns OK.
  Status GetGlobalRendezvousForFunctionLocalRendezvousStatus();

  // Returns a function which maps from step_id to rendezvous. This closure
  // respects the value of `SetReuseRendezvousForFunctions` at the time the
  // closure was created, which allows the setting to be toggled around async
  // op launches.
  //
  // The caller of the returned function owns a reference to the resulting
  // Rendezvous.
  std::function<Rendezvous*(int64_t)> RendezvousCreator() {
    // There is an implicit assumption that the global_rendezvous_for_functions_
    // is always an IntraProcessRendezvous to match the behaviour of the
    // EagerContext's rendezvous.
    // Ref: tensorflow/c/eager/c_api.cc;l=143;rcl=396387348
    // If a cross-process kernel needs a rendezvous, a new
    // InterProcessRendezvous should be created.
    if (reuse_rendezvous_for_functions_ && rendezvous_creator_ == nullptr &&
#if !defined(IS_MOBILE_PLATFORM)
        worker_env_ == nullptr &&
#endif
        remote_device_mgr() == nullptr) {
      return [this](int64_t step_id) {
        mutex_lock l(global_rendezvous_mu_);
        global_rendezvous_for_functions_->Ref();
        return global_rendezvous_for_functions_.get();
      };
    } else {
      return [this](int64_t step_id) { return CreateRendezvous(step_id); };
    }
  }
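
  // A minimal usage sketch (hypothetical; the step id 42 is arbitrary):
  //
  //   std::function<Rendezvous*(int64_t)> create_rendezvous =
  //       ctx->RendezvousCreator();
  //   Rendezvous* r = create_rendezvous(/*step_id=*/42);
  //   // ... hand `r` to the send/recv kernels for this step ...
  //   r->Unref();  // The caller owns one reference to the result.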

  CollectiveExecutorMgrInterface* collective_executor_mgr() {
    return collective_executor_mgr_.Get();
  }
  std::unique_ptr<CollectiveExecutor::Handle> GetCollectiveExecutorHandle() {
    return std::unique_ptr<CollectiveExecutor::Handle>(
        new CollectiveExecutor::Handle(
            collective_executor_mgr()->FindOrCreate(0), true /*inherit_ref*/));
  }

  tensorflow::DeviceMgr* local_device_mgr() const {
    return local_device_manager_.Get();
  }
  const tensorflow::DynamicDeviceMgr* remote_device_mgr() const {
    return remote_device_manager_.Get();
  }

  tensorflow::DynamicDeviceMgr* GetOwnedRemoteDeviceMgr() {
    return remote_device_manager_.GetOwned();
  }

  std::vector<Device*> ListLocalTfDevices() override {
    return local_device_mgr()->ListDevices();
  }

  std::vector<Device*> ListAllTfDevices() override;

  // TODO(apassos) clean up RunMetadata storage.
  mutex* MetadataMu() TF_LOCK_RETURNED(metadata_mu_) { return &metadata_mu_; }
  bool ShouldStoreGraphs() TF_LOCKS_EXCLUDED(metadata_mu_);
  void SetShouldStoreGraphs(bool value) override;
  RunMetadata* RunMetadataProto() TF_EXCLUSIVE_LOCKS_REQUIRED(metadata_mu_) {
    return run_metadata_.get();
  }
  std::unique_ptr<RunMetadata> ExportRunMetadata() override
      TF_LOCKS_EXCLUDED(metadata_mu_);

  void StartStep() override;
  void EndStep() override;
  ScopedStepContainer* StepContainer();

  FunctionLibraryDefinition* FuncLibDef() { return &func_lib_def_; }

#if !defined(IS_MOBILE_PLATFORM)
  // Assigns the EagerClient pointer to `client` based on the given device /
  // task name, and increments the refcount of the client. The reference
  // ownership is transferred to the caller, and the unref should automatically
  // happen when the RefCountPtr object is destructed at the caller's side.
  // `client` must not be initialized or hold a reference to another object
  // before calling this method.
  Status GetClient(Device* device,
                   core::RefCountPtr<eager::EagerClient>* client);
  Status GetClient(const DeviceNameUtils::ParsedName& device_name,
                   core::RefCountPtr<eager::EagerClient>* client);
  Status GetClient(const string& remote_task,
                   core::RefCountPtr<eager::EagerClient>* client);
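
  // A minimal usage sketch (hypothetical; the task name is an example):
  //
  //   core::RefCountPtr<eager::EagerClient> client;
  //   TF_RETURN_IF_ERROR(
  //       ctx->GetClient("/job:worker/replica:0/task:1", &client));
  //   // `client` now holds one reference, released when the RefCountPtr
  //   // goes out of scope.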

  uint64 GetContextId() const;
  uint64 GetContextViewId() const;
  void IncrementContextViewId();

  Status EnableCollectiveOps(const ServerDef& server_def) override;

  // TODO(nareshmodi): Encapsulate remote state into a separate
  // class/struct.
  //
  // Enables the eager context to communicate with remote devices. When
  // initializing with this method, this context will be the primary context,
  // which will kill all its remote contexts in shutdown.
  //
  // - server: A ServerInterface that exports the tensorflow.WorkerService.
  //   Note that this class expects the server to already have been started.
  // - remote_eager_workers: A cache from which we can get "EagerClient"s to
  //   communicate with remote eager services.
  // - remote_device_mgr: A DeviceMgr* which contains all remote devices
  //   (should contain no local devices).
  // - remote_contexts: A vector containing task names.
  // TODO(b/184375824): clean up parameter order for better readability.
  Status InitializeRemoteMaster(
      std::unique_ptr<ServerInterface> server, WorkerEnv* worker_env,
      std::shared_ptr<WorkerSession> worker_session,
      std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
      std::unique_ptr<DynamicDeviceMgr> remote_device_manager,
      const std::vector<string>& remote_contexts, uint64 context_id,
      /*const*/ Rendezvous* r, /*const*/ DeviceMgr* local_device_mgr,
      int keep_alive_secs, DistributedFunctionLibraryRuntime* cluster_flr,
      std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
          remote_mgr);

  // Updates an existing master context with a new set of remote workers (i.e.,
  // a new "view" of cluster membership). Similar to InitializeRemoteMaster,
  // but this will keep the current context_id and increment a
  // context_view_id, will keep the current resource manager so that resources
  // from the previous view can still be accessed, and will automatically
  // register existing functions if there are newly added hosts.
  Status UpdateRemoteMaster(
      uint64 context_id,
      std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
      const std::vector<string>& add_remote_contexts,
      const std::vector<string>& remove_remote_contexts);

  // Similar to InitializeRemoteMaster, but this context will not kill remote
  // contexts in shutdown.
  Status InitializeRemoteWorker(
      std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
      DynamicDeviceMgr* remote_device_mgr,
      const std::vector<string>& remote_contexts, uint64 context_id,
      uint64 context_view_id,
      std::function<Rendezvous*(const int64_t)> rendezvous_creator,
      DistributedFunctionLibraryRuntime* cluster_flr,
      std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
          remote_mgr,
      std::function<void()> resource_deallocator);

  // Similar to InitializeRemoteWorker, but reuses the existing context and
  // increments context_view_id.
  Status UpdateRemoteWorker(
      std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
      const std::vector<string>& remote_contexts, uint64 context_id);

  Status StoreCollectiveOpsServer(
      std::unique_ptr<ServerInterface> new_server, DeviceMgr* device_mgr,
      CollectiveExecutorMgrInterface* rpc_collective_executor_mgr);

  // For the specified remote worker, preprocess and set its device filters.
  Status SetRemoteDeviceFilters(const string& remote_worker,
                                const std::vector<string>& device_filters);

  // For the specified remote worker, apply the stored device filters to the
  // list of device attributes following these rules:
  // (1) if the remote worker does not have device filters, all devices are
  //     visible to the worker;
  // (2) if the device is on the remote worker, then it is visible;
  // (3) if the device matches at least one device filter, then it is visible.
  // The result is saved as a boolean vector of the same length (i.e.,
  // filtered_device_mask) indicating whether each of the devices is visible to
  // the remote worker.
  void FilterDevicesForRemoteWorkers(
      const string& remote_worker,
      const protobuf::RepeatedPtrField<DeviceAttributes>& device_attrs,
      std::vector<bool>* filtered_device_mask);
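
  // A hypothetical worked example: suppose "/job:worker/task:1" has the
  // stored filter "/job:worker/task:1/device:GPU:*", and `device_attrs` lists
  // {task:0 CPU:0, task:1 GPU:0, task:2 GPU:0}. Then the resulting mask is
  // {false, true, false}: task:1's GPU is visible by rules (2) and (3), while
  // the other devices are on different tasks and match no filter.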

  // TODO(fishx): Remove the custom deleter once we remove forward declaration.
  const std::unique_ptr<eager::RemoteMgr,
                        std::function<void(eager::RemoteMgr*)>>&
  RemoteMgr() {
    return remote_mgr_;
  }

  // If true, then tensors should be shipped across processes via the
  // EagerService.Enqueue(SendTensorOp). If false, _Send/_Recv ops should be
  // used instead (which in turn use WorkerService.RecvTensor RPCs).
  bool UseSendTensorRPC() { return use_send_tensor_rpc_; }

  tensorflow::ServerInterface* GetServer() { return server_.get(); }

  // For LLVM style RTTI.
  static bool classof(const AbstractContext* ptr) {
    return ptr->getKind() == kEager;
  }

  // Function to support distributed C API.
  void SetDistributedManager(
      std::unique_ptr<ImmediateExecutionDistributedManager> distributed)
      override {
    distributed_manager_ = std::move(distributed);
  }
  ImmediateExecutionDistributedManager* GetDistributedManager() override {
    return distributed_manager_.get();
  }

  // May only be used during multi-client setup so that a RemoteRendezvous
  // can be initialized instead of defaulting to the IntraProcessRendezvous.
  void SetWorkerEnv(WorkerEnv* worker_env,
                    std::shared_ptr<WorkerSession> worker_session);
#endif  // IS_MOBILE_PLATFORM

  // Closes remote eager contexts, waits for all RPCs to finish, and
  // destroys the EagerClientCache. No RPCs can be made through this context
  // after this method has been called.
  // This method exists to aid a clean shutdown. It causes all RPCs to finish
  // and remote TensorHandles to release their references to this context.
  // To avoid deadlocks, this method must not be called on the thread
  // processing RPCs because it makes RPCs and waits for their completion.
  //
  // On mobile, it just cleans the caches.
  void WaitForAndCloseRemoteContexts();
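
  // A minimal shutdown sketch (hypothetical; must not run on the thread that
  // processes RPCs):
  //
  //   ctx->WaitForAndCloseRemoteContexts();  // Drain RPCs, close remotes.
  //   ctx->Release();                        // Drop the caller's reference.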

  bool PinSmallOpsToCPU() const { return pin_small_ops_to_cpu_; }

  tensorflow::Env* TFEnv() const { return env_; }

  Status FindDeviceFromName(const char* device_name, Device** device) const;

  Status FindCompositeDeviceFromName(StringPiece device_name,
                                     CompositeDevice** device) const;

  bool IsCustomDevice(const string& device_name) override;

  Status RegisterCustomDevice(const string& name,
                              std::unique_ptr<CustomDevice> device) override;

  CustomDeviceOpHandler& GetCustomDeviceOpHandler() override {
    return custom_device_op_handler_;
  }

  // Find or create a composite device with the given `underlying_devices` and
  // `device_name` (if not empty).
  Status FindOrCreateCompositeDevice(
      const std::vector<string>& underlying_devices, const string& device_name,
      CompositeDevice** composite_device);
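
  // A minimal usage sketch (hypothetical; the device names are examples):
  //
  //   CompositeDevice* composite = nullptr;
  //   TF_RETURN_IF_ERROR(ctx->FindOrCreateCompositeDevice(
  //       {"/job:localhost/replica:0/task:0/device:CPU:0",
  //        "/job:localhost/replica:0/task:0/device:CPU:1"},
  //       /*device_name=*/"", &composite));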

  bool OnSameTask(const Device* first, const Device* second) const;
  // Gets the CPU device on the task of device.
  Status CPUDeviceOnTask(const Device* device, Device** cpu_device) const;

  const SessionOptions& session_options() const { return opts_; }
  void InitPrioritizedDeviceTypeList();

  // Re-assign cluster-FLR and re-initialize devices and FLR in process-FLR
  void UpdateClusterFLRAndInitDevices(
      DistributedFunctionLibraryRuntime* cluster_flr);

  // A constant representing the step id used for the global rendezvous.
  // This is used to distinguish whether a user-specified step id should be
  // set. The step id value kGlobalRendezvousId is reserved and should not be
  // specified by the user.
  static const int64_t kGlobalRendezvousId;

 private:
  // The class for wrapping a map of step_id to local rendezvous instances.
  class LocalRendezvousTable {
   public:
    LocalRendezvousTable() = default;
    ~LocalRendezvousTable();

    IntraProcessRendezvous* FindOrCreate(int64_t step_id,
                                         DeviceMgr* device_mgr);
    IntraProcessRendezvous* Find(int64_t step_id);
    void Remove(int64_t step_id);
    void CleanUpAll();

   private:
    mutable mutex table_lock_;
    absl::flat_hash_map<int64_t, IntraProcessRendezvous*> table_
        TF_GUARDED_BY(table_lock_);
  };

  Rendezvous* CreateRendezvous(int64_t step_id) const {
    if (rendezvous_creator_ != nullptr) {
      VLOG(6) << "Creating rendezvous using the rendezvous_creator_.";
      return rendezvous_creator_(step_id);
    }

#if !defined(IS_MOBILE_PLATFORM)
    if (worker_env_ != nullptr && worker_env_->rendezvous_mgr != nullptr) {
      VLOG(6) << "Creating rendezvous using the worker_env's rendezvous_mgr.";
      auto* remote_r = worker_env_->rendezvous_mgr->Find(step_id);
      remote_r->Initialize(worker_session_.get()).IgnoreError();
      return remote_r;
    }
#endif

    if (remote_device_mgr() == nullptr) {
      VLOG(6) << "Creating rendezvous using local_device_mgr.";
      return local_rendezvous_table_->FindOrCreate(step_id, local_device_mgr());
    }

    return nullptr;
  }

  ~EagerContext() override;

  Status MaybeRegisterFunctionRemotely(const FunctionDef& fdef);
  Status RegisterExistingFunctionsOnRemoteWorkers(
      const std::vector<string>& remote_workers);

  void ResetPFLR(const DeviceMgr* device_mgr, Env* env,
                 const ConfigProto* config, int graph_def_version,
                 const FunctionLibraryDefinition* lib_def,
                 const OptimizerOptions& optimizer_options,
                 thread::ThreadPool* thread_pool = nullptr,
                 DistributedFunctionLibraryRuntime* cluster_flr = nullptr);

  void ResetClusterFLR(DistributedFunctionLibraryRuntime* cluster_flr);
  void UpdateGlobalRendezvousDeviceManager(tensorflow::DeviceMgr* device_mgr);

  void ClearResourceContainer(const string& name);

  template <typename T>
  struct OwnedOrUnownedHelper {
   public:
    OwnedOrUnownedHelper() {}
    explicit OwnedOrUnownedHelper(T* object, const bool owned = false) {
      Reset(object, owned);
    }

    void Reset(std::unique_ptr<T> object) {
      owned_object = std::move(object);
      unowned_object_ptr = nullptr;
    }

    void Reset(T* object, const bool owned = false) {
      if (owned) {
        owned_object.reset(object);
        unowned_object_ptr = nullptr;
      } else {
        owned_object.reset(nullptr);
        unowned_object_ptr = object;
      }
    }

    bool Owned() const { return owned_object != nullptr; }

    T* GetOwned() const { return owned_object.get(); }
    T* Get() const {
      return owned_object ? owned_object.get() : unowned_object_ptr;
    }

    std::unique_ptr<T> owned_object = nullptr;
    T* unowned_object_ptr = nullptr;
  };
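
  // A minimal usage sketch for the helper above (hypothetical; assumes
  // `devices` is a std::vector<std::unique_ptr<Device>> and `existing` is a
  // DeviceMgr* owned elsewhere):
  //
  //   OwnedOrUnownedHelper<DeviceMgr> mgr;
  //   mgr.Reset(std::make_unique<StaticDeviceMgr>(std::move(devices)));
  //   mgr.Owned();          // true: the helper will delete the object.
  //   mgr.Reset(existing);  // Unowned by default; Owned() is now false.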

  SessionOptions opts_;
  const ContextDevicePlacementPolicy default_device_placement_policy_;

  // Note: we cannot use C++11 thread_local here as there is no concept of a
  // thread-local-object-local variable in C++11.
  mutable mutex policy_map_mu_;
  std::unordered_map<std::thread::id, ContextDevicePlacementPolicy>
      device_placement_policy_ TF_GUARDED_BY(policy_map_mu_);

  // This device manager maintains only the local devices on this worker.
  OwnedOrUnownedHelper<DeviceMgr> local_device_manager_;
  // Maintain copy of all previously created local device managers.
  std::vector<std::unique_ptr<DeviceMgr>> old_local_device_managers_;

  // Unowned DynamicDeviceMgr is set on remote worker to allow running
  // multi-device function on remote worker.
  // This device manager maintains all the devices (including both local and
  // remote to this worker) in the cluster.
  OwnedOrUnownedHelper<DynamicDeviceMgr> remote_device_manager_;

  Device* host_cpu_device_;  // Owned by device_manager
  mutable mutex device_type_list_mu_;
  std::shared_ptr<std::vector<DeviceType>> prioritized_device_type_list_
      TF_GUARDED_BY(device_type_list_mu_);
  Rendezvous* rendezvous_;
  std::function<Rendezvous*(const int64_t)> rendezvous_creator_;
  CustomDeviceOpHandler custom_device_op_handler_;

  mutable mutex composite_devices_mu_;
  // Maps from the fingerprint of a set of device names to a virtual
  // CompositeDevice.
  // TODO(b/145922293): Consider taking device names as keys.
  absl::flat_hash_map<uint64, std::unique_ptr<CompositeDevice>>
      composite_devices_ ABSL_GUARDED_BY(composite_devices_mu_);

  FunctionLibraryDefinition func_lib_def_{OpRegistry::Global(), {}};

  std::unique_ptr<thread::ThreadPool> thread_pool_;

  // EagerContext owns the DistributedFunctionLibraryRuntime(
  // EagerClusterFunctionLibraryRuntime) if using EagerService for remote
  // function execution (lazy_copy_function_remote_inputs_=true).
  OwnedOrUnownedHelper<DistributedFunctionLibraryRuntime> cluster_flr_;
  // One FunctionLibraryRuntime per device.
  // func_libs[i] is the FunctionLibraryRuntime corresponding to
  // session->devices[i].
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;

  std::function<void(std::function<void()>)> runner_;

  mutex cache_mu_;
  mutex device_cache_mu_;
  struct RegisteredFunction : public core::RefCounted {
    ~RegisteredFunction() override {}

    std::unique_ptr<std::vector<Fprint128>> cached_kernel_keys;
  };
  std::unordered_map<Fprint128, core::RefCountPtr<KernelAndDevice>,
                     Fprint128Hasher>
      kernel_cache_ TF_GUARDED_BY(cache_mu_);
  std::unordered_map<string, RegisteredFunction*> registered_functions_
      TF_GUARDED_BY(cache_mu_);
  absl::flat_hash_map<Fprint128, Device*, Fprint128Hasher> device_cache_
      TF_GUARDED_BY(device_cache_mu_);

  // Whether we should compute RunMetadata.
  std::atomic<bool> should_store_graphs_{false};
  mutex metadata_mu_;
  std::unique_ptr<RunMetadata> run_metadata_ TF_GUARDED_BY(metadata_mu_);
  GraphCollector graph_collector_;
  std::atomic<bool> log_device_placement_;
  std::atomic<bool> allow_soft_placement_;

  // Information related to step containers.
  std::atomic<int> num_active_steps_;
  std::unique_ptr<ScopedStepContainer> step_container_
      TF_GUARDED_BY(metadata_mu_);

  EagerExecutor default_executor_;
  mutable mutex executor_map_mu_;
  // Not owned.
  std::unordered_map<std::thread::id, EagerExecutor*> thread_local_executor_
      TF_GUARDED_BY(executor_map_mu_);
  std::unordered_map<std::thread::id, std::unordered_set<EagerExecutor*>>
      has_cleanup_ TF_GUARDED_BY(executor_map_mu_);

  const bool log_memory_;

  // The table of local rendezvous instances for intra-process communication.
  // This makes sure only one local rendezvous instance exists per step id.
  std::unique_ptr<LocalRendezvousTable> local_rendezvous_table_;

  // Whether to use same rendezvous instance across function/eager executions.
  std::atomic<bool> reuse_rendezvous_for_functions_{false};
  mutable mutex global_rendezvous_mu_;
  core::RefCountPtr<Rendezvous> global_rendezvous_for_functions_
      TF_GUARDED_BY(global_rendezvous_mu_);
  mutex reuse_rendezvous_for_functions_mu_;

  Env* const env_;

  OwnedOrUnownedHelper<CollectiveExecutorMgrInterface> collective_executor_mgr_;

#if !defined(IS_MOBILE_PLATFORM)
  std::vector<string> GetRemoteContexts() TF_LOCKS_EXCLUDED(remote_state_mu_);
  bool IsRemoteContextsEmpty() TF_LOCKS_EXCLUDED(remote_state_mu_);
  void CloseAndClearAllRemoteContexts();
  void CloseRemoteContexts(const std::vector<string>& remote_contexts,
                           uint64 context_id, uint64 context_view_id);

  // TODO(b/184375824): clean up parameter order for better readability.
  Status SetMasterContextState(
      std::unique_ptr<ServerInterface> server, WorkerEnv* worker_env,
      std::shared_ptr<WorkerSession> worker_session,
      std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
      std::unique_ptr<DynamicDeviceMgr> remote_device_manager,
      uint64 context_id, uint64 context_view_id, /*const*/ Rendezvous* r,
      /*const*/ DeviceMgr* local_device_mgr, int keep_alive_secs,
      DistributedFunctionLibraryRuntime* cluster_flr,
      std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
          remote_mgr);

  // server_ is released when the context is destroyed, so it is not marked
  // as const (even though conceptually it should be).
  std::unique_ptr<ServerInterface> server_;
  WorkerEnv* worker_env_ = nullptr;
  std::shared_ptr<WorkerSession> worker_session_;

  mutable mutex remote_state_mu_;

  uint64 context_id_ TF_GUARDED_BY(remote_state_mu_);
  // The view id of an eager context should be set to 0 when the context is
  // created, and continuously incremented when a context with the same
  // context_id gets updated. The view id should be consistent between the
  // master and workers.
  uint64 context_view_id_ TF_GUARDED_BY(remote_state_mu_);
  std::vector<string> remote_contexts_ TF_GUARDED_BY(remote_state_mu_);
  std::unique_ptr<eager::EagerClientCache> remote_eager_workers_
      TF_GUARDED_BY(remote_state_mu_);

  int keep_alive_secs_ TF_GUARDED_BY(remote_state_mu_);
  std::atomic<int> sleep_for_secs_;

  std::unique_ptr<Thread> keep_alive_thread_;
  mutex keep_alive_thread_shutdown_mu_;
  condition_variable keep_alive_thread_cv_;
  bool shutting_down_ TF_GUARDED_BY(keep_alive_thread_shutdown_mu_) = false;

  std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
      remote_mgr_;
  bool is_master_ TF_GUARDED_BY(remote_state_mu_);

  // Maps from a remote worker to a list of parsed device filters.
  std::unordered_map<string, std::vector<DeviceNameUtils::ParsedName>>
      cluster_device_filters_ TF_GUARDED_BY(remote_state_mu_);

  // A distributed manager that helps set up, update, and check the liveness
  // of member tasks in the cluster.
  std::unique_ptr<ImmediateExecutionDistributedManager> distributed_manager_;

#endif  // IS_MOBILE_PLATFORM

  // For a multi device function, the target device of each input is unknown
  // until the function is instantiated on the default function device.
  // If false, eagerly copy all remote inputs to the default function device;
  // if true, lazily copy remote inputs to their target devices to avoid
  // redundant copies.
  bool lazy_copy_function_remote_inputs_ = false;
  bool use_send_tensor_rpc_;
  const bool pin_small_ops_to_cpu_;

  // Function that will be invoked in the destructor to deallocate resources
  // related to this context.
  std::function<void()> resource_deallocator_ = nullptr;
  bool run_eager_op_as_function_;
  bool jit_compile_rewrite_;
};

inline EagerContext* ContextFromInterface(ImmediateExecutionContext* context) {
  return down_cast<EagerContext*>(context);
}

namespace internal {
struct EagerContextDeleter {
  void operator()(EagerContext* p) const {
    if (p != nullptr) {
      p->Release();
    }
  }
};
}  // namespace internal

using EagerContextPtr =
    std::unique_ptr<EagerContext, internal::EagerContextDeleter>;
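
// A minimal construction sketch (hypothetical; assumes `opts`, `device_mgr`,
// and `rendezvous` already exist in the caller's scope). Wrapping the raw
// pointer in an EagerContextPtr ensures Release() runs when the pointer goes
// out of scope:
//
//   EagerContextPtr ctx(new EagerContext(
//       opts, DEVICE_PLACEMENT_SILENT, /*async=*/false, device_mgr,
//       /*device_mgr_owned=*/false, rendezvous));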

// Sets the EagerContext owned by the current Python eager Context (see
// TFE_Py_SetEagerContext in python/eager/pywrap_tfe.h). This is always called
// in tandem with TFE_Py_SetEagerContext (but not called by it, because its
// py_context argument is opaque).
//
// Do not use this function in production. It is only intended for testing.
// (see _reset_context in context.py).
//
// Not thread-safe.
void SetCEagerContext(EagerContext* ctx);

// Returns the EagerContext owned by the current Python eager Context (see
// TFE_Py_SetEagerContext in pywrap_tfe.h).
//
// Not thread-safe.
EagerContext* GetCEagerContext();

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_H_