1 | /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/core/common_runtime/all_to_all.h" |
16 | |
17 | #include <utility> |
18 | |
19 | #include "tensorflow/core/common_runtime/collective_rma_local.h" |
20 | #include "tensorflow/core/common_runtime/collective_util.h" |
21 | #include "tensorflow/core/common_runtime/copy_tensor.h" |
22 | #include "tensorflow/core/common_runtime/device.h" |
23 | #include "tensorflow/core/common_runtime/device_mgr.h" |
24 | #include "tensorflow/core/common_runtime/dma_helper.h" |
25 | #include "tensorflow/core/common_runtime/process_util.h" |
26 | #include "tensorflow/core/framework/allocator.h" |
27 | #include "tensorflow/core/framework/device_base.h" |
28 | #include "tensorflow/core/framework/op_kernel.h" |
29 | #include "tensorflow/core/framework/tensor.h" |
30 | #include "tensorflow/core/framework/types.h" |
31 | #include "tensorflow/core/lib/core/errors.h" |
32 | #include "tensorflow/core/lib/core/notification.h" |
33 | #include "tensorflow/core/lib/core/status.h" |
34 | #include "tensorflow/core/lib/strings/str_util.h" |
35 | #include "tensorflow/core/lib/strings/strcat.h" |
36 | #include "tensorflow/core/platform/blocking_counter.h" |
37 | #include "tensorflow/core/platform/env.h" |
38 | #include "tensorflow/core/platform/types.h" |
39 | |
40 | namespace tensorflow { |
41 | |
// Default-constructs an inactive instance; real setup happens in
// InitializeCollectiveContext() and Run().
AllToAll::AllToAll()
    : col_ctx_(nullptr), col_params_(nullptr), done_(nullptr), counter_(0) {}
44 | |
// Returns a completion callback shared by every DispatchSend/DispatchRecv
// issued from Run(). Each invocation folds the op's status into status_ and
// bumps counter_; only the final invocation (send + receive per group member,
// i.e. 2 * group_size calls) reports the aggregate status to done_, copying
// out of the temp buffer first if one was used.
StatusCallback AllToAll::CheckCounterAndCallDone() {
  return [this](const Status& s) {
    Status final_status;
    {
      mutex_lock l(mu_);
      // Merge this op's status into the aggregate; Update keeps the first
      // non-OK status it sees.
      status_.Update(s);
      ++counter_;
      // For all devices other than itself, there's a send and a receive. We
      // wait until all of them complete.
      if (counter_ < 2 * col_params_->group.group_size) {
        return;
      }
      CHECK_LE(counter_, 2 * col_params_->group.group_size);  // Crash ok.
      // Copy the aggregate out so done_ is invoked without holding mu_.
      final_status = status_;
    }
    if (!final_status.ok()) {
      done_(final_status);
      return;
    }
    if (col_ctx_->output->SharesBufferWith(output_buffer_)) {
      // Results were written directly into the output tensor (see Run());
      // nothing left to copy.
      done_(final_status);
    } else {
      // We are using a temp buffer. Copy to the output tensor.
      CollectiveRemoteAccessLocal::MemCpyAsync(
          col_ctx_->op_ctx->op_device_context(),
          col_ctx_->op_ctx->op_device_context(), col_ctx_->device,
          col_ctx_->device, col_ctx_->op_ctx->input_alloc_attr(0),
          col_ctx_->op_ctx->output_alloc_attr(0), &output_buffer_,
          col_ctx_->output, /*dev_to_dev_stream_index*/ 0, done_);
    }
  };
}
77 | |
78 | Status AllToAll::InitializeCollectiveContext( |
79 | std::shared_ptr<CollectiveContext> col_ctx) { |
80 | if (col_ctx->input->dim_size(0) != col_ctx->col_params->group.group_size) { |
81 | return errors::InvalidArgument("input to all-to-all first dimension size (" , |
82 | col_ctx->input->dim_size(0), |
83 | ") must be the same as the group size (" , |
84 | col_ctx->col_params->group.group_size, ")" ); |
85 | } |
86 | DCHECK(col_ctx->dev_mgr); |
87 | col_ctx_ = col_ctx; |
88 | col_params_ = col_ctx->col_params.get(); |
89 | return collective_util::InitializeDeviceAndLocality( |
90 | col_ctx->dev_mgr, col_ctx->device_name, &col_ctx->device, |
91 | &col_ctx->device_locality); |
92 | } |
93 | |
// Executes the all-to-all: slices the input along dim 0, sends slice i to
// rank i, and receives each peer's slice into the corresponding output row.
// `done` is invoked exactly once with the aggregate status (via
// CheckCounterAndCallDone).
void AllToAll::Run(StatusCallback done) {
  done_ = std::move(done);
  input_chunks_.reserve(col_params_->group.group_size);
  output_chunks_.reserve(col_params_->group.group_size);
  if (col_ctx_->input->SharesBufferWith(*col_ctx_->output)) {
    // The input is forwarded to the output, and we need to use a temp buffer.
    output_buffer_ = Tensor(
        col_ctx_->device->GetAllocator(col_ctx_->op_ctx->output_alloc_attr(0)),
        col_ctx_->output->dtype(), col_ctx_->output->shape());
  } else {
    // No aliasing: receive directly into the output tensor.
    output_buffer_ = *col_ctx_->output;
  }
  // Build all chunks before dispatching anything: the dispatch calls below
  // hand out pointers into these vectors to asynchronous ops, so no
  // push_back may happen after the first dispatch (reallocation would
  // invalidate those pointers). Keep these two loops separate.
  for (int i = 0; i < col_params_->group.group_size; ++i) {
    input_chunks_.push_back(col_ctx_->input->SubSlice(i));
    // Select output index based on user specified rank, if available.
    int output_index = col_params_->group.members[i].rank;
    output_chunks_.push_back(output_buffer_.SubSlice(output_index));
  }

  for (int i = 0; i < col_params_->group.group_size; ++i) {
    auto default_rank = col_params_->default_rank;
    // Issue send request from current device to all devices in group.
    DispatchSend(default_rank, i, &input_chunks_[i], CheckCounterAndCallDone());
    // Issue receive requests from all devices to current device.
    DispatchRecv(i, default_rank, &output_chunks_[i],
                 CheckCounterAndCallDone());
  }
}
122 | |
123 | void AllToAll::DispatchSend(int src_rank, int target_rank, const Tensor* tensor, |
124 | const StatusCallback& done) { |
125 | string send_buf_key = |
126 | strings::StrCat(col_ctx_->exec_key, src_rank, target_rank); |
127 | col_ctx_->col_exec->remote_access()->PostToPeer( |
128 | col_params_->group.members[target_rank].device.name(), |
129 | col_params_->group.members[target_rank].task, send_buf_key, |
130 | col_ctx_->device, col_ctx_->op_ctx->op_device_context(), |
131 | col_ctx_->op_ctx->output_alloc_attr(0), tensor, col_ctx_->device_locality, |
132 | col_ctx_->op_ctx->cancellation_manager(), done); |
133 | } |
134 | |
135 | void AllToAll::DispatchRecv(int src_rank, int target_rank, Tensor* tensor, |
136 | const StatusCallback& done) { |
137 | string recv_buf_key = |
138 | strings::StrCat(col_ctx_->exec_key, src_rank, target_rank); |
139 | col_ctx_->col_exec->remote_access()->RecvFromPeer( |
140 | col_params_->group.members[src_rank].device.name(), |
141 | col_params_->group.members[src_rank].task, |
142 | col_params_->group.members[src_rank].is_local, recv_buf_key, |
143 | col_ctx_->device, col_ctx_->op_ctx->op_device_context(), |
144 | col_ctx_->op_ctx->output_alloc_attr(0), tensor, col_ctx_->device_locality, |
145 | 0, col_ctx_->op_ctx->cancellation_manager(), done); |
146 | } |
147 | |
namespace {
// Registers this implementation under the collective name "AllToAll" so the
// collective executor can instantiate it by name.
REGISTER_COLLECTIVE(AllToAll, AllToAll);
}  // namespace
151 | |
152 | } // namespace tensorflow |
153 | |