1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_COLLECTIVE_EXECUTOR_MGR_H_ |
16 | #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_COLLECTIVE_EXECUTOR_MGR_H_ |
17 | |
18 | #include "tensorflow/core/common_runtime/collective_executor_mgr.h" |
19 | #include "tensorflow/core/framework/collective.h" |
20 | |
21 | namespace tensorflow { |
22 | class CollectiveParamResolverDistributed; |
23 | class ConfigProto; |
24 | class DeviceMgr; |
25 | class DeviceResolverDistributed; |
26 | class WorkerCacheInterface; |
27 | class StepSequenceRequest; |
28 | class StepSequenceResponse; |
29 | |
30 | // An implementation of CollectiveExecutorMgr for a distributed environment |
31 | // that uses WorkerInterface::RecvBufAsync to route data transfers over RPCs. |
32 | // |
33 | // In some execution environments it may be possible to implement a |
34 | // higher-performance solution and use it in place of this class. |
35 | class RpcCollectiveExecutorMgr : public CollectiveExecutorMgr { |
36 | public: |
37 | RpcCollectiveExecutorMgr( |
38 | const ConfigProto& config, const DeviceMgr* dev_mgr, |
39 | std::unique_ptr<DeviceResolverDistributed> dev_resolver, |
40 | std::unique_ptr<CollectiveParamResolverDistributed> param_resolver, |
41 | std::unique_ptr<NcclCommunicatorInterface> nccl_communicator, |
42 | WorkerCacheInterface* worker_cache, const string& task_name); |
43 | |
44 | virtual ~RpcCollectiveExecutorMgr(); |
45 | |
46 | // This function should only be called at the group_leader, by an RPC. |
47 | // Other needs for StepIds should be satisfied by NextStepId. |
48 | void GetStepSequenceAsync(const GetStepSequenceRequest* request, |
49 | GetStepSequenceResponse* response, |
50 | const StatusCallback& done) override; |
51 | |
52 | void RefreshStepIdSequenceAsync(int64_t graph_key, |
53 | const StatusCallback& done) override; |
54 | |
55 | int64_t NextStepId(int64_t graph_key) override; |
56 | |
57 | void RetireStepId(int64_t graph_key, int64_t step_id) override; |
58 | |
59 | protected: |
60 | virtual CollectiveExecutor* Create(int64_t step_id) override; |
61 | |
62 | WorkerCacheInterface* const worker_cache_; // Not owned. |
63 | const string task_name_; |
64 | string group_leader_; |
65 | friend class RpcCollectiveExecutorMgrTest; |
66 | |
67 | private: |
68 | Status UpdateStepSequences(const GetStepSequenceResponse& resp); |
69 | |
70 | // This class maintains the step_id sequencing for a single |
71 | // collective_graph_key. |
72 | struct GraphKeySequence { |
73 | explicit GraphKeySequence(int64_t k) |
74 | : graph_key_(k), next_step_id_(CollectiveExecutor::kInvalidId) {} |
75 | |
76 | const int64_t graph_key_; |
77 | int64_t next_step_id_; |
78 | }; |
79 | |
80 | mutex sequence_mu_; |
81 | gtl::FlatMap<int64_t, GraphKeySequence*> sequence_table_ |
82 | TF_GUARDED_BY(sequence_mu_); |
83 | }; |
84 | |
// Creates a distributed CollectiveExecutorMgr with production implementations
// of each component. Cases that need to inject other implementations of these
// components should call the CollectiveExecutorMgr constructor directly.
// `worker_cache` is not owned and must outlive the returned object.
std::unique_ptr<RpcCollectiveExecutorMgr> CreateProdRpcCollectiveExecutorMgr(
    const ConfigProto& config, const DeviceMgr* device_mgr,
    std::unique_ptr<NcclCommunicatorInterface> nccl_communicator,
    WorkerCacheInterface* worker_cache, const string& default_worker_name);
92 | |
93 | } // namespace tensorflow |
94 | #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_COLLECTIVE_EXECUTOR_MGR_H_ |
95 | |