/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/ring_alg.h"

#include <stdlib.h>

#include <atomic>
#include <functional>
#include <utility>

#include "tensorflow/core/common_runtime/collective_rma_local.h"
#include "tensorflow/core/common_runtime/collective_util.h"
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/types.h"

// Set true for greater intelligibility of debug mode log messages.
#define READABLE_KEYS false
// A ring algorithm exchanges chunks of tensor between devices. The chunk size
// depends on the number of subdivisions specified in the algorithm. If the
// user does not specify the number of subdivisions we may infer the number
// dynamically so that the resulting chunk size does not exceed
// kMaxChunkSizeBytes, empirically set at 4 MiB.
constexpr size_t kMaxChunkSizeBytes = (4 * 1024 * 1024);
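// Illustrative example (hypothetical sizes, not from this file): reducing a
// 32 MiB tensor over 4 devices with a single subdivision would move
// 32 MiB / 4 = 8 MiB chunks, which exceeds this limit; with 2 subdivisions
// each chunk is 32 MiB / (4 * 2) = 4 MiB, so 2 subdivisions would be inferred.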
// kMaxSubdivsPerDeviceDefault gives an upper bound on the number of
// subdivisions dynamically generated when the user does not provide the
// parameter through the collectives API. A reasonable value would be a small
// multiple of the number of NICs adjacent to each device.
constexpr int kMaxSubdivsPerDeviceDefault = 2;

namespace tensorflow {
namespace {
// Each CollectiveOp implementation is free to define its own BufRendezvous
// key format. This function produces the key used by RingAlg instances. Note
// that the exec_key already differentiates between instances, so we don't
// need to further differentiate between subclasses of RingAlg.
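// For illustration only (hypothetical values): with name "Reduce", exec_key
// "collective_7_1", pass 1, section 3 and source_rank 2, the compact form
// below yields "collective_7_1:1:3:2", while READABLE_KEYS would yield
// "Reduce(collective_7_1):pass(1):section(3):srcrank(2)".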
string RingAlgBufKey(const string& name, const string& exec_key, int pass,
                     int section, int source_rank) {
  if (READABLE_KEYS) {
    return strings::StrCat(name, "(", exec_key, "):pass(", pass, "):section(",
                           section, "):srcrank(", source_rank, ")");
  } else {
    // TODO(b/78352018): Try out some kind of denser encoding, e.g. 128 bit
    // hash.
    return strings::StrCat(exec_key, ":", pass, ":", section, ":",
                           source_rank);
  }
}

}  // namespace

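// PCQueue is a simple mutex-guarded producer/consumer queue of RingField
// pointers: Enqueue appends a field and wakes one blocked waiter, and
// Dequeue blocks until a field is available.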
void RingAlg::PCQueue::Enqueue(RingField* rf) {
  mutex_lock l(pcq_mu_);
  deque_.push_back(rf);
  if (waiter_count_ > 0) {
    cv_.notify_one();
  }
}

RingAlg::RingField* RingAlg::PCQueue::Dequeue() {
  mutex_lock l(pcq_mu_);
  if (deque_.empty()) {
    ++waiter_count_;
    while (deque_.empty()) {
      cv_.wait(l);
    }
    --waiter_count_;
  }
  RingField* rf = deque_.front();
  deque_.pop_front();
  return rf;
}

RingAlg::RingAlg(CollectiveType type, const string& name)
    : type_(type),
      name_(name),
      col_ctx_(nullptr),
      col_params_(nullptr),
      done_(nullptr),
      group_size_(-1),
      num_subdivs_(-1) {}

namespace {
Status GenerateSubdivsInCollectiveParams(CollectiveParams* col_params) {
  // This function generates subdiv_offsets.  It expects the field to be
  // empty when called.
  DCHECK(col_params->instance.impl_details.subdiv_offsets.empty());

  if (col_params->instance.impl_details.max_subdivs_per_device == -1) {
    col_params->instance.impl_details.subdiv_offsets = {0};
    VLOG(2) << "Limiting to 1 subdivision as max_subdivs_per_device == -1";
    return OkStatus();
  }

  if (col_params->instance.shape.num_elements() == 0) {
    return errors::Internal("shape in CollectiveParams should be non-empty");
  }
  const int kAvgDevPerTask =
      col_params->group.group_size / col_params->group.num_tasks;
  const int max_subdivs_per_device =
      (col_params->instance.impl_details.max_subdivs_per_device > 0)
          ? col_params->instance.impl_details.max_subdivs_per_device
          : kMaxSubdivsPerDeviceDefault;
  const int kMaxNumSubdivs = max_subdivs_per_device * kAvgDevPerTask;
  if (kMaxNumSubdivs <= 0) {
    return errors::Internal("Unexpected kMaxNumSubdivs ", kMaxNumSubdivs,
                            " in ",
                            col_params->instance.impl_details.collective_name);
  }
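  // For example (hypothetical group): 8 devices spread evenly over 2 tasks
  // give kAvgDevPerTask = 4; with the default of 2 subdivisions per device,
  // at most 8 subdivisions will be generated below.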
  // NOTE(ayushd): If no subdiv_offsets have been specified, dynamically add
  // as many offsets as needed so that the size of tensor chunks <=
  // kMaxChunkSizeBytes. Empirically, chunks that are too small or too large
  // lead to worse performance.
  int num_subdivs = 0;
  const size_t tensor_size = col_params->instance.shape.num_elements() *
                             DataTypeSize(col_params->instance.data_type);
  size_t chunk_size;
  do {
    ++num_subdivs;
    int num_chunks = col_params->group.group_size * num_subdivs;
    chunk_size = tensor_size / num_chunks;
    VLOG(2) << "num_subdivs " << num_subdivs << " num_chunks " << num_chunks
            << " chunk_size " << chunk_size;
  } while (chunk_size > kMaxChunkSizeBytes && num_subdivs < kMaxNumSubdivs);
  if (num_subdivs <= 0) {
    return errors::Internal("Unexpected num_subdivs ", num_subdivs, " in ",
                            col_params->instance.impl_details.collective_name);
  }

  int subdiv_stride = kAvgDevPerTask / num_subdivs;
  if (subdiv_stride == 0) subdiv_stride = 1;
  col_params->instance.impl_details.subdiv_offsets.reserve(num_subdivs);
  for (int sdi = 0; sdi < num_subdivs; ++sdi) {
    int subdiv_offset = subdiv_stride * sdi;
    if (sdi % 2 == 1) subdiv_offset *= -1;
    col_params->instance.impl_details.subdiv_offsets.push_back(subdiv_offset);
  }
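  // For example (hypothetical values): with kAvgDevPerTask = 8 and
  // num_subdivs = 4, subdiv_stride is 2 and the generated offsets are
  // {0, -2, 4, -6}: even-indexed subdivisions get a positive offset and
  // odd-indexed ones the negated offset.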

  if (VLOG_IS_ON(2)) {
    string subdiv_buf;
    for (const int subdiv_offset :
         col_params->instance.impl_details.subdiv_offsets) {
      strings::StrAppend(&subdiv_buf, " ", subdiv_offset);
    }
    VLOG(2) << "Dynamically generated " << num_subdivs
            << " subdiv_offsets:" << subdiv_buf << " tensor_size "
            << tensor_size << " chunk_size " << chunk_size;
  }

  return OkStatus();
}
}  // namespace

Status RingAlg::InitializeCollectiveParams(CollectiveParams* col_params) {
  const string& device_name =
      col_params->group.members[col_params->default_rank].device.name();
  // Each subdiv permutation is a ring formed by rotating each single-task
  // subsequence of devices by an offset. This makes most sense when each
  // task has the same number of devices, but we can't depend on that being
  // the case, so we'll compute something that works in any case.
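  // For example (hypothetical group): with 2 tasks of 4 devices each and an
  // offset of 1, each task's devices are rotated by one position, giving the
  // permutation {1, 2, 3, 0, 5, 6, 7, 4} (rank -> global device index).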

  // Start by counting the devices in each task.
  // Precondition: device_names must be sorted so that all devices in
  // the same task are adjacent.
  std::vector<int> dev_per_task;
  const string* prior_task_name = &col_params->group.members[0].task;
  int dev_count = 1;
  for (int di = 1; di < col_params->group.group_size; ++di) {
    if (col_params->group.members[di].task != *prior_task_name) {
      dev_per_task.push_back(dev_count);
      dev_count = 1;
      prior_task_name = &col_params->group.members[di].task;
    } else {
      ++dev_count;
    }
  }
  dev_per_task.push_back(dev_count);
  DCHECK_EQ(col_params->group.num_tasks, dev_per_task.size());

  if (col_params->instance.impl_details.subdiv_offsets.empty()) {
    TF_RETURN_IF_ERROR(GenerateSubdivsInCollectiveParams(col_params));
  }

  // Generate a ring permutation for each requested offset.
  VLOG(2) << "Setting up perms for col_params " << col_params
          << " subdiv_permutations "
          << &col_params->instance.impl_details.subdiv_permutations;
  col_params->instance.impl_details.subdiv_permutations.resize(
      col_params->instance.impl_details.subdiv_offsets.size());
  col_params->subdiv_rank.resize(
      col_params->instance.impl_details.subdiv_offsets.size(), -1);
  for (int sdi = 0;
       sdi < col_params->instance.impl_details.subdiv_offsets.size(); ++sdi) {
    std::vector<int>& perm =
        col_params->instance.impl_details.subdiv_permutations[sdi];
    DCHECK_EQ(perm.size(), 0);
    int offset = col_params->instance.impl_details.subdiv_offsets[sdi];
    // A negative subdivision offset is interpreted as follows:
    // 1. Reverse the local device ordering.
    // 2. Begin the subdivision at abs(offset) in the reversed ordering.
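    // Continuing the hypothetical 2x4 example above: an offset of -1 reverses
    // each task's local ordering and starts at position 1 of the reversed
    // order, producing the permutation {2, 1, 0, 3, 6, 5, 4, 7}.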
    bool reverse = false;
    if (offset < 0) {
      offset = abs(offset);
      reverse = true;
    }
    int prior_dev_count = 0;  // sum over prior worker device counts
    for (int ti = 0; ti < col_params->group.num_tasks; ++ti) {
      for (int di = 0; di < dev_per_task[ti]; ++di) {
        int di_offset = (di + offset) % dev_per_task[ti];
        int offset_di =
            reverse ? (dev_per_task[ti] - (di_offset + 1)) : di_offset;
        // Device index in global subdivision permutation.
        int permuted_di = prior_dev_count + offset_di;
        int rank = static_cast<int>(perm.size());
        perm.push_back(permuted_di);
        if (col_params->group.members[permuted_di].device.name() ==
            device_name) {
          DCHECK_EQ(permuted_di, col_params->default_rank);
          col_params->subdiv_rank[sdi] = rank;
        }
      }
      prior_dev_count += dev_per_task[ti];
    }
    DCHECK_EQ(col_params->group.group_size, perm.size());
  }

  VLOG(2) << collective_util::SubdivPermDebugString(*col_params);
  return OkStatus();
}

Status RingAlg::InitializeCollectiveContext(
    std::shared_ptr<CollectiveContext> col_ctx) {
  DCHECK(col_ctx->dev_mgr);
  col_ctx_ = col_ctx;
  col_params_ = col_ctx->col_params.get();
  return collective_util::InitializeDeviceAndLocality(
      col_ctx->dev_mgr, col_ctx->device_name, &col_ctx->device,
      &col_ctx->device_locality);
}

string RingAlg::TensorDebugString(const Tensor& tensor) {
  const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info =
      col_ctx_->op_ctx->device()->tensorflow_accelerator_device_info();
  if (accelerator_device_info) {
    Tensor cpu_tensor(tensor.dtype(), tensor.shape());
    Status st =
        accelerator_device_info->default_context->CopyDeviceTensorToCPUSync(
            &tensor, "" /*tensor_name*/, col_ctx_->device, &cpu_tensor);
    DCHECK(st.ok());
    return cpu_tensor.SummarizeValue(64);
  } else {
    return tensor.SummarizeValue(64);
  }
}

void RingAlg::StartAbort(const Status& s) {
  // In abort mode we stop issuing additional ProvideBuf
  // and ConsumeBuf calls, but we need to wait for all of the
  // outstanding callbacks to be invoked before quitting.
  bool abort_started = false;
  {
    mutex_lock l(status_mu_);
    if (status_.ok()) {
      LOG(ERROR) << "Aborting Ring" << name_ << " with " << s;
      abort_started = true;
      status_.Update(s);
    }
  }
  // If this is the initial entry to abort mode and it's not a cancellation,
  // then invoke StartAbort on the CollectiveExecutor that invoked us. That
  // should start cancellation on all of the outstanding CollectiveRemoteAccess
  // actions. If it's a cancellation, all pending send/recv calls should be
  // cancelled as well, and then there's no need to abort.
  if (abort_started) {
    if (col_ctx_->op_ctx->cancellation_manager() == nullptr ||
        (!col_ctx_->op_ctx->cancellation_manager()->IsCancelled() &&
         !col_ctx_->op_ctx->cancellation_manager()->IsCancelling())) {
      col_ctx_->col_exec->StartAbort(s);
    }
  }
}

void RingAlg::Finish(bool ok) {
  if (ok) {
    // Recover the output from the adaptor.
    ca_->ConsumeFinalValue(col_ctx_->output);
  }
  Status s;
  {
    mutex_lock l(status_mu_);
    s = status_;
  }
  rfv_.clear();  // Give up Refs on output tensor.
  done_(s);
}

// At the beginning of the algorithm, initialize a RingField struct for
// every independent field of the tensor.
void RingAlg::InitRingField(RingField* rf, int chunk_idx, int subdiv_idx,
                            int field_idx) {
  // Note on field indexing: There are group_size_ devices in the
  // instance, implying the same number of chunks per tensor, where a
  // chunk is the unit of data transferred in a time step. However, if
  // a device can simultaneously send data over 2 or more independent
  // channels, we can speed up the transfer by subdividing chunks and
  // processing multiple subdivisions at once. So the actual number
  // of RingFields is group_size_ * num_subdivs_.
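  // For example (hypothetical sizes): with group_size_ = 4 and
  // num_subdivs_ = 2 there are 8 RingFields, and the field for chunk_idx 3,
  // subdiv_idx 1 has field_idx 3 * 2 + 1 = 7.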
  DCHECK_EQ(field_idx, (chunk_idx * num_subdivs_) + subdiv_idx);
  rf->chunk_idx = chunk_idx;
  rf->subdiv_idx = subdiv_idx;
  rf->sc_idx = field_idx;
  rf->rank = col_params_->subdiv_rank[subdiv_idx];
  rf->second_pass = false;
  rf->action = RF_INIT;
  // Recv from the device with preceding rank within the subdivision.
  int recv_from_rank = (rf->rank + (group_size_ - 1)) % group_size_;
  int send_to_rank = (rf->rank + 1) % group_size_;
  rf->recv_dev_idx = col_params_->instance.impl_details
                         .subdiv_permutations[subdiv_idx][recv_from_rank];
  int send_dev_idx = col_params_->instance.impl_details
                         .subdiv_permutations[subdiv_idx][send_to_rank];
  rf->recv_is_remote = !col_params_->group.members[rf->recv_dev_idx].is_local;
  rf->send_is_remote = !col_params_->group.members[send_dev_idx].is_local;
  if (ca_->ChunkBytes(rf->sc_idx) > 0) {
    // In pass 0 we skip Recv when rank = chunk_idx
    rf->do_recv = (rf->chunk_idx != rf->rank);
    // In pass 0 we skip Send when rank = chunk_idx-1
    rf->do_send =
        (rf->rank != ((rf->chunk_idx + (group_size_ - 1)) % group_size_));
  }
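  // For example (hypothetical values): with group_size_ = 4 and chunk_idx 2,
  // pass 0 starts chunk 2 at rank 2 (which therefore skips the Recv) and ends
  // at rank 1 (which skips the Send and is marked is_final below).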
  rf->is_final =
      (rf->rank == ((rf->chunk_idx + (group_size_ - 1)) % group_size_));
  if (rf->do_send || rf->do_recv) {
    rf->chunk = ca_->ChunkAlias(rf->sc_idx);
  }
  VLOG(2) << this << " InitRingField " << rf->DebugString() << " chunk "
          << ca_->TBounds(rf->chunk);
}

// When a RingField transitions from the first to the second pass, recompute
// the do_send and do_recv values.
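// Continuing the hypothetical group_size_ = 4, chunk_idx 2 example: in pass 1
// rank 1 skips the Recv and rank 0 skips the Send; rank 0, the last to
// receive, becomes is_final.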
void RingAlg::AdvanceToSecondPass(RingField* rf) {
  VLOG(3) << "IncrRingField old value " << rf->DebugString();
  DCHECK(!rf->second_pass);
  rf->second_pass = true;
  rf->action = RF_INIT;
  if (ca_->ChunkBytes(rf->sc_idx) > 0) {
    // In pass 1 the send/no-send boundary moves down 1 place.
    rf->do_recv =
        (rf->rank != ((rf->chunk_idx + (group_size_ - 1)) % group_size_));
    rf->do_send =
        (rf->rank != ((rf->chunk_idx + (group_size_ - 2)) % group_size_));
  }
  rf->is_final =
      (rf->rank == ((rf->chunk_idx + (group_size_ - 2)) % group_size_));
  VLOG(3) << "IncrRingField new value " << rf->DebugString();
}

string RingAlg::RingField::DebugString() const {
  string rv = strings::StrCat("RingField rank=", rank, " chunk_idx=", chunk_idx,
                              " subdiv=", subdiv_idx, " sc_idx=", sc_idx,
                              " action=", action);
  strings::StrAppend(&rv, " pass=", second_pass);
  strings::StrAppend(&rv, " do_send=", do_send, " do_recv=", do_recv,
                     " is_final=", is_final, " recv_is_remote=", recv_is_remote,
                     " recv_dev_idx=", recv_dev_idx, " sc_idx=", sc_idx);
  return rv;
}

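// Note: the BufRendezvous key computed in DispatchSend identifies the sender
// by rf->rank, and the downstream peer's DispatchRecv computes the same key
// from its own rank as (rank + group_size_ - 1) % group_size_; the two must
// agree for the transfer to rendezvous.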
void RingAlg::DispatchSend(RingField* rf, const StatusCallback& done) {
  DCHECK(rf->do_send);
  string send_buf_key = RingAlgBufKey(name_, col_ctx_->exec_key,
                                      rf->second_pass, rf->sc_idx, rf->rank);
  VLOG(3) << "DispatchSend rank=" << col_params_->default_rank << " send key "
          << send_buf_key << " chunk " << ca_->TBounds(rf->chunk) << " sc_idx "
          << rf->sc_idx;
  int send_to_rank = (rf->rank + 1) % group_size_;
  int send_to_dev_idx = col_params_->instance.impl_details
                            .subdiv_permutations[rf->subdiv_idx][send_to_rank];
  col_ctx_->col_exec->remote_access()->PostToPeer(
      col_params_->group.members[send_to_dev_idx].device.name(),
      col_params_->group.members[send_to_dev_idx].task, send_buf_key,
      col_ctx_->device, col_ctx_->op_ctx->op_device_context(),
      col_ctx_->op_ctx->output_alloc_attr(0), &rf->chunk,
      col_ctx_->device_locality, col_ctx_->op_ctx->cancellation_manager(),
      done);
}

void RingAlg::DispatchRecv(RingField* rf, const StatusCallback& done) {
  DCHECK(rf->do_recv);
  string recv_buf_key =
      RingAlgBufKey(name_, col_ctx_->exec_key, rf->second_pass, rf->sc_idx,
                    (rf->rank + (group_size_ - 1)) % group_size_);
  VLOG(3) << "DispatchRecv rank=" << col_params_->default_rank << " recv key "
          << recv_buf_key << " chunk " << ca_->TBounds(rf->chunk) << " into "
          << ((col_params_->merge_op != nullptr) ? "tmp_chunk" : "chunk");
  Tensor* dst_tensor = (!rf->second_pass && (col_params_->merge_op != nullptr))
                           ? &rf->tmp_chunk
                           : &rf->chunk;
  col_ctx_->col_exec->remote_access()->RecvFromPeer(
      col_params_->group.members[rf->recv_dev_idx].device.name(),
      col_params_->group.members[rf->recv_dev_idx].task,
      col_params_->group.members[rf->recv_dev_idx].is_local, recv_buf_key,
      col_ctx_->device, col_ctx_->op_ctx->op_device_context(),
      col_ctx_->op_ctx->output_alloc_attr(0), dst_tensor,
      col_ctx_->device_locality, rf->subdiv_idx,
      col_ctx_->op_ctx->cancellation_manager(), done);
}

string RingAlg::FieldState() {
  string s = strings::StrCat(
      "Ring", name_, " ", strings::Hex(reinterpret_cast<uint64>(this)),
      " exec ", col_ctx_->exec_key, " step_id=", col_ctx_->step_id,
      " state of all ", rfv_.size(), " fields:");
  for (int i = 0; i < rfv_.size(); ++i) {
    s.append("\n");
    s.append(rfv_[i].DebugString());
  }
  return s;
}

}  // namespace tensorflow