/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/ring_reducer.h"

#include <stdlib.h>

#include <atomic>
#include <functional>
#include <utility>

#include "tensorflow/core/common_runtime/collective_rma_local.h"
#include "tensorflow/core/common_runtime/collective_util.h"
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/traceme.h"

namespace tensorflow {

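// The destructor blocks on group_size_tensor_ready_ so that the asynchronous
// CopyCPUTensorToDevice callback in ContinueAfterInputCopy, which performs
// the notification, has run before this object goes away.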
RingReducer::~RingReducer() { group_size_tensor_ready_.WaitForNotification(); }

Status RingReducer::InitializeCollectiveParams(CollectiveParams* col_params) {
  // TODO(b/113171733): change CHECKs to return errors.
  CHECK_EQ(col_params->instance.type, REDUCTION_COLLECTIVE);
  CHECK_EQ(col_params->instance.impl_details.collective_name, "RingReduce");
  return RingAlg::InitializeCollectiveParams(col_params);
}

void RingReducer::Run(StatusCallback done) {
  CHECK(col_ctx_);
  CHECK(col_params_);
  // Since `RingReducer` doesn't require non-overlapping collectives, unblock
  // any collective that is blocked on this instance.
  col_ctx_->col_exec->UnblockDependencies(*col_params_);

  done_ = std::move(done);
  group_size_ = col_params_->group.group_size;
  num_subdivs_ = static_cast<int>(
      col_params_->instance.impl_details.subdiv_permutations.size());
  CHECK_GT(num_subdivs_, 0);
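  // Each subdivision is an independent ring described by its own entry in
  // subdiv_permutations; the output tensor is processed as
  // group_size_ * num_subdivs_ chunks so that subdivisions can advance in
  // parallel (see ContinueAfterInputCopy and RunAsyncParts).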

  if (VLOG_IS_ON(1)) {
    string buf;
    for (int r = 0; r < col_params_->group.members.size(); ++r) {
      strings::StrAppend(&buf, "dev ", r, " : ",
                         col_params_->group.members[r].device.name(), "\n");
    }
    for (int sd = 0;
         sd < col_params_->instance.impl_details.subdiv_permutations.size();
         ++sd) {
      strings::StrAppend(&buf, "\nsubdiv ", sd, " perm: ");
      for (auto x :
           col_params_->instance.impl_details.subdiv_permutations[sd]) {
        strings::StrAppend(&buf, x, ", ");
      }
    }
    VLOG(1) << "RingReducer::Run for device " << col_ctx_->device_name
            << " default_rank " << col_params_->default_rank << "\n"
            << buf;
  }

  // Start by copying input to output if they're not already the same, i.e. if
  // we're not computing in-place on the input tensor.
  if ((col_ctx_->input != col_ctx_->output) &&
      (DMAHelper::base(col_ctx_->input) !=
       DMAHelper::base(col_ctx_->output))) {
    // We are running in a blockable thread and the callback can't block so
    // just wait here on the copy.
    Notification note;
    Status status;
    profiler::TraceMe activity("MemCpyAsync", profiler::TraceMeLevel::kInfo);
    CollectiveRemoteAccessLocal::MemCpyAsync(
        col_ctx_->op_ctx->op_device_context(),
        col_ctx_->op_ctx->op_device_context(), col_ctx_->device,
        col_ctx_->device, col_ctx_->op_ctx->input_alloc_attr(0),
        col_ctx_->op_ctx->output_alloc_attr(0), col_ctx_->input,
        col_ctx_->output, 0 /*dev_to_dev_stream_index*/,
        [&note, &status](const Status& s) {
          status.Update(s);
          note.Notify();
        });
    note.WaitForNotification();
    if (!status.ok()) {
      done_(status);
      return;
    }
  }
  ContinueAfterInputCopy();
}

// Note that this function is blocking and must not run in any thread
// which cannot be blocked.
void RingReducer::ContinueAfterInputCopy() {
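  // Wrap the output tensor in a collective adapter that exposes it as
  // group_size_ * num_subdivs_ chunks; each chunk becomes one RingField in
  // RunAsyncParts below.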
  AllocatorAttributes attr = col_ctx_->op_ctx->output_alloc_attr(0);
  ca_.reset(MakeCollectiveAdapter(col_ctx_->output, group_size_ * num_subdivs_,
                                  col_ctx_->device->GetAllocator(attr)));

  if (col_params_->final_op) {
    // Create an on-device scalar value from group_size_ that may be needed
    // later.
    // TODO(tucker): Cache and reuse across invocations? Or maybe the scalar
    // can be provided to the kernel in host memory?
    Tensor group_size_val = ca_->Scalar(group_size_);
    if (col_params_->group.device_type != "CPU") {
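      // The scalar must live in device memory here.  When the device reports
      // a non-zero safe-allocation frontier, supply a freed_by_func so the
      // allocator can query the current frontier and avoid reusing memory
      // whose previous use may still be pending on the device.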
      uint64 safe_alloc_frontier = col_ctx_->device->SafeAllocFrontier(0);
      AllocationAttributes aa;
      std::function<uint64()> freed_by_func = [this, &safe_alloc_frontier]() {
        safe_alloc_frontier =
            col_ctx_->device->SafeAllocFrontier(safe_alloc_frontier);
        return safe_alloc_frontier;
      };
      if (safe_alloc_frontier > 0) {
        aa.freed_by_func = &freed_by_func;
      }
      group_size_tensor_ = ca_->Scalar(
          col_ctx_->device->GetAllocator(col_ctx_->op_ctx->input_alloc_attr(0)),
          aa);
      DeviceContext* op_dev_ctx = col_ctx_->op_ctx->op_device_context();
      op_dev_ctx->CopyCPUTensorToDevice(
          &group_size_val, col_ctx_->device, &group_size_tensor_,
          [this](const Status& s) {
            if (!s.ok()) {
              StartAbort(s);
            }
            group_size_tensor_ready_.Notify();
          },
          (safe_alloc_frontier == 0));
    } else {
      group_size_tensor_ = group_size_val;
      group_size_tensor_ready_.Notify();
    }
  } else {
    // Value won't be used, so no need to initialize.
    group_size_tensor_ready_.Notify();
  }
  Finish(RunAsyncParts());
}

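// Extends RingAlg::InitRingField by also allocating a temporary chunk for
// fields that receive data; during the first (reduce) pass the received
// values are combined with the local chunk via merge_op (see RF_RECV in
// RunAsyncParts).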
void RingReducer::InitRingField(RingField* rf, int chunk_idx, int subdiv_idx,
                                int field_idx) {
  RingAlg::InitRingField(rf, chunk_idx, subdiv_idx, field_idx);
  if (rf->do_recv) {
    rf->tmp_chunk = ca_->TempChunk(rf->sc_idx);
  }
}

// At the beginning of the algorithm, initialize a RingField struct for
// every independent field of the tensor.
bool RingReducer::RunAsyncParts() {
  // This function orchestrates RingReduce actions on behalf of a
  // single device. It is entered by a blockable thread that
  // loops within it until all actions assigned to that device
  // complete. Hence function local variables are accessible only by that
  // one thread and do not require an explicit mutex.
  rfv_.clear();
  rfv_.resize(group_size_ * num_subdivs_);
  PCQueue ready_queue;
  for (int chunk_idx = 0; chunk_idx < group_size_; ++chunk_idx) {
    for (int subdiv_idx = 0; subdiv_idx < num_subdivs_; ++subdiv_idx) {
      int rf_index = (chunk_idx * num_subdivs_) + subdiv_idx;
      InitRingField(&rfv_[rf_index], chunk_idx, subdiv_idx, rf_index);
      ready_queue.Enqueue(&rfv_[rf_index]);
    }
  }
  const DeviceBase::AcceleratorDeviceInfo* gpu_info =
      col_ctx_->device->tensorflow_accelerator_device_info();
  if (gpu_info) {
    // Wait for all currently queued events on the device's compute stream to
    // complete before proceeding. The previous InitRingField calls allocated
    // temp memory buffers that are not guaranteed to be valid (e.g. for RDMA
    // write) unless we do.
    profiler::TraceMe activity("WaitForQueuedEvents",
                               profiler::TraceMeLevel::kInfo);
    Notification note;
    Status s = gpu_info->default_context->ThenExecute(
        col_ctx_->device, gpu_info->stream, [&note]() { note.Notify(); });
    if (s.ok()) {
      note.WaitForNotification();
    } else {
      mutex_lock l(status_mu_);
      status_ =
          errors::Internal("Failed to dispatch ThenExecute in RingReducer");
      return false;
    }
  }

  int field_done_count = 0;
  int send_pending_count = 0;
  int recv_pending_count = 0;
  std::atomic<bool> aborted(false);

  {
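    // Each RingField moves through a simple state machine:
    //   RF_INIT -> [RF_RECV -> RF_REDUCE -> RF_FINALIZE] -> RF_SEND_READY ->
    //   [RF_SEND] -> RF_DONE
    // where the bracketed states are taken only when do_recv / do_send are
    // set, RF_REDUCE applies merge_op to received data, and RF_FINALIZE
    // applies final_op (e.g. division by the group size for a mean) to the
    // last reduced field.  Every field runs the machine twice: once for the
    // reduce pass and, after AdvanceToSecondPass, once more for the gather
    // pass, which forwards already-reduced chunks without further reduction.
    // Async send/recv completions re-enqueue the field on ready_queue.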
    profiler::TraceMe activity("Loop", profiler::TraceMeLevel::kInfo);
    // Loop until all RingFields have advanced to completion.
    while (field_done_count < rfv_.size()) {
      VLOG(4) << FieldState();
      // Wait for a RingField to appear in the ready_queue.
      RingField* rf = ready_queue.Dequeue();
      // Advance the RingField to its next action and execute, repeating
      // until either an async action has been started or the RingField
      // is done.
      bool dispatched = false;  // true if async action was initiated
      do {
        if (aborted) {
          // Requeue this RingField to be counted off below.
          ready_queue.Enqueue(rf);
          break;
        }
        switch (rf->action) {
          case RF_INIT:
            if (rf->do_recv) {
              rf->action = RF_RECV;
              auto requeue = [this, rf, &ready_queue, &aborted](Status s) {
                if (!s.ok()) {
                  aborted = true;
                  StartAbort(s);
                }
                ready_queue.Enqueue(rf);
              };
              DispatchRecv(rf, requeue);
              dispatched = true;
              ++recv_pending_count;
            } else {
              rf->action = RF_SEND_READY;
            }
            break;
          case RF_RECV:
            CHECK_GT(recv_pending_count, 0);
            --recv_pending_count;
            if (!rf->second_pass) {
              rf->action = RF_REDUCE;
              Status s = collective_util::ComputeBinOp(
                  col_ctx_->op_ctx, col_ctx_->op_params, col_ctx_->device,
                  col_params_->merge_op, &rf->chunk, &rf->tmp_chunk);
              if (!s.ok()) {
                aborted = true;
                StartAbort(s);
              }
            } else {
              rf->action = RF_SEND_READY;
            }
            break;
          case RF_REDUCE:
            if (!rf->second_pass && col_params_->final_op && rf->is_final) {
              rf->action = RF_FINALIZE;
              group_size_tensor_ready_.WaitForNotification();
              Status s = collective_util::ComputeBinOp(
                  col_ctx_->op_ctx, col_ctx_->op_params, col_ctx_->device,
                  col_params_->final_op, &rf->chunk, &group_size_tensor_);
              if (!s.ok()) {
                aborted = true;
                StartAbort(s);
              }
            } else {
              rf->action = RF_SEND_READY;
            }
            break;
          case RF_FINALIZE:
            rf->action = RF_DONE;
            break;
          case RF_SEND_READY:
            if (rf->do_send) {
              rf->action = RF_SEND;
              auto send_complete = [this, rf, &ready_queue,
                                    &aborted](Status s) {
                if (!s.ok()) {
                  aborted = true;
                  StartAbort(s);
                }
                ready_queue.Enqueue(rf);
              };
              DispatchSend(rf, send_complete);
              dispatched = true;
              ++send_pending_count;
            } else {
              rf->action = RF_DONE;
            }
            break;
          case RF_SEND:
            CHECK_GT(send_pending_count, 0);
            --send_pending_count;
            rf->action = RF_DONE;
            break;
          case RF_DONE:
            break;
        }
        if (rf->action == RF_DONE) {
          if (rf->second_pass) {
            ++field_done_count;
            break;  // from do while(!dispatched)
          } else {
            AdvanceToSecondPass(rf);
          }
        }
      } while (!dispatched);
      if (aborted) break;
    }  // while (field_done_count < number of fields)

    if (aborted) {
      // All of the pending data actions should be aborted; field the
      // callbacks and clear the queue before quitting.
      while ((send_pending_count > 0) || (recv_pending_count > 0)) {
        RingField* rf = ready_queue.Dequeue();
        switch (rf->action) {
          case RF_RECV:
            --recv_pending_count;
            break;
          case RF_SEND:
            --send_pending_count;
            break;
          default: {
          }  // Ignore any other actions
        }
      }
    }
  }

  CHECK_EQ(send_pending_count, 0);
  CHECK_EQ(recv_pending_count, 0);

  VLOG(2) << this << " device=" << col_ctx_->device_name << " finish;"
          << " final value " << TensorDebugString(ca_->Value());
  return !aborted;
}

namespace {
REGISTER_COLLECTIVE(RingReduce, RingReducer);
}  // namespace

}  // namespace tensorflow