/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/ring_gatherer.h"

#include <stdlib.h>

#include <atomic>
#include <functional>
#include <utility>

#include "tensorflow/core/common_runtime/collective_rma_local.h"
#include "tensorflow/core/common_runtime/collective_util.h"
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/traceme.h"

namespace tensorflow {
Status RingGatherer::InitializeCollectiveParams(CollectiveParams* col_params) {
  DCHECK_EQ(col_params->instance.type, GATHER_COLLECTIVE);
  DCHECK_EQ(col_params->instance.impl_details.collective_name, "RingGather");
  // TODO(tucker): Maybe add subdiv support. It's only useful with multiple
  // NICs, and gather performance may not be important enough to justify it.
  // For now, there must always be only a single subdiv at offset 0.
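  // For example, subdiv_offsets of {} or {0} are accepted by the checks
  // below, while {1} or {0, 4} would be rejected.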
  if (!col_params->instance.impl_details.subdiv_offsets.empty() &&
      (col_params->instance.impl_details.subdiv_offsets.size() > 1 ||
       col_params->instance.impl_details.subdiv_offsets[0] != 0)) {
    return errors::InvalidArgument(
        "RingGather cannot take any subdiv offset other than 0.");
  }
  if (col_params->instance.impl_details.subdiv_offsets.empty()) {
    col_params->instance.impl_details.subdiv_offsets.push_back(0);
  }
  return RingAlg::InitializeCollectiveParams(col_params);
}

void RingGatherer::Run(StatusCallback done) {
  DCHECK(col_ctx_);
  DCHECK(col_params_);
  done_ = std::move(done);
  group_size_ = col_params_->group.group_size;
  num_subdivs_ = static_cast<int>(
      col_params_->instance.impl_details.subdiv_permutations.size());
  DCHECK_GT(num_subdivs_, 0);

  if (VLOG_IS_ON(1)) {
    string buf;
    for (int r = 0; r < col_params_->group.members.size(); ++r) {
      strings::StrAppend(&buf, "dev ", r, " : ",
                         col_params_->group.members[r].device.name(), "\n");
    }
    for (int sd = 0;
         sd < col_params_->instance.impl_details.subdiv_permutations.size();
         ++sd) {
      strings::StrAppend(&buf, "\nsubdiv ", sd, " perm: ");
      for (auto x :
           col_params_->instance.impl_details.subdiv_permutations[sd]) {
        strings::StrAppend(&buf, x, ", ");
      }
    }
    VLOG(1) << "RingGatherer::Run for device " << col_ctx_->device_name
            << " default_rank " << col_params_->default_rank << "\n"
            << buf;
  }

  // Prepare to alias fields within the output.
  AllocatorAttributes attr = col_ctx_->op_ctx->output_alloc_attr(0);
  ca_.reset(MakeCollectiveAdapter(col_ctx_->output, group_size_ * num_subdivs_,
                                  col_ctx_->device->GetAllocator(attr),
                                  false /*align_chunks*/));
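  // The adapter views the output tensor as group_size_ * num_subdivs_
  // contiguous chunks; ChunkAlias(i) returns a tensor aliasing chunk i, so
  // each rank's contribution lands directly in its slot of the gathered
  // output with no separate concatenation step.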

  // Start by copying input to the rank-specific offset of output.
  // We are running in a blockable thread and the callback can't block, so
  // we simply wait here for the copy to complete.
  {
    profiler::TraceMe activity("MemCpyAsync", profiler::TraceMeLevel::kInfo);
    Notification note;
    Status status;
    Tensor alias_chunk(ca_->ChunkAlias(col_params_->subdiv_rank[0]));
    CollectiveRemoteAccessLocal::MemCpyAsync(
        col_ctx_->op_ctx->op_device_context(),
        col_ctx_->op_ctx->op_device_context(), col_ctx_->device,
        col_ctx_->device, col_ctx_->op_ctx->input_alloc_attr(0),
        col_ctx_->op_ctx->output_alloc_attr(0), col_ctx_->input, &alias_chunk,
        0 /*dev_to_dev_stream_index*/, [&note, &status](const Status& s) {
          status.Update(s);
          note.Notify();
        });
    note.WaitForNotification();
    if (!status.ok()) {
      done_(status);
      return;
    }
  }
  Finish(RunAsyncParts());
}

bool RingGatherer::RunAsyncParts() {
  // This function orchestrates RingGatherer actions on behalf of a
  // single device. It is entered by a blockable thread that
  // loops within it until all actions assigned to that device
  // complete. Hence function local variables are accessible only by that
  // one thread and do not require an explicit mutex.
  rfv_.clear();
  rfv_.resize(group_size_ * num_subdivs_);
  PCQueue ready_queue;
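  // Create one RingField per (chunk_idx, subdiv_idx) pair; every field starts
  // out ready and is enqueued once here, then re-enqueued by its async
  // send/recv callbacks as it advances through the state machine below.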
  for (int chunk_idx = 0; chunk_idx < group_size_; ++chunk_idx) {
    for (int subdiv_idx = 0; subdiv_idx < num_subdivs_; ++subdiv_idx) {
      int rf_index = (chunk_idx * num_subdivs_) + subdiv_idx;
      InitRingField(&rfv_[rf_index], chunk_idx, subdiv_idx, rf_index);
      ready_queue.Enqueue(&rfv_[rf_index]);
    }
  }
  const DeviceBase::AcceleratorDeviceInfo* gpu_info =
      col_ctx_->device->tensorflow_accelerator_device_info();
  if (gpu_info) {
    // Wait for all currently queued events on the CPU compute stream to
    // complete before proceeding. The preceding InitRingField calls allocated
    // temp memory buffers that are not guaranteed to be valid (e.g. for RDMA
    // write) until those events have drained.
    profiler::TraceMe activity("WaitForQueuedEvents",
                               profiler::TraceMeLevel::kInfo);
    Notification note;
    Status s = gpu_info->default_context->ThenExecute(
        col_ctx_->device, gpu_info->stream, [&note]() { note.Notify(); });
    if (s.ok()) {
      note.WaitForNotification();
    } else {
      mutex_lock l(status_mu_);
      status_ =
          errors::Internal("Failed to dispatch ThenExecute in RingGatherer");
      return false;
    }
  }

  int field_done_count = 0;
  int send_pending_count = 0;
  int recv_pending_count = 0;
  std::atomic<bool> aborted(false);

  // Loop until all RingFields have advanced to completion.
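  // For gather each field traverses a subset of the RingAlg states:
  //   RF_INIT -> RF_RECV (if do_recv) -> RF_SEND_READY
  //           -> RF_SEND (if do_send) -> RF_DONE
  // Completion callbacks for async sends/recvs re-enqueue the field, so this
  // thread only ever blocks in ready_queue.Dequeue().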
  {
    profiler::TraceMe activity("Loop", profiler::TraceMeLevel::kInfo);
    while (field_done_count < rfv_.size()) {
      VLOG(4) << FieldState();
      // Wait for a RingField to appear in the ready_queue.
      RingField* rf = ready_queue.Dequeue();
      // Advance the RingField to its next action and execute, repeating
      // until either an async action has been started or the RingField
      // is done.
      bool dispatched = false;  // true if async action was initiated
      do {
        if (aborted) {
          // Requeue this RingField to be counted off below.
          ready_queue.Enqueue(rf);
          break;
        }
        switch (rf->action) {
          case RF_INIT:
            if (rf->do_recv) {
              rf->action = RF_RECV;
              auto requeue = [this, rf, &ready_queue, &aborted](Status s) {
                if (!s.ok()) {
                  aborted = true;
                  StartAbort(s);
                }
                ready_queue.Enqueue(rf);
              };
              DispatchRecv(rf, requeue);
              dispatched = true;
              ++recv_pending_count;
            } else {
              rf->action = RF_SEND_READY;
            }
            break;
          case RF_RECV:
            DCHECK_GT(recv_pending_count, 0);
            --recv_pending_count;
            rf->action = RF_SEND_READY;
            break;
          case RF_REDUCE:
            // Never used for Gather, so just fall through.
            TF_FALLTHROUGH_INTENDED;
          case RF_FINALIZE:
            // Never used for Gather, so just fall through.
            TF_FALLTHROUGH_INTENDED;
          case RF_SEND_READY:
            if (rf->do_send) {
              rf->action = RF_SEND;
              auto send_complete = [this, rf, &ready_queue,
                                    &aborted](Status s) {
                if (!s.ok()) {
                  aborted = true;
                  StartAbort(s);
                }
                ready_queue.Enqueue(rf);
              };
              DispatchSend(rf, send_complete);
              dispatched = true;
              ++send_pending_count;
            } else {
              rf->action = RF_DONE;
            }
            break;
          case RF_SEND:
            DCHECK_GT(send_pending_count, 0);
            --send_pending_count;
            rf->action = RF_DONE;
            break;
          case RF_DONE:
            break;
        }
        if (rf->action == RF_DONE) {
          // There's only one pass.
          ++field_done_count;
          break;  // from do while(!dispatched)
        }
      } while (!dispatched);
      if (aborted) break;
    }  // while (field_done_count < number of fields)

    if (aborted) {
      // All of the pending data actions should be aborted; field the
      // callbacks and clear the queue before quitting.
      while ((send_pending_count > 0) || (recv_pending_count > 0)) {
        RingField* rf = ready_queue.Dequeue();
        switch (rf->action) {
          case RF_RECV:
            --recv_pending_count;
            break;
          case RF_SEND:
            --send_pending_count;
            break;
          default: {
          }  // Ignore any other actions
        }
      }
    }
  }

  DCHECK_EQ(send_pending_count, 0);
  DCHECK_EQ(recv_pending_count, 0);

  VLOG(2) << this << " device=" << col_ctx_->device_name << " finish;"
          << " final value " << TensorDebugString(ca_->Value());
  return !aborted;
}

namespace {
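// Registers RingGatherer under the collective name "RingGather", the name
// checked against collective_name in InitializeCollectiveParams() above.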
REGISTER_COLLECTIVE(RingGather, RingGatherer);
}  // namespace

}  // namespace tensorflow