#include <torch/csrc/distributed/c10d/reducer.hpp>

#include <torch/csrc/distributed/c10d/Utils.hpp>
#include <torch/csrc/distributed/c10d/default_comm_hooks.hpp>

#include <functional>

#include <c10/core/DeviceGuard.h>
#include <c10/core/StreamGuard.h>
#include <c10/util/Exception.h>
#include <c10/util/Logging.h>
#include <c10/util/hash.h>
#include <c10/util/irange.h>
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/function_hook.h>
#include <torch/csrc/autograd/functions/accumulate_grad.h>
#include <torch/csrc/autograd/profiler.h>
#include <torch/csrc/autograd/utils/grad_layout_contract.h>
#include <torch/csrc/autograd/utils/lambda_post_hook.h>
#include <torch/csrc/distributed/c10d/comm.hpp>
#include <torch/csrc/distributed/c10d/logger.hpp>
#include <torch/csrc/utils/memory.h>

namespace c10d {
namespace {

constexpr int kUnsetDivFactor = -1;

// Macro that wraps TORCH_CHECK with DDP logging.
#define REDUCER_CHECK(cond, logger_, ...)             \
  if (C10_UNLIKELY_OR_CONST(!(cond))) {               \
    if (!logger_.expired()) {                         \
      logger_.lock()->set_error_and_log(__VA_ARGS__); \
    }                                                 \
    TORCH_CHECK(false, ##__VA_ARGS__);                \
  }
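// Illustrative usage (the same check appears in mark_variable_ready() below):
//   REDUCER_CHECK(
//       variable_index < variable_locators_.size(),
//       logger_,
//       "Out of range variable index.");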

} // namespace

C10_DEFINE_TYPED_REGISTRY( // NOLINT
    TimerRegistry,
    c10::DeviceType,
    Timer,
    std::unique_ptr,
    c10::Device);

namespace {

class CpuTimer : public Timer {
 public:
  explicit CpuTimer(c10::Device /* unused */) {}

  c10::optional<int64_t> measureDifference(Event start, Event end) override {
    int64_t start_time = getTimeRef(start);
    int64_t end_time = getTimeRef(end);
    // If the end time was not recorded in this iteration, the average time
    // would be invalid. This can happen, for example, when DDP runs in
    // non-sync mode and the backward compute end time cannot be recorded in
    // this iteration. In that case, skip the measurement and return an empty
    // optional.
    if (end_time < start_time) {
      return c10::nullopt;
    }
    return end_time - start_time;
  }
};

C10_REGISTER_TYPED_CLASS(TimerRegistry, c10::kCPU, CpuTimer);
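// A timer for a given device type can then be created through the registry,
// e.g., as done in the Reducer constructor below:
//   timer_ = TimerRegistry()->Create(device.type(), device);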

std::vector<at::Tensor> extractTensors(const c10::IValue& result) {
  if (result.isPyObject()) {
    return result.toPyObjectHolder()->extractTensors();
  }
  TORCH_INTERNAL_ASSERT(
      result.isTensor() || result.isTensorList(),
      "expected the hook result to be either a Tensor or a TensorList, found ",
      result.tagKind());

  if (result.isTensor()) {
    return {result.toTensor()};
  }

  return result.toTensorVector();
}
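// For example, a Python comm hook whose future resolves to a single tensor
// yields {tensor} here, while a hook resolving to a TensorList yields the
// full list.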

} // namespace

Reducer::Reducer(
    std::vector<at::Tensor> params,
    std::vector<std::vector<size_t>> bucket_indices,
    std::vector<size_t> per_bucket_size_limits,
    c10::intrusive_ptr<c10d::ProcessGroup> process_group,
    std::vector<bool> expect_sparse_gradients,
    int64_t bucket_bytes_cap,
    bool find_unused_parameters,
    bool gradient_as_bucket_view,
    std::unordered_map<size_t, std::string> param_names,
    int64_t first_bucket_bytes_cap)
    : params_(std::move(params)),
      process_group_(std::move(process_group)),
      expect_sparse_gradients_(std::move(expect_sparse_gradients)),
      expect_autograd_hooks_(false),
      require_finalize_(false),
      next_bucket_(0),
      has_marked_unused_parameters_(false),
      find_unused_parameters_(find_unused_parameters),
      gradient_as_bucket_view_(gradient_as_bucket_view),
      local_used_map_reduced_(false),
      num_iterations_(0),
      num_buckets_ready_(0),
      has_rebuilt_bucket_(false),
      bucket_bytes_cap_(bucket_bytes_cap),
      div_factor_(kUnsetDivFactor),
      static_graph_(false),
      comm_hook_(nullptr),
      ddp_debug_level_(debug_level()),
      param_names_(std::move(param_names)),
      first_bucket_bytes_cap_(first_bucket_bytes_cap) {
  C10_LOG_API_USAGE_ONCE("torch.distributed.ddp.reducer");
  TORCH_INTERNAL_ASSERT(!params_.empty(), "Expected at least one parameter.");

  if (ddp_debug_level_ != c10d::DebugLevel::Off) {
    LOG(INFO) << "Reducer initialized with bucket_bytes_cap: "
              << bucket_bytes_cap_
              << " first_bucket_bytes_cap: " << first_bucket_bytes_cap;
  }
  // Check whether the module is multi_device_module
  {
    std::set<int> unique_devices;
    for (const auto& v : params_) {
      auto device_idx = int(v.device().index());
      if (unique_devices.find(device_idx) == unique_devices.end()) {
        unique_devices.insert(device_idx);
        if (unique_devices.size() > 1) {
          is_multi_device_module_ = true;
          break;
        }
      }
    }
  }

  // For CUDA, record events only for single device module.
  c10::Device device = params_[0].device();
  if (!(device.is_cuda() && is_multi_device_module_)) {
    timer_ = TimerRegistry()->Create(device.type(), device);
  }

  // If `expect_sparse_gradients` is not specified, initialize it such that
  // we do not expect sparse gradients for any parameter.
  if (expect_sparse_gradients_.empty()) {
    expect_sparse_gradients_ = std::vector<bool>(params_.size(), false);
  }
  TORCH_INTERNAL_ASSERT(expect_sparse_gradients_.size() == params_.size());

  // Initialize variable bucketing.
  // This can be reinitialized later after capturing runtime information.
  {
    std::lock_guard<std::mutex> lock(mutex_);
    initialize_buckets(std::move(bucket_indices));
  }

  // All variables are expected to have their `grad_fn` set to the gradient
  // accumulation function (since they are leaves in the autograd graph).
  // We store pointers to these functions such that we can check if they are
  // used in an autograd pass. If they are not, we know their grad tensors
  // can be marked as ready for reduction.
  {
    const auto variable_count = params_.size();
    grad_accumulators_.resize(variable_count);
    for (const auto variable_index : c10::irange(variable_count)) {
      auto& variable = params_[variable_index];

      // The gradient accumulator function is lazily initialized once.
      // Therefore we can use its presence in the autograd graph as
      // evidence that the parameter has participated in an iteration.
      auto grad_accumulator = torch::autograd::impl::grad_accumulator(variable);

#ifndef _WIN32
      using torch::distributed::autograd::ThreadLocalDistAutogradContext;
#endif
      // Hook to execute after the gradient accumulator has executed.
      hooks_.emplace_back(
          grad_accumulator->add_post_hook(
              torch::make_unique<torch::autograd::utils::LambdaPostHook>(
                  [=](const torch::autograd::variable_list& outputs,
                      const torch::autograd::variable_list& /* unused */) {
#ifndef _WIN32
                    this->rpc_context_.set(
                        ThreadLocalDistAutogradContext::getContextPtr());
#endif
                    this->autograd_hook(variable_index);
                    return outputs;
                  })),
          grad_accumulator);

      // Map raw function pointer to parameter index.
      // This is used later on when the autograd graph is traversed
      // to check for parameters for which no gradient is computed, if
      // find_unused_parameters=True.
      // Note that the mapping of gradient accumulator to variable should be
      // one to one as we deduplicate shared parameters before constructing
      // Reducer.
      if (find_unused_parameters_) {
        gradAccToVariableMap_[grad_accumulator.get()] = variable_index;
      }

      numGradHooksTriggeredMap_[variable_index] = 0;

      // The gradient accumulator is stored as weak_ptr in the autograd
      // metadata of the variable, so we have to keep it alive here for
      // the raw pointer to be valid.
      REDUCER_CHECK(
          grad_accumulators_[variable_index] == nullptr,
          logger_,
          c10::str(
              "Reducer tried to register duplicate grad accumulator for variable ",
              variable_index));

      grad_accumulators_[variable_index] = std::move(grad_accumulator);
    }
  }

  // Initialize backward stats vector.
  {
    const auto variable_count = params_.size();
    backward_stats_.resize(variable_count);
  }

  // See Note [Skip allreducing local_used_map_dev]
  if (find_unused_parameters_) {
    initialize_local_used_map();
  }
}

// Note [Skip allreducing local_used_map_dev]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~
// If find_unused_parameters_ is set to false, there is no need to allreduce
// local_used_map_dev_, because all parameters will be reduced anyway.
// Therefore, we can avoid allocating memory for local_used_map and
// local_used_map_dev_ if find_unused_parameters_ is false.

// Note [DDP Communication Hook]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~
// If no DDP communication hook is registered, the reducer reduces the buckets
// by just calling allreduce. If a hook is registered, it calls the hook and
// uses the returned future work handle. In that case, the reducer also skips
// dividing grads by world size. The reason for this is that the communication
// hook is expected to completely override how we perform communication and the
// user should have complete control over how the grads are handled.
//
// The DDP communication hook is an enhancement that provides a hook which can
// be used to override how DDP communicates gradients across ranks; this can be
// used for algorithms like Gradient Compression/GossipGrad. This hook can be
// registered from the Python API using `register_comm_hook`. `PythonCommHook`
// enables registering a Python hook and is a subclass of `CommHookInterface`.
// Additionally, there are also some built-in C++ hook implementations that can
// be specified by calling `register_builtin_comm_hook` from the Python API.
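//
// A sketch of registering a hook from Python (illustrative; assumes the
// built-in fp16 compression hook from
// torch.distributed.algorithms.ddp_comm_hooks.default_hooks):
//   model = torch.nn.parallel.DistributedDataParallel(module, device_ids=[rank])
//   model.register_comm_hook(state=None, hook=default_hooks.fp16_compress_hook)
// Built-in C++ hooks can instead be selected via `register_builtin_comm_hook`.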

Reducer::~Reducer() noexcept(false) {
  // Remove all hooks on variables registered by this Reducer. This is
  // necessary to make DDP failure recoverable. Otherwise, multiple Reducer
  // instances (from recoveries) will add their hooks to the original model,
  // and those hooks will try to invoke methods on a deleted Reducer object.
  for (auto& hook : hooks_) {
    auto& key = hook.first;
    auto& grad_accumulator = hook.second;

    TORCH_INTERNAL_ASSERT(
        grad_accumulator->del_post_hook(key),
        "Reducer attempts to delete a non-existing hook.");
  }
}

bool Reducer::dynamic_graph_find_unused() {
  return !static_graph_ && find_unused_parameters_;
}

bool Reducer::static_graph_first_iteration() {
  return static_graph_ && num_iterations_ == 1;
}

bool Reducer::static_graph_after_first_iteration() {
  return static_graph_ && num_iterations_ > 1;
}
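
// Note: num_iterations_ is incremented in prepare_for_forward(), so a value of
// 1 in the helpers above corresponds to the backward pass of the first
// training iteration.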

bool Reducer::ddp_graph_static() {
  std::lock_guard<std::mutex> lock(mutex_);
  return ddp_graph_static_;
}

void Reducer::initialize_local_used_map() {
  const auto variable_count = params_.size();
  at::TensorOptions options;
  options = options.dtype(at::kInt);

  // Deliberately don't pin the memory even if local_used_map_dev_ will
  // be cuda. See Note [local_used_map_ -> local_used_map_dev copying]
  local_used_map_ = at::zeros({static_cast<long>(variable_count)}, options);

  // This tensor needs to be on the same device as the replica params because
  // backends such as NCCL may not support CPU tensors, so it might not work
  // if we always put it on CPU.
  options = options.device(params_[0].device());
  local_used_map_dev_ = at::empty({static_cast<long>(variable_count)}, options);
}

void Reducer::check_grad_layout(
    const at::Tensor& grad,
    const at::Tensor& bucket_view) {
  // Ensure that the gradient type matches the bucket type.
  REDUCER_CHECK(
      grad.options().type_equal(bucket_view.options()),
      logger_,
      c10::str("Expected ", bucket_view.toString(), ", got ", grad.toString()));

  TORCH_INTERNAL_ASSERT(grad.device() == bucket_view.device());
  TORCH_INTERNAL_ASSERT(grad.numel() == bucket_view.numel());
  // AccumulateGrad doesn't HAVE to obey the grad layout contract.
  // The penalty for disobedience is reduced performance, not numerical
  // death. Warnings here help diagnose poor DDP performance.
  if (grad.strides() != bucket_view.strides()) {
    TORCH_WARN_ONCE(
        "Grad strides do not match bucket view strides. "
        "This may indicate grad was not created according to the "
        "gradient layout contract, or that the param's strides "
        "changed since DDP was constructed. This is not an error, "
        "but may impair performance.\n"
        "grad.sizes() = ",
        grad.sizes(),
        ", strides() = ",
        grad.strides(),
        "\n",
        "bucket_view.sizes() = ",
        bucket_view.sizes(),
        ", strides() = ",
        bucket_view.strides());
  }
  if (!gradient_as_bucket_view_) {
    TORCH_INTERNAL_ASSERT(!grad.is_alias_of(bucket_view));
  }
}

void Reducer::mark_variable_ready_dense(size_t variable_index) {
  const auto& bucket_index = variable_locators_[variable_index];
  auto& bucket = buckets_[bucket_index.bucket_index];
  auto& variable = bucket.variables[bucket_index.intra_bucket_index];
  auto& bucket_view = bucket.bucket_views_in[bucket_index.intra_bucket_index];

  // Copy the contents of the gradient tensor to the corresponding part of the
  // bucket's flattened gradient tensor.
  // If the gradient is not set, we assume it wasn't computed as part of the
  // current backwards pass, and we zero the part of the bucket it would
  // otherwise hold.
  runGradCallbackForVariable(variable, [&](auto& grad) {
    if (grad.defined()) {
      this->check_grad_layout(grad, bucket_view);
      // When gradient_as_bucket_view_ is false, or (in rare cases) even when
      // it is true and users set grad to None after every iteration, grad and
      // bucket_view point to different storages and we thus need to copy the
      // grad into bucket_view. If gradient_as_bucket_view_ is true, let grad
      // point to bucket_view afterwards. If grad was already set as a view of
      // the bucket in a previous iteration, no copy is needed.
      if (!grad.is_alias_of(bucket_view)) {
        if (comm_hook_ == nullptr) {
          auto wrapped =
              at::native::wrapped_scalar_tensor(double(1.) / div_factor_);
          if (!grad.requires_grad()) {
            // Divides while copying into the bucket view to save one scan over
            // all the input parameters.
            at::mul_out(bucket_view, grad, wrapped);
          } else {
            // If DDP is running with create_graph=True, gradients require_grad
            // themselves in order to compute higher order derivatives. However,
            // DDP will not sync up these gradients currently (see
            // https://github.com/pytorch/pytorch/issues/63812).
            C10_LOG_EVERY_N(WARNING, 1000)
                << "Using DistributedDataParallel with create_graph=True "
                << "is not well-supported. The higher-order gradient will "
                << "not be synchronized across ranks, and backpropagation "
                << "through all_reduce operations will not occur. If you require "
                << "DDP to work with higher-order gradients for your use case, "
                << "please ping https://github.com/pytorch/pytorch/issues/63929";
            auto div_result = at::mul(grad, wrapped);
            bucket_view.copy_(div_result);
          }
        } else {
          bucket_view.copy_(grad);
        }

        if (gradient_as_bucket_view_) {
          // Let grad point to bucket_view buffer.
          grad = bucket_view;
          // The grad is modified and needs to be written back.
          return true;
        }
      } else {
        // If grad and bucket view point to the same storage, no need to copy.
        if (comm_hook_ == nullptr) {
          bucket_view.div_(div_factor_);
        }
      }
    } else {
      // Gradient is undefined. When find_unused_parameters=True, ensure it is
      // not marked as locally used, otherwise we will be allreducing zeros
      // instead of not touching the .grad field of the parameter.
      if (this->dynamic_graph_find_unused() ||
          this->static_graph_first_iteration()) {
        REDUCER_CHECK(
            local_used_map_[variable_index].item<int>() == 0,
            logger_,
            "Encountered gradient which is undefined, but still allreduced by "
            "DDP reducer. This indicates a bug in DDP implementation, please "
            "report a bug with a repro to PyTorch.");
      }
      bucket_view.zero_();
    }
    // The grad is not modified and doesn't need to be written back.
    return false;
  });
}

void Reducer::mark_variable_ready_sparse(size_t variable_index) {
  const auto& bucket_index = variable_locators_[variable_index];
  auto& bucket = buckets_[bucket_index.bucket_index];
  auto& variable = bucket.variables[bucket_index.intra_bucket_index];

  runGradCallbackForVariable(variable, [&](auto& grad) {
    REDUCER_CHECK(
        grad.defined(), logger_, "Expected sparse gradient to be defined.");
    REDUCER_CHECK(
        grad.options().layout() == c10::kSparse,
        logger_,
        "Expected variable to have sparse gradient.");

    // Sparse tensors cannot be grouped together with other sparse tensors in a
    // single reduction operation like we can for dense tensors. Therefore, the
    // `offsets` and `lengths` vectors in the bucket struct are empty, and
    // there is no pre-existing accumulation tensor.
    // Directly assign the sparse tensor to the `gradients` field.
    bucket.gradients = grad;
    // If no DDP comm hook is registered, the allreduce only sums up the
    // value, and a separate division is required.
    if (comm_hook_ == nullptr) {
      bucket.gradients.div_(div_factor_);
    }
    // The grad is modified in place and needs to be written back.
    return true;
  });
}

std::vector<c10d::GradBucket> Reducer::get_grad_buckets(
    bool return_zero_tensors) const {
  std::lock_guard<std::mutex> lock(mutex_);
  std::vector<c10d::GradBucket> gradBuckets;
  gradBuckets.reserve(buckets_.size());
  for (const auto i : c10::irange(buckets_.size())) {
    auto& bucket = buckets_[i];
    auto variables_for_bucket = get_variables_for_bucket(i, bucket);
    gradBuckets.emplace_back(
        i,
        buckets_.size(),
        return_zero_tensors ? at::zeros_like(bucket.gradients)
                            : bucket.gradients,
        bucket.offsets,
        bucket.lengths,
        bucket.sizes_vec,
        variables_for_bucket);
  }
  return gradBuckets;
}

void Reducer::set_forward_pass_work_handle(
    c10::intrusive_ptr<c10d::Work> forwardPassWorkHandle,
    bool useStaticWorldSize) {
  std::lock_guard<std::mutex> lock(mutex_);
  forwardPassWorkHandle_.workHandle = std::move(forwardPassWorkHandle);
  forwardPassWorkHandle_.useStaticWorldSize = useStaticWorldSize;
}

at::Tensor Reducer::get_local_used_map_on_device() const {
  std::lock_guard<std::mutex> lock(mutex_);
  return local_used_map_dev_;
}

void Reducer::push_rebuilt_params_for_all_indices() {
  std::lock_guard<std::mutex> lock(mutex_);
  if (!should_rebuild_buckets() || !rebuilt_param_indices_.empty()) {
    return;
  }
  const auto variable_count = params_.size();
  for (const auto variable_index : c10::irange(variable_count)) {
    push_rebuilt_params(variable_index);
  }
}

void Reducer::push_rebuilt_params(const size_t& index) {
  rebuilt_params_.push_back(params_[index]);
  rebuilt_param_indices_.push_back(index);
}

void Reducer::set_divide_factor() {
  // If it was scheduled, wait on the allreduce in the forward pass that tells
  // us the division factor based on the number of currently participating
  // processes.
  if (div_factor_ == kUnsetDivFactor) {
    div_factor_ = process_group_->getSize();
    auto& workHandle = forwardPassWorkHandle_.workHandle;
    if (workHandle && !forwardPassWorkHandle_.useStaticWorldSize) {
      workHandle->wait();
      // PyProcessGroup::PyWork doesn't expose value, so fetch it from the
      // future
      auto results = extractTensors(workHandle->getFuture()->value());

      // Guard against the results being empty
      TORCH_INTERNAL_ASSERT(!results.empty());
      at::Tensor& res = results.front();
      div_factor_ = res.item().to<int>();
    }
  }
}

// Right now delay_all_reduce is only called when static_graph_=true and
// num_iterations_==1.
void Reducer::delay_all_reduce() {
  std::lock_guard<std::mutex> lock(this->mutex_);

  if (should_collect_runtime_stats()) {
    record_backward_compute_end_time();
    record_backward_comm_start_time();
  }

  // Launch the allreduce for the local used map.
  all_reduce_local_used_map();

  // Prepare to set unused_parameters_. Since this is a static graph,
  // unused_parameters_ will not change after the first iteration.
  unused_parameters_.clear();

  // Copy all gradients to buckets.
  for (const auto variable_index : c10::irange(params_.size())) {
    // Set unused_parameters_.
    if (numGradHooksTriggeredMap_[variable_index] == 0) {
      unused_parameters_.push_back(variable_index);
    }
    require_finalize_ = true;
    set_divide_factor();
    if (expect_sparse_gradients_[variable_index]) {
      mark_variable_ready_sparse(variable_index);
    } else {
      mark_variable_ready_dense(variable_index);
    }
  }

  // To avoid confusion around why static graph picks up some parameters as
  // unused on one rank but not another, we log the unused parameter names for
  // each rank for better debuggability when TORCH_DISTRIBUTED_DEBUG is set to
  // INFO or DETAIL.
  if (ddp_debug_level_ != c10d::DebugLevel::Off) {
    // Construct one string to output.
    std::ostringstream unused_params_stream;

    for (const auto& unused_index : unused_parameters_) {
      auto param_name = param_names_.find(unused_index);
      TORCH_INTERNAL_ASSERT(
          param_name != param_names_.end(),
          "Expected to find parameter name from unused parameters map in debug mode.");
      // Add the param_name
      unused_params_stream << "{" << param_name->second << "," << unused_index
                           << "}";
    }

    // Each rank prints out all the unused parameters detected
    if (!unused_parameters_.empty()) {
      LOG(INFO) << "[Rank " << process_group_->getRank() << "]: "
                << "Parameter(s) (in the format of {param_name, index}): "
                << unused_params_stream.str()
                << " is(are) unused during first iteration. Since"
                << " static_graph=True is enabled for DDP, we expect"
                << " this set of unused parameters to remain consistent"
                << " on this rank throughout the training.";
    }
  }

  // Launch allreduces for all buckets.
  for (auto& bucket : buckets_) {
    all_reduce_bucket(bucket);
  }

  finalize_backward();
}

void Reducer::set_logger(std::weak_ptr<c10d::Logger> logger) {
  logger_ = logger;
}

// The function `autograd_hook` is called after the gradient for a
// model parameter has been accumulated into its gradient tensor.
// This function is only to be called from the autograd thread.
void Reducer::autograd_hook(size_t index) {
  std::lock_guard<std::mutex> lock(this->mutex_);
  // Ignore if we don't expect to be called.
  // This may be the case if the user wants to accumulate gradients
  // for a number of iterations before reducing them.
  if (!expect_autograd_hooks_) {
    return;
  }

  grad_ready_order_indices_.push_back(index);

  // See Note [Skip allreducing local_used_map_dev]
  if (dynamic_graph_find_unused() || static_graph_first_iteration()) {
    // Since it got here, this param has been used for this iteration. We want
    // to mark it in local_used_map_. During a no_sync session, the same var
    // can be set multiple times, which is OK as it does not affect
    // correctness. As long as it is used once during the no_sync session, it
    // is marked as used.
    // Only set it as locally used if the grad is defined. Otherwise, hooks can
    // be fired with undefined grads, such as when not all outputs are used in
    // DDP when computing loss. In this case, we don't want to mark it as
    // locally used to ensure we don't touch the parameter's .grad field.
    auto& variable = get_param_from_index(index);
    runGradCallbackForVariable(variable, [&](auto& grad) {
      if (grad.defined()) {
        local_used_map_[index] = 1;
      }
      // The gradient is never modified.
      return false;
    });
  }

  if (static_graph_first_iteration()) {
    numGradHooksTriggeredMap_[index] += 1;
    return;
  }

  // If `find_unused_parameters_` is true there may be model parameters that
  // went unused when computing the model output, they won't be part of the
  // autograd graph, and won't receive gradients. These parameters are
  // discovered in the `prepare_for_backward` function and their indexes stored
  // in the `unused_parameters_` vector.
  if (!has_marked_unused_parameters_) {
    has_marked_unused_parameters_ = true;
    for (const auto& unused_index : unused_parameters_) {
      mark_variable_ready(unused_index);
    }
  }

  // Rebuild buckets only if 1) it is the first time to rebuild buckets,
  // 2) static_graph_ is true or find_unused_parameters_ is false, and
  // 3) this backward pass needs to run allreduce.
  // Here, we just dump tensors and their parameter indices into
  // rebuilt_params_ and rebuilt_param_indices_ based on gradient arriving
  // order, and then at the end of finalize_backward(), buckets will be
  // rebuilt based on rebuilt_params_ and rebuilt_param_indices_, and then
  // will be broadcasted and initialized.
  // If it is a static graph, after the 1st iteration, check if a variable
  // is ready for communication based on numGradHooksTriggeredMap_.
  if (static_graph_after_first_iteration()) {
    REDUCER_CHECK(
        numGradHooksTriggeredMapPerIteration_[index] > 0,
        logger_,
        "Your training graph has changed in this iteration, ",
        "e.g., one parameter is unused in the first iteration, but ",
        "then got used in the second iteration. This is not ",
        "compatible with static_graph set to True.");
    if (--numGradHooksTriggeredMapPerIteration_[index] == 0) {
      if (should_rebuild_buckets()) {
        push_rebuilt_params(index);
      }
      // Finally mark variable for which this function was originally called.
      mark_variable_ready(index);
    }
  } else {
    if (should_rebuild_buckets()) {
      push_rebuilt_params(index);
    }
    // Finally mark variable for which this function was originally called.
    mark_variable_ready(index);
  }
}

void Reducer::all_reduce_local_used_map() {
  // See Note [Skip allreducing local_used_map_dev]
  // H2D from local_used_map_ to local_used_map_dev_
  if (local_used_map_dev_.is_cuda()) {
    // Note [local_used_map_ -> local_used_map_dev copying]
    // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    // We do async H2D to avoid the blocking overhead. The async copy and
    // allreduce respect the current stream, so will be sequenced
    // correctly.
    //
    // Correct sequencing with respect to host operations is also
    // essential. The H2D copy_ is stream ordered, while the host's
    // changes to local_used_map_ are host ordered. If a large backlog of
    // cuda-stream work pushes the copy_ far into the future, and if no
    // blocking calls occur between now and finalize_backward()** such
    // that finalize_backward() re-zeroes local_used_map_ on the host
    // before the stream executes the copy_, copy_ will read those zeros
    // instead of the values we thought we told it to read here. Copying
    // local_used_map_ to a pinned temporary (which the pinned caching
    // allocator should supply asynchronously) avoids this nasty, rare
    // race condition.
    //
    // ** In the hoped-for case where all params are used, DDP itself
    // won't do any blocking work between now and the re-zeroing, so the
    // danger is real.
    //
    // Defensively ensures local_used_map_tmp is distinct from
    // local_used_map_
    auto local_used_map_tmp = at::native::empty_like(
        local_used_map_,
        optTypeMetaToScalarType(local_used_map_.options().dtype_opt()),
        local_used_map_.options().layout_opt(),
        local_used_map_.options().device_opt(),
        true /* pinned_memory */);
    // Paranoid asserts here because in some workloads, the pinned
    // allocator behaves in a way we don't understand, and may be bugged.
    // See https://github.com/pytorch/pytorch/pull/54474
    TORCH_INTERNAL_ASSERT(local_used_map_tmp.is_pinned());
    TORCH_INTERNAL_ASSERT(
        local_used_map_tmp.data_ptr() != local_used_map_.data_ptr());
    local_used_map_tmp.copy_(local_used_map_);
    local_used_map_dev_.copy_(local_used_map_tmp, true);
  } else {
    local_used_map_dev_.copy_(local_used_map_, true);
  }
  std::vector<at::Tensor> temp_local_used_map_dev_vec_ = {local_used_map_dev_};
  local_used_work_ = process_group_->allreduce(temp_local_used_map_dev_vec_);
}

at::Tensor& Reducer::get_param_from_index(size_t index) {
  const auto& bucket_index = variable_locators_[index];
  auto& bucket = buckets_[bucket_index.bucket_index];
  // Cannot simply access variable via `bucket.variables[variable_index]` since
  // return value is used in `runGradCallbackForVariable()` which does not
  // accept const tensors.
  auto& variable = bucket.variables[bucket_index.intra_bucket_index];
  return variable;
}

void Reducer::checkAndRaiseMarkedTwiceError(size_t index) {
  // Something is wrong if all variables contained in this bucket have
  // already been marked as ready.
  // We don't expect the same variable to be marked ready twice.
  bool marked_twice =
      perIterationReadyParams_.find(index) != perIterationReadyParams_.end();

  if (marked_twice) {
    // Report index of param that has been marked twice. In debug mode, also
    // report fully qualified parameter name.
    auto param_name = param_names_.find(index);
    const bool found_param_name = param_name != param_names_.end();
    TORCH_INTERNAL_ASSERT(
        ddp_debug_level_ == c10d::DebugLevel::Off || found_param_name,
        "Expected to find parameter name in debug mode.");
    std::string paramInfo = c10::str(
        "Parameter at index ",
        index,
        found_param_name ? c10::str(" with name ", param_name->second) : "",
        " has been marked as ready twice. This means that multiple autograd engine ",
        "hooks have fired for this particular parameter during this iteration.");
    // param_names_ is empty in debug mode.
    if (!found_param_name) {
      paramInfo += c10::str(
          " You can set the environment variable TORCH_DISTRIBUTED_DEBUG to either",
          " INFO or DETAIL to print parameter names for further debugging.");
    }
    std::string common_error = c10::str(
        "Expected to mark a variable ready only once. ",
        "",
        "This error is caused by one of the following reasons: ",
        "1) Use of a module parameter outside the `forward` function. ",
        "Please make sure model parameters are not shared across multiple ",
        "concurrent forward-backward passes, or try to use _set_static_graph() ",
        "as a workaround if this module graph does not change ",
        "during the training loop. ",
        "2) Reused parameters in multiple reentrant backward passes. For ",
        "example, if you use multiple `checkpoint` functions to wrap the ",
        "same part of your model, it would result in the same set of ",
        "parameters being used by different reentrant backward passes ",
        "multiple times, and hence marking a variable ready multiple times. ",
        "DDP does not support such use cases by default. You can try to ",
        "use _set_static_graph() as a workaround if your module graph ",
        "does not change over iterations.");

    common_error += c10::str("\n", paramInfo);

    REDUCER_CHECK(
        has_marked_unused_parameters_,
        logger_,
        common_error,
        "3) Incorrect unused parameter detection. The return value of the ",
        "`forward` function is inspected by the distributed data parallel ",
        "wrapper to figure out if any of the module's parameters went ",
        "unused. For unused parameters, DDP would not expect gradients from ",
        "them. However, if an unused parameter becomes part of the autograd ",
        "graph at a later point in time (e.g., in a reentrant backward when ",
        "using `checkpoint`), the gradient will show up unexpectedly. If all ",
        "parameters in the model participate in the backward pass, you can ",
        "disable unused parameter detection by passing the keyword argument ",
        "`find_unused_parameters=False` to ",
        "`torch.nn.parallel.DistributedDataParallel`. If unused parameters ",
        "in the model do not change over iterations, you can try to use ",
        "_set_static_graph() as a workaround if this module graph does not ",
        "change during the training loop.");
    REDUCER_CHECK(!has_marked_unused_parameters_, logger_, common_error);
  }
}

void Reducer::mark_variable_ready(size_t variable_index) {
  REDUCER_CHECK(
      variable_index < variable_locators_.size(),
      logger_,
      "Out of range variable index.");

  checkAndRaiseMarkedTwiceError(variable_index);
  perIterationReadyParams_.insert(variable_index);
  backward_stats_[variable_index] =
      current_time_in_nanos() - backward_compute_start_time_;

  // Any time we mark a variable ready (be it in line due to unused parameters,
  // or via an autograd hook), we require a call to the finalize function. If
  // this doesn't happen before the next iteration (or call to
  // `prepare_for_backwards`), we know something is wrong.
  require_finalize_ = true;

  const auto& bucket_index = variable_locators_[variable_index];
  auto& bucket = buckets_[bucket_index.bucket_index];

  set_divide_factor();

  if (bucket.expect_sparse_gradient) {
    mark_variable_ready_sparse(variable_index);
  } else {
    mark_variable_ready_dense(variable_index);
  }

  // TODO(@pietern): Make this work for both CPU/CUDA tensors.
  // When using CPU tensors we don't need to do this.
  // Record event so that we can wait for all of them.
  // auto& event = bucket.events[bucket_index.intra_bucket_index];
  // event.record();

  // Check if this was the final gradient for this bucket.
  if (--bucket.pending == 0) {
    mark_bucket_ready(bucket_index.bucket_index);
  }

  // Run finalizer function and kick off reduction for local_used_map once the
  // final bucket was marked ready.
  if (next_bucket_ == buckets_.size()) {
    if (dynamic_graph_find_unused()) {
      all_reduce_local_used_map();
    }

    torch::autograd::Engine::get_default_engine().queue_callback([=] {
      std::lock_guard<std::mutex> lock(this->mutex_);
      if (should_collect_runtime_stats()) {
        record_backward_compute_end_time();
      }
      // Check that all buckets were completed and had their work kicked off.
      TORCH_INTERNAL_ASSERT(next_bucket_ == buckets_.size());
      if (static_graph_after_first_iteration() && should_rebuild_buckets()) {
        for (const auto& unused_index : unused_parameters_) {
          push_rebuilt_params(unused_index);
        }
      }
      this->finalize_backward();
    });
  }
}

c10::intrusive_ptr<c10::ivalue::Future> Reducer::run_comm_hook(
    GradBucket& grad_bucket) {
  if (comm_hook_ == nullptr) {
    return run_allreduce_hook(grad_bucket);
  } else {
    return comm_hook_->runHook(grad_bucket);
  }
}

c10::intrusive_ptr<c10::ivalue::Future> Reducer::run_allreduce_hook(
    GradBucket& grad_bucket) {
  _AllReduceBySumCommHook allreduce_hook(process_group_);
  return allreduce_hook.runHook(grad_bucket);
}

void Reducer::all_reduce_bucket(Bucket& bucket) {
  auto variables_for_bucket = get_variables_for_bucket(next_bucket_, bucket);
  // TODO(@pietern): Ensure proper synchronization with the CUDA events
  // that recorded copies into this `gradients` tensor. If these copies are
  // executed on non-default streams, the current stream for the device
  // that holds the `gradients` tensor must wait on these events.
  //
  // As long as autograd uses the default stream for every device,
  // these operations are implicitly sequenced, and we don't need to
  // do any extra synchronization here.
  const auto& tensor = bucket.gradients;

  GradBucket grad_bucket(
      next_bucket_,
      buckets_.size(),
      tensor,
      bucket.offsets,
      bucket.lengths,
      bucket.sizes_vec,
      variables_for_bucket);
  bucket.future_work = run_comm_hook(grad_bucket);
}

std::vector<at::Tensor> Reducer::get_variables_for_bucket(
    size_t bucket_index,
    const Bucket& bucket) const {
  // Check if we have cached the mapping previously.
  if (has_rebuilt_bucket_ &&
      cached_variables_for_bucket_.find(bucket_index) !=
          cached_variables_for_bucket_.end()) {
    return cached_variables_for_bucket_[bucket_index];
  }
  std::vector<at::Tensor> variables_for_bucket;
  variables_for_bucket.reserve(bucket.variable_indices.size());
  for (const auto& variable_index : bucket.variable_indices) {
    // Grab bucket index where gradient is located using variable_locators_.
    auto& bucket_index_for_variable = variable_locators_[variable_index];
    // Grab the actual model parameter.
    auto& variable =
        bucket.variables[bucket_index_for_variable.intra_bucket_index];
    variables_for_bucket.emplace_back(variable);
  }

  if (has_rebuilt_bucket_) {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        cached_variables_for_bucket_.find(bucket_index) ==
        cached_variables_for_bucket_.end());
    cached_variables_for_bucket_.insert(
        {bucket_index, std::move(variables_for_bucket)});
    return cached_variables_for_bucket_[bucket_index];
  } else {
    return variables_for_bucket;
  }
}

// Called when the bucket at the specified index is ready to be reduced.
void Reducer::mark_bucket_ready(size_t bucket_index) {
  TORCH_INTERNAL_ASSERT(bucket_index >= next_bucket_);

  // Buckets are reduced in sequence. Ignore this bucket if
  // it's not its turn to be reduced.
  if (bucket_index > next_bucket_) {
    return;
  }

  // Keep going, until we either:
  // - have kicked off reduction for all buckets, or
  // - found a bucket that's not yet ready for reduction.
  for (; next_bucket_ < buckets_.size() && buckets_[next_bucket_].pending == 0;
       next_bucket_++) {
    num_buckets_ready_++;
    if (num_buckets_ready_ == 1 && should_collect_runtime_stats()) {
      record_backward_comm_start_time();
    }
    auto& bucket = buckets_[next_bucket_];
    all_reduce_bucket(bucket);
  }
}

void Reducer::install_futures(
    c10::List<c10::intrusive_ptr<c10::ivalue::Future>> futs) {
  // Append instead of overwrite so that this method can be called multiple
  // times in one iteration.
  if (!installed_futures_) {
    installed_futures_ = std::move(futs);
  } else {
    installed_futures_->append(futs);
  }
}

void Reducer::initialize_buckets(
    std::vector<std::vector<size_t>> bucket_indices) {
  // If initialize_buckets is called inside the DDP constructor, then it does
  // not matter whether the rpc context ptr is nullptr or not, as grad will
  // not be mutated.
  // If initialize_buckets is called during the training loop, e.g., inside
  // rebuild_buckets(), then grad could be mutated and pointed to the
  // bucket_view, so it needs to check whether the rpc context ptr is nullptr.
  // If the rpc context ptr is nullptr, mutate variable.grad(); otherwise,
  // mutate the grad in the rpc context.
#ifndef _WIN32
  using torch::distributed::autograd::ThreadLocalDistAutogradContext;
  this->rpc_context_.set(ThreadLocalDistAutogradContext::getContextPtr());
#endif

  // This shouldn't be called if we're expecting autograd hooks to fire.
  REDUCER_CHECK(
      !expect_autograd_hooks_,
      logger_,
      "`initialize_buckets` must NOT be called during autograd execution.");

  // Clear current bucket assignment.
  buckets_.clear();
  variable_locators_.clear();

  // Ensure we have a bucket index for every variable.
  variable_locators_.resize(params_.size());

  // Iterate over buckets.
  const auto bucket_count = bucket_indices.size();
  buckets_.reserve(bucket_count);
  for (const auto bucket_index : c10::irange(bucket_count)) {
    Bucket bucket;

    // TODO(@pietern): Validate indices.
    // Must be non-empty, unique, and unique across buckets.
    REDUCER_CHECK(
        !bucket_indices[bucket_index].empty(),
        logger_,
        "Empty bucket specified.");

    // Variables that expect sparse gradients must have their own bucket.
    if (bucket_indices[bucket_index].size() == 1) {
      const auto variable_index = bucket_indices[bucket_index].front();
      bucket.expect_sparse_gradient = expect_sparse_gradients_[variable_index];
    } else {
      for (const auto variable_index : bucket_indices[bucket_index]) {
        REDUCER_CHECK(
            !expect_sparse_gradients_[variable_index],
            logger_,
            "Buckets with more than one variable cannot include variables ",
            "that expect a sparse gradient.");
      }
    }

    if (bucket.expect_sparse_gradient) {
      const auto variable_index = bucket_indices[bucket_index].front();
      const auto& variable = params_[variable_index];
      TORCH_INTERNAL_ASSERT(bucket_indices[bucket_index].size() == 1);
      bucket.variables = {variable};
    } else {
      at::TensorOptions options;
      // The start index of the variable in the flattened tensor.
      size_t offset = 0;

      // Reserve enough space for the per-variable fields stored in the bucket
      // for efficiency.
      const size_t num_variables = bucket_indices[bucket_index].size();
      bucket.variables.reserve(num_variables);
      bucket.offsets.reserve(num_variables);
      bucket.lengths.reserve(num_variables);
      bucket.sizes_vec.reserve(num_variables);

      // Iterate over bucket variables.
      for (const auto variable_index : bucket_indices[bucket_index]) {
        TORCH_INTERNAL_ASSERT(
            variable_index < params_.size(),
            "Out of range variable index specified.");
        const auto& variable = params_[variable_index];
        if (!options.has_device()) {
          options = options.device(variable.device());
        } else {
          REDUCER_CHECK(
              variable.device() == options.device(),
              logger_,
              "All parameters in a bucket must be ",
              "placed on the same device.");
        }
        if (!options.has_dtype()) {
          options = options.dtype(variable.dtype());
        } else {
          REDUCER_CHECK(
              variable.dtype() == options.dtype(),
              logger_,
              "All parameters in a bucket must have the same dtype.");
        }
        const auto length = variable.numel();
        bucket.variables.push_back(variable);
        bucket.offsets.push_back(offset);
        bucket.lengths.push_back(length);
        bucket.sizes_vec.push_back(variable.sizes());
        offset += length;
      }

      // Allocate the bucket's flattened `gradients` tensor.
      bucket.gradients = at::empty({static_cast<long>(offset)}, options);

      // Note: "Gradient Layout Contract"
      //
      // Here, create views into the `gradients` tensor for each variable's
      // grad. Views serve as entry points to `copy_()` each grad's data in/out
      // of the flattened `gradients` tensor.
      //
      // Gradients may have dense memory but non-row-major-contiguous strides
      // (e.g. channels_last or channels_last_3d). For coalesced accesses
      // during copy_s, it's beneficial for each view's layout to match its
      // grad's layout.
      //
      // Specifically, we expect torch/csrc/autograd/functions/accumulate_grad.h
      // to produce grads that obey the "Gradient Layout Contract":
      //   (1) if variable.is_non_overlapping_and_dense(), the stashed grad's
      //       strides match variable.
      //   (2) else, stashed grad is rowmajor contiguous.
      // and create views to match.
      //
      // If AccumulateGrad breaks the contract, and produces a grad with an
      // unexpected layout, performance will degrade due to poor memory access
      // patterns when copy_ing grad data in and out of its bucket view.
      // However, numerics remain correct, because the bucket view is the same
      // on either end of the raw allreduce. bucket_view_in.copy_(grad)
      // transposes (+ densifies) to the bucket view's layout, the data is
      // allreduced, then grad.copy_(bucket_view_out) transposes it back to
      // grad's layout.
      //
      // The only way the numerics can go haywire is if the bucket views
      // themselves have different layouts across processes.
      // Bucket views' sizes and strides are set based on param layouts, using
      // the same logic that (we expect) AccumulateGrad uses for their grads.
      // Therefore, the only way a bucket view could have different layouts in
      // different processes is if its param has a different layout in
      // different processes. We can check that param layouts match across
      // processes in Reducer's constructor by allreducing some metadata.
      // Checking just once won't catch if someone messes with
      // param layouts over time, but not messing with params after DDP
      // construction is already a documented constraint.
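      //
      // For example (illustrative values): a conv weight of size (64, 3, 7, 7)
      // kept in channels_last format has strides (147, 1, 21, 3); its bucket
      // view is created with those same strides rather than the row-major
      // strides (147, 49, 7, 1), matching the layout AccumulateGrad is
      // expected to use for its grad.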
      initialize_bucket_views(bucket);
    }

    // Map participating variables to this bucket.
    size_t intra_bucket_index = 0;
    for (const auto variable_index : bucket_indices[bucket_index]) {
      TORCH_INTERNAL_ASSERT(
          variable_index < variable_locators_.size(),
          "Out of range variable index specified.");
      variable_locators_[variable_index] =
          VariableLocator(bucket_index, intra_bucket_index++);
    }
    bucket.variable_indices = std::move(bucket_indices[bucket_index]);

    buckets_.push_back(std::move(bucket));
  }
}

// (see Note: "Gradient Layout Contract" in initialize_buckets).
void Reducer::initialize_bucket_views(Reducer::Bucket& bucket) {
  const auto& gradients = bucket.gradients;
  for (const auto i : c10::irange(bucket.variables.size())) {
    auto& v = bucket.variables[i];
    const auto offset = bucket.offsets[i];
    const auto length = bucket.lengths[i];
    if (v.is_non_overlapping_and_dense()) {
      // If the param's memory is dense, match its layout, anticipating
      // the autograd engine (AccumulateGrad) will also create gradients
      // matching its layout.
      bucket.bucket_views_in.push_back(
          gradients.as_strided(v.sizes(), v.strides(), offset));
    } else {
      // Fall back to a C-style contiguous view, again anticipating
      // AccumulateGrad will do the same when stashing grads for non-dense
      // params.
      bucket.bucket_views_in.push_back(
          gradients.narrow(0, offset, length).view(v.sizes()));
    }
    // By default `bucket_views_out` and `bucket_views_in` are
    // essentially the same thing.
    bucket.bucket_views_out = bucket.bucket_views_in;

    // If gradient_as_bucket_view_ is set to true, then there are two cases to
    // handle. First, initialize_bucket_views could be called inside
    // initialize_buckets when rebuilding buckets; if a grad has already been
    // defined/calculated in a previous iteration, the old grad needs to be
    // copied into the new bucket_view, and grad must then point to the new
    // bucket_view. Second, initialize_bucket_views could be called inside
    // initialize_buckets during construction. Grads are not defined at
    // construction time; in this case, do not let grad point to bucket_view,
    // because grads should be kept undefined for globally unused parameters.
    if (gradient_as_bucket_view_) {
      auto& bucket_view = bucket.bucket_views_in.back();
      runGradCallbackForVariable(v, [&](auto& grad) {
        if (grad.defined() && !grad.is_alias_of(bucket_view)) {
          bucket_view.copy_(grad);
          grad = bucket_view;
          // The grad is modified and needs to be written back.
          return true;
        }
        // The grad is not modified and does not need to be written back.
        return false;
      });
    }
  }
}

// (see Note: "Gradient Layout Contract" in initialize_buckets).
void Reducer::populate_bucket_views_out(
    Reducer::Bucket& bucket,
    at::Tensor& tensor) {
  bucket.bucket_views_out.clear();
  for (const auto i : c10::irange(bucket.variables.size())) {
    const auto& v = bucket.variables[i];
    const auto offset = bucket.offsets[i];
    const auto length = bucket.lengths[i];
    if (v.is_non_overlapping_and_dense()) {
      // If the param's memory is dense, match its layout, anticipating
      // the autograd engine (AccumulateGrad) will also create gradients
      // matching its layout.
      bucket.bucket_views_out.push_back(
          tensor.as_strided(v.sizes(), v.strides(), offset));
    } else {
      // Fall back to a C-style contiguous view, again anticipating
      // AccumulateGrad will do the same when stashing grads for non-dense
      // params.
      bucket.bucket_views_out.push_back(
          tensor.narrow(0, offset, length).view(v.sizes()));
    }
  }
}

void Reducer::prepare_for_forward() {
  std::lock_guard<std::mutex> lock(mutex_);
  num_iterations_++;
  if (should_collect_runtime_stats()) {
    record_forward_compute_start_time();
  }
}

void Reducer::reset_bucket_counting() {
  next_bucket_ = 0;
  // Reset num_buckets_ready_ at the beginning of backward computation
  // in each iteration.
  num_buckets_ready_ = 0;

  for (auto& bucket : buckets_) {
    bucket.pending = bucket.variables.size();
  }

  if (static_graph_) {
    numGradHooksTriggeredMapPerIteration_ = numGradHooksTriggeredMap_;
  }
}

// Traverse the autograd graph starting at the specified output.
// All parameters for which we have a pointer to their gradient accumulation
// functions, but which don't show up in the autograd graph, will be marked
// ready for reduction as soon as the first autograd hook is called. This is
// not done immediately because the model output may be ignored, and we only
// want to start performing reductions on `torch.autograd.backward()`.
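//
// For example (illustrative), a model with a conditionally used auxiliary
// head:
//   out = self.backbone(x)
//   if self.use_aux:
//       out = out + self.aux_head(x)
// In iterations where the branch is not taken, the gradient accumulators of
// aux_head's parameters never appear in the graph seeded by the outputs, so
// those parameters are collected here as unused.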
void Reducer::search_unused_parameters(
    const std::vector<torch::autograd::Variable>& outputs) {
  std::unordered_set<torch::autograd::Node*> seen;
  std::vector<torch::autograd::Node*> queue;

  RECORD_FUNCTION(
      "torch.distributed.ddp.reducer::search_unused_parameters",
      std::vector<c10::IValue>());

  // Seed queue with the grad functions of all outputs.
  for (const auto& output : outputs) {
    const auto& grad_fn = output.grad_fn();
    if (grad_fn) {
      queue.push_back(grad_fn.get());
    }
  }

  // Traverse the autograd graph starting at the specified output.
  while (!queue.empty()) {
    auto fn = queue.back();
    queue.pop_back();
    for (const auto& edge : fn->next_edges()) {
      if (auto next_ptr = edge.function.get()) {
        const bool was_inserted = seen.insert(next_ptr).second;
        if (was_inserted) {
          queue.push_back(next_ptr);
        }
      }
    }
  }

  // Find accumulator functions that don't show up in this graph.
  for (const auto& it : gradAccToVariableMap_) {
    // If the accumulator function is present in the graph, we know
    // a gradient will be computed for the corresponding parameter.
    if (seen.count(it.first) == 0) {
      if (ddp_debug_level_ == c10d::DebugLevel::Detail) {
        const auto param_info = param_names_.find(it.second);
        TORCH_INTERNAL_ASSERT(
            param_info != param_names_.end(),
            "Did not find variable index ",
            it.second,
            " in DDP parameter name mapping!");
        const auto param_name = param_info->second;
        LOG(INFO) << "[Rank " << process_group_->getRank() << "]: "
                  << "Parameter " << param_name << " at index " << it.second
                  << " is marked as unused.";
      }
      unused_parameters_.push_back(it.second);
    }
  }

  // Warn user about unnecessary perf hit if all parameters were used in
  // forward.
  if (unused_parameters_.empty()) {
    TORCH_WARN_ONCE(
        "find_unused_parameters=True was specified in DDP constructor, "
        "but did not find any unused parameters in the forward pass. This flag "
        "results in an extra traversal of the autograd graph every iteration, "
        "which can adversely affect performance. If your model indeed never "
        "has any unused parameters in the forward pass, consider turning this "
        "flag off. Note that this warning may be a false positive if your model "
        "has flow control causing later iterations to have unused parameters.");
  }
  if (!static_graph_ && ddp_graph_static_) {
    if (num_iterations_ > 1) {
      // Graph is still static if the set of unused parameters did not change.
      ddp_graph_static_ =
          prev_iteration_unused_parameters_ == unused_parameters_;

      if (!ddp_graph_static_) {
        // Log graph is not static. Logger takes care of ensuring this is done
        // only once to avoid overhead.
        logger_.lock()->log_if_graph_static(false);
      }
    }
    prev_iteration_unused_parameters_ = unused_parameters_;
  }
}

void Reducer::prepare_for_backward(
    const std::vector<torch::autograd::Variable>& outputs) {
  std::lock_guard<std::mutex> lock(mutex_);

  backward_compute_start_time_ = current_time_in_nanos();
  if (should_collect_runtime_stats()) {
    record_backward_compute_start_time();
  }

  // Reset accounting.
  expect_autograd_hooks_ = true;
  // Clear gradient ready order as it can be different in the next iteration.
  grad_ready_order_indices_.clear();

  reset_bucket_counting();

  // Reset unused parameter accounting.
  has_marked_unused_parameters_ = false;
  // Reset per iteration marked ready parameters.
  perIterationReadyParams_.clear();

  // If static graph is not set, search graph to detect unused parameters.
  // When static graph is set, unused_parameters_ will be detected and will
  // not change after the 1st iteration.
  // If static_graph_ = false and find_unused_parameters_ is false,
  // we assume that autograd hooks for ALL variables will be called,
  // and we don't have to search the autograd graph for presence of these hooks.
  if (dynamic_graph_find_unused()) {
    unused_parameters_.clear();
    search_unused_parameters(outputs);
  }
}
1357
1358void Reducer::copy_bucket_to_grad(
1359 at::Tensor& variable,
1360 Reducer::Bucket& bucket,
1361 size_t intra_bucket_index,
1362 bool global_unused) {
1363 const auto& bucket_view = bucket.bucket_views_out[intra_bucket_index];
1364 runGradCallbackForVariable(variable, [&](auto& grad) {
1365 // If a parameter is globally unused, we keep its grad untouched.
1366 if (!global_unused) {
1367 if (!grad.defined()) {
1368 // Creates grad according to the "Gradient Layout Contract"
1369 // (see torch/csrc/autograd/functions/accumulate_grad.h)
1370 grad =
1371 torch::autograd::utils::clone_obey_contract(bucket_view, variable);
1372 } else {
1373 grad.copy_(bucket_view);
1374 }
1375 // The grad is modified and needs to be written back.
1376 return true;
1377 }
1378 // The grad is not modified.
1379 return false;
1380 });
1381}
1382
1383std::vector<std::string> Reducer::getUnmarkedParamsForIteration() {
1384 std::vector<std::string> unMarkedParamNames;
1385 for (const auto& it : param_names_) {
1386 if (perIterationReadyParams_.find(it.first) ==
1387 perIterationReadyParams_.end()) {
1388 unMarkedParamNames.push_back(it.second);
1389 }
1390 }
1391 return unMarkedParamNames;
1392}
1393
1394std::vector<size_t> Reducer::getUnmarkedParamIndicesForIteration() {
1395 std::vector<size_t> unmarked_param_indices;
1396 const auto variable_count = params_.size();
1397 for (const auto variable_index : c10::irange(variable_count)) {
1398 if (perIterationReadyParams_.find(variable_index) ==
1399 perIterationReadyParams_.end()) {
1400 unmarked_param_indices.push_back(variable_index);
1401 }
1402 }
1403 return unmarked_param_indices;
1404}
1405
1406// A bucket with one or more dense tensors needs to be unflattened.
1407void Reducer::finalize_bucket_dense(Bucket& bucket) {
1408 for (const auto intra_bucket_index : c10::irange(bucket.variables.size())) {
1409 auto& variable = bucket.variables[intra_bucket_index];
1410
1411 bool global_unused = false;
1412 // See Note [Skip allreducing local_used_map_dev]
1413 if (static_graph_ || find_unused_parameters_) {
1414 // Determine if this param has been used globally or not.
1415 //
      // If the variable was used locally, it is also used globally and then
      // we don't need to wait for the reduction. Otherwise we lazily wait
      // for the reduction to complete, only when we see a variable that was
      // unused locally. This delays the synchronization point that
      // local_used_work_->wait() implies: if no parameters are unused at
      // all, we skip waiting for the work to complete altogether, which adds
      // negligible overhead for the large majority of models, where all
      // parameters are always used. We only pay the cost when a parameter is
      // indeed locally unused, because then we need to check whether it is
      // also globally unused.
1428 size_t variable_index = bucket.variable_indices[intra_bucket_index];
      // Note: global_unused might not be global yet. Since we lazily wait for
      // the reduction to complete, it becomes truly global only once we reach
      // the point below where we wait for the reduction work, do the D2H
      // copy, and update global_unused with the real global consensus, i.e.
      // once local_used_map_reduced_ is true.
1434 global_unused = local_used_map_[variable_index].item<int>() == 0;
1435 if (global_unused && !local_used_map_reduced_) {
1436 // Wait for local_used_map reduction to complete.
1437 local_used_work_->wait();
1438 // D2H from local_used_map_dev_ to local_used_map_
1439 // Blocking copy, if local_used_map_dev_ is cuda
1440 local_used_map_.copy_(local_used_map_dev_);
1441
1442 global_unused = local_used_map_[variable_index].item<int>() == 0;
1443 local_used_map_reduced_ = true;
1444 }
1445 }
1446
1447 if (!gradient_as_bucket_view_) {
1448 if (set_grads_to_none_) {
1449 runGradCallbackForVariable(variable, [&](auto& grad) {
1450 grad.reset();
1451 return true;
1452 });
1453 } else {
1454 RECORD_FUNCTION(
1455 "torch.distributed.ddp.reducer::copy_bucket_to_grad",
1456 std::vector<c10::IValue>({variable}));
1457 copy_bucket_to_grad(
1458 variable, bucket, intra_bucket_index, global_unused);
1459 }
1460 } else {
1461 const auto& bucket_view_out = bucket.bucket_views_out[intra_bucket_index];
1462 auto& bucket_view_in = bucket.bucket_views_in[intra_bucket_index];
1463 // If a communication hook is registered, then `bucket_view_out` stores
1464 // the allreduced results in a newly allocated tensor, so we copy
1465 // `bucket_view_out` back to `bucket_view_in` for this gradient.
1466 if (!bucket_view_in.is_alias_of(bucket_view_out)) {
1467 bucket_view_in.copy_(bucket_view_out);
1468 }
1469 runGradCallbackForVariable(variable, [&](auto& grad) {
1470 if (set_grads_to_none_) {
1471 grad.reset();
1472 return true;
1473 }
1474 // If a parameter is globally unused, we keep its grad untouched.
1475 if (!global_unused) {
1476 // If grad is globally used but locally unused, let grad point to
1477 // bucket_view_in
1478 if (!grad.defined()) {
1479 grad = bucket_view_in;
1480 } else {
1481 if (!grad.is_alias_of(bucket_view_in)) {
1482 REDUCER_CHECK(
1483 false,
1484 logger_,
1485 "Detected at least one parameter gradient is not the "
1486 "expected DDP bucket view with gradient_as_bucket_view=True. "
1487 "This may happen (for example) if multiple allreduce hooks "
1488 "were registered onto the same parameter. If you hit this error, "
1489 "please file an issue with a minimal repro.");
1490 }
1491 }
1492 // The grad is modified and needs to be written back.
1493 return true;
1494 }
1495 // The grad is not modified.
1496 return false;
1497 });
1498 }
1499 }
1500}
1501
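// Waits for every bucket's asynchronous reduction, writes the reduced
// gradients back into the parameters' `grad` fields for dense buckets, and
// resets per-iteration state such as the division factor and the local used
// map.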
1502void Reducer::finalize_backward() {
1503 // No longer expect autograd hooks to fire after this function returns.
1504 TORCH_INTERNAL_ASSERT(expect_autograd_hooks_);
1505 expect_autograd_hooks_ = false;
1506
1507 // No longer require call to finalize after this function returns.
1508 TORCH_INTERNAL_ASSERT(require_finalize_);
1509 require_finalize_ = false;
1510
1511 // Wait for asynchronous reduction to complete, and unflatten the bucket's
1512 // flattened `gradients` tensor.
1513 for (auto& bucket : buckets_) {
1514 // See Note [DDP Communication Hook]
1515 TORCH_INTERNAL_ASSERT(
1516 bucket.future_work,
1517 "Expected bucket.future_work not to be null. "
        "This may indicate that the communication hook was not properly installed.");
1519 bucket.future_work->wait();
1520 auto future_result = comm_hook_ == nullptr
1521 ? detail::parseCppCommHookResult(bucket.future_work->value())
1522 : comm_hook_->parseHookResult(bucket.future_work->value());
1523 if (bucket.expect_sparse_gradient) {
1524 bucket.gradients.copy_(future_result);
1525 } else {
1526 // Reinitialize only `bucket_views_out` with the future_result by
1527 // following the same logic in `initialize_buckets`.
1528 populate_bucket_views_out(bucket, future_result);
1529 }
1530
    // Unset allreduce division factor, as it may change in the next backwards
    // pass when running with DDP join mode.
1533 div_factor_ = kUnsetDivFactor;
1534
1535 if (!bucket.expect_sparse_gradient) {
1536 // We don't need to finalize the sparse bucket since the sparse grad and
1537 // the bucket essentially point to the same storage. As a result, once
1538 // the allreduce is done, the sparse grads are automatically updated.
1539 finalize_bucket_dense(bucket);
1540 }
1541 }
1542
1543 if (installed_futures_ != c10::nullopt) {
1544 c10::collectAll(*installed_futures_)->wait();
1545 installed_futures_ = c10::nullopt;
1546 }
1547
  // See Note [Skip allreducing local_used_map_dev]
1549 if (dynamic_graph_find_unused() || static_graph_first_iteration()) {
    // Due to the lazy wait, it is possible that the reduction for the current
    // iteration is still running when the one for the next iteration gets
    // kicked off. In that case we wait explicitly to make sure the reduction
    // completes before kicking off the next one. Otherwise the previous
    // reduction may interfere, writing to the device-side memory and
    // clobbering the contents of local_used_map_dev_.
1556 if (!local_used_map_reduced_) {
1557 local_used_work_->wait();
1558 }
1559 }
1560
1561 if (dynamic_graph_find_unused()) {
1562 // Reset unused parameter accounting.
1563 // See Note [local_used_map_ -> local_used_map_dev copying]
1564 local_used_map_.fill_(0);
1565 local_used_map_reduced_ = false;
1566 }
1567
1568 if (should_collect_runtime_stats()) {
1569 record_backward_comm_end_time();
1570 }
1571}
1572
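// Applies `cb` to the gradient of `variable`. When a distributed autograd
// context is active (non-Windows builds only), the callback is routed through
// that context so it operates on the context-local gradient instead of
// `variable.grad()`.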
1573void Reducer::runGradCallbackForVariable(
1574 at::Tensor& variable,
1575 GradCallback&& cb) {
1576#ifdef _WIN32
1577 cb(variable.mutable_grad());
1578#else
1579 auto context_ptr = rpc_context_.context_ptr.load();
1580 if (context_ptr == nullptr) {
1581 cb(variable.mutable_grad());
1582 } else {
1583 // Under distributed autograd
1584 context_ptr->runGradCallbackForVariable(variable, std::move(cb));
1585 }
1586#endif
1587}
1588
1589#ifndef _WIN32
1590void Reducer::RpcContext::set(ContextPtr&& new_context_ptr) {
  // We should set 'new_context_ptr' even if it's nullptr; that means the
  // reducer is running a local backward pass without distributed autograd.
1593 const auto new_context_raw_ptr = new_context_ptr.get();
1594 if (context_ptr.exchange(new_context_raw_ptr) != new_context_raw_ptr) {
    // Set the shared_ptr to the context only the first time it is set.
    // All call sites should use the same context ptr.
    // The atomic exchange avoids data races between multiple threads.
1598 context_ptr_holder = std::move(new_context_ptr);
1599 }
1600}
1601#endif
1602
1603void Reducer::sync_bucket_indices(
1604 std::vector<std::vector<size_t>>& bucket_indices) {
1605 auto num_buckets = bucket_indices.size();
1606 std::vector<size_t> bucket_sizes;
1607 bucket_sizes.reserve(num_buckets);
1608 int64_t total_size = 0;
1609 for (const auto i : c10::irange(num_buckets)) {
1610 auto bucket_size = bucket_indices.at(i).size();
1611 bucket_sizes.push_back(bucket_size);
1612 total_size += bucket_size;
1613 }
1614
1615 at::TensorOptions options;
1616 options = options.dtype(at::kInt);
1617 options = options.device(params_[0].device());
1618
  // Group indices and num_buckets together into indices_tensor
1620 // Broadcast this tensor first, as its size is equal among all processes
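  // For example, with bucket_indices = {{0, 2}, {1}} the flattened layout is
  // [0, 2, 1, 2], where the trailing entry is num_buckets.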
1621 auto indices_tensor = at::empty({total_size + 1}, at::kInt);
1622 auto indices_accessor = indices_tensor.accessor<int, 1>();
1623 auto indices_accessor_Index = 0;
1624 for (const auto i : c10::irange(num_buckets)) {
1625 const auto& bucket_size = bucket_indices.at(i).size();
1626 for (const auto j : c10::irange(bucket_size)) {
1627 indices_accessor[indices_accessor_Index++] = bucket_indices[i][j];
1628 }
1629 }
1630 indices_accessor[indices_accessor_Index] = num_buckets;
1631
1632 // Copy CPU tensor to device tensor, as the process_group_ could be NCCL and
1633 // it can only broadcast device tensors.
1634 auto indices_tensor_device = at::empty({total_size + 1}, options);
1635 indices_tensor_device.copy_(indices_tensor, /*non_blocking=*/true);
1636 std::vector<at::Tensor> indices_tensor_list = {indices_tensor_device};
1637 process_group_->broadcast(indices_tensor_list)->wait();
1638 indices_tensor.copy_(indices_tensor_list.front(), /*non_blocking=*/false);
1639
1640 // Update num_buckets after receiving it from rank 0
1641 num_buckets = indices_accessor[indices_accessor_Index];
1642
1643 // Broadcast bucket_sizes
1644 auto bucket_sizes_tensor = at::empty({(int64_t)num_buckets}, at::kInt);
1645 auto bucket_sizes_accessor = bucket_sizes_tensor.accessor<int, 1>();
1646 for (const auto i : c10::irange(num_buckets)) {
    // For rank != 0, the local number of buckets (bucket_sizes.size()) may be
    // smaller than the broadcasted num_buckets.
1649 bucket_sizes_accessor[i] =
1650 bucket_sizes.at(std::min(i, (bucket_sizes.size() - 1)));
1651 }
1652 auto bucket_sizes_tensor_device = at::empty({(int64_t)num_buckets}, options);
1653 bucket_sizes_tensor_device.copy_(bucket_sizes_tensor, /*non_blocking=*/true);
1654 std::vector<at::Tensor> bucket_sizes_tensor_list = {
1655 bucket_sizes_tensor_device};
1656 process_group_->broadcast(bucket_sizes_tensor_list)->wait();
1657 bucket_sizes_tensor.copy_(
1658 bucket_sizes_tensor_list.front(), /*non_blocking=*/false);
1659
1660 // Clear bucket_indices first, and then update bucket_indices using received
1661 // num_buckets, bucket_sizes_tensor and indices_tensor from rank 0
1662 bucket_indices.clear();
1663 bucket_indices.reserve(num_buckets);
1664 indices_accessor_Index = 0;
1665 for (const auto i : c10::irange(num_buckets)) {
1666 const auto& bucket_size = bucket_sizes_accessor[i];
1667 std::vector<size_t> bucket;
1668 bucket.reserve(bucket_size);
1669 for (const auto j : c10::irange(bucket_size)) {
1670 (void)j;
1671 bucket.push_back(indices_accessor[indices_accessor_Index++]);
1672 }
1673 bucket_indices.emplace_back(std::move(bucket));
1674 }
1675}
1676
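// Rebuilds buckets from rebuilt_params_ / rebuilt_param_indices_, which were
// recorded in the order gradients became ready in earlier iterations, syncs
// the new bucket assignment across ranks, and re-initializes the buckets.
// Returns false if rebuilding is skipped (e.g. it already happened or no
// parameters have been recorded yet).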
1677bool Reducer::rebuild_buckets() {
  // Ensure reduction for the previous backwards pass has finished. If the
  // user's model has unused parameters, for example, this will raise an error
  // recommending running with find_unused_parameters=True, instead of the
  // size mismatch exception below.
1682 std::lock_guard<std::mutex> lock(mutex_);
1683 ensure_prior_reduction_finished();
1684 if (!should_rebuild_buckets() || rebuilt_params_.empty()) {
1685 return false;
1686 }
1687
1688 TORCH_INTERNAL_ASSERT(
1689 rebuilt_params_.size() == rebuilt_param_indices_.size(),
1690 c10::str(
          "rebuilt parameter tensors size is not the same as rebuilt parameter indices size: ",
1692 rebuilt_params_.size(),
1693 " versus ",
1694 rebuilt_param_indices_.size()));
1695 TORCH_INTERNAL_ASSERT(
1696 params_.size() == rebuilt_param_indices_.size(),
1697 c10::str(
          "rebuilt parameter indices size is not the same as original model parameters size. ",
1699 "Original model param size is: ",
1700 params_.size(),
1701 " versus rebuilt params size of: ",
1702 rebuilt_param_indices_.size()));
1703 std::vector<std::vector<size_t>> rebuilt_bucket_indices;
1704 std::vector<size_t> bucket_size_limits;
1705 bucket_size_limits.push_back(first_bucket_bytes_cap_);
1706 bucket_size_limits.push_back(bucket_bytes_cap_);
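  // The first bucket is capped at first_bucket_bytes_cap_ and every later
  // bucket at bucket_bytes_cap_; compute_bucket_assignment_by_size keeps
  // reusing the last limit once this vector is exhausted.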
1707 std::vector<size_t> per_bucket_size_limits;
1708 auto ddp_set_last_bucket_as_small =
1709 (parse_env("DDP_SET_LAST_BUCKET_CAP") == "1");
1710
1711 if (ddp_set_last_bucket_as_small) {
    // Reverse so that first_bucket_bytes_cap_ (the smaller cap) becomes the
    // last bucket. We cannot simply pass in {bucket_bytes_cap_,
    // first_bucket_bytes_cap_} as the size limits, as we would immediately
    // advance to the 2nd element after the first bucket, whereas we only want
    // the last bucket to have a smaller size.
1717 std::reverse(rebuilt_params_.begin(), rebuilt_params_.end());
1718 std::reverse(rebuilt_param_indices_.begin(), rebuilt_param_indices_.end());
1719 }
1720 std::tie(rebuilt_bucket_indices, per_bucket_size_limits) =
1721 compute_bucket_assignment_by_size(
1722 rebuilt_params_,
1723 bucket_size_limits,
1724 expect_sparse_gradients_,
1725 rebuilt_param_indices_,
1726 logger_);
1727
1728 if (ddp_set_last_bucket_as_small) {
1729 // Reverse again because buckets were rebuilt in the opposite of gradient
1730 // ready order.
1731 std::reverse(rebuilt_bucket_indices.begin(), rebuilt_bucket_indices.end());
1732 std::reverse(per_bucket_size_limits.begin(), per_bucket_size_limits.end());
1733 }
1734
1735 if (ddp_debug_level_ != c10d::DebugLevel::Off) {
1736 TORCH_INTERNAL_ASSERT(
        rebuilt_bucket_indices.size() == per_bucket_size_limits.size());
1738 LOG(INFO) << rebuilt_bucket_indices.size()
1739 << " buckets rebuilt with size limits: "
1740 << c10::Join(", ", per_bucket_size_limits) << " bytes.";
1741 }
1742
  // The rebuilt bucket indices need to be synced across all ranks, so
  // broadcast the newly rebuilt bucket indices from rank 0 by default.
  // After syncing up the rebuilt bucket indices, initialize buckets for the
  // reducer.
1746 sync_bucket_indices(rebuilt_bucket_indices);
1747
1748 has_rebuilt_bucket_ = true;
1749 rebuilt_params_.clear();
1750 rebuilt_param_indices_.clear();
1751
1752 initialize_buckets(std::move(rebuilt_bucket_indices));
1753
1754 return true;
1755}
1756
1757// See Note [DDP Communication Hook]
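// Illustrative Python-side usage (not part of this file):
//   ddp_model.register_comm_hook(state, hook)
// where `hook(state, bucket)` returns a Future holding the reduced tensor;
// that Future is consumed as `bucket.future_work` in `finalize_backward`.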
1758void Reducer::register_comm_hook(std::unique_ptr<CommHookInterface> iface) {
1759 REDUCER_CHECK(
1760 comm_hook_ == nullptr,
1761 logger_,
1762 "register_comm_hook or register_builtin_comm_hook can only be called once.");
1763
1764 comm_hook_ = std::move(iface);
1765}
1766
1767// See Note [DDP Communication Hook]
1768void Reducer::register_builtin_comm_hook(
1769 c10d::BuiltinCommHookType comm_hook_type) {
1770 REDUCER_CHECK(
1771 comm_hook_ == nullptr,
1772 logger_,
1773 "register_builtin_comm_hook or register_comm_hook can only be called once.");
1774
1775 switch (comm_hook_type) {
1776 case c10d::BuiltinCommHookType::ALLREDUCE:
1777 comm_hook_ = std::make_unique<c10d::AllReduceCommHook>(process_group_);
1778 LOG(INFO) << "Built-in communication hook ALLREDUCE is registered.";
1779 break;
1780 case c10d::BuiltinCommHookType::FP16_COMPRESS:
1781 comm_hook_ = std::make_unique<c10d::FP16CompressCommHook>(process_group_);
1782 LOG(INFO) << "Built-in communication hook FP16_COMPRESS is registered.";
1783 break;
1784 default:
1785 TORCH_WARN_ONCE(
1786 "Unknown built-in DDP comm hook type is provided. No comm hook will be used.");
1787 }
1788}
1789
1790void Reducer::set_grads_to_none(bool set_to_none) {
1791 set_grads_to_none_ = set_to_none;
1792}
1793
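// Raises an error if the previous reduction never finished, i.e. some
// parameters did not have their autograd hooks fire in the last backward
// pass. The message is tailored to static_graph_ / find_unused_parameters_
// and, in debug mode, lists the parameters that did not receive gradients.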
1794void Reducer::ensure_prior_reduction_finished() {
1795 // Check that any prior reduction has finished.
1796 // The variable `require_finalize_` is true until all gradients
1797 // have been computed and reduction of all buckets has been kicked off.
1798 if (require_finalize_) {
    // Collect unmarked parameter indices; additionally, in debug mode,
    // retrieve parameter names.
1801 auto unmarked_param_indices = getUnmarkedParamIndicesForIteration();
1802 // We should have some unmarked parameter indices, otherwise we would not
1803 // have run into this error branch.
1804 TORCH_INTERNAL_ASSERT(!unmarked_param_indices.empty());
1805
1806 std::string kBaseErrorMsg =
1807 "Expected to have finished reduction in the prior iteration before "
1808 "starting a new one. "
1809 ""
1810 "This error indicates that your module has parameters that were "
1811 "not used in producing loss. ";
1812 std::string kOutputsNotUsedInLossErrorMsg =
1813 "making sure all "
1814 "`forward` function outputs participate in calculating loss. ";
1815 std::string kDDPBugErrorMsg =
1816 "\nIf you already have done the above, then the distributed "
1817 "data parallel module wasn't able to locate the output tensors in the "
1818 "return value of your module's `forward` function. "
1819 "Please include the loss function and the structure of the return "
1820 "value of `forward` of your module when reporting this issue (e.g. "
1821 "list, dict, iterable).";
1822
1823 if (static_graph_) {
1824 kBaseErrorMsg =
1825 "Expected to have finished reduction in the prior iteration before "
1826 "starting a new one. "
1827 "This error indicates that your training graph has changed "
          "in this iteration, e.g., one parameter is used in the first "
          "iteration, but then not used in the second iteration. "
          "This is not compatible with static_graph set to True.";
1831 } else if (!find_unused_parameters_) {
1832 // Parameters may have been unused in forward pass, or not all outputs
1833 // were used in producing loss.
1834 kBaseErrorMsg +=
1835 "You can enable unused parameter detection by passing the "
1836 "keyword argument `find_unused_parameters=True` to "
1837 "`torch.nn.parallel.DistributedDataParallel`, and by \n";
1838 kBaseErrorMsg += kOutputsNotUsedInLossErrorMsg;
1839 kBaseErrorMsg += kDDPBugErrorMsg;
1840 } else {
1841 // Note that it does not really matter whether unused_parameters_.empty(),
1842 // since user may have enabled detection but this particular iteration
1843 // could have used or not used all parameters.
1844 kBaseErrorMsg +=
          "Since `find_unused_parameters=True` is enabled, this likely "
          "means that not all `forward` outputs participate in computing "
          "loss. You can fix this by ";
1847 kBaseErrorMsg += kOutputsNotUsedInLossErrorMsg;
1848 kBaseErrorMsg += kDDPBugErrorMsg;
1849 }
1850
1851 const std::string unmarked_param_indices_info = c10::str(
1852 "\n",
1853 "Parameter indices which did not receive grad for rank ",
1854 process_group_->getRank(),
1855 ": ",
1856 unmarked_param_indices);
1857
1858 if (ddp_debug_level_ == DebugLevel::Off) {
1859 // Without debug mode, log unmarked_param_indices, as well as
1860 // recommendation to use debug mode to print parameter names.
1861 kBaseErrorMsg += unmarked_param_indices_info;
1862 kBaseErrorMsg +=
1863 "\n In addition, you can set the environment variable "
1864 "TORCH_DISTRIBUTED_DEBUG to either INFO or DETAIL to print out information "
1865 "about which particular parameters did not receive gradient on this rank "
          "as part of this error.";
1867 } else {
1868 // Retrieve set of parameter names that did not receive gradient.
1869 auto unmarkedParams = getUnmarkedParamsForIteration();
1870 TORCH_INTERNAL_ASSERT(!unmarkedParams.empty());
1871 for (const auto& s : unmarkedParams) {
1872 LOG(INFO) << "[Rank " << process_group_->getRank() << "] "
1873 << "Parameter: " << s
1874 << " did not get gradient in backwards pass.";
1875 }
1876 const std::string unmarkedParamInfo = c10::Join(", ", unmarkedParams);
1877 // In debug mode, log param names and indices that went unused.
1878 kBaseErrorMsg += c10::str(
1879 "\n",
1880 "Parameters which did not receive grad for rank ",
1881 process_group_->getRank(),
1882 ": ",
1883 unmarkedParamInfo);
1884 kBaseErrorMsg += unmarked_param_indices_info;
1885 }
1886 REDUCER_CHECK(false, logger_, kBaseErrorMsg);
1887 }
1888}
1889
1890void Reducer::set_ddp_runtime_logging_sample_rate(int sample_rate) {
1891 ddp_runtime_logging_sample_rate_ = sample_rate;
1892}
1893
1894int Reducer::get_ddp_runtime_logging_sample_rate() {
1895 return ddp_runtime_logging_sample_rate_;
1896}
1897
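// Runtime stats are collected for each of the first 10 iterations and, after
// that, only every get_ddp_runtime_logging_sample_rate()-th iteration.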
1898bool Reducer::should_collect_runtime_stats() {
1899 if (num_iterations_ > 0 &&
1900 (num_iterations_ <= 10 ||
1901 num_iterations_ % get_ddp_runtime_logging_sample_rate() == 0)) {
1902 return true;
1903 }
1904 return false;
1905}
1906
1907void Reducer::record_forward_compute_start_time() {
1908 if (timer_) {
1909 timer_->record(Timer::Event::kForwardStart);
1910 }
1911}
1912
1913void Reducer::record_backward_compute_start_time() {
1914 if (timer_) {
1915 timer_->record(Timer::Event::kBackwardComputeStart);
1916 }
1917}
1918
1919void Reducer::record_backward_compute_end_time() {
1920 if (timer_) {
1921 timer_->record(Timer::Event::kBackwardComputeEnd);
1922 }
1923}
1924
1925void Reducer::record_backward_comm_start_time() {
1926 if (timer_) {
1927 timer_->record(Timer::Event::kBackwardCommStart);
1928 }
1929}
1930
1931void Reducer::record_backward_comm_end_time() {
1932 if (timer_) {
1933 timer_->record(Timer::Event::kBackwardCommEnd);
1934 }
1935}
1936
1937void Reducer::set_static_graph() {
1938 std::lock_guard<std::mutex> lock(mutex_);
1939 REDUCER_CHECK(
1940 num_iterations_ == 0,
1941 logger_,
1942 "set_static_graph() should be called before training loop starts "
1943 "and after DistributedDataParallel is constructed.");
1944 static_graph_ = true;
  // When static_graph_ is set to true, always initialize the local used map
  // and detect globally unused parameters in the first iteration.
1947 initialize_local_used_map();
1948}
1949
1950namespace {
1951
1952// Tensors may be coalesced into buckets. Buckets must contain tensors of
// the same type, on the same device, so a bucket can be identified by a
1954// composite key of a tensor's type identifier and its device.
1955struct BucketKey {
1956 BucketKey(c10::ScalarType type, c10::Device device)
1957 : type(type), device(device) {}
1958
1959 const c10::ScalarType type;
1960 const c10::Device device;
1961
  // See c10/util/hash.h for dispatch code.
1963 static size_t hash(const BucketKey& key) {
1964 return c10::get_hash(key.type, key.device);
1965 }
1966};
1967
1968inline bool operator==(const BucketKey& lhs, const BucketKey& rhs) {
1969 return lhs.type == rhs.type && lhs.device == rhs.device;
1970}
1971
1972} // namespace
1973
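// Greedily groups tensors into buckets keyed by (scalar type, device): a
// bucket is closed once its accumulated byte size reaches the current entry
// of `bucket_size_limits` (the last entry is reused once the list is
// exhausted), and tensors expecting sparse gradients get a bucket of their
// own. Returns the per-bucket tensor indices together with the size limit
// that was applied to each bucket.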
1974std::tuple<std::vector<std::vector<size_t>>, std::vector<size_t>>
1975compute_bucket_assignment_by_size(
1976 const std::vector<at::Tensor>& tensors,
1977 const std::vector<size_t>& bucket_size_limits,
1978 const std::vector<bool>& expect_sparse_gradient,
1979 const std::vector<int64_t>& tensor_indices,
1980 const c10::optional<std::weak_ptr<c10d::Logger>>& logger) {
1981 // Either expect_sparse_gradient is not specified or it has as many elements
1982 // as the vector with tensors.
1983 TORCH_INTERNAL_ASSERT(
1984 expect_sparse_gradient.empty() ||
1985 (tensors.size() == expect_sparse_gradient.size()));
1986 TORCH_INTERNAL_ASSERT(!tensors.empty());
1987 // Store bucket indices and their sizes together, because we later sort the
1988 // resulting indices by minimum tensor index and want to keep sizes
1989 // consistent.
1990 std::vector<std::tuple<std::vector<size_t>, size_t>> result;
1991 // Sparse tensors go in their own bucket, so they do not have an enforced size
1992 // limit.
1993 size_t kNoSizeLimit = 0;
1994 result.reserve(tensors.size());
1995
1996 // Keep iterator into the size_limit vector by tensor type and device.
1997 // This is done so that we can use the consecutive bucket limits per type.
1998 std::unordered_map<
1999 BucketKey,
2000 std::vector<size_t>::const_iterator,
2001 c10::hash<BucketKey>>
2002 bucket_size_limit_iterators;
2003
2004 // Keep vector of indices and size accumulator by tensor type and device.
2005 std::unordered_map<BucketKey, BucketAccumulator, c10::hash<BucketKey>>
2006 buckets;
2007
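  // Worked example (hypothetical sizes): four dense float tensors of 0.4 MB
  // each on one device with bucket_size_limits = {1 MB}. Tensors 0-2
  // accumulate to 1.2 MB >= 1 MB and close the first bucket {0, 1, 2};
  // tensor 3 stays in the accumulator and is flushed as the final bucket {3}
  // after the loop below.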
2008 for (const auto i : c10::irange(tensors.size())) {
2009 const auto& tensor = tensors[i];
2010 auto msg = std::string("No support for sparse tensors.");
2011 if (logger.has_value()) {
2012 REDUCER_CHECK(!tensor.is_sparse(), logger.value(), msg);
2013 } else {
2014 TORCH_CHECK(!tensor.is_sparse(), msg);
2015 }
2016
    // When tensor_indices is empty, the index of tensors[i] assigned to the
    // bucket is i; otherwise the tensor index is tensor_indices[i].
2019 auto tensor_index = i;
2020 if (!tensor_indices.empty()) {
2021 tensor_index = tensor_indices[i];
2022 }
2023 // If we expect a sparse gradient to be produced for this tensor, it cannot
2024 // be grouped together with other gradients and gets its own bucket.
2025 if (!expect_sparse_gradient.empty() &&
2026 expect_sparse_gradient[tensor_index]) {
2027 result.emplace_back(std::vector<size_t>({tensor_index}), kNoSizeLimit);
2028 continue;
2029 }
2030
2031 auto key = BucketKey(tensor.scalar_type(), tensor.device());
2032 auto& bucket = buckets[key];
2033 bucket.indices.push_back(tensor_index);
2034 bucket.size += tensor.numel() * tensor.element_size();
2035
2036 // Initialize bucket size limit iterator if necessary.
2037 if (bucket_size_limit_iterators.count(key) == 0) {
2038 bucket_size_limit_iterators[key] = bucket_size_limits.begin();
2039 }
2040
2041 auto& bucket_size_limit_iterator = bucket_size_limit_iterators[key];
2042 const auto bucket_size_limit = *bucket_size_limit_iterator;
2043 bucket.size_limit = bucket_size_limit;
2044 if (bucket.size >= bucket_size_limit) {
2045 result.emplace_back(std::move(bucket.indices), bucket.size_limit);
2046 bucket = BucketAccumulator();
2047
2048 // Advance to the next bucket size limit for this type/device.
2049 auto next = bucket_size_limit_iterator + 1;
2050 if (next != bucket_size_limits.end()) {
2051 bucket_size_limit_iterator = next;
2052 }
2053 }
2054 }
2055
2056 // Add remaining buckets.
2057 for (auto& it : buckets) {
2058 auto& bucket = it.second;
2059 if (!bucket.indices.empty()) {
2060 result.emplace_back(std::move(bucket.indices), bucket.size_limit);
2061 }
2062 }
2063
2064 // If tensor_indices is not empty, the order of the tensors is in the gradient
2065 // ready order, so no need to sort.
2066 // If tensor_indices is empty, sort resulting buckets by the minimum tensor
2067 // index they include. We assume that the order of the tensors is the order in
2068 // which they are used (or the reverse order in which their gradients are
2069 // produced). This sorting step ensures that the buckets are ready in
2070 // consecutive order.
2071 if (tensor_indices.empty()) {
2072 std::sort(
2073 result.begin(),
2074 result.end(),
2075 [](const std::tuple<std::vector<size_t>, size_t>& a,
2076 const std::tuple<std::vector<size_t>, size_t>& b) {
2077 auto indices_a = std::get<0>(a);
2078 auto indices_b = std::get<0>(b);
2079 const auto amin =
2080 std::min_element(indices_a.begin(), indices_a.end());
2081 const auto bmin =
2082 std::min_element(indices_b.begin(), indices_b.end());
2083 return *amin < *bmin;
2084 });
2085 }
2086
2087 // Return bucket indices and size limits as separate entries in tuple, as some
2088 // APIs only need to consume bucket indices.
2089 std::vector<std::vector<size_t>> bucket_indices;
2090 bucket_indices.reserve(result.size());
2091 std::vector<size_t> per_bucket_size_limits;
2092 per_bucket_size_limits.reserve(result.size());
2093 for (const auto& bucket_indices_with_size : result) {
2094 bucket_indices.emplace_back(std::get<0>(bucket_indices_with_size));
2095 per_bucket_size_limits.emplace_back(std::get<1>(bucket_indices_with_size));
2096 }
2097 return std::make_tuple(bucket_indices, per_bucket_size_limits);
2098}
2099
2100// Verifies corresponding params in the model replica have the same
2101// sizes/strides across processes.
2102void verify_params_across_processes(
2103 const c10::intrusive_ptr<c10d::ProcessGroup>& process_group,
2104 const std::vector<at::Tensor>& params,
2105 const c10::optional<std::weak_ptr<c10d::Logger>>& logger) {
2106 // First verify number of parameters to avoid inconsistent inputs into
2107 // broadcast which can cause a crash.
2108 // See https://github.com/pytorch/pytorch/issues/73547
2109
2110 at::TensorOptions param_size_options;
2111 param_size_options = param_size_options.dtype(at::kLong);
2112 param_size_options = param_size_options.device(params[0].device());
2113 // Note: Not using tensor building API because of
2114 // https://github.com/pytorch/pytorch/issues/74114
2115 at::Tensor param_size_tensor =
2116 at::tensor({static_cast<int64_t>(params.size())}, param_size_options);
2117
2118 // Allgather and verify parameter size.
2119 std::vector<std::vector<at::Tensor>> param_size_output_tensors;
2120 param_size_output_tensors.emplace_back();
2121 auto world_size = process_group->getSize();
2122 for (size_t i = 0; i < world_size; ++i) {
2123 param_size_output_tensors.front().emplace_back(
2124 at::empty_like(param_size_tensor));
2125 }
2126
2127 std::vector<at::Tensor> param_size_vec{param_size_tensor};
2128 process_group->allgather(param_size_output_tensors, param_size_vec)->wait();
2129 auto result_size_tensors = param_size_output_tensors.front();
2130 for (size_t i = 0; i < world_size; ++i) {
2131 auto param_size_for_rank = result_size_tensors[i][0].item<int>();
2132 TORCH_CHECK(
2133 param_size_for_rank == params.size(),
2134 c10::str(
2135 "DDP expects same model across all ranks, but Rank ",
2136 process_group->getRank(),
2137 " has ",
2138 params.size(),
2139 " params, while rank ",
2140 i,
2141 " has inconsistent ",
2142 param_size_for_rank,
2143 " params."));
2144 }
2145
2146 // Continue with parameter shape verification.
2147 size_t i = 0;
2148 for (const auto& t : params) {
2149 i += 2 * t.dim();
2150 }
2151 at::TensorOptions options;
2152 options = options.dtype(at::kLong);
2153 auto metadata = at::empty({static_cast<long>(i)}, options);
2154
2155 // Technically, process 0 is the broadcast source, so only process 0 needs
2156 // to populate metadata. But no harm keeping work aligned across processes.
2157 auto metadata_accessor = metadata.accessor<int64_t, 1>();
2158 i = 0;
2159 for (const auto& t : params) {
2160 for (const auto& sz : t.sizes()) {
2161 metadata_accessor[i++] = sz;
2162 }
2163 for (const auto& str : t.strides()) {
2164 metadata_accessor[i++] = str;
2165 }
2166 }
2167
2168 auto metadata_dev = metadata.clone().to(params[0].device());
2169 std::vector<at::Tensor> vec{metadata_dev};
2170 process_group->broadcast(vec)->wait();
2171
2172 // Technically, process 0 doesn't need to double-check metadata, because it
2173 // was the source. But no harm keeping work aligned.
2174 auto control = at::empty({static_cast<long>(i)}, options);
2175 control.copy_(metadata_dev, /*non_blocking=*/false);
2176 auto control_accessor = control.accessor<int64_t, 1>();
2177 i = 0;
2178 for (const auto p : c10::irange(params.size())) {
2179 const auto& t = params[p];
2180 for (const auto& sz : t.sizes()) {
2181 auto msg = c10::str(
2182 "[",
2183 process_group->getRank(),
2184 "]: params[",
2185 p,
2186 "] in this process",
2187 " with sizes ",
2188 t.sizes(),
2189 " appears not to match sizes of the same param in process 0.");
2190 if (logger.has_value()) {
2191 REDUCER_CHECK(sz == control_accessor[i++], logger.value(), msg)
2192 } else {
2193 TORCH_CHECK(sz == control_accessor[i++], msg)
2194 }
2195 }
2196 for (const auto& str : t.strides()) {
2197 auto msg = c10::str(
2198 "params[",
2199 p,
2200 "] in this process",
2201 " with sizes ",
2202 t.sizes(),
2203 " appears not to match strides of the same param in process 0.");
2204 if (logger.has_value()) {
2205 REDUCER_CHECK(str == control_accessor[i++], logger.value(), msg)
2206 } else {
2207 TORCH_CHECK(str == control_accessor[i++], msg)
2208 }
2209 }
2210 }
2211}
2212
2213} // namespace c10d
2214