1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h" |
16 | |
17 | #include "absl/strings/escaping.h" |
18 | #include "tensorflow/core/common_runtime/device.h" |
19 | #include "tensorflow/core/common_runtime/device_mgr.h" |
20 | #include "tensorflow/core/distributed_runtime/cancellable_call.h" |
21 | #include "tensorflow/core/distributed_runtime/device_resolver_distributed.h" |
22 | #include "tensorflow/core/distributed_runtime/worker_cache.h" |
23 | #include "tensorflow/core/framework/device_attributes.pb.h" |
24 | #include "tensorflow/core/platform/errors.h" |
25 | #include "tensorflow/core/platform/status.h" |
26 | #include "tensorflow/core/protobuf/config.pb.h" |
27 | #include "tensorflow/core/util/device_name_utils.h" |
28 | |
29 | namespace tensorflow { |
30 | namespace { |
31 | |
32 | class CompleteGroupCall : public CancellableCall { |
33 | public: |
34 | CompleteGroupCall(const CollGroupParams& group, |
35 | const DeviceAttributes& device, |
36 | CancellationManager* cancel_mgr, |
37 | const string& remote_worker, WorkerCacheInterface* wc) |
38 | : CancellableCall(cancel_mgr, remote_worker, wc) { |
39 | req_.set_group_key(group.group_key); |
40 | req_.set_group_size(group.group_size); |
41 | req_.set_device_type(group.device_type.type_string()); |
42 | *req_.mutable_device_attributes() = device; |
43 | } |
44 | ~CompleteGroupCall() override {} |
45 | |
46 | void IssueCall(const StatusCallback& done) override { |
47 | wi_->CompleteGroupAsync(&opts_, &req_, &resp_, done); |
48 | } |
49 | |
50 | CompleteGroupRequest req_; |
51 | CompleteGroupResponse resp_; |
52 | }; |
53 | |
54 | class CompleteInstanceCall : public CancellableCall { |
55 | public: |
56 | CompleteInstanceCall(const CollGroupParams& group, |
57 | const CollInstanceParams& instance, |
58 | const string& node_name, const string& device_name, |
59 | bool is_source, CancellationManager* cancel_mgr, |
60 | const string& remote_worker, WorkerCacheInterface* wc) |
61 | : CancellableCall(cancel_mgr, remote_worker, wc) { |
62 | req_.set_name(node_name); |
63 | req_.set_type(instance.type); |
64 | req_.set_data_type(instance.data_type); |
65 | instance.shape.AsProto(req_.mutable_shape()); |
66 | req_.set_group_key(group.group_key); |
67 | req_.set_group_size(group.group_size); |
68 | req_.set_instance_key(instance.instance_key); |
69 | req_.set_device_type(group.device_type.type_string()); |
70 | for (int32_t offset : instance.impl_details.subdiv_offsets) { |
71 | req_.add_subdiv_offset(offset); |
72 | } |
73 | req_.set_device(device_name); |
74 | req_.set_is_source(is_source); |
75 | } |
76 | |
77 | ~CompleteInstanceCall() override {} |
78 | |
79 | void IssueCall(const StatusCallback& done) override { |
80 | wi_->CompleteInstanceAsync(&opts_, &req_, &resp_, done); |
81 | } |
82 | |
83 | CompleteInstanceRequest req_; |
84 | CompleteInstanceResponse resp_; |
85 | }; |
86 | |
87 | } // namespace |
88 | |
89 | CollectiveParamResolverDistributed::CollectiveParamResolverDistributed( |
90 | const ConfigProto& config, const DeviceMgr* dev_mgr, |
91 | DeviceResolverDistributed* dev_resolver, |
92 | NcclCommunicatorInterface* nccl_communicator, |
93 | WorkerCacheInterface* worker_cache, const string& task_name) |
94 | : CollectiveParamResolverLocal(config, dev_mgr, dev_resolver, |
95 | nccl_communicator, task_name), |
96 | worker_cache_(worker_cache), |
97 | group_leader_(task_name == config.experimental().collective_group_leader() |
98 | ? "" |
99 | : config.experimental().collective_group_leader()) { |
100 | VLOG(1) << "CompleteParamResolverDistributed ctor task={" << task_name |
101 | << "} config.collective_group_leader={" |
102 | << config.experimental().collective_group_leader() << "}" |
103 | << " config.collective_nccl={" |
104 | << config.experimental().collective_nccl() << "}" ; |
105 | } |
106 | |
// Completes all collective params for `cp`: first the group (unless the
// caller indicates it is already initialized), then the instance. `done` is
// invoked exactly once with the final status.
void CollectiveParamResolverDistributed::CompleteParamsAsync(
    const DeviceAttributes& device, CollectiveParams* cp,
    CancellationManager* cancel_mgr, const StatusCallback& done) {
  VLOG(1) << "CompleteParams distributed " << device.name() << " for " << cp
          << ": " << cp->ToString();
  if (cp->run_group_initialization) {
    CompleteGroupDistributed(
        device, &cp->group, cancel_mgr,
        [this, device, cp, cancel_mgr, done](Status s) {
          if (s.ok()) {
            // Publish the resolved group members' device attributes to the
            // local device resolver before instance resolution uses them.
            std::vector<DeviceAttributes> devices;
            devices.reserve(cp->group.group_size);
            for (const CollGroupMember& m : cp->group.members) {
              devices.push_back(m.device);
            }
            s = dev_resolver_->UpdateDeviceAttributes(devices);
          }
          if (s.ok()) {
            CompleteInstanceDistributed(device.name(), cp, cancel_mgr, done);
          } else {
            done(s);
          }
        });
  } else {
    // For Collective V3 ops, group is already initialized. Fetch attributes
    // for the already initialized group to pass to Instance initialization.
    auto s = LookupGroup(cp->group.group_key, &cp->group);
    if (s.ok()) {
      CompleteInstanceDistributed(device.name(), cp, cancel_mgr, done);
    } else {
      done(s);
    }
  }
}
141 | |
// Resolves only the group portion of the collective params. Delegates to
// CompleteGroupDistributed, which contacts the group leader when this task
// is not itself the leader.
void CollectiveParamResolverDistributed::CompleteGroupAsync(
    const DeviceAttributes& device, CollGroupParams* group_params,
    CancellationManager* cancel_mgr, const StatusCallback& done) {
  CompleteGroupDistributed(device, group_params, cancel_mgr, done);
}
147 | |
148 | void CollectiveParamResolverDistributed::CompleteInstanceAsync( |
149 | const CompleteInstanceRequest* request, CompleteInstanceResponse* response, |
150 | CancellationManager* cancel_mgr, const StatusCallback& done) { |
151 | GroupRec* gr = GetCachedGroup(request->group_key()); |
152 | if (gr == nullptr) { |
153 | done(errors::FailedPrecondition( |
154 | "group " , request->group_key(), |
155 | " not found. This normally means the server has restarted" )); |
156 | return; |
157 | } |
158 | CollectiveParams* cp = new CollectiveParams; |
159 | { |
160 | mutex_lock l(gr->mu); |
161 | if (!gr->status.ok()) { |
162 | done(gr->status); |
163 | return; |
164 | } else if (gr->group.members.size() != gr->group.group_size) { |
165 | done(errors::FailedPrecondition( |
166 | "group " , request->group_key(), |
167 | " failed to resolve. This normally means the server has restarted" )); |
168 | return; |
169 | } |
170 | cp->group = gr->group; |
171 | } |
172 | cp->name = request->name(); |
173 | cp->instance.type = CollectiveType(request->type()); |
174 | cp->instance.instance_key = request->instance_key(); |
175 | cp->instance.data_type = request->data_type(); |
176 | cp->instance.shape = TensorShape(request->shape()); |
177 | cp->is_source = request->is_source(); |
178 | for (int32_t offset : request->subdiv_offset()) { |
179 | cp->instance.impl_details.subdiv_offsets.push_back(offset); |
180 | } |
181 | StatusCallback done_and_cleanup = [cp, done](const Status& s) { |
182 | done(s); |
183 | cp->Unref(); |
184 | }; |
185 | CompleteInstanceDistributed( |
186 | request->device(), cp, cancel_mgr, |
187 | [this, cp, response, done_and_cleanup](Status status) { |
188 | if (status.ok()) { |
189 | // Now source_rank should be known, so retrieve it. |
190 | bool created_irec; |
191 | InstanceRec* ir = GetOrCreateInstanceRec(cp, &created_irec); |
192 | { |
193 | mutex_lock l(ir->mu); |
194 | status = ir->status; |
195 | if (ir->status.ok()) { |
196 | response->set_instance_key(cp->instance.instance_key); |
197 | response->set_source_rank(ir->source_rank); |
198 | } |
199 | } |
200 | } |
201 | done_and_cleanup(status); |
202 | }); |
203 | } |
204 | |
205 | CollectiveParamResolverDistributed::GroupRec* |
206 | CollectiveParamResolverDistributed::GetCachedGroup(int32_t group_key) { |
207 | mutex_lock l(group_mu_); |
208 | auto it = group_table_.find(group_key); |
209 | if (it == group_table_.end()) { |
210 | return nullptr; |
211 | } |
212 | return it->second.get(); |
213 | } |
214 | |
// Installs the group described by the leader's CompleteGroupResponse into
// the local group_table_ cache. Validates the response, and if a record for
// the same group_key already exists (e.g. due to a racing resolution),
// verifies that it agrees on the communicator key. Returns the first error
// encountered, or OK.
Status CollectiveParamResolverDistributed::UpdateGroupCache(
    const CompleteGroupResponse& resp) {
  // Build a new record from resp.
  std::unique_ptr<GroupRec> gr(new GroupRec);
  {
    // gr is not yet visible to any other thread; the lock is taken to
    // satisfy the mutex annotations on GroupRec's guarded fields.
    mutex_lock grl(gr->mu);
    gr->group.device_type = DeviceType(resp.device_type());
    gr->group.group_key = resp.group_key();
    gr->group.group_size = resp.group_size();
    gr->group.num_tasks = resp.num_tasks();
    // On early return the unique_ptr frees the partially-built record.
    if (resp.device_attributes().empty()) {
      return errors::Internal(
          "CompleteGroupResponse device_attributes is empty. Make sure you're "
          "running the same version of Tensorflow on all workers.");
    }
    if (resp.device_attributes_size() != gr->group.group_size) {
      return errors::Internal(
          "CompleteGroupResponse group_size doesn't match device_name list");
    }
    gr->group.members.reserve(resp.device_attributes().size());
    for (const DeviceAttributes& device : resp.device_attributes()) {
      CollGroupMember member;
      member.device = device;
      gr->group.members.push_back(std::move(member));
      gr->incarnations_by_device_name[device.name()] = device.incarnation();
    }
    gr->group.runtime_details.communicator_key = resp.communicator_key();
    FinishGroup(gr.get());
  }
  GroupRec* previous_gr = nullptr;
  {
    // Group membership should never change. Once a record is in group_table_
    // it never gets removed.
    mutex_lock l(group_mu_);
    auto it = group_table_.find(resp.group_key());
    if (it == group_table_.end()) {
      VLOG(2) << "UpdateGroupCache: communicator_key="
              << absl::CEscape(resp.communicator_key());
      group_table_[gr->group.group_key] = std::move(gr);
    } else {
      // Another thread installed a record first; validate against it below
      // (outside group_mu_ to respect lock ordering).
      previous_gr = it->second.get();
    }
  }
  if (previous_gr != nullptr) {
    mutex_lock grl(previous_gr->mu);
    if (previous_gr->group.runtime_details.communicator_key !=
        resp.communicator_key()) {
      return errors::Internal(
          "UpdateGroupCache: CompleteGroupResponse for group ",
          resp.group_key(),
          " gives communicator_key=", absl::CEscape(resp.communicator_key()),
          " but cache already holds communicator_key=",
          absl::CEscape(previous_gr->group.runtime_details.communicator_key));
    }
  }
  return OkStatus();
}
272 | |
// Resolves group params, either locally (if this task is the leader, or the
// group is already cached) or by issuing a CompleteGroup RPC to the leader
// and caching its response. `done` is invoked exactly once.
void CollectiveParamResolverDistributed::CompleteGroupDistributed(
    const DeviceAttributes& device, CollGroupParams* group_params,
    CancellationManager* cancel_mgr, const StatusCallback& done) {
  VLOG(1) << "CompleteGroupDistributed group_key=" << group_params->group_key
          << " dev: " << device.name()
          << " is_leader=" << (group_leader_.empty());
  if (group_leader_.empty()) {
    // This is the group leader, so resolution is local.
    return CompleteGroupLocal(device, group_params, cancel_mgr, done);
  } else if (GetCachedGroup(group_params->group_key) == nullptr) {
    // Need to update Group cache from the leader.
    CompleteGroupCall* call = new CompleteGroupCall(
        *group_params, device, cancel_mgr, group_leader_, worker_cache_);
    // Register for abortion BEFORE starting the call, so a concurrent
    // StartAbort cannot miss this in-flight RPC.
    CancellationToken abortion_token =
        abortion_cancel_mgr_.get_cancellation_token();
    bool already_aborted = !abortion_cancel_mgr_.RegisterCallback(
        abortion_token, [call] { call->Cancel(); });
    if (already_aborted) {
      done(errors::Cancelled("collective ops already aborted"));
      delete call;
      return;
    }
    // The callback owns `call` and deletes it after use; it must deregister
    // the abortion callback first so Cancel() can't race with deletion.
    call->Start([this, device, group_params, call, cancel_mgr, abortion_token,
                 done](const Status& s) {
      abortion_cancel_mgr_.DeregisterCallback(abortion_token);
      if (s.ok()) {
        Status status = UpdateGroupCache(call->resp_);
        if (status.ok()) {
          CompleteGroupLocal(device, group_params, cancel_mgr, done);
        } else {
          done(status);
        }
      } else {
        done(s);
      }
      delete call;
    });
    return;
  } else {
    // Group is already cached; finish resolution locally.
    return CompleteGroupLocal(device, group_params, cancel_mgr, done);
  }
}
315 | |
316 | bool CollectiveParamResolverDistributed::InstanceIsCached( |
317 | int32_t group_key, int32_t instance_key) { |
318 | mutex_lock l(instance_mu_); |
319 | auto group_it = instance_table_.find(group_key); |
320 | if (group_it == instance_table_.end()) { |
321 | return false; |
322 | } |
323 | auto instance_it = group_it->second.find(instance_key); |
324 | return instance_it != group_it->second.end(); |
325 | } |
326 | |
327 | Status CollectiveParamResolverDistributed::UpdateInstanceCache( |
328 | CollectiveParams* cp, const CompleteInstanceResponse& resp) { |
329 | int32_t source_rank = resp.source_rank(); |
330 | bool created_irec; |
331 | InstanceRec* ir = GetOrCreateInstanceRec(cp, &created_irec); |
332 | mutex_lock l(ir->mu); |
333 | if (!ir->status.ok()) { |
334 | return ir->status; |
335 | } |
336 | if (ir->source_rank != source_rank) { |
337 | if (ir->source_rank >= 0) { |
338 | ir->status = errors::Internal( |
339 | "UpdateInstanceCache: CompleteInstanceResponse for instance " , |
340 | cp->instance.instance_key, " gives source_rank=" , source_rank, |
341 | " but cache already holds value=" , ir->source_rank); |
342 | return ir->status; |
343 | } |
344 | ir->source_rank = source_rank; |
345 | } |
346 | if (ir->known_count < cp->group.group_size) { |
347 | ir->known_count = cp->group.group_size; |
348 | const int ir_known_size = ir->known.size(); |
349 | if (ir_known_size != cp->group.group_size) { |
350 | ir->status = errors::Internal( |
351 | "UpdateInstanceCache:: CompleteInstanceResponse for instance " , |
352 | cp->instance.instance_key, " has known.size()=" , ir->known.size(), |
353 | " < group_size=" , cp->group.group_size); |
354 | return ir->status; |
355 | } |
356 | for (int i = 0; i < ir_known_size; ++i) { |
357 | ir->known[i] = true; |
358 | } |
359 | } |
360 | return ir->status; |
361 | } |
362 | |
// Resolves instance params, either locally (if this task is the leader, or
// the instance is already cached) or by issuing a CompleteInstance RPC to
// the leader and merging its response into the local cache. `done` is
// invoked exactly once.
void CollectiveParamResolverDistributed::CompleteInstanceDistributed(
    const string& device, CollectiveParams* cp, CancellationManager* cancel_mgr,
    const StatusCallback& done) {
  if (group_leader_.empty()) {
    // This is the group leader so resolution is local.
    return CompleteInstanceLocal(device, cp, done);
  } else if (InstanceIsCached(cp->group.group_key, cp->instance.instance_key)) {
    return CompleteInstanceLocal(device, cp, done);
  } else {
    CompleteInstanceCall* call = new CompleteInstanceCall(
        cp->group, cp->instance, cp->name, device, cp->is_source, cancel_mgr,
        group_leader_, worker_cache_);
    // Register for abortion BEFORE starting the call, so a concurrent
    // StartAbort cannot miss this in-flight RPC.
    CancellationToken abortion_token =
        abortion_cancel_mgr_.get_cancellation_token();
    bool already_aborted = !abortion_cancel_mgr_.RegisterCallback(
        abortion_token, [call] { call->Cancel(); });
    if (already_aborted) {
      done(errors::Cancelled("collective ops already aborted"));
      delete call;
      return;
    }
    // The callback owns `call` and deletes it after use; it must deregister
    // the abortion callback first so Cancel() can't race with deletion.
    call->Start([this, device, cp, call, abortion_token, done](Status s) {
      abortion_cancel_mgr_.DeregisterCallback(abortion_token);
      if (s.ok()) {
        s = UpdateInstanceCache(cp, call->resp_);
      }
      if (s.ok()) {
        CompleteInstanceLocal(device, cp, done);
      } else {
        done(s);
      }
      delete call;
    });
    return;
  }
}
399 | |
400 | void CollectiveParamResolverDistributed::StartAbort(const Status& s) { |
401 | { |
402 | mutex_lock l(status_mu_); |
403 | if (!status_.ok()) { |
404 | VLOG(2) << "CollectiveParamResolverDistributed already aborted. Ignoring " |
405 | "subsequent abortion with status: " |
406 | << s; |
407 | return; |
408 | } |
409 | status_ = s; |
410 | } |
411 | StartAbortLocal(s); |
412 | abortion_cancel_mgr_.StartCancel(); |
413 | } |
414 | |
415 | } // namespace tensorflow |
416 | |