/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/ring_gatherer.h"

#include <stdlib.h>

#include <atomic>
#include <functional>
#include <utility>

#include "tensorflow/core/common_runtime/collective_rma_local.h"
#include "tensorflow/core/common_runtime/collective_util.h"
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/traceme.h"

namespace tensorflow {
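// Checks gather-specific constraints (at most a single subdiv at offset 0)
// before deferring to RingAlg::InitializeCollectiveParams for the common
// ring setup.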
Status RingGatherer::InitializeCollectiveParams(CollectiveParams* col_params) {
  DCHECK_EQ(col_params->instance.type, GATHER_COLLECTIVE);
  DCHECK_EQ(col_params->instance.impl_details.collective_name, "RingGather");
  // TODO(tucker): Maybe add subdiv support. It's only useful with
  // multiple NICS, and maybe gather performance isn't important enough.
  // For now, there must always be only a single subdiv at offset 0.
  if (!col_params->instance.impl_details.subdiv_offsets.empty() &&
      (col_params->instance.impl_details.subdiv_offsets.size() > 1 ||
       col_params->instance.impl_details.subdiv_offsets[0] != 0)) {
    return errors::InvalidArgument(
        "RingGather cannot take any subdiv offset other than 0.");
  }
  if (col_params->instance.impl_details.subdiv_offsets.empty()) {
    col_params->instance.impl_details.subdiv_offsets.push_back(0);
  }
  return RingAlg::InitializeCollectiveParams(col_params);
}

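// Copies this rank's input into its own chunk of the output, then drives the
// ring passes that gather the remaining chunks from the other ranks.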
void RingGatherer::Run(StatusCallback done) {
  DCHECK(col_ctx_);
  DCHECK(col_params_);
  done_ = std::move(done);
  group_size_ = col_params_->group.group_size;
  num_subdivs_ = static_cast<int>(
      col_params_->instance.impl_details.subdiv_permutations.size());
  DCHECK_GT(num_subdivs_, 0);

  if (VLOG_IS_ON(1)) {
    string buf;
    for (int r = 0; r < col_params_->group.members.size(); ++r) {
      strings::StrAppend(&buf, "dev ", r, " : ",
                         col_params_->group.members[r].device.name(), "\n");
    }
    for (int sd = 0;
         sd < col_params_->instance.impl_details.subdiv_permutations.size();
         ++sd) {
      strings::StrAppend(&buf, "\nsubdiv ", sd, " perm: ");
      for (auto x :
           col_params_->instance.impl_details.subdiv_permutations[sd]) {
        strings::StrAppend(&buf, x, ", ");
      }
    }
    VLOG(1) << "RingGatherer::Run for device " << col_ctx_->device_name
            << " default_rank " << col_params_->default_rank << "\n"
            << buf;
  }

  // Prepare to alias fields within the output.
  AllocatorAttributes attr = col_ctx_->op_ctx->output_alloc_attr(0);
  ca_.reset(MakeCollectiveAdapter(col_ctx_->output, group_size_ * num_subdivs_,
                                  col_ctx_->device->GetAllocator(attr),
                                  false /*align_chunks*/));
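  // The adapter carves the output into group_size_ * num_subdivs_ chunks,
  // each aliasing one rank's slice of the gathered result.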

  // Start by copying input to the rank-specific offset of output.
  // We are running in a blockable thread and the callback can't block so
  // just wait here on the copy.
  {
    profiler::TraceMe activity("MemCpyAsync", profiler::TraceMeLevel::kInfo);
    Notification note;
    Status status;
    Tensor alias_chunk(ca_->ChunkAlias(col_params_->subdiv_rank[0]));
    CollectiveRemoteAccessLocal::MemCpyAsync(
        col_ctx_->op_ctx->op_device_context(),
        col_ctx_->op_ctx->op_device_context(), col_ctx_->device,
        col_ctx_->device, col_ctx_->op_ctx->input_alloc_attr(0),
        col_ctx_->op_ctx->output_alloc_attr(0), col_ctx_->input, &alias_chunk,
        0 /*dev_to_dev_stream_index*/, [&note, &status](const Status& s) {
          status.Update(s);
          note.Notify();
        });
    note.WaitForNotification();
    if (!status.ok()) {
      done_(status);
      return;
    }
  }
  Finish(RunAsyncParts());
}

bool RingGatherer::RunAsyncParts() {
  // This function orchestrates RingGatherer actions on behalf of a
  // single device. It is entered by a blockable thread that
  // loops within it until all actions assigned to that device
  // complete. Hence function local variables are accessible only by that
  // one thread and do not require an explicit mutex.
  rfv_.clear();
  rfv_.resize(group_size_ * num_subdivs_);
  PCQueue ready_queue;
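  // Seed the ready queue with one RingField per (chunk, subdiv) pair; the
  // loop below advances each field through its action state machine.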
  for (int chunk_idx = 0; chunk_idx < group_size_; ++chunk_idx) {
    for (int subdiv_idx = 0; subdiv_idx < num_subdivs_; ++subdiv_idx) {
      int rf_index = (chunk_idx * num_subdivs_) + subdiv_idx;
      InitRingField(&rfv_[rf_index], chunk_idx, subdiv_idx, rf_index);
      ready_queue.Enqueue(&rfv_[rf_index]);
    }
  }
  const DeviceBase::AcceleratorDeviceInfo* gpu_info =
      col_ctx_->device->tensorflow_accelerator_device_info();
  if (gpu_info) {
    // Wait for all currently queued events on the device's compute stream to
    // complete before proceeding. The previous InitRingField calls allocated
    // temp memory buffers that are not guaranteed to be valid (e.g. for RDMA
    // write) unless we do.
    profiler::TraceMe activity("WaitForQueuedEvents",
                               profiler::TraceMeLevel::kInfo);
    Notification note;
    Status s = gpu_info->default_context->ThenExecute(
        col_ctx_->device, gpu_info->stream, [&note]() { note.Notify(); });
    if (s.ok()) {
      note.WaitForNotification();
    } else {
      mutex_lock l(status_mu_);
      status_ =
          errors::Internal("Failed to dispatch ThenExecute in RingGatherer");
      return false;
    }
  }

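  // Progress and in-flight counters: field_done_count tracks completed
  // RingFields, while the pending counts track outstanding async sends and
  // recvs so they can be drained if the collective aborts.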
  int field_done_count = 0;
  int send_pending_count = 0;
  int recv_pending_count = 0;
  std::atomic<bool> aborted(false);

  // Loop until all RingFields have advanced to completion.
  {
    profiler::TraceMe activity("Loop", profiler::TraceMeLevel::kInfo);
    while (field_done_count < rfv_.size()) {
      VLOG(4) << FieldState();
      // Wait for a RingField to appear in the ready_queue.
      RingField* rf = ready_queue.Dequeue();
      // Advance the RingField to its next action and execute, repeating
      // until either an async action has been started or the RingField
      // is done.
      bool dispatched = false;  // true if async action was initiated
      do {
        if (aborted) {
          // Requeue this RingField to be counted off below.
          ready_queue.Enqueue(rf);
          break;
        }
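        // Per-field state machine for gather:
        // RF_INIT -> RF_RECV (if do_recv) -> RF_SEND_READY ->
        // RF_SEND (if do_send) -> RF_DONE.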
        switch (rf->action) {
          case RF_INIT:
            if (rf->do_recv) {
              rf->action = RF_RECV;
              auto requeue = [this, rf, &ready_queue, &aborted](Status s) {
                if (!s.ok()) {
                  aborted = true;
                  StartAbort(s);
                }
                ready_queue.Enqueue(rf);
              };
              DispatchRecv(rf, requeue);
              dispatched = true;
              ++recv_pending_count;
            } else {
              rf->action = RF_SEND_READY;
            }
            break;
          case RF_RECV:
            DCHECK_GT(recv_pending_count, 0);
            --recv_pending_count;
            rf->action = RF_SEND_READY;
            break;
          case RF_REDUCE:
            // Never used for Gather, so just fall through.
            TF_FALLTHROUGH_INTENDED;
          case RF_FINALIZE:
            // Never used for Gather, so just fall through.
            TF_FALLTHROUGH_INTENDED;
          case RF_SEND_READY:
            if (rf->do_send) {
              rf->action = RF_SEND;
              auto send_complete = [this, rf, &ready_queue,
                                    &aborted](Status s) {
                if (!s.ok()) {
                  aborted = true;
                  StartAbort(s);
                }
                ready_queue.Enqueue(rf);
              };
              DispatchSend(rf, send_complete);
              dispatched = true;
              ++send_pending_count;
            } else {
              rf->action = RF_DONE;
            }
            break;
          case RF_SEND:
            DCHECK_GT(send_pending_count, 0);
            --send_pending_count;
            rf->action = RF_DONE;
            break;
          case RF_DONE:
            break;
        }
        if (rf->action == RF_DONE) {
          // There's only one pass.
          ++field_done_count;
          break;  // from do while(!dispatched)
        }
      } while (!dispatched);
      if (aborted) break;
    }  // while (field_done_count < number of fields)

    if (aborted) {
      // All of the pending data actions should be aborted; field the
      // callbacks and clear the queue before quitting.
      while ((send_pending_count > 0) || (recv_pending_count > 0)) {
        RingField* rf = ready_queue.Dequeue();
        switch (rf->action) {
          case RF_RECV:
            --recv_pending_count;
            break;
          case RF_SEND:
            --send_pending_count;
            break;
          default: {
          }  // Ignore any other actions
        }
      }
    }
  }

  DCHECK_EQ(send_pending_count, 0);
  DCHECK_EQ(recv_pending_count, 0);

  VLOG(2) << this << " device=" << col_ctx_->device_name << " finish;"
          << " final value " << TensorDebugString(ca_->Value());
  return !aborted;
}

namespace {
REGISTER_COLLECTIVE(RingGather, RingGatherer);
}  // namespace

}  // namespace tensorflow
