/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/ring_alg.h"

#include <stdlib.h>

#include <atomic>
#include <functional>
#include <utility>

#include "tensorflow/core/common_runtime/collective_rma_local.h"
#include "tensorflow/core/common_runtime/collective_util.h"
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/types.h"

// Set to true to make debug-mode log messages more readable.
#define READABLE_KEYS false
// A ring algorithm exchanges chunks of a tensor between devices. The chunk
// size depends on the number of subdivisions specified in the algorithm. If
// the user does not specify the number of subdivisions, we may infer the
// number dynamically so that the resulting chunk size does not exceed
// kMaxChunkSizeBytes, empirically set at 4 MiB.
constexpr size_t kMaxChunkSizeBytes = (4 * 1024 * 1024);
// kMaxSubdivsPerDeviceDefault gives an upper bound on the number of
// subdivisions generated dynamically when the user does not provide the
// parameter through the collectives API. A reasonable value would be a small
// multiple of the number of NICs adjacent to each device.
constexpr int kMaxSubdivsPerDeviceDefault = 2;

namespace tensorflow {
namespace {
// Each CollectiveOp implementation is free to define its own
// BufRendezvous key format. This function produces the key used by
// RingAlg instances. Note that the exec_key already differentiates between
// instances, so we don't need to further differentiate between subclasses of
// RingAlg.
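// For example (illustrative values only): with READABLE_KEYS false, a key for
// pass 0, section 3, source rank 1 looks like "<exec_key>:0:3:1".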
string RingAlgBufKey(const string& name, const string& exec_key, int pass,
                     int section, int source_rank) {
  if (READABLE_KEYS) {
    return strings::StrCat(name, "(", exec_key, "):pass(", pass, "):section(",
                           section, "):srcrank(", source_rank, ")");
  } else {
    // TODO(b/78352018): Try out some kind of denser encoding, e.g. 128 bit
    // hash.
    return strings::StrCat(exec_key, ":", pass, ":", section, ":", source_rank);
  }
}

}  // namespace

void RingAlg::PCQueue::Enqueue(RingField* rf) {
  mutex_lock l(pcq_mu_);
  deque_.push_back(rf);
  if (waiter_count_ > 0) {
    cv_.notify_one();
  }
}

RingAlg::RingField* RingAlg::PCQueue::Dequeue() {
  mutex_lock l(pcq_mu_);
  if (deque_.empty()) {
    ++waiter_count_;
    while (deque_.empty()) {
      cv_.wait(l);
    }
    --waiter_count_;
  }
  RingField* rf = deque_.front();
  deque_.pop_front();
  return rf;
}

RingAlg::RingAlg(CollectiveType type, const string& name)
    : type_(type),
      name_(name),
      col_ctx_(nullptr),
      col_params_(nullptr),
      done_(nullptr),
      group_size_(-1),
      num_subdivs_(-1) {}

namespace {
Status GenerateSubdivsInCollectiveParams(CollectiveParams* col_params) {
  // This function generates subdiv_offsets, which is expected to be empty
  // when this is called.
  DCHECK(col_params->instance.impl_details.subdiv_offsets.empty());

  if (col_params->instance.impl_details.max_subdivs_per_device == -1) {
    col_params->instance.impl_details.subdiv_offsets = {0};
    VLOG(2) << "Limiting to 1 subdivision as max_subdivs_per_device == -1";
    return OkStatus();
  }

  if (col_params->instance.shape.num_elements() == 0) {
    return errors::Internal("shape in CollectiveParams should be non-empty");
  }
  const int kAvgDevPerTask =
      col_params->group.group_size / col_params->group.num_tasks;
  const int max_subdivs_per_device =
      (col_params->instance.impl_details.max_subdivs_per_device > 0)
          ? col_params->instance.impl_details.max_subdivs_per_device
          : kMaxSubdivsPerDeviceDefault;
  const int kMaxNumSubdivs = max_subdivs_per_device * kAvgDevPerTask;
  if (kMaxNumSubdivs <= 0) {
    return errors::Internal("Unexpected kMaxNumSubdivs ", kMaxNumSubdivs,
                            " in ",
                            col_params->instance.impl_details.collective_name);
  }
  // NOTE(ayushd): If no subdiv_offsets have been specified, dynamically add
  // as many offsets as needed so that the size of tensor chunks is <=
  // kMaxChunkSizeBytes. Empirically, chunks that are too small or too large
  // lead to worse performance.
  int num_subdivs = 0;
  const size_t tensor_size = col_params->instance.shape.num_elements() *
                             DataTypeSize(col_params->instance.data_type);
  size_t chunk_size;
  do {
    ++num_subdivs;
    int num_chunks = col_params->group.group_size * num_subdivs;
    chunk_size = tensor_size / num_chunks;
    VLOG(2) << "num_subdivs " << num_subdivs << " num_chunks " << num_chunks
            << " chunk_size " << chunk_size;
  } while (chunk_size > kMaxChunkSizeBytes && num_subdivs < kMaxNumSubdivs);
  if (num_subdivs <= 0) {
    return errors::Internal("Unexpected num_subdivs ", num_subdivs, " in ",
                            col_params->instance.impl_details.collective_name);
  }

  int subdiv_stride = kAvgDevPerTask / num_subdivs;
  if (subdiv_stride == 0) subdiv_stride = 1;
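  // The loop below staggers offsets by subdiv_stride and negates every
  // odd-indexed offset. For example (illustrative numbers only): with
  // kAvgDevPerTask = 4 and num_subdivs = 2, subdiv_stride = 2 and the
  // generated offsets are {0, -2}.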
  col_params->instance.impl_details.subdiv_offsets.reserve(num_subdivs);
  for (int sdi = 0; sdi < num_subdivs; ++sdi) {
    int subdiv_offset = subdiv_stride * sdi;
    if (sdi % 2 == 1) subdiv_offset *= -1;
    col_params->instance.impl_details.subdiv_offsets.push_back(subdiv_offset);
  }

  if (VLOG_IS_ON(2)) {
    string subdiv_buf;
    for (const int subdiv_offset :
         col_params->instance.impl_details.subdiv_offsets) {
      strings::StrAppend(&subdiv_buf, " ", subdiv_offset);
    }
    VLOG(2) << "Dynamically generated " << num_subdivs
            << " subdiv_offsets:" << subdiv_buf << " tensor_size "
            << tensor_size << " chunk_size " << chunk_size;
  }

  return OkStatus();
}
}  // namespace

Status RingAlg::InitializeCollectiveParams(CollectiveParams* col_params) {
  const string& device_name =
      col_params->group.members[col_params->default_rank].device.name();
  // Each subdiv permutation is a ring formed by rotating each single-task
  // subsequence of devices by an offset. This makes the most sense when each
  // task has the same number of devices, but since we can't depend on that
  // being the case, we compute something that works regardless.

  // Start by counting the devices in each task.
  // Precondition: the group's devices must be ordered so that all devices in
  // the same task are adjacent.
  std::vector<int> dev_per_task;
  const string* prior_task_name = &col_params->group.members[0].task;
  int dev_count = 1;
  for (int di = 1; di < col_params->group.group_size; ++di) {
    if (col_params->group.members[di].task != *prior_task_name) {
      dev_per_task.push_back(dev_count);
      dev_count = 1;
      prior_task_name = &col_params->group.members[di].task;
    } else {
      ++dev_count;
    }
  }
  dev_per_task.push_back(dev_count);
  DCHECK_EQ(col_params->group.num_tasks, dev_per_task.size());

  if (col_params->instance.impl_details.subdiv_offsets.empty()) {
    TF_RETURN_IF_ERROR(GenerateSubdivsInCollectiveParams(col_params));
  }

  // Generate a ring permutation for each requested offset.
  VLOG(2) << "Setting up perms for col_params " << col_params
          << " subdiv_permutations "
          << &col_params->instance.impl_details.subdiv_permutations;
  col_params->instance.impl_details.subdiv_permutations.resize(
      col_params->instance.impl_details.subdiv_offsets.size());
  col_params->subdiv_rank.resize(
      col_params->instance.impl_details.subdiv_offsets.size(), -1);
  for (int sdi = 0;
       sdi < col_params->instance.impl_details.subdiv_offsets.size(); ++sdi) {
    std::vector<int>& perm =
        col_params->instance.impl_details.subdiv_permutations[sdi];
    DCHECK_EQ(perm.size(), 0);
    int offset = col_params->instance.impl_details.subdiv_offsets[sdi];
    // A negative subdivision offset is interpreted as follows:
    //  1. Reverse the local device ordering.
    //  2. Begin the subdivision at abs(offset) in the reversed ordering.
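    // For example (illustrative numbers only): with 4 devices in a task and
    // offset = -1, the reversed local ordering is {3, 2, 1, 0}, and starting
    // at index 1 of it yields the local order {2, 1, 0, 3}.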
    bool reverse = false;
    if (offset < 0) {
      offset = abs(offset);
      reverse = true;
    }
    int prior_dev_count = 0;  // sum over prior worker device counts
    for (int ti = 0; ti < col_params->group.num_tasks; ++ti) {
      for (int di = 0; di < dev_per_task[ti]; ++di) {
        int di_offset = (di + offset) % dev_per_task[ti];
        int offset_di =
            reverse ? (dev_per_task[ti] - (di_offset + 1)) : di_offset;
        // Device index in global subdivision permutation.
        int permuted_di = prior_dev_count + offset_di;
        int rank = static_cast<int>(perm.size());
        perm.push_back(permuted_di);
        if (col_params->group.members[permuted_di].device.name() ==
            device_name) {
          DCHECK_EQ(permuted_di, col_params->default_rank);
          col_params->subdiv_rank[sdi] = rank;
        }
      }
      prior_dev_count += dev_per_task[ti];
    }
    DCHECK_EQ(col_params->group.group_size, perm.size());
  }

  VLOG(2) << collective_util::SubdivPermDebugString(*col_params);
  return OkStatus();
}

Status RingAlg::InitializeCollectiveContext(
    std::shared_ptr<CollectiveContext> col_ctx) {
  DCHECK(col_ctx->dev_mgr);
  col_ctx_ = col_ctx;
  col_params_ = col_ctx->col_params.get();
  return collective_util::InitializeDeviceAndLocality(
      col_ctx->dev_mgr, col_ctx->device_name, &col_ctx->device,
      &col_ctx->device_locality);
}

string RingAlg::TensorDebugString(const Tensor& tensor) {
  const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info =
      col_ctx_->op_ctx->device()->tensorflow_accelerator_device_info();
  if (accelerator_device_info) {
    Tensor cpu_tensor(tensor.dtype(), tensor.shape());
    Status st =
        accelerator_device_info->default_context->CopyDeviceTensorToCPUSync(
            &tensor, "" /*tensor_name*/, col_ctx_->device, &cpu_tensor);
    DCHECK(st.ok());
    return cpu_tensor.SummarizeValue(64);
  } else {
    return tensor.SummarizeValue(64);
  }
}

void RingAlg::StartAbort(const Status& s) {
  // In abort mode we stop issuing additional ProvideBuf
  // and ConsumeBuf calls, but we need to wait for all of the
  // outstanding callbacks to be invoked before quitting.
  bool abort_started = false;
  {
    mutex_lock l(status_mu_);
    if (status_.ok()) {
      LOG(ERROR) << "Aborting Ring" << name_ << " with " << s;
      abort_started = true;
      status_.Update(s);
    }
  }
  // If this is the initial entry to abort mode and it's not a cancellation,
  // then invoke StartAbort on the CollectiveExecutor that invoked us. That
  // should start cancellation on all of the outstanding CollectiveRemoteAccess
  // actions. If it is a cancellation, all pending sends/recvs will be
  // cancelled as well, so there's no need to abort.
  if (abort_started) {
    if (col_ctx_->op_ctx->cancellation_manager() == nullptr ||
        (!col_ctx_->op_ctx->cancellation_manager()->IsCancelled() &&
         !col_ctx_->op_ctx->cancellation_manager()->IsCancelling())) {
      col_ctx_->col_exec->StartAbort(s);
    }
  }
}

void RingAlg::Finish(bool ok) {
  if (ok) {
    // Recover the output from the adaptor.
    ca_->ConsumeFinalValue(col_ctx_->output);
  }
  Status s;
  {
    mutex_lock l(status_mu_);
    s = status_;
  }
  rfv_.clear();  // Give up Refs on output tensor.
  done_(s);
}

// At the beginning of the algorithm, initialize a RingField struct for
// every independent field of the tensor.
void RingAlg::InitRingField(RingField* rf, int chunk_idx, int subdiv_idx,
                            int field_idx) {
  // Note on field indexing: There are group_size_ devices in the
  // instance, implying the same number of chunks per tensor, where a
  // chunk is the unit of data transferred in a time step. However, if
  // a device can simultaneously send data over 2 or more independent
  // channels, we can speed up the transfer by subdividing chunks and
  // processing multiple subdivisions at once. So the actual number
  // of RingFields is group_size_ * num_subdivs_.
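  // For example (illustrative numbers only): with group_size_ = 4 and
  // num_subdivs_ = 2 there are 8 RingFields, and the field for chunk_idx 3,
  // subdiv_idx 1 has field_idx = 3 * 2 + 1 = 7.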
  DCHECK_EQ(field_idx, (chunk_idx * num_subdivs_) + subdiv_idx);
  rf->chunk_idx = chunk_idx;
  rf->subdiv_idx = subdiv_idx;
  rf->sc_idx = field_idx;
  rf->rank = col_params_->subdiv_rank[subdiv_idx];
  rf->second_pass = false;
  rf->action = RF_INIT;
  // Recv from the device with preceding rank within the subdivision.
  int recv_from_rank = (rf->rank + (group_size_ - 1)) % group_size_;
  int send_to_rank = (rf->rank + 1) % group_size_;
  rf->recv_dev_idx = col_params_->instance.impl_details
                         .subdiv_permutations[subdiv_idx][recv_from_rank];
  int send_dev_idx = col_params_->instance.impl_details
                         .subdiv_permutations[subdiv_idx][send_to_rank];
  rf->recv_is_remote = !col_params_->group.members[rf->recv_dev_idx].is_local;
  rf->send_is_remote = !col_params_->group.members[send_dev_idx].is_local;
  if (ca_->ChunkBytes(rf->sc_idx) > 0) {
    // In pass 0 we skip Recv when rank = chunk_idx
    rf->do_recv = (rf->chunk_idx != rf->rank);
    // In pass 0 we skip Send when rank = chunk_idx-1
    rf->do_send =
        (rf->rank != ((rf->chunk_idx + (group_size_ - 1)) % group_size_));
  }
  rf->is_final =
      (rf->rank == ((rf->chunk_idx + (group_size_ - 1)) % group_size_));
  if (rf->do_send || rf->do_recv) {
    rf->chunk = ca_->ChunkAlias(rf->sc_idx);
  }
  VLOG(2) << this << " InitRingField " << rf->DebugString() << " chunk "
          << ca_->TBounds(rf->chunk);
}

// When a RingField transitions from the first pass to the second, recompute
// the do_send and do_recv values.
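// For example (illustrative numbers only): with group_size_ = 4, rank 1 skips
// the Recv for chunk 2 and the Send for chunk 3 in the second pass, and
// is_final is set for chunk 3.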
void RingAlg::AdvanceToSecondPass(RingField* rf) {
  VLOG(3) << "IncrRingField old value " << rf->DebugString();
  DCHECK(!rf->second_pass);
  rf->second_pass = true;
  rf->action = RF_INIT;
  if (ca_->ChunkBytes(rf->sc_idx) > 0) {
    // In pass 1 the send/no-send boundary moves down 1 place.
    rf->do_recv =
        (rf->rank != ((rf->chunk_idx + (group_size_ - 1)) % group_size_));
    rf->do_send =
        (rf->rank != ((rf->chunk_idx + (group_size_ - 2)) % group_size_));
  }
  rf->is_final =
      (rf->rank == ((rf->chunk_idx + (group_size_ - 2)) % group_size_));
  VLOG(3) << "IncrRingField new value " << rf->DebugString();
}

string RingAlg::RingField::DebugString() const {
  string rv = strings::StrCat("RingField rank=", rank, " chunk_idx=", chunk_idx,
                              " subdiv=", subdiv_idx, " sc_idx=", sc_idx,
                              " action=", action);
  strings::StrAppend(&rv, " pass=", second_pass);
  strings::StrAppend(&rv, " do_send=", do_send, " do_recv=", do_recv,
                     " is_final=", is_final, " recv_is_remote=", recv_is_remote,
                     " recv_dev_idx=", recv_dev_idx, " sc_idx=", sc_idx);
  return rv;
}

void RingAlg::DispatchSend(RingField* rf, const StatusCallback& done) {
  DCHECK(rf->do_send);
  string send_buf_key = RingAlgBufKey(name_, col_ctx_->exec_key,
                                      rf->second_pass, rf->sc_idx, rf->rank);
  VLOG(3) << "DispatchSend rank=" << col_params_->default_rank << " send key "
          << send_buf_key << " chunk " << ca_->TBounds(rf->chunk) << " sc_idx "
          << rf->sc_idx;
  int send_to_rank = (rf->rank + 1) % group_size_;
  int send_to_dev_idx = col_params_->instance.impl_details
                            .subdiv_permutations[rf->subdiv_idx][send_to_rank];
  col_ctx_->col_exec->remote_access()->PostToPeer(
      col_params_->group.members[send_to_dev_idx].device.name(),
      col_params_->group.members[send_to_dev_idx].task, send_buf_key,
      col_ctx_->device, col_ctx_->op_ctx->op_device_context(),
      col_ctx_->op_ctx->output_alloc_attr(0), &rf->chunk,
      col_ctx_->device_locality, col_ctx_->op_ctx->cancellation_manager(),
      done);
}

void RingAlg::DispatchRecv(RingField* rf, const StatusCallback& done) {
  DCHECK(rf->do_recv);
  string recv_buf_key =
      RingAlgBufKey(name_, col_ctx_->exec_key, rf->second_pass, rf->sc_idx,
                    (rf->rank + (group_size_ - 1)) % group_size_);
  VLOG(3) << "DispatchRecv rank=" << col_params_->default_rank << " recv key "
          << recv_buf_key << " chunk " << ca_->TBounds(rf->chunk) << " into "
          << ((col_params_->merge_op != nullptr) ? "tmp_chunk" : "chunk");
  Tensor* dst_tensor = (!rf->second_pass && (col_params_->merge_op != nullptr))
                           ? &rf->tmp_chunk
                           : &rf->chunk;
  col_ctx_->col_exec->remote_access()->RecvFromPeer(
      col_params_->group.members[rf->recv_dev_idx].device.name(),
      col_params_->group.members[rf->recv_dev_idx].task,
      col_params_->group.members[rf->recv_dev_idx].is_local, recv_buf_key,
      col_ctx_->device, col_ctx_->op_ctx->op_device_context(),
      col_ctx_->op_ctx->output_alloc_attr(0), dst_tensor,
      col_ctx_->device_locality, rf->subdiv_idx,
      col_ctx_->op_ctx->cancellation_manager(), done);
}

string RingAlg::FieldState() {
  string s = strings::StrCat(
      "Ring", name_, " ", strings::Hex(reinterpret_cast<uint64>(this)),
      " exec ", col_ctx_->exec_key, " step_id=", col_ctx_->step_id,
      " state of all ", rfv_.size(), " fields:");
  for (int i = 0; i < rfv_.size(); ++i) {
    s.append("\n");
    s.append(rfv_[i].DebugString());
  }
  return s;
}

}  // namespace tensorflow