1 | #include <torch/csrc/distributed/c10d/reducer.hpp> |
2 | |
3 | #include <torch/csrc/distributed/c10d/Utils.hpp> |
4 | #include <torch/csrc/distributed/c10d/default_comm_hooks.hpp> |
5 | |
6 | #include <functional> |
7 | |
8 | #include <c10/core/DeviceGuard.h> |
9 | #include <c10/core/StreamGuard.h> |
10 | #include <c10/util/Exception.h> |
11 | #include <c10/util/Logging.h> |
12 | #include <c10/util/hash.h> |
13 | #include <c10/util/irange.h> |
14 | #include <torch/csrc/autograd/engine.h> |
15 | #include <torch/csrc/autograd/function_hook.h> |
16 | #include <torch/csrc/autograd/functions/accumulate_grad.h> |
17 | #include <torch/csrc/autograd/profiler.h> |
18 | #include <torch/csrc/autograd/utils/grad_layout_contract.h> |
19 | #include <torch/csrc/autograd/utils/lambda_post_hook.h> |
20 | #include <torch/csrc/distributed/c10d/comm.hpp> |
21 | #include <torch/csrc/distributed/c10d/logger.hpp> |
22 | #include <torch/csrc/utils/memory.h> |
23 | |
24 | namespace c10d { |
25 | namespace { |
26 | |
27 | constexpr int kUnsetDivFactor = -1; |
28 | |
29 | // Macro that wraps TORCH_CHECK with DDP logging. |
30 | #define REDUCER_CHECK(cond, logger_, ...) \ |
31 | if (C10_UNLIKELY_OR_CONST(!(cond))) { \ |
32 | if (!logger_.expired()) { \ |
33 | logger_.lock()->set_error_and_log(__VA_ARGS__); \ |
34 | } \ |
35 | TORCH_CHECK(false, ##__VA_ARGS__); \ |
36 | } |
37 | |
38 | } // namespace |
39 | |
40 | C10_DEFINE_TYPED_REGISTRY( // NOLINT |
41 | TimerRegistry, |
42 | c10::DeviceType, |
43 | Timer, |
44 | std::unique_ptr, |
45 | c10::Device); |
46 | |
47 | namespace { |
48 | |
49 | class CpuTimer : public Timer { |
50 | public: |
51 | explicit CpuTimer(c10::Device /* unused */) {} |
52 | |
53 | c10::optional<int64_t> measureDifference(Event start, Event end) override { |
54 | int64_t start_time = getTimeRef(start); |
55 | int64_t end_time = getTimeRef(end); |
56 | // If cpu_end_time is not recorded in this iteration, |
57 | // avg_time will return invalid value. |
58 | // For some cases like DDP runs on non-sync mode, backward compute |
59 | // end time can not be recorded in this iteration and thus can not |
60 | // calculate the valid avg_time. |
61 | // In this case, skip calculating the avg_time and return. |
62 | if (end_time < start_time) { |
63 | return c10::nullopt; |
64 | } |
65 | return end_time - start_time; |
66 | } |
67 | }; |
68 | |
69 | C10_REGISTER_TYPED_CLASS(TimerRegistry, c10::kCPU, CpuTimer); |
70 | |
71 | std::vector<at::Tensor> (const c10::IValue& result) { |
72 | if (result.isPyObject()) { |
73 | return result.toPyObjectHolder()->extractTensors(); |
74 | } |
75 | TORCH_INTERNAL_ASSERT( |
76 | result.isTensor() || result.isTensorList(), |
77 | "expected the hook result is either a Tensor or a TensorList found " , |
78 | result.tagKind()); |
79 | |
80 | if (result.isTensor()) { |
81 | return {result.toTensor()}; |
82 | } |
83 | |
84 | return result.toTensorVector(); |
85 | } |
86 | |
87 | } // namespace |
88 | |
89 | Reducer::Reducer( |
90 | std::vector<at::Tensor> params, |
91 | std::vector<std::vector<size_t>> bucket_indices, |
92 | std::vector<size_t> per_bucket_size_limits, |
93 | c10::intrusive_ptr<c10d::ProcessGroup> process_group, |
94 | std::vector<bool> expect_sparse_gradients, |
95 | int64_t bucket_bytes_cap, |
96 | bool find_unused_parameters, |
97 | bool gradient_as_bucket_view, |
98 | std::unordered_map<size_t, std::string> param_names, |
99 | int64_t first_bucket_bytes_cap) |
100 | : params_(std::move(params)), |
101 | process_group_(std::move(process_group)), |
102 | expect_sparse_gradients_(std::move(expect_sparse_gradients)), |
103 | expect_autograd_hooks_(false), |
104 | require_finalize_(false), |
105 | next_bucket_(0), |
106 | has_marked_unused_parameters_(false), |
107 | find_unused_parameters_(find_unused_parameters), |
108 | gradient_as_bucket_view_(gradient_as_bucket_view), |
109 | local_used_map_reduced_(false), |
110 | num_iterations_(0), |
111 | num_buckets_ready_(0), |
112 | has_rebuilt_bucket_(false), |
113 | bucket_bytes_cap_(bucket_bytes_cap), |
114 | div_factor_(kUnsetDivFactor), |
115 | static_graph_(false), |
116 | comm_hook_(nullptr), |
117 | ddp_debug_level_(debug_level()), |
118 | param_names_(std::move(param_names)), |
119 | first_bucket_bytes_cap_(first_bucket_bytes_cap) { |
120 | C10_LOG_API_USAGE_ONCE("torch.distributed.ddp.reducer" ); |
121 | TORCH_INTERNAL_ASSERT(!params_.empty(), "Expected at least one parameter." ); |
122 | |
123 | if (ddp_debug_level_ != c10d::DebugLevel::Off) { |
124 | LOG(INFO) << "Reducer initialized with bucket_bytes_cap: " |
125 | << bucket_bytes_cap_ |
126 | << " first_bucket_bytes_cap: " << first_bucket_bytes_cap; |
127 | } |
128 | // Check whether the module is multi_device_module |
129 | { |
130 | std::set<int> unique_devices; |
131 | for (const auto& v : params_) { |
132 | auto device_idx = int(v.device().index()); |
133 | if (unique_devices.find(device_idx) == unique_devices.end()) { |
134 | unique_devices.insert(device_idx); |
135 | if (unique_devices.size() > 1) { |
136 | is_multi_device_module_ = true; |
137 | break; |
138 | } |
139 | } |
140 | } |
141 | } |
142 | |
143 | // For CUDA, record events only for single device module. |
144 | c10::Device device = params_[0].device(); |
145 | if (!(device.is_cuda() && is_multi_device_module_)) { |
146 | timer_ = TimerRegistry()->Create(device.type(), device); |
147 | } |
148 | |
149 | // If `expect_sparse_gradients` is not specified, initialize it such that |
150 | // we do not expect sparse gradients for any parameter. |
151 | if (expect_sparse_gradients_.empty()) { |
152 | expect_sparse_gradients_ = std::vector<bool>(params_.size(), false); |
153 | } |
154 | TORCH_INTERNAL_ASSERT(expect_sparse_gradients_.size() == params_.size()); |
155 | |
156 | // Initialize variable bucketing. |
157 | // This can be reinitialized later after capturing runtime information. |
158 | { |
159 | std::lock_guard<std::mutex> lock(mutex_); |
160 | initialize_buckets(std::move(bucket_indices)); |
161 | } |
162 | |
163 | // All variables are expected to have their `grad_fn` set to the gradient |
164 | // accumulation function (since they are leafs in the autograd graph). |
165 | // We store pointers to these functions such that we can check if they are |
166 | // used in an autograd pass. If they are not, we know their grad tensors |
167 | // can be marked as ready for reduction. |
168 | { |
169 | const auto variable_count = params_.size(); |
170 | grad_accumulators_.resize(variable_count); |
171 | for (const auto variable_index : c10::irange(variable_count)) { |
172 | auto& variable = params_[variable_index]; |
173 | |
174 | // The gradient accumulator function is lazily initialized once. |
175 | // Therefore we can use its presence in the autograd graph as |
176 | // evidence that the parameter has participated in an iteration. |
177 | auto grad_accumulator = torch::autograd::impl::grad_accumulator(variable); |
178 | |
179 | #ifndef _WIN32 |
180 | using torch::distributed::autograd::ThreadLocalDistAutogradContext; |
181 | #endif |
182 | // Hook to execute after the gradient accumulator has executed. |
183 | hooks_.emplace_back( |
184 | grad_accumulator->add_post_hook( |
185 | torch::make_unique<torch::autograd::utils::LambdaPostHook>( |
186 | [=](const torch::autograd::variable_list& outputs, |
187 | const torch::autograd::variable_list& /* unused */) { |
188 | #ifndef _WIN32 |
189 | this->rpc_context_.set( |
190 | ThreadLocalDistAutogradContext::getContextPtr()); |
191 | #endif |
192 | this->autograd_hook(variable_index); |
193 | return outputs; |
194 | })), |
195 | grad_accumulator); |
196 | |
197 | // Map raw function pointer to parameter index. |
198 | // This is used later on when the autograd graph is traversed |
199 | // to check for parameters for which no gradient is computed, if |
200 | // find_unused_parameters=True. |
201 | // Note that the mapping of gradient accumulator to variable should be |
202 | // one to one as we deduplicate shared parameters before constructing |
203 | // Reducer. |
204 | if (find_unused_parameters_) { |
205 | gradAccToVariableMap_[grad_accumulator.get()] = variable_index; |
206 | } |
207 | |
208 | numGradHooksTriggeredMap_[variable_index] = 0; |
209 | |
210 | // The gradient accumulator is stored as weak_ptr in the autograd |
211 | // metadata of the variable, so we have to keep it alive here for |
212 | // the raw pointer to be valid. |
213 | REDUCER_CHECK( |
214 | grad_accumulators_[variable_index] == nullptr, |
215 | logger_, |
216 | c10::str( |
217 | "Reducer tried to register duplicate grad accumulator for variable " , |
218 | variable_index)); |
219 | |
220 | grad_accumulators_[variable_index] = std::move(grad_accumulator); |
221 | } |
222 | } |
223 | |
224 | // Initialize backward stats vector. |
225 | { |
226 | const auto variable_count = params_.size(); |
227 | backward_stats_.resize(variable_count); |
228 | } |
229 | |
230 | // See Note [Skip allreducing local_used_map_dev] |
231 | if (find_unused_parameters_) { |
232 | initialize_local_used_map(); |
233 | } |
234 | } |
235 | |
236 | // Note [Skip allreducing local_used_map_dev] |
237 | // ~~~~~~~~~~~~~~~~~~~~~~~~~~ |
238 | // If find_unused_parameters_ is set to false, there is no need to allreduce |
239 | // local_used_map_dev_, because all parameters will be reduced anyway. |
240 | // Therefore, we can avoid allocating memory for local_used_map and |
241 | // local_used_map_dev_ if find_unused_parameters_ is false. |
242 | |
243 | // Note [DDP Communication Hook] |
244 | // ~~~~~~~~~~~~~~~~~~~~~~~~~~ |
245 | // If DDP communication hook is not registered, the reducer reduces the buckets |
246 | // by just calling allreduce. If registered, it calls the hook and uses future |
247 | // work handle. If registered, reducer also skips dividing grads by world size. |
248 | // The reason for this is that the communication hook is expected to completely |
249 | // override how we perform communication and the user should have complete |
250 | // control over how the grads are handled. |
251 | // |
252 | // DDP communication hook is an enhancement that provides a hook which can be |
253 | // used to override how DDP communicates gradients across ranks, this can be |
254 | // used for algorithms like Gradient Compression/GossipGrad. This hook can be |
255 | // registered from Python API using `register_comm_hook`. `PythonCommHook` |
256 | // enables registering a Python hook and is a subclass of `CommHookInterface`. |
257 | // Additionally, there are also some built-in C++ hook implementations that can |
258 | // be specified by calling `register_builtin_comm_hook` from Python API. |
259 | |
260 | Reducer::~Reducer() noexcept(false) { |
261 | // Remove all hooks on variables registered by this Reducer. This is necessary |
262 | // to make DDP failure recoverable. Otherwise, multiple Reducer instances |
263 | // (from recoveries) will add their hooks to the original model, and those |
264 | // hooks will try to invoke methods on a deleted Reducer objects. |
265 | for (auto& hook : hooks_) { |
266 | auto& key = hook.first; |
267 | auto& grad_accumulator = hook.second; |
268 | |
269 | TORCH_INTERNAL_ASSERT( |
270 | grad_accumulator->del_post_hook(key), |
271 | "Reducer attempts to delete a non-existing hook." ); |
272 | } |
273 | } |
274 | |
275 | bool Reducer::dynamic_graph_find_unused() { |
276 | return !static_graph_ && find_unused_parameters_; |
277 | } |
278 | |
279 | bool Reducer::static_graph_first_iteration() { |
280 | return static_graph_ && num_iterations_ == 1; |
281 | } |
282 | |
283 | bool Reducer::static_graph_after_first_iteration() { |
284 | return static_graph_ && num_iterations_ > 1; |
285 | } |
286 | |
287 | bool Reducer::ddp_graph_static() { |
288 | std::lock_guard<std::mutex> lock(mutex_); |
289 | return ddp_graph_static_; |
290 | } |
291 | |
292 | void Reducer::initialize_local_used_map() { |
293 | const auto variable_count = params_.size(); |
294 | at::TensorOptions options; |
295 | options = options.dtype(at::kInt); |
296 | |
297 | // Deliberately don't pin the memory even if local_used_map_dev_ will |
298 | // be cuda. See Note [local_used_map_ -> local_used_map_dev copying] |
299 | local_used_map_ = at::zeros({static_cast<long>(variable_count)}, options); |
300 | |
301 | // This tensor needs to be on the same device as the replica params because |
302 | // backend such as NCCL may not support CPU tensors, and hence it might not |
303 | // work if we always put it on CPU. |
304 | options = options.device(params_[0].device()); |
305 | local_used_map_dev_ = at::empty({static_cast<long>(variable_count)}, options); |
306 | } |
307 | |
308 | void Reducer::check_grad_layout( |
309 | const at::Tensor& grad, |
310 | const at::Tensor& bucket_view) { |
311 | // Ensure that the gradient type matches the bucket type. |
312 | REDUCER_CHECK( |
313 | grad.options().type_equal(bucket_view.options()), |
314 | logger_, |
315 | c10::str("Expected " , bucket_view.toString(), ", got " , grad.toString())); |
316 | |
317 | TORCH_INTERNAL_ASSERT(grad.device() == bucket_view.device()); |
318 | TORCH_INTERNAL_ASSERT(grad.numel() == bucket_view.numel()); |
319 | // AccumulateGrad doesn't HAVE to obey the grad layout contract. |
320 | // The penalty for disobedience is reduced performance, not numerical |
321 | // death. Warnings here help diagnose poor DDP performance. |
322 | if (grad.strides() != bucket_view.strides()) { |
323 | TORCH_WARN_ONCE( |
324 | "Grad strides do not match bucket view strides. " |
325 | "This may indicate grad was not created according to the " |
326 | "gradient layout contract, or that the param's strides " |
327 | "changed since DDP was constructed. This is not an error, " |
328 | "but may impair performance.\n" |
329 | "grad.sizes() = " , |
330 | grad.sizes(), |
331 | ", strides() = " , |
332 | grad.strides(), |
333 | "\n" , |
334 | "bucket_view.sizes() = " , |
335 | bucket_view.sizes(), |
336 | ", strides() = " , |
337 | bucket_view.strides()); |
338 | } |
339 | if (!gradient_as_bucket_view_) { |
340 | TORCH_INTERNAL_ASSERT(!grad.is_alias_of(bucket_view)); |
341 | } |
342 | } |
343 | |
344 | void Reducer::mark_variable_ready_dense(size_t variable_index) { |
345 | const auto& bucket_index = variable_locators_[variable_index]; |
346 | auto& bucket = buckets_[bucket_index.bucket_index]; |
347 | auto& variable = bucket.variables[bucket_index.intra_bucket_index]; |
348 | auto& bucket_view = bucket.bucket_views_in[bucket_index.intra_bucket_index]; |
349 | |
350 | // Copy the contents of the gradient tensor to the corresponding part of the |
351 | // bucket's flattened gradient tensor. |
352 | // If the gradient is not set, we assume it wasn't computed as part of the |
353 | // current backwards pass, and we zero the part of the bucket it would |
354 | // otherwise hold. |
355 | runGradCallbackForVariable(variable, [&](auto& grad) { |
356 | if (grad.defined()) { |
357 | this->check_grad_layout(grad, bucket_view); |
358 | // When gradient_as_bucket_view_ is false, or even when |
359 | // gradient_as_bucket_view_ is true, in rare cases users may set grad to |
360 | // be None after every iteration. In these cases, grad and bucket_view are |
361 | // pointing to different storages and thus need to copy grads to |
362 | // bucket_view. If gradient_as_bucket_view_ is set as true, let grad point |
363 | // to bucket_view. If grad has already been set as views of buckets in |
364 | // previous iterations, no copy is needed. |
365 | if (!grad.is_alias_of(bucket_view)) { |
366 | if (comm_hook_ == nullptr) { |
367 | auto wrapped = |
368 | at::native::wrapped_scalar_tensor(double(1.) / div_factor_); |
369 | if (!grad.requires_grad()) { |
370 | // Divides while copying into the bucket view to save one scan over |
371 | // all the input parameters. |
372 | at::mul_out(bucket_view, grad, wrapped); |
373 | } else { |
374 | // If DDP is running with create_graph=True, gradients require_grad |
375 | // themselves in order to compute higher order derivatives. However, |
376 | // DDP will not sync up these gradients currently (see |
377 | // https://github.com/pytorch/pytorch/issues/63812). |
378 | C10_LOG_EVERY_N(WARNING, 1000) |
379 | << "Using DistributedDataParallel with create_graph=True " |
380 | << " is not well-supported. The higher-order gradient will " |
381 | << " not be synchronized across ranks, and backpropagation " |
382 | << " through all_reduce operations will not occur. If you require " |
383 | << " DDP to work with higher-order gradients for your use case, " |
384 | << " please ping https://github.com/pytorch/pytorch/issues/63929" ; |
385 | auto div_result = at::mul(grad, wrapped); |
386 | bucket_view.copy_(div_result); |
387 | } |
388 | } else { |
389 | bucket_view.copy_(grad); |
390 | } |
391 | |
392 | if (gradient_as_bucket_view_) { |
393 | // Let grad point to bucket_view buffer. |
394 | grad = bucket_view; |
395 | // The grad is modified and need to be written back. |
396 | return true; |
397 | } |
398 | } else { |
399 | // If grad and bucket view point to the same storage, no need to copy. |
400 | if (comm_hook_ == nullptr) { |
401 | bucket_view.div_(div_factor_); |
402 | } |
403 | } |
404 | } else { |
405 | // Gradient is undefined. When find_unused_parameters=True, ensure it is |
406 | // not marked as locally used, otherwise we will be allreducing zero's |
407 | // instead of not touching .grad field of parameter. |
408 | if (this->dynamic_graph_find_unused() || |
409 | this->static_graph_first_iteration()) { |
410 | REDUCER_CHECK( |
411 | local_used_map_[variable_index].item<int>() == 0, |
412 | logger_, |
413 | "Encountered gradient which is undefined, but still allreduced by " |
414 | "DDP reducer. This indicates a bug in DDP implementation, please " |
415 | "report a bug with a repro to PyTorch." ); |
416 | } |
417 | bucket_view.zero_(); |
418 | } |
419 | // The grad is not modified and doesn't need to be written back. |
420 | return false; |
421 | }); |
422 | } |
423 | |
424 | void Reducer::mark_variable_ready_sparse(size_t variable_index) { |
425 | const auto& bucket_index = variable_locators_[variable_index]; |
426 | auto& bucket = buckets_[bucket_index.bucket_index]; |
427 | auto& variable = bucket.variables[bucket_index.intra_bucket_index]; |
428 | |
429 | runGradCallbackForVariable(variable, [&](auto& grad) { |
430 | REDUCER_CHECK( |
431 | grad.defined(), logger_, "Expected sparse gradient to be defined." ); |
432 | REDUCER_CHECK( |
433 | grad.options().layout() == c10::kSparse, |
434 | logger_, |
435 | "Expected variable to have sparse gradient." ); |
436 | |
437 | // Sparse tensors cannot be grouped together with other sparse tensors in a |
438 | // single reduction operation like we can for dense tensors. Therefore, the |
439 | // `offsets` and `lengths` vectors in the bucket struct are empty, and |
440 | // there is no pre-existing accumulation tensor. |
441 | // Directly assign the sparse tensor to the `gradients` field. |
442 | bucket.gradients = grad; |
443 | // If no DDP comm hook is registered, the allreduce only sums up the |
444 | // value, and a separate division is required. |
445 | if (comm_hook_ == nullptr) { |
446 | bucket.gradients.div_(div_factor_); |
447 | } |
448 | // The grad is modified in place and needs to be written back. |
449 | return true; |
450 | }); |
451 | } |
452 | |
453 | std::vector<c10d::GradBucket> Reducer::get_grad_buckets( |
454 | bool return_zero_tensors) const { |
455 | std::lock_guard<std::mutex> lock(mutex_); |
456 | std::vector<c10d::GradBucket> gradBuckets; |
457 | gradBuckets.reserve(buckets_.size()); |
458 | for (const auto i : c10::irange(buckets_.size())) { |
459 | auto& bucket = buckets_[i]; |
460 | auto variables_for_bucket = get_variables_for_bucket(i, bucket); |
461 | gradBuckets.emplace_back( |
462 | i, |
463 | buckets_.size(), |
464 | return_zero_tensors ? at::zeros_like(bucket.gradients) |
465 | : bucket.gradients, |
466 | bucket.offsets, |
467 | bucket.lengths, |
468 | bucket.sizes_vec, |
469 | variables_for_bucket); |
470 | } |
471 | return gradBuckets; |
472 | } |
473 | |
474 | void Reducer::set_forward_pass_work_handle( |
475 | c10::intrusive_ptr<c10d::Work> forwardPassWorkHandle, |
476 | bool useStaticWorldSize) { |
477 | std::lock_guard<std::mutex> lock(mutex_); |
478 | forwardPassWorkHandle_.workHandle = std::move(forwardPassWorkHandle); |
479 | forwardPassWorkHandle_.useStaticWorldSize = useStaticWorldSize; |
480 | } |
481 | |
482 | at::Tensor Reducer::get_local_used_map_on_device() const { |
483 | std::lock_guard<std::mutex> lock(mutex_); |
484 | return local_used_map_dev_; |
485 | } |
486 | |
487 | void Reducer::push_rebuilt_params_for_all_indices() { |
488 | std::lock_guard<std::mutex> lock(mutex_); |
489 | if (!should_rebuild_buckets() || !rebuilt_param_indices_.empty()) { |
490 | return; |
491 | } |
492 | const auto variable_count = params_.size(); |
493 | for (const auto variable_index : c10::irange(variable_count)) { |
494 | push_rebuilt_params(variable_index); |
495 | } |
496 | } |
497 | |
498 | void Reducer::push_rebuilt_params(const size_t& index) { |
499 | rebuilt_params_.push_back(params_[index]); |
500 | rebuilt_param_indices_.push_back(index); |
501 | } |
502 | |
503 | void Reducer::set_divide_factor() { |
504 | // If it was scheduled, wait on allreduce in forward pass that tells us |
505 | // division factor based on no. of currently participating processes. |
506 | if (div_factor_ == kUnsetDivFactor) { |
507 | div_factor_ = process_group_->getSize(); |
508 | auto& workHandle = forwardPassWorkHandle_.workHandle; |
509 | if (workHandle && !forwardPassWorkHandle_.useStaticWorldSize) { |
510 | workHandle->wait(); |
511 | // PyProcessGroup::PyWork doesn't expose value, so fetch it from the |
512 | // future |
513 | auto results = extractTensors(workHandle->getFuture()->value()); |
514 | |
515 | // Guard against the results being empty |
516 | TORCH_INTERNAL_ASSERT(!results.empty()); |
517 | at::Tensor& res = results.front(); |
518 | div_factor_ = res.item().to<int>(); |
519 | } |
520 | } |
521 | } |
522 | |
523 | // Right now delay_all_reduce is only called when static_graph_=true and |
524 | // num_iterations_==1. |
525 | void Reducer::delay_all_reduce() { |
526 | std::lock_guard<std::mutex> lock(this->mutex_); |
527 | |
528 | if (should_collect_runtime_stats()) { |
529 | record_backward_compute_end_time(); |
530 | record_backward_comm_start_time(); |
531 | } |
532 | |
533 | // launch all reduce local used map |
534 | all_reduce_local_used_map(); |
535 | |
536 | // prepare to set unused_parameters_, if it is static graph, |
537 | // unused_parameters_ will not change after 1st iteration. |
538 | unused_parameters_.clear(); |
539 | |
540 | // copy all gradients to buckets |
541 | for (const auto variable_index : c10::irange(params_.size())) { |
542 | // set unused_parameters_ |
543 | if (numGradHooksTriggeredMap_[variable_index] == 0) { |
544 | unused_parameters_.push_back(variable_index); |
545 | } |
546 | require_finalize_ = true; |
547 | set_divide_factor(); |
548 | if (expect_sparse_gradients_[variable_index]) { |
549 | mark_variable_ready_sparse(variable_index); |
550 | } else { |
551 | mark_variable_ready_dense(variable_index); |
552 | } |
553 | } |
554 | |
555 | // To avoid confusion around why static graph is picking up |
556 | // some parameters as unused on a rank vs not, we log |
557 | // unused parameter names for each rank for better |
558 | // debugability when TORCH_DISTRIBUTED_DEBUG is set to |
559 | // INFO or DETAIL |
560 | if (ddp_debug_level_ != c10d::DebugLevel::Off) { |
561 | // construct one string to output |
562 | std::ostringstream unused_params_stream; |
563 | |
564 | for (const auto& unused_index : unused_parameters_) { |
565 | auto param_name = param_names_.find(unused_index); |
566 | TORCH_INTERNAL_ASSERT( |
567 | param_name != param_names_.end(), |
568 | "Expected to find parameter name from unused parameters map in debug mode." ); |
569 | // Add the param_name |
570 | unused_params_stream << "{" << param_name->second << "," << unused_index |
571 | << "}" ; |
572 | } |
573 | |
574 | // Each rank prints out all the unused parameters detected |
575 | if (!unused_parameters_.empty()) { |
576 | LOG(INFO) << "[Rank " << process_group_->getRank() << "]: " |
577 | << "Parameter(s) (in the format of {param_name, index}): " |
578 | << unused_params_stream.str() |
579 | << " is(are) unused during first iteration. Since" |
580 | << " static_graph=True is enabled for DDP, we expect" |
581 | << " this set of unused parameters to remain consistent" |
582 | << " on this rank throughout the training." ; |
583 | } |
584 | } |
585 | |
586 | // launch all reduces for all buckets |
587 | for (auto& bucket : buckets_) { |
588 | all_reduce_bucket(bucket); |
589 | } |
590 | |
591 | finalize_backward(); |
592 | } |
593 | |
594 | void Reducer::set_logger(std::weak_ptr<c10d::Logger> logger) { |
595 | logger_ = logger; |
596 | } |
597 | |
598 | // The function `autograd_hook` is called after the gradient for a |
599 | // model parameter has been accumulated into its gradient tensor. |
600 | // This function is only to be called from the autograd thread. |
601 | void Reducer::autograd_hook(size_t index) { |
602 | std::lock_guard<std::mutex> lock(this->mutex_); |
603 | // Ignore if we don't expect to be called. |
604 | // This may be the case if the user wants to accumulate gradients |
605 | // for number of iterations before reducing them. |
606 | if (!expect_autograd_hooks_) { |
607 | return; |
608 | } |
609 | |
610 | grad_ready_order_indices_.push_back(index); |
611 | |
612 | // See Note [Skip allreducing local_used_map_dev] |
613 | if (dynamic_graph_find_unused() || static_graph_first_iteration()) { |
614 | // Since it gets here, this param has been used for this iteration. We want |
615 | // to mark it in local_used_map_. During no_sync session, the same var can |
616 | // be set multiple times, which is OK as does not affect correctness. As |
617 | // long as it is used once during no_sync session, it is marked as used. |
618 | // Only set it as locally used if the grad is defined. Otherwise, hooks can |
619 | // be fired with undefined grads, such as when not all outputs are used in |
620 | // DDP when computing loss. In this case, we don't want to mark it as |
621 | // locally used to ensure we don't touch the parameter's .grad field. |
622 | auto& variable = get_param_from_index(index); |
623 | runGradCallbackForVariable(variable, [&](auto& grad) { |
624 | if (grad.defined()) { |
625 | local_used_map_[index] = 1; |
626 | } |
627 | // The gradient is never modified. |
628 | return false; |
629 | }); |
630 | } |
631 | |
632 | if (static_graph_first_iteration()) { |
633 | numGradHooksTriggeredMap_[index] += 1; |
634 | return; |
635 | } |
636 | |
637 | // If `find_unused_parameters_` is true there may be model parameters that |
638 | // went unused when computing the model output, they won't be part of the |
639 | // autograd graph, and won't receive gradients. These parameters are |
640 | // discovered in the `prepare_for_backward` function and their indexes stored |
641 | // in the `unused_parameters_` vector. |
642 | if (!has_marked_unused_parameters_) { |
643 | has_marked_unused_parameters_ = true; |
644 | for (const auto& unused_index : unused_parameters_) { |
645 | mark_variable_ready(unused_index); |
646 | } |
647 | } |
648 | |
649 | // Rebuild bucket only if 1) it is the first time to rebuild bucket 2) |
650 | // static_graph_ is true or find_unused_parameters_ is false, |
651 | // 3) this backward pass needs to run allreduce. |
652 | // Here, we just dump tensors and their parameter indices into |
653 | // rebuilt_params_ and rebuilt_param_indices_ based on gradient arriving |
654 | // order, and then at the end of finalize_backward(), buckets will be |
655 | // rebuilt based on rebuilt_params_ and rebuilt_param_indices_, and then |
656 | // will be broadcasted and initialized. |
657 | // If it is static graph, after 1st iteration, check if a variable |
658 | // is ready for communication based on numGradHooksTriggeredMap_. |
659 | if (static_graph_after_first_iteration()) { |
660 | REDUCER_CHECK( |
661 | numGradHooksTriggeredMapPerIteration_[index] > 0, |
662 | logger_, |
663 | "Your training graph has changed in this iteration, " , |
664 | "e.g., one parameter is unused in first iteration, but " , |
665 | "then got used in the second iteration. this is not " , |
666 | "compatible with static_graph set to True." ); |
667 | if (--numGradHooksTriggeredMapPerIteration_[index] == 0) { |
668 | if (should_rebuild_buckets()) { |
669 | push_rebuilt_params(index); |
670 | } |
671 | // Finally mark variable for which this function was originally called. |
672 | mark_variable_ready(index); |
673 | } |
674 | } else { |
675 | if (should_rebuild_buckets()) { |
676 | push_rebuilt_params(index); |
677 | } |
678 | // Finally mark variable for which this function was originally called. |
679 | mark_variable_ready(index); |
680 | } |
681 | } |
682 | |
683 | void Reducer::all_reduce_local_used_map() { |
684 | // See Note [Skip allreducing local_used_map_dev] |
685 | // H2D from local_used_map_ to local_used_map_dev_ |
686 | if (local_used_map_dev_.is_cuda()) { |
687 | // Note [local_used_map_ -> local_used_map_dev copying] |
688 | // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
689 | // We do async H2D to avoid the blocking overhead. The async copy and |
690 | // allreduce respect the current stream, so will be sequenced |
691 | // correctly. |
692 | // |
693 | // Correct sequencing with respect to host operations is also |
694 | // essential. The H2D copy_ is stream ordered, while the host's |
695 | // changes to local_used_map_ are host ordered. If a large backlog of |
696 | // cuda-stream work pushes the copy_ far into the future, and if no |
697 | // blocking calls occur between now and finalize_backward()** such |
698 | // that finalize_backward() re-zeroes local_used_map_ on the host |
699 | // before the stream executes the copy_, copy_ will read those zeros |
700 | // instead of the values we thought we told it to read here. Copying |
701 | // local_used_map_ to a pinned temporary (which the pinned caching |
702 | // allocator should supply asynchronously) avoids this nasty, rare |
703 | // race condition. |
704 | // |
705 | // ** In the hoped-for case where all params are used, DDP itself |
706 | // won't do any blocking work between now and the re-zeroing, so the |
707 | // danger is real. |
708 | // |
709 | // Defensively ensures local_used_map_tmp is distinct from |
710 | // local_used_map_ |
711 | auto local_used_map_tmp = at::native::empty_like( |
712 | local_used_map_, |
713 | optTypeMetaToScalarType(local_used_map_.options().dtype_opt()), |
714 | local_used_map_.options().layout_opt(), |
715 | local_used_map_.options().device_opt(), |
716 | true /* pinned_memory */); |
717 | // Paranoid asserts here because in some workloads, the pinned |
718 | // allocator behaves in a way we don't understand, and may be bugged. |
719 | // See https://github.com/pytorch/pytorch/pull/54474 |
720 | TORCH_INTERNAL_ASSERT(local_used_map_tmp.is_pinned()); |
721 | TORCH_INTERNAL_ASSERT( |
722 | local_used_map_tmp.data_ptr() != local_used_map_.data_ptr()); |
723 | local_used_map_tmp.copy_(local_used_map_); |
724 | local_used_map_dev_.copy_(local_used_map_tmp, true); |
725 | } else { |
726 | local_used_map_dev_.copy_(local_used_map_, true); |
727 | } |
728 | std::vector<at::Tensor> temp_local_used_map_dev_vec_ = {local_used_map_dev_}; |
729 | local_used_work_ = process_group_->allreduce(temp_local_used_map_dev_vec_); |
730 | } |
731 | |
732 | at::Tensor& Reducer::get_param_from_index(size_t index) { |
733 | const auto& bucket_index = variable_locators_[index]; |
734 | auto& bucket = buckets_[bucket_index.bucket_index]; |
735 | // Cannot simply access variable via `bucket.variables[variable_index]` since |
736 | // return value is used in `runGradCallbackForVariable()` which does not |
737 | // accept const tensors. |
738 | auto& variable = bucket.variables[bucket_index.intra_bucket_index]; |
739 | return variable; |
740 | } |
741 | |
742 | void Reducer::checkAndRaiseMarkedTwiceError(size_t index) { |
743 | // Something is wrong if all variables contained in this bucket have |
744 | // already been marked as ready. |
745 | // We don't expect the same variable to be marked ready twice. |
746 | bool marked_twice = |
747 | perIterationReadyParams_.find(index) != perIterationReadyParams_.end(); |
748 | |
749 | if (marked_twice) { |
750 | // Report index of param that has been marked twice. In debug mode, also |
751 | // report fully qualified parameter name. |
752 | auto param_name = param_names_.find(index); |
753 | const bool found_param_name = param_name != param_names_.end(); |
754 | TORCH_INTERNAL_ASSERT( |
755 | ddp_debug_level_ == c10d::DebugLevel::Off || found_param_name, |
756 | "Expected to find parameter name in debug mode." ); |
757 | std::string paramInfo = c10::str( |
758 | "Parameter at index " , |
759 | index, |
760 | found_param_name ? c10::str(" with name " , param_name->second) : "" , |
761 | " has been marked as ready twice. This means that multiple autograd engine " , |
762 | " hooks have fired for this particular parameter during this iteration." ); |
763 | // param_names_ is empty in debug mode. |
764 | if (!found_param_name) { |
765 | paramInfo += c10::str( |
766 | " You can set the environment variable TORCH_DISTRIBUTED_DEBUG to either" , |
767 | " INFO or DETAIL to print parameter names for further debugging." ); |
768 | } |
769 | std::string common_error = c10::str( |
770 | "Expected to mark a variable ready only once. " , |
771 | "" , |
772 | "This error is caused by one of the following reasons: " , |
773 | "1) Use of a module parameter outside the `forward` function. " , |
774 | "Please make sure model parameters are not shared across multiple " , |
775 | "concurrent forward-backward passes. or try to use _set_static_graph() " , |
776 | "as a workaround if this module graph does not change " , |
777 | "during training loop." , |
778 | "2) Reused parameters in multiple reentrant backward passes. For " , |
779 | "example, if you use multiple `checkpoint` functions to wrap the " , |
780 | "same part of your model, it would result in the same set of " , |
781 | "parameters been used by different reentrant backward passes " , |
782 | "multiple times, and hence marking a variable ready multiple times. " , |
783 | "DDP does not support such use cases in default. You can try to " , |
784 | "use _set_static_graph() as a workaround if your module graph " , |
785 | "does not change over iterations." ); |
786 | |
787 | common_error += c10::str("\n" , paramInfo); |
788 | |
789 | REDUCER_CHECK( |
790 | has_marked_unused_parameters_, |
791 | logger_, |
792 | common_error, |
793 | "3) Incorrect unused parameter detection. The return value of the " , |
794 | "`forward` function is inspected by the distributed data parallel " , |
795 | "wrapper to figure out if any of the module's parameters went " , |
796 | "unused. For unused parameters, DDP would not expect gradients from " , |
797 | "then. However, if an unused parameter becomes part of the autograd " , |
798 | "graph at a later point in time (e.g., in a reentrant backward when " , |
799 | "using `checkpoint`), the gradient will show up unexpectedly. If all " , |
800 | "parameters in the model participate in the backward pass, you can " , |
801 | "disable unused parameter detection by passing the keyword argument " , |
802 | "`find_unused_parameters=False` to " , |
803 | "`torch.nn.parallel.DistributedDataParallel`. If unused parameters " , |
804 | "in the model do not change over iterations, You can try to use " , |
805 | "_set_static_graph() as a workaround if this module graph does not " , |
806 | "change during training loop." ); |
807 | REDUCER_CHECK(!has_marked_unused_parameters_, logger_, common_error); |
808 | } |
809 | } |
810 | |
811 | void Reducer::mark_variable_ready(size_t variable_index) { |
812 | REDUCER_CHECK( |
813 | variable_index < variable_locators_.size(), |
814 | logger_, |
815 | "Out of range variable index." ); |
816 | |
817 | checkAndRaiseMarkedTwiceError(variable_index); |
818 | perIterationReadyParams_.insert(variable_index); |
819 | backward_stats_[variable_index] = |
820 | current_time_in_nanos() - backward_compute_start_time_; |
821 | |
822 | // Any time we mark a variable ready (be it in line due to unused parameters, |
823 | // or via an autograd hook), we require a call to the finalize function. If |
824 | // this doesn't happen before the next iteration (or call to |
825 | // `prepare_for_backwards`), we know something is wrong. |
826 | require_finalize_ = true; |
827 | |
828 | const auto& bucket_index = variable_locators_[variable_index]; |
829 | auto& bucket = buckets_[bucket_index.bucket_index]; |
830 | |
831 | set_divide_factor(); |
832 | |
833 | if (bucket.expect_sparse_gradient) { |
834 | mark_variable_ready_sparse(variable_index); |
835 | } else { |
836 | mark_variable_ready_dense(variable_index); |
837 | } |
838 | |
839 | // TODO(@pietern): Make this work for both CPU/CUDA tensors. |
840 | // When using CPU tensors we don't need to do this. |
841 | // Record event so that we can wait for all of them. |
842 | // auto& event = bucket.events[bucket_index.intra_bucket_index]; |
843 | // event.record(); |
844 | |
845 | // Check if this was the final gradient for this bucket. |
846 | if (--bucket.pending == 0) { |
847 | mark_bucket_ready(bucket_index.bucket_index); |
848 | } |
849 | |
850 | // Run finalizer function and kick off reduction for local_used_map once the |
851 | // final bucket was marked ready. |
852 | if (next_bucket_ == buckets_.size()) { |
853 | if (dynamic_graph_find_unused()) { |
854 | all_reduce_local_used_map(); |
855 | } |
856 | |
857 | torch::autograd::Engine::get_default_engine().queue_callback([=] { |
858 | std::lock_guard<std::mutex> lock(this->mutex_); |
859 | if (should_collect_runtime_stats()) { |
860 | record_backward_compute_end_time(); |
861 | } |
862 | // Check that all buckets were completed and had their work kicked off. |
863 | TORCH_INTERNAL_ASSERT(next_bucket_ == buckets_.size()); |
864 | if (static_graph_after_first_iteration() && should_rebuild_buckets()) { |
865 | for (const auto& unused_index : unused_parameters_) { |
866 | push_rebuilt_params(unused_index); |
867 | } |
868 | } |
869 | this->finalize_backward(); |
870 | }); |
871 | } |
872 | } |
873 | |
874 | c10::intrusive_ptr<c10::ivalue::Future> Reducer::run_comm_hook( |
875 | GradBucket& grad_bucket) { |
876 | if (comm_hook_ == nullptr) { |
877 | return run_allreduce_hook(grad_bucket); |
878 | } else { |
879 | return comm_hook_->runHook(grad_bucket); |
880 | } |
881 | } |
882 | |
883 | c10::intrusive_ptr<c10::ivalue::Future> Reducer::run_allreduce_hook( |
884 | GradBucket& grad_bucket) { |
885 | _AllReduceBySumCommHook allreduce_hook(process_group_); |
886 | return allreduce_hook.runHook(grad_bucket); |
887 | } |
888 | |
889 | void Reducer::all_reduce_bucket(Bucket& bucket) { |
890 | auto variables_for_bucket = get_variables_for_bucket(next_bucket_, bucket); |
891 | // TODO(@pietern): Ensure proper synchronization with the CUDA events |
892 | // that recorded copies into this `gradients` tensor. If these copies are |
893 | // executed on non-default streams, the current stream for the device |
894 | // that holds the `gradients` tensor must wait on these events. |
895 | // |
896 | // As long as autograd uses the default stream for every device, |
897 | // these operations are implicitly sequenced, and we don't need to |
898 | // do any extra synchronization here. |
899 | const auto& tensor = bucket.gradients; |
900 | |
901 | GradBucket grad_bucket( |
902 | next_bucket_, |
903 | buckets_.size(), |
904 | tensor, |
905 | bucket.offsets, |
906 | bucket.lengths, |
907 | bucket.sizes_vec, |
908 | variables_for_bucket); |
909 | bucket.future_work = run_comm_hook(grad_bucket); |
910 | } |
911 | |
912 | std::vector<at::Tensor> Reducer::get_variables_for_bucket( |
913 | size_t bucket_index, |
914 | const Bucket& bucket) const { |
915 | // Check if we have cached mapping previously. |
916 | if (has_rebuilt_bucket_ && |
917 | cached_variables_for_bucket_.find(bucket_index) != |
918 | cached_variables_for_bucket_.end()) { |
919 | return cached_variables_for_bucket_[bucket_index]; |
920 | } |
921 | std::vector<at::Tensor> variables_for_bucket; |
922 | variables_for_bucket.reserve(bucket.variable_indices.size()); |
923 | for (const auto& variable_index : bucket.variable_indices) { |
924 | // Grab bucket index where gradient is located using variable_locators_. |
925 | auto& bucket_index_for_variable = variable_locators_[variable_index]; |
926 | // Grab the actual model parameter. |
927 | auto& variable = |
928 | bucket.variables[bucket_index_for_variable.intra_bucket_index]; |
929 | variables_for_bucket.emplace_back(variable); |
930 | } |
931 | |
932 | if (has_rebuilt_bucket_) { |
933 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY( |
934 | cached_variables_for_bucket_.find(bucket_index) == |
935 | cached_variables_for_bucket_.end()); |
936 | cached_variables_for_bucket_.insert( |
937 | {bucket_index, std::move(variables_for_bucket)}); |
938 | return cached_variables_for_bucket_[bucket_index]; |
939 | } else { |
940 | return variables_for_bucket; |
941 | } |
942 | } |
943 | |
944 | // Called when the bucket at the specified index is ready to be reduced. |
945 | void Reducer::mark_bucket_ready(size_t bucket_index) { |
946 | TORCH_INTERNAL_ASSERT(bucket_index >= next_bucket_); |
947 | |
948 | // Buckets are reduced in sequence. Ignore this bucket if |
949 | // it's not its turn to be reduced. |
950 | if (bucket_index > next_bucket_) { |
951 | return; |
952 | } |
953 | |
954 | // Keep going, until we either: |
955 | // - have kicked off reduction for all buckets, or |
956 | // - found a bucket that's not yet ready for reduction. |
957 | for (; next_bucket_ < buckets_.size() && buckets_[next_bucket_].pending == 0; |
958 | next_bucket_++) { |
959 | num_buckets_ready_++; |
960 | if (num_buckets_ready_ == 1 && should_collect_runtime_stats()) { |
961 | record_backward_comm_start_time(); |
962 | } |
963 | auto& bucket = buckets_[next_bucket_]; |
964 | all_reduce_bucket(bucket); |
965 | } |
966 | } |
967 | |
968 | void Reducer::install_futures( |
969 | c10::List<c10::intrusive_ptr<c10::ivalue::Future>> futs) { |
970 | // Append instead of overwrite so that this method can be called multiple |
971 | // times in one iteration. |
972 | if (!installed_futures_) { |
973 | installed_futures_ = std::move(futs); |
974 | } else { |
975 | installed_futures_->append(futs); |
976 | } |
977 | } |
978 | |
979 | void Reducer::initialize_buckets( |
980 | std::vector<std::vector<size_t>> bucket_indices) { |
981 | // If initialize_buckets is called inside DDP constructor, then |
982 | // it does not matter rpc context ptr is nullptr or not, as grad |
983 | // will not be mutated. |
984 | // If initialize_buckets is called during training loop, e.g, inside |
985 | // rebuild_buckets(), since grad could be mutated and be pointed to |
986 | // bucket_view, then it needs to check rpc context ptr is nullptr or not, |
987 | // If rpc context ptr is nullptr, mutate variable.grad(); otherwise, |
988 | // mutate grad in rpc context. |
989 | #ifndef _WIN32 |
990 | using torch::distributed::autograd::ThreadLocalDistAutogradContext; |
991 | this->rpc_context_.set(ThreadLocalDistAutogradContext::getContextPtr()); |
992 | #endif |
993 | |
994 | // This shouldn't be called if we're expecting autograd hooks to fire. |
995 | REDUCER_CHECK( |
996 | !expect_autograd_hooks_, |
997 | logger_, |
998 | "`initialize_buckets` must NOT be called during autograd execution." ); |
999 | |
1000 | // Clear current bucket assignment. |
1001 | buckets_.clear(); |
1002 | variable_locators_.clear(); |
1003 | |
1004 | // Ensure we have a bucket index for every variable. |
1005 | variable_locators_.resize(params_.size()); |
1006 | |
1007 | // Iterate over buckets. |
1008 | const auto bucket_count = bucket_indices.size(); |
1009 | buckets_.reserve(bucket_count); |
1010 | for (const auto bucket_index : c10::irange(bucket_count)) { |
1011 | Bucket bucket; |
1012 | |
1013 | // TODO(@pietern): Validate indices. |
1014 | // Must be non-empty, unique, and unique across buckets. |
1015 | REDUCER_CHECK( |
1016 | !bucket_indices[bucket_index].empty(), |
1017 | logger_, |
1018 | "Empty bucket specified." ); |
1019 | |
1020 | // Variables that expect sparse gradients must have their own bucket. |
1021 | if (bucket_indices[bucket_index].size() == 1) { |
1022 | const auto variable_index = bucket_indices[bucket_index].front(); |
1023 | bucket.expect_sparse_gradient = expect_sparse_gradients_[variable_index]; |
1024 | } else { |
1025 | for (const auto variable_index : bucket_indices[bucket_index]) { |
1026 | REDUCER_CHECK( |
1027 | !expect_sparse_gradients_[variable_index], |
1028 | logger_, |
1029 | "Buckets with more than one variable cannot include variables " , |
1030 | "that expect a sparse gradient." ); |
1031 | } |
1032 | } |
1033 | |
1034 | if (bucket.expect_sparse_gradient) { |
1035 | const auto variable_index = bucket_indices[bucket_index].front(); |
1036 | const auto& variable = params_[variable_index]; |
1037 | TORCH_INTERNAL_ASSERT(bucket_indices[bucket_index].size() == 1); |
1038 | bucket.variables = {variable}; |
1039 | } else { |
1040 | at::TensorOptions options; |
1041 | // The start index of the variable in the flattened tensor. |
1042 | size_t offset = 0; |
1043 | |
1044 | // Reserve enough space for the per-variable fields stored in the bucket |
1045 | // for efficiency. |
1046 | const size_t num_variables = bucket_indices[bucket_index].size(); |
1047 | bucket.variables.reserve(num_variables); |
1048 | bucket.offsets.reserve(num_variables); |
1049 | bucket.lengths.reserve(num_variables); |
1050 | bucket.sizes_vec.reserve(num_variables); |
1051 | |
1052 | // Iterate over bucket variables. |
1053 | for (const auto variable_index : bucket_indices[bucket_index]) { |
1054 | TORCH_INTERNAL_ASSERT( |
1055 | variable_index < params_.size(), |
1056 | "Out of range variable index specified." ); |
1057 | const auto& variable = params_[variable_index]; |
1058 | if (!options.has_device()) { |
1059 | options = options.device(variable.device()); |
1060 | } else { |
1061 | REDUCER_CHECK( |
1062 | variable.device() == options.device(), |
1063 | logger_, |
1064 | "All parameters in a bucket must be " , |
1065 | "placed on the same device." ); |
1066 | } |
1067 | if (!options.has_dtype()) { |
1068 | options = options.dtype(variable.dtype()); |
1069 | } else { |
1070 | REDUCER_CHECK( |
1071 | variable.dtype() == options.dtype(), |
1072 | logger_, |
1073 | "All parameters in a bucket must have the same dtype." ); |
1074 | } |
1075 | const auto length = variable.numel(); |
1076 | bucket.variables.push_back(variable); |
1077 | bucket.offsets.push_back(offset); |
1078 | bucket.lengths.push_back(length); |
1079 | bucket.sizes_vec.push_back(variable.sizes()); |
1080 | offset += length; |
1081 | } |
1082 | |
1083 | // Allocate the bucket's flattened `gradients` tensor. |
1084 | bucket.gradients = at::empty({static_cast<long>(offset)}, options); |
1085 | |
1086 | // Note: "Gradient Layout Contract" |
1087 | // |
1088 | // Here, create views into the `gradients` tensor for each variable's |
1089 | // grad. Views serve as entry points to `copy_()` each grad's data in/out |
1090 | // of the flattened `gradients` tensor. |
1091 | // |
1092 | // Gradients may have dense memory but non-row-major-contiguous strides |
1093 | // (e.g. channels_last or channels_last_3d). For coalesced accesses |
1094 | // during copy_s, it's beneficial for each view's layout to match its |
1095 | // grad's layout. |
1096 | // |
1097 | // Specifically, we expect torch/csrc/autograd/functions/accumulate_grad.h |
1098 | // produces grads that obey the "Gradient Layout Contract": |
1099 | // (1) if variable.is_non_overlapping_and_dense(), the stashed grad's |
1100 | // strides match variable. |
1101 | // (2) else, stashed grad is rowmajor contiguous. |
1102 | // and create views to match. |
1103 | // |
1104 | // If AccumulateGrad breaks the contract, and produces a grad with an |
1105 | // unexpected layout, performance will degrade due to poor memory access |
1106 | // patterns when copy_ing grad data in and out of its bucket view. |
1107 | // However, numerics remain correct, because the bucket view is the same |
1108 | // on either end of the raw allreduce. bucket_view_in.copy(grad) |
1109 | // tranposes |
1110 | // (+ densifies) to the bucket view's layout, the data is allreduced, |
1111 | // then grad.copy_(bucket_view_out) transposes it back to grad's layout. |
1112 | // |
1113 | // The only way the numerics can go haywire is if the bucket views |
1114 | // themselves have different layouts across processes. |
1115 | // Bucket views' sizes and strides are set based on param layouts, using |
1116 | // the same logic that (we expect) AccumulateGrad uses for their grads. |
1117 | // Therefore, the only way a bucket view could have different layouts in |
1118 | // different processes is if its param has a different layout in |
1119 | // different processes. We can check that param layouts match across |
1120 | // processes in Reducer's constructor by allreducing some metadata. |
1121 | // Checking just once won't catch if someone messes with |
1122 | // param layouts over time, but not messing with params after DDP |
1123 | // construction is already a documented constraint. |
1124 | initialize_bucket_views(bucket); |
1125 | } |
1126 | |
1127 | // Map participating variables to this bucket. |
1128 | size_t intra_bucket_index = 0; |
1129 | for (const auto variable_index : bucket_indices[bucket_index]) { |
1130 | TORCH_INTERNAL_ASSERT( |
1131 | variable_index < variable_locators_.size(), |
1132 | "Out of range variable index specified." ); |
1133 | variable_locators_[variable_index] = |
1134 | VariableLocator(bucket_index, intra_bucket_index++); |
1135 | } |
1136 | bucket.variable_indices = std::move(bucket_indices[bucket_index]); |
1137 | |
1138 | buckets_.push_back(std::move(bucket)); |
1139 | } |
1140 | } |
1141 | |
1142 | // (see Note: "Gradient Layout Contract" in initialize_buckets). |
1143 | void Reducer::initialize_bucket_views(Reducer::Bucket& bucket) { |
1144 | const auto& gradients = bucket.gradients; |
1145 | for (const auto i : c10::irange(bucket.variables.size())) { |
1146 | auto& v = bucket.variables[i]; |
1147 | const auto offset = bucket.offsets[i]; |
1148 | const auto length = bucket.lengths[i]; |
1149 | if (v.is_non_overlapping_and_dense()) { |
1150 | // If the param's memory is dense, match its layout, anticipating |
1151 | // the autograd engine (AccumulateGrad) will also create gradients |
1152 | // matching its layout. |
1153 | bucket.bucket_views_in.push_back( |
1154 | gradients.as_strided(v.sizes(), v.strides(), offset)); |
1155 | } else { |
1156 | // Fall back to a C-style contiguous view, again anticipating |
1157 | // AccumulateGrad will do the same when stashing grads for non-dense |
1158 | // params. |
1159 | bucket.bucket_views_in.push_back( |
1160 | gradients.narrow(0, offset, length).view(v.sizes())); |
1161 | } |
1162 | // By default `bucket_views_out` and `bucket_views_in` are |
1163 | // essentially the same thing. |
1164 | bucket.bucket_views_out = bucket.bucket_views_in; |
1165 | |
1166 | // If gradient_as_bucket_view_ is set as true, then there are two cases to |
1167 | // handle: initialize_bucket_views could be called inside initialize_buckets |
1168 | // when rebuild_buckets, if grad has already been defined/calculated in |
1169 | // previous iteration, old grad needs to be copied into new bucket_view and |
1170 | // let grad point to the new bucket_view, initialize_bucket_views could also |
1171 | // be called inside initialize_buckets during construction. Grads are not |
1172 | // defined during construction time, in this case, do not let grad point to |
1173 | // bucket_view, because grads should be kept as being undefined for globally |
1174 | // unused parameters. |
1175 | if (gradient_as_bucket_view_) { |
1176 | auto& bucket_view = bucket.bucket_views_in.back(); |
1177 | runGradCallbackForVariable(v, [&](auto& grad) { |
1178 | if (grad.defined() && !grad.is_alias_of(bucket_view)) { |
1179 | bucket_view.copy_(grad); |
1180 | grad = bucket_view; |
1181 | // The grad is modefied and needs to be written back. |
1182 | return true; |
1183 | } |
1184 | // The grad is not modified and does not need to be written back. |
1185 | return false; |
1186 | }); |
1187 | } |
1188 | } |
1189 | } |
1190 | |
1191 | // (see Note: "Gradient Layout Contract" in initialize_buckets). |
1192 | void Reducer::populate_bucket_views_out( |
1193 | Reducer::Bucket& bucket, |
1194 | at::Tensor& tensor) { |
1195 | bucket.bucket_views_out.clear(); |
1196 | for (const auto i : c10::irange(bucket.variables.size())) { |
1197 | const auto& v = bucket.variables[i]; |
1198 | const auto offset = bucket.offsets[i]; |
1199 | const auto length = bucket.lengths[i]; |
1200 | if (v.is_non_overlapping_and_dense()) { |
1201 | // If the param's memory is dense, match its layout, anticipating |
1202 | // the autograd engine (AccumulateGrad) will also create gradients |
1203 | // matching its layout. |
1204 | bucket.bucket_views_out.push_back( |
1205 | tensor.as_strided(v.sizes(), v.strides(), offset)); |
1206 | } else { |
1207 | // Fall back to a C-style contiguous view, again anticipating |
1208 | // AccumulateGrad will do the same when stashing grads for non-dense |
1209 | // params. |
1210 | bucket.bucket_views_out.push_back( |
1211 | tensor.narrow(0, offset, length).view(v.sizes())); |
1212 | } |
1213 | } |
1214 | } |
1215 | |
1216 | void Reducer::prepare_for_forward() { |
1217 | std::lock_guard<std::mutex> lock(mutex_); |
1218 | num_iterations_++; |
1219 | if (should_collect_runtime_stats()) { |
1220 | record_forward_compute_start_time(); |
1221 | } |
1222 | } |
1223 | |
1224 | void Reducer::reset_bucket_counting() { |
1225 | next_bucket_ = 0; |
1226 | // Reset num_buckets_ready_ at the beginning of backward computation |
1227 | // in each iteration. |
1228 | num_buckets_ready_ = 0; |
1229 | |
1230 | for (auto& bucket : buckets_) { |
1231 | bucket.pending = bucket.variables.size(); |
1232 | } |
1233 | |
1234 | if (static_graph_) { |
1235 | numGradHooksTriggeredMapPerIteration_ = numGradHooksTriggeredMap_; |
1236 | } |
1237 | } |
1238 | |
1239 | // Traverse the autograd graph starting at the specified output. |
1240 | // All parameters for which we have a pointer to their gradient accumulation |
1241 | // functions, but don't show up in the autograd graph will be marked ready for |
1242 | // for reduction as soon as the first autograd hook is called. This is not |
1243 | // done immediately because the model output may be ignored, and we only |
1244 | // want to start performing reductions on `torch.autograd.backward()`. |
1245 | void Reducer::search_unused_parameters( |
1246 | const std::vector<torch::autograd::Variable>& outputs) { |
1247 | std::unordered_set<torch::autograd::Node*> seen; |
1248 | std::vector<torch::autograd::Node*> queue; |
1249 | |
1250 | RECORD_FUNCTION( |
1251 | "torch.distributed.ddp.reducer::search_unused_parameters" , |
1252 | std::vector<c10::IValue>()); |
1253 | |
1254 | // Seed queue with the grad functions of all outputs. |
1255 | for (const auto& output : outputs) { |
1256 | const auto& grad_fn = output.grad_fn(); |
1257 | if (grad_fn) { |
1258 | queue.push_back(grad_fn.get()); |
1259 | } |
1260 | } |
1261 | |
1262 | // Traverse the autograd graph starting at the specified output. |
1263 | while (!queue.empty()) { |
1264 | auto fn = queue.back(); |
1265 | queue.pop_back(); |
1266 | for (const auto& edge : fn->next_edges()) { |
1267 | if (auto next_ptr = edge.function.get()) { |
1268 | const bool was_inserted = seen.insert(next_ptr).second; |
1269 | if (was_inserted) { |
1270 | queue.push_back(next_ptr); |
1271 | } |
1272 | } |
1273 | } |
1274 | } |
1275 | |
1276 | // Find accumulator functions that don't show up in this graph. |
1277 | for (const auto& it : gradAccToVariableMap_) { |
1278 | // If the accumulator function is present in the graph, we know |
1279 | // a gradient will be computed for the corresponding parameter. |
1280 | if (seen.count(it.first) == 0) { |
1281 | if (ddp_debug_level_ == c10d::DebugLevel::Detail) { |
1282 | const auto param_info = param_names_.find(it.second); |
1283 | TORCH_INTERNAL_ASSERT( |
1284 | param_info != param_names_.end(), |
1285 | "Did not find variable index " , |
1286 | it.second, |
1287 | " in DDP parameter name mapping!" ); |
1288 | const auto param_name = param_info->second; |
1289 | LOG(INFO) << "[Rank " << process_group_->getRank() << "]: " |
1290 | << "Parameter " << param_name << " at index " << it.second |
1291 | << " is marked as unused." ; |
1292 | } |
1293 | unused_parameters_.push_back(it.second); |
1294 | } |
1295 | } |
1296 | |
1297 | // Warn user about unnecessary perf hit if all parameters were used in |
1298 | // forward. |
1299 | if (unused_parameters_.empty()) { |
1300 | TORCH_WARN_ONCE( |
1301 | "find_unused_parameters=True was specified in DDP constructor, " |
1302 | "but did not find any unused parameters in the forward pass. This flag " |
1303 | "results in an extra traversal of the autograd graph every iteration, " |
1304 | " which can adversely affect performance. If your model indeed never " |
1305 | "has any unused parameters in the forward pass, consider turning this " |
1306 | "flag off. Note that this warning may be a false positive if your model " |
1307 | "has flow control causing later iterations to have unused parameters." ); |
1308 | } |
1309 | if (!static_graph_ && ddp_graph_static_) { |
1310 | if (num_iterations_ > 1) { |
1311 | // Graph is still static if the set of unused parameters did not change. |
1312 | ddp_graph_static_ = |
1313 | prev_iteration_unused_parameters_ == unused_parameters_; |
1314 | |
1315 | if (!ddp_graph_static_) { |
1316 | // Log graph is not static. Logger takes care of ensuring this is done |
1317 | // only once to avoid overhead. |
1318 | logger_.lock()->log_if_graph_static(false); |
1319 | } |
1320 | } |
1321 | prev_iteration_unused_parameters_ = unused_parameters_; |
1322 | } |
1323 | } |
1324 | |
1325 | void Reducer::prepare_for_backward( |
1326 | const std::vector<torch::autograd::Variable>& outputs) { |
1327 | std::lock_guard<std::mutex> lock(mutex_); |
1328 | |
1329 | backward_compute_start_time_ = current_time_in_nanos(); |
1330 | if (should_collect_runtime_stats()) { |
1331 | record_backward_compute_start_time(); |
1332 | } |
1333 | |
1334 | // Reset accounting. |
1335 | expect_autograd_hooks_ = true; |
1336 | // Clear gradient ready order as it can be different in the next iteration. |
1337 | grad_ready_order_indices_.clear(); |
1338 | |
1339 | reset_bucket_counting(); |
1340 | |
1341 | // Reset unused parameter accounting. |
1342 | has_marked_unused_parameters_ = false; |
1343 | // Reset per iteration marked ready parameters. |
1344 | perIterationReadyParams_.clear(); |
1345 | |
1346 | // If static graph is not set, search graph to detect unused parameters. |
1347 | // When static graph is set, unused_parameters_ will be detected and will |
1348 | // not change after 1st iteration. |
1349 | // If static_graph_ = false and find_unused_parameters_ is false, |
1350 | // we assume that autograd hooks for ALL variables will be called, |
1351 | // and we don't have to search the autograd graph for presence of these hooks. |
1352 | if (dynamic_graph_find_unused()) { |
1353 | unused_parameters_.clear(); |
1354 | search_unused_parameters(outputs); |
1355 | } |
1356 | } |
1357 | |
1358 | void Reducer::copy_bucket_to_grad( |
1359 | at::Tensor& variable, |
1360 | Reducer::Bucket& bucket, |
1361 | size_t intra_bucket_index, |
1362 | bool global_unused) { |
1363 | const auto& bucket_view = bucket.bucket_views_out[intra_bucket_index]; |
1364 | runGradCallbackForVariable(variable, [&](auto& grad) { |
1365 | // If a parameter is globally unused, we keep its grad untouched. |
1366 | if (!global_unused) { |
1367 | if (!grad.defined()) { |
1368 | // Creates grad according to the "Gradient Layout Contract" |
1369 | // (see torch/csrc/autograd/functions/accumulate_grad.h) |
1370 | grad = |
1371 | torch::autograd::utils::clone_obey_contract(bucket_view, variable); |
1372 | } else { |
1373 | grad.copy_(bucket_view); |
1374 | } |
1375 | // The grad is modified and needs to be written back. |
1376 | return true; |
1377 | } |
1378 | // The grad is not modified. |
1379 | return false; |
1380 | }); |
1381 | } |
1382 | |
1383 | std::vector<std::string> Reducer::getUnmarkedParamsForIteration() { |
1384 | std::vector<std::string> unMarkedParamNames; |
1385 | for (const auto& it : param_names_) { |
1386 | if (perIterationReadyParams_.find(it.first) == |
1387 | perIterationReadyParams_.end()) { |
1388 | unMarkedParamNames.push_back(it.second); |
1389 | } |
1390 | } |
1391 | return unMarkedParamNames; |
1392 | } |
1393 | |
1394 | std::vector<size_t> Reducer::getUnmarkedParamIndicesForIteration() { |
1395 | std::vector<size_t> unmarked_param_indices; |
1396 | const auto variable_count = params_.size(); |
1397 | for (const auto variable_index : c10::irange(variable_count)) { |
1398 | if (perIterationReadyParams_.find(variable_index) == |
1399 | perIterationReadyParams_.end()) { |
1400 | unmarked_param_indices.push_back(variable_index); |
1401 | } |
1402 | } |
1403 | return unmarked_param_indices; |
1404 | } |
1405 | |
1406 | // A bucket with one or more dense tensors needs to be unflattened. |
1407 | void Reducer::finalize_bucket_dense(Bucket& bucket) { |
1408 | for (const auto intra_bucket_index : c10::irange(bucket.variables.size())) { |
1409 | auto& variable = bucket.variables[intra_bucket_index]; |
1410 | |
1411 | bool global_unused = false; |
1412 | // See Note [Skip allreducing local_used_map_dev] |
1413 | if (static_graph_ || find_unused_parameters_) { |
1414 | // Determine if this param has been used globally or not. |
1415 | // |
1416 | // If the variable was used locally, it is also used globally and then |
1417 | // we don't need to wait for the reduction. Otherwise we lazily wait for |
1418 | // the reduction to complete, only when we see a variable that was |
1419 | // unused locally. Then we end up delaying the synchronization point |
1420 | // that local_used_work_->wait() implies. If we don't have any unused |
1421 | // parameters at all, we can skip waiting for the work to complete |
1422 | // altogether, and cause negligible performance overhead for models |
1423 | // where all parameters are used. Such lazily waiting means minimizing |
1424 | // performance impact for the big majority of models where all |
1425 | // parameters are always used. Then we only pay the overhead cost if |
1426 | // there is indeed a parameter that is locally unused, because we need |
1427 | // to check if it's also globally unused. |
1428 | size_t variable_index = bucket.variable_indices[intra_bucket_index]; |
1429 | // Note: global_unused might not be global yet. As we lazily wait for |
1430 | // the reduction to complete, it becomes really global only if we get to |
1431 | // the point as below where we wait for the reduction work, make D2H |
1432 | // copy, and update global_unused with the real global consensus, i.e. |
1433 | // local_used_map_reduced_ is true. |
1434 | global_unused = local_used_map_[variable_index].item<int>() == 0; |
1435 | if (global_unused && !local_used_map_reduced_) { |
1436 | // Wait for local_used_map reduction to complete. |
1437 | local_used_work_->wait(); |
1438 | // D2H from local_used_map_dev_ to local_used_map_ |
1439 | // Blocking copy, if local_used_map_dev_ is cuda |
1440 | local_used_map_.copy_(local_used_map_dev_); |
1441 | |
1442 | global_unused = local_used_map_[variable_index].item<int>() == 0; |
1443 | local_used_map_reduced_ = true; |
1444 | } |
1445 | } |
1446 | |
1447 | if (!gradient_as_bucket_view_) { |
1448 | if (set_grads_to_none_) { |
1449 | runGradCallbackForVariable(variable, [&](auto& grad) { |
1450 | grad.reset(); |
1451 | return true; |
1452 | }); |
1453 | } else { |
1454 | RECORD_FUNCTION( |
1455 | "torch.distributed.ddp.reducer::copy_bucket_to_grad" , |
1456 | std::vector<c10::IValue>({variable})); |
1457 | copy_bucket_to_grad( |
1458 | variable, bucket, intra_bucket_index, global_unused); |
1459 | } |
1460 | } else { |
1461 | const auto& bucket_view_out = bucket.bucket_views_out[intra_bucket_index]; |
1462 | auto& bucket_view_in = bucket.bucket_views_in[intra_bucket_index]; |
1463 | // If a communication hook is registered, then `bucket_view_out` stores |
1464 | // the allreduced results in a newly allocated tensor, so we copy |
1465 | // `bucket_view_out` back to `bucket_view_in` for this gradient. |
1466 | if (!bucket_view_in.is_alias_of(bucket_view_out)) { |
1467 | bucket_view_in.copy_(bucket_view_out); |
1468 | } |
1469 | runGradCallbackForVariable(variable, [&](auto& grad) { |
1470 | if (set_grads_to_none_) { |
1471 | grad.reset(); |
1472 | return true; |
1473 | } |
1474 | // If a parameter is globally unused, we keep its grad untouched. |
1475 | if (!global_unused) { |
1476 | // If grad is globally used but locally unused, let grad point to |
1477 | // bucket_view_in |
1478 | if (!grad.defined()) { |
1479 | grad = bucket_view_in; |
1480 | } else { |
1481 | if (!grad.is_alias_of(bucket_view_in)) { |
1482 | REDUCER_CHECK( |
1483 | false, |
1484 | logger_, |
1485 | "Detected at least one parameter gradient is not the " |
1486 | "expected DDP bucket view with gradient_as_bucket_view=True. " |
1487 | "This may happen (for example) if multiple allreduce hooks " |
1488 | "were registered onto the same parameter. If you hit this error, " |
1489 | "please file an issue with a minimal repro." ); |
1490 | } |
1491 | } |
1492 | // The grad is modified and needs to be written back. |
1493 | return true; |
1494 | } |
1495 | // The grad is not modified. |
1496 | return false; |
1497 | }); |
1498 | } |
1499 | } |
1500 | } |
1501 | |
1502 | void Reducer::finalize_backward() { |
1503 | // No longer expect autograd hooks to fire after this function returns. |
1504 | TORCH_INTERNAL_ASSERT(expect_autograd_hooks_); |
1505 | expect_autograd_hooks_ = false; |
1506 | |
1507 | // No longer require call to finalize after this function returns. |
1508 | TORCH_INTERNAL_ASSERT(require_finalize_); |
1509 | require_finalize_ = false; |
1510 | |
1511 | // Wait for asynchronous reduction to complete, and unflatten the bucket's |
1512 | // flattened `gradients` tensor. |
1513 | for (auto& bucket : buckets_) { |
1514 | // See Note [DDP Communication Hook] |
1515 | TORCH_INTERNAL_ASSERT( |
1516 | bucket.future_work, |
1517 | "Expected bucket.future_work not to be null. " |
1518 | "This may indicate that communication hook was not properly installed." ); |
1519 | bucket.future_work->wait(); |
1520 | auto future_result = comm_hook_ == nullptr |
1521 | ? detail::parseCppCommHookResult(bucket.future_work->value()) |
1522 | : comm_hook_->parseHookResult(bucket.future_work->value()); |
1523 | if (bucket.expect_sparse_gradient) { |
1524 | bucket.gradients.copy_(future_result); |
1525 | } else { |
1526 | // Reinitialize only `bucket_views_out` with the future_result by |
1527 | // following the same logic in `initialize_buckets`. |
1528 | populate_bucket_views_out(bucket, future_result); |
1529 | } |
1530 | |
1531 | // Unset allreduce division factor, as it may change in next backwards pass |
1532 | // when running with DDP join mode. |
1533 | div_factor_ = kUnsetDivFactor; |
1534 | |
1535 | if (!bucket.expect_sparse_gradient) { |
1536 | // We don't need to finalize the sparse bucket since the sparse grad and |
1537 | // the bucket essentially point to the same storage. As a result, once |
1538 | // the allreduce is done, the sparse grads are automatically updated. |
1539 | finalize_bucket_dense(bucket); |
1540 | } |
1541 | } |
1542 | |
1543 | if (installed_futures_ != c10::nullopt) { |
1544 | c10::collectAll(*installed_futures_)->wait(); |
1545 | installed_futures_ = c10::nullopt; |
1546 | } |
1547 | |
1548 | // See Note [Skip allreducing local_used_maps_dev] |
1549 | if (dynamic_graph_find_unused() || static_graph_first_iteration()) { |
1550 | // Due to the lazy wait, it is possible that reduction of the current |
1551 | // iteration is still going when the one for next iteration gets kicked off. |
1552 | // For such case, we want to wait explicitly to make sure the reduction does |
1553 | // complete before kicking off next one. Otherwise the previous one may |
1554 | // interfere, write to the device-side memory and clobber the content of |
1555 | // local_unused_maps_dev_. |
1556 | if (!local_used_map_reduced_) { |
1557 | local_used_work_->wait(); |
1558 | } |
1559 | } |
1560 | |
1561 | if (dynamic_graph_find_unused()) { |
1562 | // Reset unused parameter accounting. |
1563 | // See Note [local_used_map_ -> local_used_map_dev copying] |
1564 | local_used_map_.fill_(0); |
1565 | local_used_map_reduced_ = false; |
1566 | } |
1567 | |
1568 | if (should_collect_runtime_stats()) { |
1569 | record_backward_comm_end_time(); |
1570 | } |
1571 | } |
1572 | |
1573 | void Reducer::runGradCallbackForVariable( |
1574 | at::Tensor& variable, |
1575 | GradCallback&& cb) { |
1576 | #ifdef _WIN32 |
1577 | cb(variable.mutable_grad()); |
1578 | #else |
1579 | auto context_ptr = rpc_context_.context_ptr.load(); |
1580 | if (context_ptr == nullptr) { |
1581 | cb(variable.mutable_grad()); |
1582 | } else { |
1583 | // Under distributed autograd |
1584 | context_ptr->runGradCallbackForVariable(variable, std::move(cb)); |
1585 | } |
1586 | #endif |
1587 | } |
1588 | |
1589 | #ifndef _WIN32 |
1590 | void Reducer::RpcContext::set(ContextPtr&& new_context_ptr) { |
1591 | // We should set 'new_context_ptr' even if it's nullptr. That means the |
1592 | // reducer is under a local backward run. |
1593 | const auto new_context_raw_ptr = new_context_ptr.get(); |
1594 | if (context_ptr.exchange(new_context_raw_ptr) != new_context_raw_ptr) { |
1595 | // Set the shared ptr to the context only if it's set first time. |
1596 | // All call sites should use the same context ptr. |
1597 | // Use an atomic to avoid data race from multiple threads. |
1598 | context_ptr_holder = std::move(new_context_ptr); |
1599 | } |
1600 | } |
1601 | #endif |
1602 | |
1603 | void Reducer::sync_bucket_indices( |
1604 | std::vector<std::vector<size_t>>& bucket_indices) { |
1605 | auto num_buckets = bucket_indices.size(); |
1606 | std::vector<size_t> bucket_sizes; |
1607 | bucket_sizes.reserve(num_buckets); |
1608 | int64_t total_size = 0; |
1609 | for (const auto i : c10::irange(num_buckets)) { |
1610 | auto bucket_size = bucket_indices.at(i).size(); |
1611 | bucket_sizes.push_back(bucket_size); |
1612 | total_size += bucket_size; |
1613 | } |
1614 | |
1615 | at::TensorOptions options; |
1616 | options = options.dtype(at::kInt); |
1617 | options = options.device(params_[0].device()); |
1618 | |
1619 | // Group indices and num_bucket together into indices_tensor |
1620 | // Broadcast this tensor first, as its size is equal among all processes |
1621 | auto indices_tensor = at::empty({total_size + 1}, at::kInt); |
1622 | auto indices_accessor = indices_tensor.accessor<int, 1>(); |
1623 | auto indices_accessor_Index = 0; |
1624 | for (const auto i : c10::irange(num_buckets)) { |
1625 | const auto& bucket_size = bucket_indices.at(i).size(); |
1626 | for (const auto j : c10::irange(bucket_size)) { |
1627 | indices_accessor[indices_accessor_Index++] = bucket_indices[i][j]; |
1628 | } |
1629 | } |
1630 | indices_accessor[indices_accessor_Index] = num_buckets; |
1631 | |
1632 | // Copy CPU tensor to device tensor, as the process_group_ could be NCCL and |
1633 | // it can only broadcast device tensors. |
1634 | auto indices_tensor_device = at::empty({total_size + 1}, options); |
1635 | indices_tensor_device.copy_(indices_tensor, /*non_blocking=*/true); |
1636 | std::vector<at::Tensor> indices_tensor_list = {indices_tensor_device}; |
1637 | process_group_->broadcast(indices_tensor_list)->wait(); |
1638 | indices_tensor.copy_(indices_tensor_list.front(), /*non_blocking=*/false); |
1639 | |
1640 | // Update num_buckets after receiving it from rank 0 |
1641 | num_buckets = indices_accessor[indices_accessor_Index]; |
1642 | |
1643 | // Broadcast bucket_sizes |
1644 | auto bucket_sizes_tensor = at::empty({(int64_t)num_buckets}, at::kInt); |
1645 | auto bucket_sizes_accessor = bucket_sizes_tensor.accessor<int, 1>(); |
1646 | for (const auto i : c10::irange(num_buckets)) { |
1647 | // For rank != 0, it is possible that local num buckets bucket_sizes.size() |
1648 | // is smaller than broadcasted num_buckets |
1649 | bucket_sizes_accessor[i] = |
1650 | bucket_sizes.at(std::min(i, (bucket_sizes.size() - 1))); |
1651 | } |
1652 | auto bucket_sizes_tensor_device = at::empty({(int64_t)num_buckets}, options); |
1653 | bucket_sizes_tensor_device.copy_(bucket_sizes_tensor, /*non_blocking=*/true); |
1654 | std::vector<at::Tensor> bucket_sizes_tensor_list = { |
1655 | bucket_sizes_tensor_device}; |
1656 | process_group_->broadcast(bucket_sizes_tensor_list)->wait(); |
1657 | bucket_sizes_tensor.copy_( |
1658 | bucket_sizes_tensor_list.front(), /*non_blocking=*/false); |
1659 | |
1660 | // Clear bucket_indices first, and then update bucket_indices using received |
1661 | // num_buckets, bucket_sizes_tensor and indices_tensor from rank 0 |
1662 | bucket_indices.clear(); |
1663 | bucket_indices.reserve(num_buckets); |
1664 | indices_accessor_Index = 0; |
1665 | for (const auto i : c10::irange(num_buckets)) { |
1666 | const auto& bucket_size = bucket_sizes_accessor[i]; |
1667 | std::vector<size_t> bucket; |
1668 | bucket.reserve(bucket_size); |
1669 | for (const auto j : c10::irange(bucket_size)) { |
1670 | (void)j; |
1671 | bucket.push_back(indices_accessor[indices_accessor_Index++]); |
1672 | } |
1673 | bucket_indices.emplace_back(std::move(bucket)); |
1674 | } |
1675 | } |
1676 | |
1677 | bool Reducer::rebuild_buckets() { |
1678 | // Ensure reduction for previous backwards pass is finished. If user's model |
1679 | // has unused parameters for example, this will raise an error recommending to |
1680 | // run with find_unused_parameters=True, instead of the size mismatch |
1681 | // exception below. |
1682 | std::lock_guard<std::mutex> lock(mutex_); |
1683 | ensure_prior_reduction_finished(); |
1684 | if (!should_rebuild_buckets() || rebuilt_params_.empty()) { |
1685 | return false; |
1686 | } |
1687 | |
1688 | TORCH_INTERNAL_ASSERT( |
1689 | rebuilt_params_.size() == rebuilt_param_indices_.size(), |
1690 | c10::str( |
1691 | "rebuilt parameter tensors size is not same as rebuilt parameter indices size: " , |
1692 | rebuilt_params_.size(), |
1693 | " versus " , |
1694 | rebuilt_param_indices_.size())); |
1695 | TORCH_INTERNAL_ASSERT( |
1696 | params_.size() == rebuilt_param_indices_.size(), |
1697 | c10::str( |
1698 | "rebuilt parameter indices size is not same as original model parameters size." , |
1699 | "Original model param size is: " , |
1700 | params_.size(), |
1701 | " versus rebuilt params size of: " , |
1702 | rebuilt_param_indices_.size())); |
1703 | std::vector<std::vector<size_t>> rebuilt_bucket_indices; |
1704 | std::vector<size_t> bucket_size_limits; |
1705 | bucket_size_limits.push_back(first_bucket_bytes_cap_); |
1706 | bucket_size_limits.push_back(bucket_bytes_cap_); |
1707 | std::vector<size_t> per_bucket_size_limits; |
1708 | auto ddp_set_last_bucket_as_small = |
1709 | (parse_env("DDP_SET_LAST_BUCKET_CAP" ) == "1" ); |
1710 | |
1711 | if (ddp_set_last_bucket_as_small) { |
1712 | // Reverse so that first_bucket_bytes_cap_ (smaller bucket) becomes the last |
1713 | // bucket. We cannot simply pass in {bucket_bytes_cap_, |
1714 | // first_bucket_bytes_cap} as the bucket order as we would immediately |
1715 | // advance to the 2nd element after the first bucket, whereas we only want |
1716 | // the last bucket to have a smaller size. |
1717 | std::reverse(rebuilt_params_.begin(), rebuilt_params_.end()); |
1718 | std::reverse(rebuilt_param_indices_.begin(), rebuilt_param_indices_.end()); |
1719 | } |
1720 | std::tie(rebuilt_bucket_indices, per_bucket_size_limits) = |
1721 | compute_bucket_assignment_by_size( |
1722 | rebuilt_params_, |
1723 | bucket_size_limits, |
1724 | expect_sparse_gradients_, |
1725 | rebuilt_param_indices_, |
1726 | logger_); |
1727 | |
1728 | if (ddp_set_last_bucket_as_small) { |
1729 | // Reverse again because buckets were rebuilt in the opposite of gradient |
1730 | // ready order. |
1731 | std::reverse(rebuilt_bucket_indices.begin(), rebuilt_bucket_indices.end()); |
1732 | std::reverse(per_bucket_size_limits.begin(), per_bucket_size_limits.end()); |
1733 | } |
1734 | |
1735 | if (ddp_debug_level_ != c10d::DebugLevel::Off) { |
1736 | TORCH_INTERNAL_ASSERT( |
1737 | rebuilt_bucket_indices.size() == per_bucket_size_limits.size()) |
1738 | LOG(INFO) << rebuilt_bucket_indices.size() |
1739 | << " buckets rebuilt with size limits: " |
1740 | << c10::Join(", " , per_bucket_size_limits) << " bytes." ; |
1741 | } |
1742 | |
1743 | // For rebuilt bucket indices, it needs to be synced across all ranks. |
1744 | // Broadcast the newly rebuilt bucket indices from rank 0 in default. |
1745 | // After syncing up rebuilt bucket indices, initialize buckets for reducer. |
1746 | sync_bucket_indices(rebuilt_bucket_indices); |
1747 | |
1748 | has_rebuilt_bucket_ = true; |
1749 | rebuilt_params_.clear(); |
1750 | rebuilt_param_indices_.clear(); |
1751 | |
1752 | initialize_buckets(std::move(rebuilt_bucket_indices)); |
1753 | |
1754 | return true; |
1755 | } |
1756 | |
1757 | // See Note [DDP Communication Hook] |
1758 | void Reducer::register_comm_hook(std::unique_ptr<CommHookInterface> iface) { |
1759 | REDUCER_CHECK( |
1760 | comm_hook_ == nullptr, |
1761 | logger_, |
1762 | "register_comm_hook or register_builtin_comm_hook can only be called once." ); |
1763 | |
1764 | comm_hook_ = std::move(iface); |
1765 | } |
1766 | |
1767 | // See Note [DDP Communication Hook] |
1768 | void Reducer::register_builtin_comm_hook( |
1769 | c10d::BuiltinCommHookType comm_hook_type) { |
1770 | REDUCER_CHECK( |
1771 | comm_hook_ == nullptr, |
1772 | logger_, |
1773 | "register_builtin_comm_hook or register_comm_hook can only be called once." ); |
1774 | |
1775 | switch (comm_hook_type) { |
1776 | case c10d::BuiltinCommHookType::ALLREDUCE: |
1777 | comm_hook_ = std::make_unique<c10d::AllReduceCommHook>(process_group_); |
1778 | LOG(INFO) << "Built-in communication hook ALLREDUCE is registered." ; |
1779 | break; |
1780 | case c10d::BuiltinCommHookType::FP16_COMPRESS: |
1781 | comm_hook_ = std::make_unique<c10d::FP16CompressCommHook>(process_group_); |
1782 | LOG(INFO) << "Built-in communication hook FP16_COMPRESS is registered." ; |
1783 | break; |
1784 | default: |
1785 | TORCH_WARN_ONCE( |
1786 | "Unknown built-in DDP comm hook type is provided. No comm hook will be used." ); |
1787 | } |
1788 | } |
1789 | |
1790 | void Reducer::set_grads_to_none(bool set_to_none) { |
1791 | set_grads_to_none_ = set_to_none; |
1792 | } |
1793 | |
1794 | void Reducer::ensure_prior_reduction_finished() { |
1795 | // Check that any prior reduction has finished. |
1796 | // The variable `require_finalize_` is true until all gradients |
1797 | // have been computed and reduction of all buckets has been kicked off. |
1798 | if (require_finalize_) { |
1799 | // Collect unmarked parameter indices, additionally, in debug mode retrieve |
1800 | // parameter names. |
1801 | auto unmarked_param_indices = getUnmarkedParamIndicesForIteration(); |
1802 | // We should have some unmarked parameter indices, otherwise we would not |
1803 | // have run into this error branch. |
1804 | TORCH_INTERNAL_ASSERT(!unmarked_param_indices.empty()); |
1805 | |
1806 | std::string kBaseErrorMsg = |
1807 | "Expected to have finished reduction in the prior iteration before " |
1808 | "starting a new one. " |
1809 | "" |
1810 | "This error indicates that your module has parameters that were " |
1811 | "not used in producing loss. " ; |
1812 | std::string kOutputsNotUsedInLossErrorMsg = |
1813 | "making sure all " |
1814 | "`forward` function outputs participate in calculating loss. " ; |
1815 | std::string kDDPBugErrorMsg = |
1816 | "\nIf you already have done the above, then the distributed " |
1817 | "data parallel module wasn't able to locate the output tensors in the " |
1818 | "return value of your module's `forward` function. " |
1819 | "Please include the loss function and the structure of the return " |
1820 | "value of `forward` of your module when reporting this issue (e.g. " |
1821 | "list, dict, iterable)." ; |
1822 | |
1823 | if (static_graph_) { |
1824 | kBaseErrorMsg = |
1825 | "Expected to have finished reduction in the prior iteration before " |
1826 | "starting a new one. " |
1827 | "This error indicates that your training graph has changed " |
1828 | "in this iteration, e.g., one parameter is used in first " |
1829 | "iteration, but then got unused in the second iteration. " |
1830 | "this is not compatible with static_graph set to True." ; |
1831 | } else if (!find_unused_parameters_) { |
1832 | // Parameters may have been unused in forward pass, or not all outputs |
1833 | // were used in producing loss. |
1834 | kBaseErrorMsg += |
1835 | "You can enable unused parameter detection by passing the " |
1836 | "keyword argument `find_unused_parameters=True` to " |
1837 | "`torch.nn.parallel.DistributedDataParallel`, and by \n" ; |
1838 | kBaseErrorMsg += kOutputsNotUsedInLossErrorMsg; |
1839 | kBaseErrorMsg += kDDPBugErrorMsg; |
1840 | } else { |
1841 | // Note that it does not really matter whether unused_parameters_.empty(), |
1842 | // since user may have enabled detection but this particular iteration |
1843 | // could have used or not used all parameters. |
1844 | kBaseErrorMsg += |
1845 | "Since `find_unused_parameters=True` is enabled, this likely " |
1846 | " means that not all `forward` outputs participate in computing loss. You can fix this by " ; |
1847 | kBaseErrorMsg += kOutputsNotUsedInLossErrorMsg; |
1848 | kBaseErrorMsg += kDDPBugErrorMsg; |
1849 | } |
1850 | |
1851 | const std::string unmarked_param_indices_info = c10::str( |
1852 | "\n" , |
1853 | "Parameter indices which did not receive grad for rank " , |
1854 | process_group_->getRank(), |
1855 | ": " , |
1856 | unmarked_param_indices); |
1857 | |
1858 | if (ddp_debug_level_ == DebugLevel::Off) { |
1859 | // Without debug mode, log unmarked_param_indices, as well as |
1860 | // recommendation to use debug mode to print parameter names. |
1861 | kBaseErrorMsg += unmarked_param_indices_info; |
1862 | kBaseErrorMsg += |
1863 | "\n In addition, you can set the environment variable " |
1864 | "TORCH_DISTRIBUTED_DEBUG to either INFO or DETAIL to print out information " |
1865 | "about which particular parameters did not receive gradient on this rank " |
1866 | "as part of this error" ; |
1867 | } else { |
1868 | // Retrieve set of parameter names that did not receive gradient. |
1869 | auto unmarkedParams = getUnmarkedParamsForIteration(); |
1870 | TORCH_INTERNAL_ASSERT(!unmarkedParams.empty()); |
1871 | for (const auto& s : unmarkedParams) { |
1872 | LOG(INFO) << "[Rank " << process_group_->getRank() << "] " |
1873 | << "Parameter: " << s |
1874 | << " did not get gradient in backwards pass." ; |
1875 | } |
1876 | const std::string unmarkedParamInfo = c10::Join(", " , unmarkedParams); |
1877 | // In debug mode, log param names and indices that went unused. |
1878 | kBaseErrorMsg += c10::str( |
1879 | "\n" , |
1880 | "Parameters which did not receive grad for rank " , |
1881 | process_group_->getRank(), |
1882 | ": " , |
1883 | unmarkedParamInfo); |
1884 | kBaseErrorMsg += unmarked_param_indices_info; |
1885 | } |
1886 | REDUCER_CHECK(false, logger_, kBaseErrorMsg); |
1887 | } |
1888 | } |
1889 | |
1890 | void Reducer::set_ddp_runtime_logging_sample_rate(int sample_rate) { |
1891 | ddp_runtime_logging_sample_rate_ = sample_rate; |
1892 | } |
1893 | |
1894 | int Reducer::get_ddp_runtime_logging_sample_rate() { |
1895 | return ddp_runtime_logging_sample_rate_; |
1896 | } |
1897 | |
1898 | bool Reducer::should_collect_runtime_stats() { |
1899 | if (num_iterations_ > 0 && |
1900 | (num_iterations_ <= 10 || |
1901 | num_iterations_ % get_ddp_runtime_logging_sample_rate() == 0)) { |
1902 | return true; |
1903 | } |
1904 | return false; |
1905 | } |
1906 | |
1907 | void Reducer::record_forward_compute_start_time() { |
1908 | if (timer_) { |
1909 | timer_->record(Timer::Event::kForwardStart); |
1910 | } |
1911 | } |
1912 | |
1913 | void Reducer::record_backward_compute_start_time() { |
1914 | if (timer_) { |
1915 | timer_->record(Timer::Event::kBackwardComputeStart); |
1916 | } |
1917 | } |
1918 | |
1919 | void Reducer::record_backward_compute_end_time() { |
1920 | if (timer_) { |
1921 | timer_->record(Timer::Event::kBackwardComputeEnd); |
1922 | } |
1923 | } |
1924 | |
1925 | void Reducer::record_backward_comm_start_time() { |
1926 | if (timer_) { |
1927 | timer_->record(Timer::Event::kBackwardCommStart); |
1928 | } |
1929 | } |
1930 | |
1931 | void Reducer::record_backward_comm_end_time() { |
1932 | if (timer_) { |
1933 | timer_->record(Timer::Event::kBackwardCommEnd); |
1934 | } |
1935 | } |
1936 | |
1937 | void Reducer::set_static_graph() { |
1938 | std::lock_guard<std::mutex> lock(mutex_); |
1939 | REDUCER_CHECK( |
1940 | num_iterations_ == 0, |
1941 | logger_, |
1942 | "set_static_graph() should be called before training loop starts " |
1943 | "and after DistributedDataParallel is constructed." ); |
1944 | static_graph_ = true; |
1945 | // when static_graph_ is set as true, always initialize_local_used_map |
1946 | // and detect the global unused parameters in the first iteration. |
1947 | initialize_local_used_map(); |
1948 | } |
1949 | |
1950 | namespace { |
1951 | |
1952 | // Tensors may be coalesced into buckets. Buckets must contain tensors of |
1953 | // the same type, on the same device, so a bucket can identified by a |
1954 | // composite key of a tensor's type identifier and its device. |
1955 | struct BucketKey { |
1956 | BucketKey(c10::ScalarType type, c10::Device device) |
1957 | : type(type), device(device) {} |
1958 | |
1959 | const c10::ScalarType type; |
1960 | const c10::Device device; |
1961 | |
1962 | // See torch/csrc/utils/hash.h for dispatch code. |
1963 | static size_t hash(const BucketKey& key) { |
1964 | return c10::get_hash(key.type, key.device); |
1965 | } |
1966 | }; |
1967 | |
1968 | inline bool operator==(const BucketKey& lhs, const BucketKey& rhs) { |
1969 | return lhs.type == rhs.type && lhs.device == rhs.device; |
1970 | } |
1971 | |
1972 | } // namespace |
1973 | |
1974 | std::tuple<std::vector<std::vector<size_t>>, std::vector<size_t>> |
1975 | compute_bucket_assignment_by_size( |
1976 | const std::vector<at::Tensor>& tensors, |
1977 | const std::vector<size_t>& bucket_size_limits, |
1978 | const std::vector<bool>& expect_sparse_gradient, |
1979 | const std::vector<int64_t>& tensor_indices, |
1980 | const c10::optional<std::weak_ptr<c10d::Logger>>& logger) { |
1981 | // Either expect_sparse_gradient is not specified or it has as many elements |
1982 | // as the vector with tensors. |
1983 | TORCH_INTERNAL_ASSERT( |
1984 | expect_sparse_gradient.empty() || |
1985 | (tensors.size() == expect_sparse_gradient.size())); |
1986 | TORCH_INTERNAL_ASSERT(!tensors.empty()); |
1987 | // Store bucket indices and their sizes together, because we later sort the |
1988 | // resulting indices by minimum tensor index and want to keep sizes |
1989 | // consistent. |
1990 | std::vector<std::tuple<std::vector<size_t>, size_t>> result; |
1991 | // Sparse tensors go in their own bucket, so they do not have an enforced size |
1992 | // limit. |
1993 | size_t kNoSizeLimit = 0; |
1994 | result.reserve(tensors.size()); |
1995 | |
1996 | // Keep iterator into the size_limit vector by tensor type and device. |
1997 | // This is done so that we can use the consecutive bucket limits per type. |
1998 | std::unordered_map< |
1999 | BucketKey, |
2000 | std::vector<size_t>::const_iterator, |
2001 | c10::hash<BucketKey>> |
2002 | bucket_size_limit_iterators; |
2003 | |
2004 | // Keep vector of indices and size accumulator by tensor type and device. |
2005 | std::unordered_map<BucketKey, BucketAccumulator, c10::hash<BucketKey>> |
2006 | buckets; |
2007 | |
2008 | for (const auto i : c10::irange(tensors.size())) { |
2009 | const auto& tensor = tensors[i]; |
2010 | auto msg = std::string("No support for sparse tensors." ); |
2011 | if (logger.has_value()) { |
2012 | REDUCER_CHECK(!tensor.is_sparse(), logger.value(), msg); |
2013 | } else { |
2014 | TORCH_CHECK(!tensor.is_sparse(), msg); |
2015 | } |
2016 | |
2017 | // when tensor_indices is empty, the index of tensors[i] assigned to |
2018 | // bucket is i, otherwise the tensor index is tensor_indices[i]. |
2019 | auto tensor_index = i; |
2020 | if (!tensor_indices.empty()) { |
2021 | tensor_index = tensor_indices[i]; |
2022 | } |
2023 | // If we expect a sparse gradient to be produced for this tensor, it cannot |
2024 | // be grouped together with other gradients and gets its own bucket. |
2025 | if (!expect_sparse_gradient.empty() && |
2026 | expect_sparse_gradient[tensor_index]) { |
2027 | result.emplace_back(std::vector<size_t>({tensor_index}), kNoSizeLimit); |
2028 | continue; |
2029 | } |
2030 | |
2031 | auto key = BucketKey(tensor.scalar_type(), tensor.device()); |
2032 | auto& bucket = buckets[key]; |
2033 | bucket.indices.push_back(tensor_index); |
2034 | bucket.size += tensor.numel() * tensor.element_size(); |
2035 | |
2036 | // Initialize bucket size limit iterator if necessary. |
2037 | if (bucket_size_limit_iterators.count(key) == 0) { |
2038 | bucket_size_limit_iterators[key] = bucket_size_limits.begin(); |
2039 | } |
2040 | |
2041 | auto& bucket_size_limit_iterator = bucket_size_limit_iterators[key]; |
2042 | const auto bucket_size_limit = *bucket_size_limit_iterator; |
2043 | bucket.size_limit = bucket_size_limit; |
2044 | if (bucket.size >= bucket_size_limit) { |
2045 | result.emplace_back(std::move(bucket.indices), bucket.size_limit); |
2046 | bucket = BucketAccumulator(); |
2047 | |
2048 | // Advance to the next bucket size limit for this type/device. |
2049 | auto next = bucket_size_limit_iterator + 1; |
2050 | if (next != bucket_size_limits.end()) { |
2051 | bucket_size_limit_iterator = next; |
2052 | } |
2053 | } |
2054 | } |
2055 | |
2056 | // Add remaining buckets. |
2057 | for (auto& it : buckets) { |
2058 | auto& bucket = it.second; |
2059 | if (!bucket.indices.empty()) { |
2060 | result.emplace_back(std::move(bucket.indices), bucket.size_limit); |
2061 | } |
2062 | } |
2063 | |
2064 | // If tensor_indices is not empty, the order of the tensors is in the gradient |
2065 | // ready order, so no need to sort. |
2066 | // If tensor_indices is empty, sort resulting buckets by the minimum tensor |
2067 | // index they include. We assume that the order of the tensors is the order in |
2068 | // which they are used (or the reverse order in which their gradients are |
2069 | // produced). This sorting step ensures that the buckets are ready in |
2070 | // consecutive order. |
2071 | if (tensor_indices.empty()) { |
2072 | std::sort( |
2073 | result.begin(), |
2074 | result.end(), |
2075 | [](const std::tuple<std::vector<size_t>, size_t>& a, |
2076 | const std::tuple<std::vector<size_t>, size_t>& b) { |
2077 | auto indices_a = std::get<0>(a); |
2078 | auto indices_b = std::get<0>(b); |
2079 | const auto amin = |
2080 | std::min_element(indices_a.begin(), indices_a.end()); |
2081 | const auto bmin = |
2082 | std::min_element(indices_b.begin(), indices_b.end()); |
2083 | return *amin < *bmin; |
2084 | }); |
2085 | } |
2086 | |
2087 | // Return bucket indices and size limits as separate entries in tuple, as some |
2088 | // APIs only need to consume bucket indices. |
2089 | std::vector<std::vector<size_t>> bucket_indices; |
2090 | bucket_indices.reserve(result.size()); |
2091 | std::vector<size_t> per_bucket_size_limits; |
2092 | per_bucket_size_limits.reserve(result.size()); |
2093 | for (const auto& bucket_indices_with_size : result) { |
2094 | bucket_indices.emplace_back(std::get<0>(bucket_indices_with_size)); |
2095 | per_bucket_size_limits.emplace_back(std::get<1>(bucket_indices_with_size)); |
2096 | } |
2097 | return std::make_tuple(bucket_indices, per_bucket_size_limits); |
2098 | } |
2099 | |
2100 | // Verifies corresponding params in the model replica have the same |
2101 | // sizes/strides across processes. |
2102 | void verify_params_across_processes( |
2103 | const c10::intrusive_ptr<c10d::ProcessGroup>& process_group, |
2104 | const std::vector<at::Tensor>& params, |
2105 | const c10::optional<std::weak_ptr<c10d::Logger>>& logger) { |
2106 | // First verify number of parameters to avoid inconsistent inputs into |
2107 | // broadcast which can cause a crash. |
2108 | // See https://github.com/pytorch/pytorch/issues/73547 |
2109 | |
2110 | at::TensorOptions param_size_options; |
2111 | param_size_options = param_size_options.dtype(at::kLong); |
2112 | param_size_options = param_size_options.device(params[0].device()); |
2113 | // Note: Not using tensor building API because of |
2114 | // https://github.com/pytorch/pytorch/issues/74114 |
2115 | at::Tensor param_size_tensor = |
2116 | at::tensor({static_cast<int64_t>(params.size())}, param_size_options); |
2117 | |
2118 | // Allgather and verify parameter size. |
2119 | std::vector<std::vector<at::Tensor>> param_size_output_tensors; |
2120 | param_size_output_tensors.emplace_back(); |
2121 | auto world_size = process_group->getSize(); |
2122 | for (size_t i = 0; i < world_size; ++i) { |
2123 | param_size_output_tensors.front().emplace_back( |
2124 | at::empty_like(param_size_tensor)); |
2125 | } |
2126 | |
2127 | std::vector<at::Tensor> param_size_vec{param_size_tensor}; |
2128 | process_group->allgather(param_size_output_tensors, param_size_vec)->wait(); |
2129 | auto result_size_tensors = param_size_output_tensors.front(); |
2130 | for (size_t i = 0; i < world_size; ++i) { |
2131 | auto param_size_for_rank = result_size_tensors[i][0].item<int>(); |
2132 | TORCH_CHECK( |
2133 | param_size_for_rank == params.size(), |
2134 | c10::str( |
2135 | "DDP expects same model across all ranks, but Rank " , |
2136 | process_group->getRank(), |
2137 | " has " , |
2138 | params.size(), |
2139 | " params, while rank " , |
2140 | i, |
2141 | " has inconsistent " , |
2142 | param_size_for_rank, |
2143 | " params." )); |
2144 | } |
2145 | |
2146 | // Continue with parameter shape verification. |
2147 | size_t i = 0; |
2148 | for (const auto& t : params) { |
2149 | i += 2 * t.dim(); |
2150 | } |
2151 | at::TensorOptions options; |
2152 | options = options.dtype(at::kLong); |
2153 | auto metadata = at::empty({static_cast<long>(i)}, options); |
2154 | |
2155 | // Technically, process 0 is the broadcast source, so only process 0 needs |
2156 | // to populate metadata. But no harm keeping work aligned across processes. |
2157 | auto metadata_accessor = metadata.accessor<int64_t, 1>(); |
2158 | i = 0; |
2159 | for (const auto& t : params) { |
2160 | for (const auto& sz : t.sizes()) { |
2161 | metadata_accessor[i++] = sz; |
2162 | } |
2163 | for (const auto& str : t.strides()) { |
2164 | metadata_accessor[i++] = str; |
2165 | } |
2166 | } |
2167 | |
2168 | auto metadata_dev = metadata.clone().to(params[0].device()); |
2169 | std::vector<at::Tensor> vec{metadata_dev}; |
2170 | process_group->broadcast(vec)->wait(); |
2171 | |
2172 | // Technically, process 0 doesn't need to double-check metadata, because it |
2173 | // was the source. But no harm keeping work aligned. |
2174 | auto control = at::empty({static_cast<long>(i)}, options); |
2175 | control.copy_(metadata_dev, /*non_blocking=*/false); |
2176 | auto control_accessor = control.accessor<int64_t, 1>(); |
2177 | i = 0; |
2178 | for (const auto p : c10::irange(params.size())) { |
2179 | const auto& t = params[p]; |
2180 | for (const auto& sz : t.sizes()) { |
2181 | auto msg = c10::str( |
2182 | "[" , |
2183 | process_group->getRank(), |
2184 | "]: params[" , |
2185 | p, |
2186 | "] in this process" , |
2187 | " with sizes " , |
2188 | t.sizes(), |
2189 | " appears not to match sizes of the same param in process 0." ); |
2190 | if (logger.has_value()) { |
2191 | REDUCER_CHECK(sz == control_accessor[i++], logger.value(), msg) |
2192 | } else { |
2193 | TORCH_CHECK(sz == control_accessor[i++], msg) |
2194 | } |
2195 | } |
2196 | for (const auto& str : t.strides()) { |
2197 | auto msg = c10::str( |
2198 | "params[" , |
2199 | p, |
2200 | "] in this process" , |
2201 | " with sizes " , |
2202 | t.sizes(), |
2203 | " appears not to match strides of the same param in process 0." ); |
2204 | if (logger.has_value()) { |
2205 | REDUCER_CHECK(str == control_accessor[i++], logger.value(), msg) |
2206 | } else { |
2207 | TORCH_CHECK(str == control_accessor[i++], msg) |
2208 | } |
2209 | } |
2210 | } |
2211 | } |
2212 | |
2213 | } // namespace c10d |
2214 | |