/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/ring_reducer.h"

#include <stdlib.h>

#include <atomic>
#include <functional>
#include <utility>

#include "tensorflow/core/common_runtime/collective_rma_local.h"
#include "tensorflow/core/common_runtime/collective_util.h"
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/traceme.h"

namespace tensorflow {

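// The destructor waits on group_size_tensor_ready_ because
// ContinueAfterInputCopy may have launched an asynchronous copy whose
// callback notifies it; see the CopyCPUTensorToDevice call below.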
RingReducer::~RingReducer() { group_size_tensor_ready_.WaitForNotification(); }

Status RingReducer::InitializeCollectiveParams(CollectiveParams* col_params) {
  // TODO(b/113171733): change CHECKs to return errors.
  CHECK_EQ(col_params->instance.type, REDUCTION_COLLECTIVE);
  CHECK_EQ(col_params->instance.impl_details.collective_name, "RingReduce");
  return RingAlg::InitializeCollectiveParams(col_params);
}

void RingReducer::Run(StatusCallback done) {
  CHECK(col_ctx_);
  CHECK(col_params_);
  // Since `RingReducer` doesn't require non-overlapping collectives, unblock
  // any collective that is blocked on this instance.
  col_ctx_->col_exec->UnblockDependencies(*col_params_);

  done_ = std::move(done);
  group_size_ = col_params_->group.group_size;
  num_subdivs_ = static_cast<int>(
      col_params_->instance.impl_details.subdiv_permutations.size());
  CHECK_GT(num_subdivs_, 0);

  if (VLOG_IS_ON(1)) {
    string buf;
    for (int r = 0; r < col_params_->group.members.size(); ++r) {
      strings::StrAppend(&buf, "dev ", r, " : ",
                         col_params_->group.members[r].device.name(), "\n");
    }
    for (int sd = 0;
         sd < col_params_->instance.impl_details.subdiv_permutations.size();
         ++sd) {
      strings::StrAppend(&buf, "\nsubdiv ", sd, " perm: ");
      for (auto x :
           col_params_->instance.impl_details.subdiv_permutations[sd]) {
        strings::StrAppend(&buf, x, ", ");
      }
    }
    VLOG(1) << "RingReducer::Run for device " << col_ctx_->device_name
            << " default_rank " << col_params_->default_rank << "\n"
            << buf;
  }

  // Start by copying input to output if they're not already the same, i.e. if
  // we're not computing in-place on the input tensor.
  if ((col_ctx_->input != col_ctx_->output) &&
      (DMAHelper::base(col_ctx_->input) != DMAHelper::base(col_ctx_->output))) {
    // We are running in a blockable thread and the callback can't block, so
    // just wait here on the copy.
    Notification note;
    Status status;
    profiler::TraceMe activity("MemCpyAsync", profiler::TraceMeLevel::kInfo);
    CollectiveRemoteAccessLocal::MemCpyAsync(
        col_ctx_->op_ctx->op_device_context(),
        col_ctx_->op_ctx->op_device_context(), col_ctx_->device,
        col_ctx_->device, col_ctx_->op_ctx->input_alloc_attr(0),
        col_ctx_->op_ctx->output_alloc_attr(0), col_ctx_->input,
        col_ctx_->output, 0 /*dev_to_dev_stream_index*/,
        [&note, &status](const Status& s) {
          status.Update(s);
          note.Notify();
        });
    note.WaitForNotification();
    if (!status.ok()) {
      done_(status);
      return;
    }
  }
  ContinueAfterInputCopy();
}

// Note that this function is blocking and must not run in any thread
// which cannot be blocked.
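// It wraps the output tensor in a CollectiveAdapter of group_size_ *
// num_subdivs_ chunks and, if a final_op is configured, stages the group
// size as an on-device scalar for later use, before kicking off the ring
// passes via RunAsyncParts().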
void RingReducer::ContinueAfterInputCopy() {
  AllocatorAttributes attr = col_ctx_->op_ctx->output_alloc_attr(0);
  ca_.reset(MakeCollectiveAdapter(col_ctx_->output, group_size_ * num_subdivs_,
                                  col_ctx_->device->GetAllocator(attr)));

  if (col_params_->final_op) {
    // Create an on-device scalar value from group_size_ that may be needed
    // later.
    // TODO(tucker): Cache and reuse across invocations? Or maybe the scalar
    // can be provided to the kernel in host memory?
    Tensor group_size_val = ca_->Scalar(group_size_);
    if (col_params_->group.device_type != "CPU") {
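      // If the device reports a nonzero safe allocation frontier via
      // SafeAllocFrontier(), hand the allocator a freed_by_func so the
      // scalar's buffer can be allocated without forcing a full device
      // sync; otherwise request a synchronous copy via the final argument
      // to CopyCPUTensorToDevice below.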
      uint64 safe_alloc_frontier = col_ctx_->device->SafeAllocFrontier(0);
      AllocationAttributes aa;
      std::function<uint64()> freed_by_func = [this, &safe_alloc_frontier]() {
        safe_alloc_frontier =
            col_ctx_->device->SafeAllocFrontier(safe_alloc_frontier);
        return safe_alloc_frontier;
      };
      if (safe_alloc_frontier > 0) {
        aa.freed_by_func = &freed_by_func;
      }
      group_size_tensor_ = ca_->Scalar(
          col_ctx_->device->GetAllocator(col_ctx_->op_ctx->input_alloc_attr(0)),
          aa);
      DeviceContext* op_dev_ctx = col_ctx_->op_ctx->op_device_context();
      op_dev_ctx->CopyCPUTensorToDevice(
          &group_size_val, col_ctx_->device, &group_size_tensor_,
          [this](const Status& s) {
            if (!s.ok()) {
              StartAbort(s);
            }
            group_size_tensor_ready_.Notify();
          },
          (safe_alloc_frontier == 0));
    } else {
      group_size_tensor_ = group_size_val;
      group_size_tensor_ready_.Notify();
    }
  } else {
    // Value won't be used, so no need to initialize.
    group_size_tensor_ready_.Notify();
  }
  Finish(RunAsyncParts());
}

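// For fields that will receive data, pre-allocate a temporary chunk to
// receive into; the received data is later combined into the output chunk
// by merge_op (see the RF_RECV case in RunAsyncParts).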
void RingReducer::InitRingField(RingField* rf, int chunk_idx, int subdiv_idx,
                                int field_idx) {
  RingAlg::InitRingField(rf, chunk_idx, subdiv_idx, field_idx);
  if (rf->do_recv) {
    rf->tmp_chunk = ca_->TempChunk(rf->sc_idx);
  }
}

// At the beginning of the algorithm, initialize a RingField struct for
// every independent field of the tensor.
bool RingReducer::RunAsyncParts() {
  // This function orchestrates RingReduce actions on behalf of a single
  // device. It is entered by a blockable thread that loops within it until
  // all actions assigned to that device complete. Hence, function-local
  // variables are accessible only to that one thread and do not require an
  // explicit mutex.
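  //
  // Each RingField advances through the state machine below twice: a first
  // (reduction) pass in which received chunks are combined into the local
  // chunk via merge_op, with final_op applied once a chunk is fully reduced,
  // and a second (gather) pass in which the finished values are forwarded
  // around the ring unchanged (see AdvanceToSecondPass).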
  rfv_.clear();
  rfv_.resize(group_size_ * num_subdivs_);
  PCQueue ready_queue;
  for (int chunk_idx = 0; chunk_idx < group_size_; ++chunk_idx) {
    for (int subdiv_idx = 0; subdiv_idx < num_subdivs_; ++subdiv_idx) {
      int rf_index = (chunk_idx * num_subdivs_) + subdiv_idx;
      InitRingField(&rfv_[rf_index], chunk_idx, subdiv_idx, rf_index);
      ready_queue.Enqueue(&rfv_[rf_index]);
    }
  }
  const DeviceBase::AcceleratorDeviceInfo* gpu_info =
      col_ctx_->device->tensorflow_accelerator_device_info();
  if (gpu_info) {
    // Wait for all currently queued events on the device's compute stream to
    // complete before proceeding. The previous InitRingField calls allocated
    // temp memory buffers that are not guaranteed to be valid (e.g. for RDMA
    // write) unless we do.
    profiler::TraceMe activity("WaitForQueuedEvents",
                               profiler::TraceMeLevel::kInfo);
    Notification note;
    Status s = gpu_info->default_context->ThenExecute(
        col_ctx_->device, gpu_info->stream, [&note]() { note.Notify(); });
    if (s.ok()) {
      note.WaitForNotification();
    } else {
      mutex_lock l(status_mu_);
      status_ =
          errors::Internal("Failed to dispatch ThenExecute in RingReducer");
      return false;
    }
  }

  int field_done_count = 0;
  int send_pending_count = 0;
  int recv_pending_count = 0;
  std::atomic<bool> aborted(false);

  {
    profiler::TraceMe activity("Loop", profiler::TraceMeLevel::kInfo);
    // Loop until all RingFields have advanced to completion.
    while (field_done_count < rfv_.size()) {
      VLOG(4) << FieldState();
      // Wait for a RingField to appear in the ready_queue.
      RingField* rf = ready_queue.Dequeue();
      // Advance the RingField to its next action and execute, repeating
      // until either an async action has been started or the RingField
      // is done.
      bool dispatched = false;  // true if async action was initiated
      do {
        if (aborted) {
          // Requeue this RingField to be counted off below.
          ready_queue.Enqueue(rf);
          break;
        }
        switch (rf->action) {
          case RF_INIT:
            if (rf->do_recv) {
              rf->action = RF_RECV;
              auto requeue = [this, rf, &ready_queue, &aborted](Status s) {
                if (!s.ok()) {
                  aborted = true;
                  StartAbort(s);
                }
                ready_queue.Enqueue(rf);
              };
              DispatchRecv(rf, requeue);
              dispatched = true;
              ++recv_pending_count;
            } else {
              rf->action = RF_SEND_READY;
            }
            break;
          case RF_RECV:
            CHECK_GT(recv_pending_count, 0);
            --recv_pending_count;
            if (!rf->second_pass) {
              rf->action = RF_REDUCE;
              Status s = collective_util::ComputeBinOp(
                  col_ctx_->op_ctx, col_ctx_->op_params, col_ctx_->device,
                  col_params_->merge_op, &rf->chunk, &rf->tmp_chunk);
              if (!s.ok()) {
                aborted = true;
                StartAbort(s);
              }
            } else {
              rf->action = RF_SEND_READY;
            }
            break;
          case RF_REDUCE:
            if (!rf->second_pass && col_params_->final_op && rf->is_final) {
              rf->action = RF_FINALIZE;
              group_size_tensor_ready_.WaitForNotification();
              Status s = collective_util::ComputeBinOp(
                  col_ctx_->op_ctx, col_ctx_->op_params, col_ctx_->device,
                  col_params_->final_op, &rf->chunk, &group_size_tensor_);
              if (!s.ok()) {
                aborted = true;
                StartAbort(s);
              }
            } else {
              rf->action = RF_SEND_READY;
            }
            break;
          case RF_FINALIZE:
            rf->action = RF_DONE;
            break;
          case RF_SEND_READY:
            if (rf->do_send) {
              rf->action = RF_SEND;
              auto send_complete = [this, rf, &ready_queue,
                                    &aborted](Status s) {
                if (!s.ok()) {
                  aborted = true;
                  StartAbort(s);
                }
                ready_queue.Enqueue(rf);
              };
              DispatchSend(rf, send_complete);
              dispatched = true;
              ++send_pending_count;
            } else {
              rf->action = RF_DONE;
            }
            break;
          case RF_SEND:
            CHECK_GT(send_pending_count, 0);
            --send_pending_count;
            rf->action = RF_DONE;
            break;
          case RF_DONE:
            break;
        }
        if (rf->action == RF_DONE) {
          if (rf->second_pass) {
            ++field_done_count;
            break;  // from do while(!dispatched)
          } else {
            AdvanceToSecondPass(rf);
          }
        }
      } while (!dispatched);
      if (aborted) break;
    }  // while (field_done_count < number of fields)

    if (aborted) {
      // All of the pending data actions should be aborted; field the
      // callbacks and clear the queue before quitting.
      while ((send_pending_count > 0) || (recv_pending_count > 0)) {
        RingField* rf = ready_queue.Dequeue();
        switch (rf->action) {
          case RF_RECV:
            --recv_pending_count;
            break;
          case RF_SEND:
            --send_pending_count;
            break;
          default: {
          }  // Ignore any other actions
        }
      }
    }
  }

  CHECK_EQ(send_pending_count, 0);
  CHECK_EQ(recv_pending_count, 0);

  VLOG(2) << this << " device=" << col_ctx_->device_name << " finish;"
          << " final value " << TensorDebugString(ca_->Value());
  return !aborted;
}

namespace {
REGISTER_COLLECTIVE(RingReduce, RingReducer);
}  // namespace

}  // namespace tensorflow