/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/all_to_all.h"

#include <utility>

#include "tensorflow/core/common_runtime/collective_rma_local.h"
#include "tensorflow/core/common_runtime/collective_util.h"
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/blocking_counter.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

AllToAll::AllToAll()
    : col_ctx_(nullptr), col_params_(nullptr), done_(nullptr), counter_(0) {}

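// Returns a callback that records the status of one completed send or
// receive. Once all 2 * group_size callbacks have fired, it either reports
// the first failure to done_, or on success copies the temporary buffer to
// the output tensor if one was used.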
StatusCallback AllToAll::CheckCounterAndCallDone() {
  return [this](const Status& s) {
    Status final_status;
    {
      mutex_lock l(mu_);
      status_.Update(s);
      ++counter_;
      // Every device in the group, including this one, contributes one send
      // and one receive, so wait until all 2 * group_size callbacks complete.
      if (counter_ < 2 * col_params_->group.group_size) {
        return;
      }
      CHECK_LE(counter_, 2 * col_params_->group.group_size);  // Crash ok.
      final_status = status_;
    }
    if (!final_status.ok()) {
      done_(final_status);
      return;
    }
    if (col_ctx_->output->SharesBufferWith(output_buffer_)) {
      done_(final_status);
    } else {
      // We are using a temp buffer. Copy the result to the output tensor.
      CollectiveRemoteAccessLocal::MemCpyAsync(
          col_ctx_->op_ctx->op_device_context(),
          col_ctx_->op_ctx->op_device_context(), col_ctx_->device,
          col_ctx_->device, col_ctx_->op_ctx->input_alloc_attr(0),
          col_ctx_->op_ctx->output_alloc_attr(0), &output_buffer_,
          col_ctx_->output, /*dev_to_dev_stream_index*/ 0, done_);
    }
  };
}

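// Checks that the input's first dimension matches the group size, then
// resolves the local device and its locality from the device manager.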
Status AllToAll::InitializeCollectiveContext(
    std::shared_ptr<CollectiveContext> col_ctx) {
  if (col_ctx->input->dim_size(0) != col_ctx->col_params->group.group_size) {
    return errors::InvalidArgument("input to all-to-all first dimension size (",
                                   col_ctx->input->dim_size(0),
                                   ") must be the same as the group size (",
                                   col_ctx->col_params->group.group_size, ")");
  }
  DCHECK(col_ctx->dev_mgr);
  col_ctx_ = col_ctx;
  col_params_ = col_ctx->col_params.get();
  return collective_util::InitializeDeviceAndLocality(
      col_ctx->dev_mgr, col_ctx->device_name, &col_ctx->device,
      &col_ctx->device_locality);
}

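// Slices the input and output into one chunk per rank and dispatches a send
// and a receive for every member of the group. If the input aliases the
// output, a temporary buffer holds received chunks until all transfers
// finish.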
void AllToAll::Run(StatusCallback done) {
  done_ = std::move(done);
  input_chunks_.reserve(col_params_->group.group_size);
  output_chunks_.reserve(col_params_->group.group_size);
  if (col_ctx_->input->SharesBufferWith(*col_ctx_->output)) {
    // The input is forwarded to the output, and we need to use a temp buffer.
    output_buffer_ = Tensor(
        col_ctx_->device->GetAllocator(col_ctx_->op_ctx->output_alloc_attr(0)),
        col_ctx_->output->dtype(), col_ctx_->output->shape());
  } else {
    output_buffer_ = *col_ctx_->output;
  }
  for (int i = 0; i < col_params_->group.group_size; ++i) {
    input_chunks_.push_back(col_ctx_->input->SubSlice(i));
    // Select output index based on user specified rank, if available.
    int output_index = col_params_->group.members[i].rank;
    output_chunks_.push_back(output_buffer_.SubSlice(output_index));
  }

  for (int i = 0; i < col_params_->group.group_size; ++i) {
    auto default_rank = col_params_->default_rank;
    // Issue send request from current device to all devices in group.
    DispatchSend(default_rank, i, &input_chunks_[i], CheckCounterAndCallDone());
    // Issue receive requests from all devices to current device.
    DispatchRecv(i, default_rank, &output_chunks_[i],
                 CheckCounterAndCallDone());
  }
}

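// Posts this rank's chunk destined for target_rank to the peer, keyed by the
// exec key and the (src_rank, target_rank) pair.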
void AllToAll::DispatchSend(int src_rank, int target_rank, const Tensor* tensor,
                            const StatusCallback& done) {
  string send_buf_key =
      strings::StrCat(col_ctx_->exec_key, src_rank, target_rank);
  col_ctx_->col_exec->remote_access()->PostToPeer(
      col_params_->group.members[target_rank].device.name(),
      col_params_->group.members[target_rank].task, send_buf_key,
      col_ctx_->device, col_ctx_->op_ctx->op_device_context(),
      col_ctx_->op_ctx->output_alloc_attr(0), tensor, col_ctx_->device_locality,
      col_ctx_->op_ctx->cancellation_manager(), done);
}

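// Receives the chunk sent by src_rank into the matching slot of the output
// (or temporary) buffer, using the same key scheme as DispatchSend.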
void AllToAll::DispatchRecv(int src_rank, int target_rank, Tensor* tensor,
                            const StatusCallback& done) {
  string recv_buf_key =
      strings::StrCat(col_ctx_->exec_key, src_rank, target_rank);
  col_ctx_->col_exec->remote_access()->RecvFromPeer(
      col_params_->group.members[src_rank].device.name(),
      col_params_->group.members[src_rank].task,
      col_params_->group.members[src_rank].is_local, recv_buf_key,
      col_ctx_->device, col_ctx_->op_ctx->op_device_context(),
      col_ctx_->op_ctx->output_alloc_attr(0), tensor, col_ctx_->device_locality,
      0, col_ctx_->op_ctx->cancellation_manager(), done);
}

namespace {
REGISTER_COLLECTIVE(AllToAll, AllToAll);
}  // namespace

}  // namespace tensorflow