1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/core/common_runtime/base_collective_executor.h" |
16 | |
17 | #include <algorithm> |
18 | #include <functional> |
19 | #include <utility> |
20 | |
21 | #include "tensorflow/core/common_runtime/copy_tensor.h" |
22 | #include "tensorflow/core/common_runtime/device_mgr.h" |
23 | #include "tensorflow/core/common_runtime/dma_helper.h" |
24 | #include "tensorflow/core/common_runtime/process_util.h" |
25 | #include "tensorflow/core/framework/allocator.h" |
26 | #include "tensorflow/core/framework/cancellation.h" |
27 | #include "tensorflow/core/framework/op_kernel.h" |
28 | #include "tensorflow/core/framework/tensor.h" |
29 | #include "tensorflow/core/framework/tensor_shape.h" |
30 | #include "tensorflow/core/framework/types.h" |
31 | #include "tensorflow/core/framework/types.pb.h" |
32 | #include "tensorflow/core/lib/core/errors.h" |
33 | #include "tensorflow/core/lib/core/notification.h" |
34 | #include "tensorflow/core/lib/core/status.h" |
35 | #include "tensorflow/core/lib/strings/strcat.h" |
36 | #include "tensorflow/core/platform/macros.h" |
37 | #include "tensorflow/core/platform/refcount.h" |
38 | #include "tensorflow/core/platform/tracing.h" |
39 | #include "tensorflow/core/platform/types.h" |
40 | #include "tensorflow/core/profiler/lib/connected_traceme.h" |
41 | #include "tensorflow/core/profiler/lib/scoped_memory_debug_annotation.h" |
42 | #include "tensorflow/core/profiler/lib/traceme.h" |
43 | |
44 | #define VALUE_IN_DEBUG_STRING false |
45 | |
46 | namespace tensorflow { |
47 | |
48 | namespace { |
49 | bool IsCancelled(CancellationManager* cancel_mgr) { |
50 | return cancel_mgr != nullptr && |
51 | (cancel_mgr->IsCancelled() || cancel_mgr->IsCancelling()); |
52 | } |
53 | } // namespace |
54 | |
55 | /*static*/ |
56 | int64_t CollectiveAdapter::AlignedChunkElts(int64_t elt_bytes, |
57 | int64_t total_elts, |
58 | int64_t num_chunks) { |
59 | DCHECK_GT(num_chunks, 0); |
60 | int64_t base_chunk_elts = (total_elts + (num_chunks - 1)) / num_chunks; |
61 | if (EIGEN_MAX_ALIGN_BYTES == 0) return base_chunk_elts; |
62 | if (EIGEN_MAX_ALIGN_BYTES <= elt_bytes) { |
63 | // Tolerate weird small values of EIGEN_MAX_ALIGN_BYTES |
64 | DCHECK_EQ(0, elt_bytes % EIGEN_MAX_ALIGN_BYTES); |
65 | return base_chunk_elts; |
66 | } |
  // Here elt_bytes < EIGEN_MAX_ALIGN_BYTES, and EIGEN_MAX_ALIGN_BYTES must
  // be a common multiple of the sizes of the various atomic data types.
69 | DCHECK_EQ(0, EIGEN_MAX_ALIGN_BYTES % elt_bytes) |
70 | << "total_elts=" << total_elts << " num_chunks=" << num_chunks |
71 | << " EIGEN_MAX_ALIGN_BYTES=" << EIGEN_MAX_ALIGN_BYTES |
72 | << " elt_bytes=" << elt_bytes; |
73 | // Round bytes per chunk up to the next multiple of EIGEN_MAX_ALIGN_BYTES. |
74 | int64_t chunk_bytes = base_chunk_elts * elt_bytes; |
75 | int64_t diff = |
76 | (chunk_bytes < EIGEN_MAX_ALIGN_BYTES) |
77 | ? (EIGEN_MAX_ALIGN_BYTES - chunk_bytes) |
78 | : (EIGEN_MAX_ALIGN_BYTES - (chunk_bytes % EIGEN_MAX_ALIGN_BYTES)); |
79 | DCHECK_EQ(0, diff % elt_bytes); |
80 | base_chunk_elts += (diff / elt_bytes); |
81 | DCHECK_EQ(0, ((base_chunk_elts * elt_bytes) % EIGEN_MAX_ALIGN_BYTES)) |
82 | << "total_elts=" << total_elts << " num_chunks=" << num_chunks |
83 | << " EIGEN_MAX_ALIGN_BYTES=" << EIGEN_MAX_ALIGN_BYTES |
84 | << " base_chunk_elts=" << base_chunk_elts << " elt_bytes=" << elt_bytes; |
85 | return base_chunk_elts; |
86 | } |
87 | |
88 | namespace { |
89 | template <typename T> |
90 | class CollectiveAdapterImpl : public CollectiveAdapter { |
91 | public: |
92 | // Takes ownership of output and prepares to properly alias its chunks. |
93 | // Ownership is taken because the shape may temporarily change. |
94 | CollectiveAdapterImpl(Tensor* output, int64_t num_chunks, |
95 | Allocator* allocator, bool align_chunks) |
96 | : output_(std::move(*output)), |
97 | dt_(output_.dtype()), |
98 | old_shape_(output_.shape()), |
99 | num_chunks_(num_chunks), |
100 | allocator_(allocator), |
101 | total_elts_(output_.NumElements()), |
102 | chunk_elts_(align_chunks |
103 | ? AlignedChunkElts(sizeof(T), total_elts_, num_chunks_) |
104 | : total_elts_ / num_chunks_), |
105 | data_start_(reinterpret_cast<T*>(DMAHelper::base(&output_))), |
106 | data_end_(data_start_ + total_elts_) { |
107 | if (!align_chunks) { |
108 | DCHECK_EQ(total_elts_, num_chunks_ * chunk_elts_); |
109 | } |
110 | DCHECK_GT(chunk_elts_, 0); |
111 | Flatten(); |
112 | } |
113 | |
114 | ~CollectiveAdapterImpl() override {} |
115 | |
116 | const Tensor& Value() const override { return output_; } |
117 | |
118 | // If necessary, flatten output. |
119 | void Flatten() { |
120 | if (old_shape_.dims() != 1) { |
121 | TensorShape new_shape = TensorShape({old_shape_.num_elements()}); |
122 | DMAHelper::UnsafeSetShape(&output_, new_shape); |
123 | } |
124 | } |
125 | |
126 | void ConsumeFinalValue(Tensor* output) override { |
127 | if (old_shape_ != output_.shape()) { |
128 | DMAHelper::UnsafeSetShape(&output_, old_shape_); |
129 | } |
130 | *output = std::move(output_); |
131 | } |
132 | |
133 | // Number of T elements in a particular chunk. |
134 | inline int64_t ChunkElts(int i) const { |
135 | DCHECK_LT(i, num_chunks_); |
136 | const T* chunk_start = std::min(data_end_, data_start_ + i * chunk_elts_); |
137 | const T* chunk_end = std::min(data_end_, chunk_start + chunk_elts_); |
138 | return chunk_end - chunk_start; |
139 | } |
140 | |
141 | int64_t ChunkBytes(int i) const override { return sizeof(T) * ChunkElts(i); } |
142 | |
143 | // Returns a new Tensor that aliases the required chunk. |
144 | Tensor ChunkAlias(int i) override { |
145 | int64_t start = chunk_elts_ * i; |
146 | int64_t num_elts = ChunkElts(i); |
    // If this chunk is empty the prior chunk might also be short,
    // so always take an empty slice from the front of the tensor
    // to avoid failing an illegal-offset check in Slice().
150 | return (num_elts > 0) ? output_.Slice(start, start + num_elts) |
151 | : output_.Slice(0, 0); |
152 | } |
153 | |
154 | Tensor TempChunk(int i) const override { |
155 | AllocationAttributes empty; |
156 | profiler::ScopedMemoryDebugAnnotation op_annotation( |
157 | "CollectiveAdapterImpl::TempChunk" ); |
158 | return Tensor(allocator_, dt_, {ChunkElts(i)}, empty); |
159 | } |
160 | |
161 | string DebugString() const override { |
    return strings::StrCat(
        "base addr ", reinterpret_cast<int64_t>(DMAHelper::base(&output_)),
        " num_chunks ", num_chunks_, " total_elts ", total_elts_,
        " chunk_elts ", chunk_elts_, " value ",
        VALUE_IN_DEBUG_STRING ? output_.SummarizeValue(1024) : "<hidden>");
167 | } |
168 | |
169 | string TBounds(const Tensor& t) const override { |
170 | int64_t base_addr = reinterpret_cast<int64_t>(DMAHelper::base(&t)); |
171 | return strings::StrCat("(" , base_addr, ", " , (base_addr + t.TotalBytes()), |
172 | ")" ); |
173 | } |
174 | |
175 | Tensor Scalar(int v) const override { return Tensor(static_cast<T>(v)); } |
176 | |
177 | Tensor Scalar(Allocator* a, const AllocationAttributes& attr) const override { |
178 | Tensor t(a, dt_, TensorShape({}), attr); |
179 | return t; |
180 | } |
181 | |
182 | Tensor output_; |
183 | const DataType dt_; |
184 | const TensorShape old_shape_; |
185 | const int64_t num_chunks_; |
186 | Allocator* allocator_; |
187 | const int64_t total_elts_; |
188 | const int64_t chunk_elts_; |
189 | const T* data_start_; |
190 | const T* data_end_; |
191 | }; |
192 | |
193 | } // namespace |
194 | |
195 | CollectiveAdapter* MakeCollectiveAdapter(Tensor* output, int num_chunks, |
196 | Allocator* allocator, |
197 | bool align_chunks) { |
  switch (output->dtype()) {
    case DT_BFLOAT16:
      return new CollectiveAdapterImpl<Eigen::bfloat16>(
          output, num_chunks, allocator, align_chunks);
    case DT_HALF:
      return new CollectiveAdapterImpl<Eigen::half>(output, num_chunks,
                                                    allocator, align_chunks);
    case DT_FLOAT:
      return new CollectiveAdapterImpl<float>(output, num_chunks, allocator,
                                              align_chunks);
    case DT_DOUBLE:
      return new CollectiveAdapterImpl<double>(output, num_chunks, allocator,
                                               align_chunks);
    case DT_INT32:
      return new CollectiveAdapterImpl<int32>(output, num_chunks, allocator,
                                              align_chunks);
    case DT_INT64:
      return new CollectiveAdapterImpl<int64_t>(output, num_chunks, allocator,
                                                align_chunks);
    default:
      LOG(FATAL) << "Unsupported type " << DataTypeString(output->dtype())
                 << " to MakeCollectiveAdapter";
      return nullptr;
  }
228 | } |
229 | |
230 | BaseCollectiveExecutor::~BaseCollectiveExecutor() {} |
231 | |
232 | void BaseCollectiveExecutor::StartAbort(const Status& s) { |
233 | Status status; |
234 | { |
235 | mutex_lock l(status_mu_); |
236 | if (!status_.ok()) { |
237 | VLOG(2) << "BaseCollectiveExecutor already aborted, ignoring StartAbort: " |
238 | << s; |
239 | return; |
240 | } |
241 | status_ = StatusGroup::MakeDerived(Status( |
242 | s.code(), |
243 | absl::StrCat( |
244 | "Collective ops is aborted by: " , s.error_message(), |
245 | "\nThe error could be from a previous operation. Restart your " |
246 | "program to reset." ))); |
247 | status = status_; |
248 | } |
249 | LOG(ERROR) << "BaseCollectiveExecutor::StartAbort " << s; |
250 | cem_->GetParamResolver()->StartAbort(status); |
251 | remote_access_->StartAbort(status); |
252 | if (cem_->GetNcclCommunicator() != nullptr) { |
253 | cem_->GetNcclCommunicator()->StartAbort(status); |
254 | } |
255 | } |
256 | |
257 | Status BaseCollectiveExecutor::GetStatus(const Status& s) { |
258 | if (s.ok()) return s; |
259 | mutex_lock l(status_mu_); |
260 | // If the collective executor is already aborted, use the aborted status |
261 | // which is more likely the actual error instead of an artifact of an |
262 | // abortion. |
263 | if (!status_.ok()) { |
264 | VLOG(2) << "Overriding status with collective ops executor status. " |
265 | "Original status: " |
266 | << s; |
267 | return status_; |
268 | } |
269 | return s; |
270 | } |
271 | |
272 | void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx, |
273 | const CollectiveParams* col_params, |
274 | const string& exec_key, |
275 | StatusCallback done) { |
  // See CompleteParamsAsync() for how done() and the timeout callback
  // interact.
277 | const auto is_callback_called = std::make_shared<std::atomic<bool>>(false); |
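  // Whichever of the normal completion path or the timeout watchdog below
  // flips this flag first gets to invoke done(); the atomic exchange makes
  // the callback run exactly once.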
278 | auto done_safe = [this, done, ctx, is_callback_called](const Status& s) { |
279 | bool called = is_callback_called->exchange(true); |
280 | if (!called) { |
281 | if (!s.ok() && !IsCancelled(ctx->cancellation_manager())) { |
282 | // This is a collective error. Abort CollectiveExecutor so that this |
283 | // error can propagate to other workers. |
284 | StartAbort(s); |
285 | } |
286 | done(GetStatus(s)); |
287 | } |
288 | }; |
289 | auto timeout_microseconds = static_cast<int64_t>( |
290 | col_params->instance.impl_details.timeout_seconds * 1'000'000); |
291 | if (timeout_microseconds > 0) { |
292 | // TODO(xldrx): Share the timeout watchdog thread among collectives. |
293 | SchedNonBlockingClosureAfter( |
294 | timeout_microseconds, [this, is_callback_called, done] { |
295 | bool called = is_callback_called->exchange(true); |
296 | if (!called) { |
297 | Status status(error::DEADLINE_EXCEEDED, |
298 | "Collective has timed out during execution." ); |
299 | StartAbort(status); |
300 | done(status); |
301 | } |
302 | }); |
303 | } |
304 | |
305 | Tensor* output = ctx->mutable_output(0); |
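  // Only collective types that consume data read input 0; a broadcast
  // participant that is not the source produces output only.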
306 | const Tensor* input = (col_params->instance.type == REDUCTION_COLLECTIVE || |
307 | col_params->instance.type == GATHER_COLLECTIVE || |
308 | col_params->instance.type == PERMUTE_COLLECTIVE || |
309 | col_params->instance.type == ALL_TO_ALL_COLLECTIVE || |
310 | (col_params->instance.type == BROADCAST_COLLECTIVE && |
311 | col_params->is_source)) |
312 | ? &ctx->input(0) |
313 | : nullptr; |
314 | CollectiveImplementationInterface* col_impl = nullptr; |
315 | Status status = CreateCollective(*col_params, &col_impl); |
316 | if (!status.ok()) { |
317 | done_safe(status); |
318 | DCHECK_EQ(nullptr, col_impl); |
319 | return; |
320 | } |
321 | core::ScopedUnref unref(col_impl); |
322 | auto col_ctx = std::make_shared<CollectiveContext>( |
323 | this, cem_->GetNcclCommunicator(), dev_mgr_, ctx, CtxParams(ctx), |
324 | col_params, exec_key, step_id_, input, output); |
325 | status = col_impl->InitializeCollectiveContext(col_ctx); |
326 | if (!status.ok()) { |
327 | done_safe(status); |
328 | return; |
329 | } |
330 | // Run on an unbounded work queue that can handle blocking work so as to not |
331 | // starve executor threads. |
332 | col_impl->Ref(); |
  profiler::TraceMeProducer producer("BaseCollectiveExecutor::ExecuteAsync");
334 | RunClosure([col_impl, col_ctx, done_safe, ctx, |
335 | context_id = producer.GetContextId()]() { |
336 | core::ScopedUnref unref(col_impl); |
337 | profiler::TraceMeConsumer consumer( |
338 | [ctx, col_ctx] { |
339 | string op = profiler::TraceMeOp(ctx->op_kernel().name_view(), |
340 | ctx->op_kernel().type_string_view()); |
341 | return profiler::TraceMeEncode( |
342 | std::move(op), |
343 | {{"id" , ctx->step_id()}, |
344 | {"instance_key" , col_ctx->col_params->instance.instance_key}, |
345 | {"collective" , col_ctx->col_params->instance.type}}); |
346 | }, |
347 | context_id); |
348 | col_impl->Ref(); |
349 | col_impl->Run([col_impl, col_ctx, done_safe](const Status& s) { |
350 | core::ScopedUnref unref(col_impl); |
351 | done_safe(s); |
352 | }); |
353 | }); |
354 | } |
355 | |
356 | void BaseCollectiveExecutor::CompleteParamsAsync( |
357 | const DeviceAttributes& device, CollectiveParams* cp, |
358 | CancellationManager* cancel_mgr, StatusCallback done) { |
  // We need to make sure that when the timeout callback executes,
  // CollectiveExecutor and CollectiveExecutorMgr are both alive. After done()
  // is called, CollectiveExecutorMgr may be destructed and we don't have a
  // way to keep it alive without complicating the ownership model. Therefore
  // if the timeout callback executes, done_safe becomes a no-op and the
  // timeout callback is responsible for invoking done() at the end.
365 | const auto is_callback_called = std::make_shared<std::atomic<bool>>(false); |
366 | int64_t trace_id = profiler::TraceMe::ActivityStart([cp]() { |
    return profiler::TraceMeEncode("CollectiveExecutor::CompleteParams",
                                   {{"group_key", cp->group.group_key},
                                    {"group_size", cp->group.group_size}});
370 | }); |
371 | |
372 | auto done_safe = [this, is_callback_called, cancel_mgr, trace_id, |
373 | done](const Status& s) { |
374 | profiler::TraceMe::ActivityEnd(trace_id); |
375 | bool called = is_callback_called->exchange(true); |
376 | if (!called) { |
377 | if (!s.ok() && !IsCancelled(cancel_mgr)) { |
378 | // This is a collective error. Abort CollectiveExecutor so that this |
379 | // error can propagate to other workers. |
380 | StartAbort(s); |
381 | } |
382 | done(GetStatus(s)); |
383 | } |
384 | }; |
385 | auto timeout_microseconds = static_cast<int64_t>( |
386 | cp->instance.impl_details.timeout_seconds * 1'000'000); |
387 | if (timeout_microseconds > 0) { |
388 | // TODO(xldrx): Share the timeout watchdog thread among collectives. |
389 | SchedNonBlockingClosureAfter( |
390 | timeout_microseconds, [this, is_callback_called, done]() { |
391 | bool called = is_callback_called->exchange(true); |
392 | if (!called) { |
393 | Status status( |
394 | error::DEADLINE_EXCEEDED, |
395 | "Collective has timed out waiting for other workers." ); |
396 | StartAbort(status); |
397 | done(status); |
398 | } |
399 | }); |
400 | } |
401 | cem_->GetParamResolver()->CompleteParamsAsync(device, cp, cancel_mgr, |
402 | done_safe); |
403 | } |
404 | |
405 | Status BaseCollectiveExecutor::CreateCollective( |
406 | const CollectiveParams& col_params, |
407 | CollectiveImplementationInterface** col_impl) { |
408 | VLOG(2) << "CreateCollective type " |
409 | << DataTypeString(col_params.instance.data_type) << " name " |
410 | << col_params.instance.impl_details.collective_name; |
411 | *col_impl = nullptr; |
412 | switch (col_params.instance.data_type) { |
413 | case DT_BOOL: |
414 | if (col_params.instance.type == BROADCAST_COLLECTIVE) { |
415 | return CollectiveRegistry::Lookup( |
416 | col_params.instance.impl_details.collective_name, col_impl); |
417 | } else { |
418 | return errors::Internal( |
419 | "No collective other than broadcast supports DT_BOOL" ); |
420 | } |
421 | case DT_INT32: |
422 | if (col_params.group.device_type == DEVICE_GPU && |
423 | col_params.instance.type == REDUCTION_COLLECTIVE) { |
424 | // TODO(b/139421603): enable int32 all-reduce on GPU. |
425 | return errors::Internal( |
426 | "Collective all-reduce does not support datatype DT_INT32 on " |
427 | "DEVICE_GPU" ); |
428 | } else { |
429 | return CollectiveRegistry::Lookup( |
430 | col_params.instance.impl_details.collective_name, col_impl); |
431 | } |
432 | case DT_BFLOAT16: |
433 | if (col_params.group.device_type == DEVICE_GPU && |
434 | col_params.instance.type == REDUCTION_COLLECTIVE) { |
435 | return errors::Internal( |
436 | "Collective all-reduce does not support datatype DT_BFLOAT16 on " |
437 | "DEVICE_GPU" ); |
438 | } else { |
439 | return CollectiveRegistry::Lookup( |
440 | col_params.instance.impl_details.collective_name, col_impl); |
441 | } |
442 | case DT_HALF: |
443 | case DT_FLOAT: |
444 | case DT_DOUBLE: |
445 | case DT_INT64: { |
446 | return CollectiveRegistry::Lookup( |
447 | col_params.instance.impl_details.collective_name, col_impl); |
448 | } |
449 | default: |
450 | return errors::Internal( |
451 | "CollectiveImplementation does not support datatype " , |
452 | DataTypeString(col_params.instance.data_type)); |
453 | } |
454 | } |
455 | |
456 | bool BaseCollectiveExecutor::CheckDependencies( |
457 | const CollectiveParams& col_params) { |
458 | for (int32_t instance : col_params.instance.impl_details.dependencies) { |
459 | auto find_iter = launched_.find(instance); |
460 | if (find_iter == launched_.end() || find_iter->second != 0) { |
461 | VLOG(1) << "Collective " << col_params.ToString() |
462 | << " blocked by instance " << instance; |
463 | return false; |
464 | } |
465 | } |
466 | return true; |
467 | } |
468 | |
469 | void BaseCollectiveExecutor::WaitForDependencies( |
470 | const CollectiveParams& col_params) { |
471 | mutex_lock l(launch_mu_); |
472 | while (!CheckDependencies(col_params)) { |
473 | launch_cv_.wait(l); |
474 | } |
475 | VLOG(1) << "Unblocking collective " << col_params.ToString(); |
476 | } |
477 | |
478 | void BaseCollectiveExecutor::UnblockDependencies( |
479 | const CollectiveParams& col_params) { |
480 | mutex_lock l(launch_mu_); |
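  // The first device to arrive seeds a countdown equal to the number of
  // devices participating on this task; each arrival decrements it, and at
  // zero every local device has launched, so waiters can be released.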
481 | if (launched_.find(col_params.instance.instance_key) == launched_.end()) { |
482 | const string& task_name = |
483 | col_params.group.members[col_params.default_rank].task; |
484 | const int32_t num_devices = |
485 | col_params.group.num_devices_per_task.at(task_name); |
486 | launched_[col_params.instance.instance_key] = num_devices; |
487 | } |
488 | if (--launched_[col_params.instance.instance_key] == 0) { |
489 | VLOG(1) << "Unblocking dependencies for collective instance " |
490 | << col_params.instance.instance_key; |
491 | launch_cv_.notify_all(); |
492 | } |
493 | } |
494 | |
495 | } // namespace tensorflow |
496 | |