1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h" |
16 | |
17 | #include "tensorflow/core/common_runtime/base_collective_executor.h" |
18 | #include "tensorflow/core/common_runtime/collective_executor_mgr.h" |
19 | #include "tensorflow/core/common_runtime/collective_rma_local.h" |
20 | #include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h" |
21 | #include "tensorflow/core/distributed_runtime/collective_rma_distributed.h" |
22 | #include "tensorflow/core/distributed_runtime/device_resolver_distributed.h" |
23 | #include "tensorflow/core/distributed_runtime/worker_cache.h" |
24 | #include "tensorflow/core/lib/random/random.h" |
25 | |
26 | namespace tensorflow { |
27 | |
28 | RpcCollectiveExecutorMgr::RpcCollectiveExecutorMgr( |
29 | const ConfigProto& config, const DeviceMgr* dev_mgr, |
30 | std::unique_ptr<DeviceResolverDistributed> dev_resolver, |
31 | std::unique_ptr<CollectiveParamResolverDistributed> param_resolver, |
32 | std::unique_ptr<NcclCommunicatorInterface> nccl_communicator, |
33 | WorkerCacheInterface* worker_cache, const string& task_name) |
34 | : CollectiveExecutorMgr(config, dev_mgr, std::move(dev_resolver), |
35 | std::move(param_resolver), |
36 | std::move(nccl_communicator)), |
37 | worker_cache_(worker_cache), |
38 | task_name_(task_name) { |
39 | group_leader_ = (task_name == config.experimental().collective_group_leader()) |
40 | ? "" |
41 | : config.experimental().collective_group_leader(); |
42 | } |
43 | |
44 | RpcCollectiveExecutorMgr::~RpcCollectiveExecutorMgr() { |
45 | for (auto it : sequence_table_) { |
46 | delete it.second; |
47 | } |
48 | } |
49 | |
50 | CollectiveExecutor* RpcCollectiveExecutorMgr::Create(int64_t step_id) { |
51 | CollectiveRemoteAccessDistributed* rma = |
52 | new CollectiveRemoteAccessDistributed(dev_mgr_, dev_resolver_.get(), |
53 | work_queue_, worker_cache_, step_id, |
54 | task_name_); |
55 | return new BaseCollectiveExecutor(this, rma, step_id, dev_mgr_, work_queue_); |
56 | } |
57 | |
58 | namespace { |
59 | // StepId must leave the most-significant 7 bits empty for future use. |
60 | static const int64_t kStepIdMask = (((1uLL << 56) - 1) | (1uLL << 56)); |
61 | |
62 | int64_t NewRandomStepId() { |
63 | int64_t step_id = random::New64(); |
64 | // Leave MS 8 bits clear for future use. |
65 | step_id &= kStepIdMask; |
66 | return step_id; |
67 | } |
68 | } // namespace |
69 | |
// Refreshes the next step id for `graph_key`, then invokes `done`.
//
// If this task is the group leader (group_leader_ is empty) the refresh is
// purely local: a fresh random step id is stored under the sequence mutex.
// Otherwise the leader is queried over RPC via GetStepSequenceAsync and the
// returned sequences are merged into the local table.
void RpcCollectiveExecutorMgr::RefreshStepIdSequenceAsync(
    int64_t graph_key, const StatusCallback& done) {
  if (group_leader_.empty()) {
    // This worker is the group leader: generate the new step id locally.
    mutex_lock l(sequence_mu_);
    GraphKeySequence* gks = nullptr;
    auto it = sequence_table_.find(graph_key);
    if (it == sequence_table_.end()) {
      // First time this graph_key is seen; table owns the new entry.
      gks = new GraphKeySequence(graph_key);
      sequence_table_[graph_key] = gks;
    } else {
      gks = it->second;
    }
    gks->next_step_id_ = NewRandomStepId();
    done(OkStatus());
  } else {
    // Non-leader: ask the group leader for the authoritative sequence.
    WorkerInterface* wi = worker_cache_->GetOrCreateWorker(group_leader_);
    // req/resp are heap-allocated because they must outlive this call; they
    // are deleted inside the completion lambda after `done` has run.
    GetStepSequenceRequest* req = new GetStepSequenceRequest;
    GetStepSequenceResponse* resp = new GetStepSequenceResponse;
    req->add_graph_key(graph_key);
    wi->GetStepSequenceAsync(
        req, resp, [this, req, resp, done](const Status& s) {
          if (!s.ok()) {
            LOG(ERROR) << "Bad response [" << s
                       << "] from GetStepSequenceAsync call to "
                       << group_leader_;
            done(s);
          } else {
            // Merge the leader's sequences into the local table before
            // signalling completion.
            done(UpdateStepSequences(*resp));
          }
          delete req;
          delete resp;
        });
  }
}
104 | |
105 | void RpcCollectiveExecutorMgr::GetStepSequenceAsync( |
106 | const GetStepSequenceRequest* request, GetStepSequenceResponse* response, |
107 | const StatusCallback& done) { |
108 | if (!group_leader_.empty()) { |
109 | LOG(ERROR) << "GetStepSequence called at non-group-leader" ; |
110 | done(errors::Internal("GetStepSequenceAsync called at non-group-leader" )); |
111 | } else { |
112 | mutex_lock l(sequence_mu_); |
113 | for (int64_t graph_key : request->graph_key()) { |
114 | auto it = sequence_table_.find(graph_key); |
115 | GraphKeySequence* gks = nullptr; |
116 | if (it == sequence_table_.end()) { |
117 | gks = new GraphKeySequence(graph_key); |
118 | gks->next_step_id_ = NewRandomStepId(); |
119 | sequence_table_[graph_key] = gks; |
120 | } else { |
121 | gks = it->second; |
122 | } |
123 | StepSequence* ss = response->add_step_sequence(); |
124 | ss->set_graph_key(graph_key); |
125 | ss->set_next_step_id(gks->next_step_id_); |
126 | } |
127 | done(OkStatus()); |
128 | } |
129 | } |
130 | |
131 | Status RpcCollectiveExecutorMgr::UpdateStepSequences( |
132 | const GetStepSequenceResponse& resp) { |
133 | mutex_lock l(sequence_mu_); |
134 | for (const StepSequence& ss : resp.step_sequence()) { |
135 | GraphKeySequence* gks = nullptr; |
136 | auto it = sequence_table_.find(ss.graph_key()); |
137 | if (it == sequence_table_.end()) { |
138 | gks = new GraphKeySequence(ss.graph_key()); |
139 | sequence_table_[ss.graph_key()] = gks; |
140 | } else { |
141 | gks = it->second; |
142 | } |
143 | gks->next_step_id_ = ss.next_step_id(); |
144 | } |
145 | return OkStatus(); |
146 | } |
147 | |
148 | int64_t RpcCollectiveExecutorMgr::NextStepId(int64_t graph_key) { |
149 | mutex_lock l(sequence_mu_); |
150 | auto it = sequence_table_.find(graph_key); |
151 | if (it != sequence_table_.end()) { |
152 | return it->second->next_step_id_; |
153 | } |
154 | return CollectiveExecutor::kInvalidId; |
155 | } |
156 | |
157 | void RpcCollectiveExecutorMgr::RetireStepId(int64_t graph_key, |
158 | int64_t step_id) { |
159 | mutex_lock l(sequence_mu_); |
160 | auto it = sequence_table_.find(graph_key); |
161 | if (it != sequence_table_.end()) { |
162 | if (step_id == it->second->next_step_id_) { |
163 | it->second->next_step_id_ = (it->second->next_step_id_ + 1) & kStepIdMask; |
164 | } else { |
165 | it->second->next_step_id_ = CollectiveExecutor::kInvalidId; |
166 | } |
167 | } else { |
168 | LOG(ERROR) << "Failed to find graph_key " << graph_key << " to retire." ; |
169 | } |
170 | } |
171 | |
172 | std::unique_ptr<RpcCollectiveExecutorMgr> CreateProdRpcCollectiveExecutorMgr( |
173 | const ConfigProto& config, const DeviceMgr* device_mgr, |
174 | std::unique_ptr<NcclCommunicatorInterface> nccl_communicator, |
175 | WorkerCacheInterface* worker_cache, const string& default_worker_name) { |
176 | auto dev_resolver = std::make_unique<DeviceResolverDistributed>(device_mgr); |
177 | auto param_resolver = std::make_unique<CollectiveParamResolverDistributed>( |
178 | config, device_mgr, dev_resolver.get(), nccl_communicator.get(), |
179 | worker_cache, default_worker_name); |
180 | return std::make_unique<RpcCollectiveExecutorMgr>( |
181 | config, device_mgr, std::move(dev_resolver), std::move(param_resolver), |
182 | std::move(nccl_communicator), worker_cache, default_worker_name); |
183 | } |
184 | |
185 | } // namespace tensorflow |
186 | |