1#pragma once
2
#include <atomic>
#include <cstdint>
#include <functional>
#include <memory>
#include <mutex>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include <ATen/core/ivalue_inl.h>
#include <c10/macros/Macros.h>
#include <c10/util/intrusive_ptr.h>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/Utils.hpp>
#include <torch/csrc/distributed/c10d/comm.hpp>
#include <torch/csrc/distributed/c10d/debug.h>
#include <torch/csrc/distributed/c10d/reducer_timer.hpp>
#include <torch/csrc/distributed/c10d/default_comm_hooks.hpp>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/profiler.h>
#include <torch/csrc/autograd/variable.h>
#ifndef _WIN32
#include <torch/csrc/distributed/autograd/context/context.h>
#endif
25
26namespace c10d {
27
// Default byte capacity of the first bucket (1 MiB = 1 << 20 bytes).
constexpr int kDefaultFirstBucketBytes = 1 << 20;
// Default byte capacity for buckets (25 MiB); the first bucket has its own
// default above.
constexpr int kDefaultBucketBytesCap = 25 << 20;
// By default, runtime stats are collected once every
// kDDPRuntimeLoggingSampleRate iterations.
constexpr int kDDPRuntimeLoggingSampleRate = 100;

// Forward declaration: the Reducer only refers to Logger through a weak
// pointer and a friend declaration, so the full definition is not needed here.
class Logger;
35
// Scratch state for filling a single bucket during bucket assignment.
//   indices    - indices of the tensors gathered into this bucket so far.
//   size       - running total size of those tensors (presumably in bytes;
//                TODO confirm against compute_bucket_assignment_by_size).
//   size_limit - cap on `size` for this bucket.
// Both counters start at zero.
struct BucketAccumulator {
  std::vector<size_t> indices;
  size_t size{0};
  size_t size_limit{0};
};
42
43class TORCH_API Reducer {
44 public:
45 // The constructor takes a list of variables (i.e. parameters) for this
46 // process's single model replica (as DDP assumes single-process
47 // single-device). The bucket assignment for this reducer, `bucket_indices`,
48 // is specified as a list of buckets, each of which is specified as a list of
49 // indices into the bucket's `variables` list.
50 explicit Reducer(
51 std::vector<at::Tensor> params,
52 std::vector<std::vector<size_t>> bucket_indices,
53 std::vector<size_t> per_bucket_size_limits,
54 c10::intrusive_ptr<c10d::ProcessGroup> process_group,
55 std::vector<bool> expect_sparse_gradients,
56 int64_t bucket_bytes_cap,
57 bool find_unused_parameters,
58 bool gradient_as_bucket_view,
59 std::unordered_map<size_t, std::string> param_names,
60 int64_t first_bucket_bytes_cap);
61
62 ~Reducer() noexcept(false);
63
64 // To (re-)initialize bucket assignment, pass a list of buckets, each of
65 // which is specified by a list of indices in the bucket's `variables` list.
66 // This function performs validation that the variables within a bucket
67 // all live on the same device and have the same dimensionality.
68 void initialize_buckets(std::vector<std::vector<size_t>> bucket_indices);
69
70 // This function is called when the forward function has produced an output,
71 // and the user wishes to reduce gradients in the backwards pass.
72 // If they don't, and wish to accumulate gradients before reducing them,
73 // a call to this function can simply be omitted.
74 void prepare_for_backward(const std::vector<at::Tensor>& outputs);
75
76 // Called at the begginning of forward() inside DistributedDataParallel,
77 // right now it caputures the starting time of forward in each iteration.
78 void prepare_for_forward();
79
80 // Returns the relative time in nanoseconds when gradients were ready,
81 // with respect to the time `prepare_for_backward` was called. The
82 // vector is for parameters for a single model replica.
83 std::vector<int64_t> get_backward_stats() const {
84 return backward_stats_;
85 }
86
87 // Registers a hook to the reducer. The hook is `CommHookInterface`
88 // type to allow both Python and CPP hooks. This function can only
89 // be called once before calling backward.
90 // Cannot combine with the call of `register_builtin_comm_hook`.
91 void register_comm_hook(std::unique_ptr<CommHookInterface> iface);
92
93 // Registers a built-in C++ comm hook to the reducer. This function can only
94 // be called once before calling backward.
95 // Cannot combine with the call of `register_comm_hook`.
96 void register_builtin_comm_hook(c10d::BuiltinCommHookType comm_hook_type);
97
98 // If set_to_none=True, reducer will set gradients to None in
99 // finalize_backward callback.
100 void set_grads_to_none(bool set_to_none);
101
102 // Runs allreduce or installed communication hook given GradBucket instance.
103 c10::intrusive_ptr<c10::ivalue::Future> run_comm_hook(
104 GradBucket& grad_bucket);
105
106 // Runs default allreduce hook.
107 c10::intrusive_ptr<c10::ivalue::Future> run_allreduce_hook(
108 GradBucket& grad_bucket);
109
110 // Returns gradient buckets in sequential order of buckets_. This is the order
111 // in which buckets are reduced across processes. If return_zero_tensors=true,
112 // will return zero tensors of the same shape instead of the true tensors.
113 std::vector<c10d::GradBucket> get_grad_buckets(
114 bool return_zero_tensors = true) const;
115
116 // Rebuild buckets based on rebuilt_params_ and rebuilt_param_indices_
117 // according to when tensors received grads in the backward pass.
118 // TODO this function makes broadcast communication call and
119 // could be overlapped with next forward() call, thus
120 // it could be async. Will make it async when rebuilding buckets for
121 // find_unused_parameters = true case, as we could rebuild buckets more than
122 // once for find_unused_parameters = true case, where subgraphs are trained
123 // and parameter indices order may change more frequently.
124 // For find_unused_parameters = false case, buckets are only rebuilt once,
125 // the performance cost is negligible. Returns true if the buckets were
126 // rebuilt.
127 bool rebuild_buckets();
128
129 // Install futures that should be awaited at end of backwards. Currently these
130 // are only used by user-defined custom buffer reduction hooks, but can be generalized
131 // to any user-originating futures that need to be awaited.
132 void install_futures(c10::List<c10::intrusive_ptr<c10::ivalue::Future>> futs);
133
134 // Returns true if we should rebuild buckets, else false. We only rebuild
135 // buckets once after the first iteration and never rebuild them if
136 // find_unused_parameters_.
137 inline bool should_rebuild_buckets() const {
138 return (static_graph_ || !find_unused_parameters_) && !has_rebuilt_bucket_;
139 }
140
141 // Pushes all parameters to be rebuilt.
142 void push_rebuilt_params_for_all_indices();
143
144 // Creates and sets ForwardPassWorkHandle given a Work and the
145 // corresponding tensor being reduced.
146 void set_forward_pass_work_handle(
147 c10::intrusive_ptr<c10d::Work> forwardPassWorkHandle,
148 bool useStaticWorldSize);
149
150 // Retrieve on-device tensors used to track locally unused parameters. It is
151 // a tensor where index i = 1 if the Variable with that index has been used.
152 at::Tensor get_local_used_map_on_device() const;
153
154 // An function for users to set sample_rate of collecting
155 // runtime stats. The time stats will be recorded for the
156 // first 10 iterations, after 10 iteratons time stats will be
157 // recorded once every "sample_rate" training iterations.
158 void set_ddp_runtime_logging_sample_rate(int sample_rate);
159
160 // Specify the training graph is static.
161 void set_static_graph();
162
163 // Delay all reduce to be after all gradients' calculation is complete.
164 void delay_all_reduce();
165
166 // Weak reference to associated DDP logger. The reference is weak to avoid
167 // refcycle between reducer and logger.
168 void set_logger(std::weak_ptr<c10d::Logger> logger);
169
170 // When graph is not explicitly set by user as static and has unused
171 // parameters, this will return whether the graph has been static until the
172 // current iteration, which means unused params set has not changed.
173 bool ddp_graph_static();
174
175 protected:
176 // Forward declaration.
177 struct Bucket;
178
179 void push_rebuilt_params(const size_t& index);
180
181 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
182 mutable std::mutex mutex_;
183 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
184 const std::vector<at::Tensor> params_;
185 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
186 const c10::intrusive_ptr<::c10d::ProcessGroup> process_group_;
187 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
188 std::vector<bool> expect_sparse_gradients_;
189
190 std::vector<std::shared_ptr<torch::autograd::Node>>
191 grad_accumulators_; // NOLINT(cppcoreguidelines-non-private-member-variables-in-classes)
192 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
193 std::unordered_map<torch::autograd::Node*, size_t> gradAccToVariableMap_;
194 std::vector<std::pair<uintptr_t, std::shared_ptr<torch::autograd::Node>>>
195 hooks_; // NOLINT(cppcoreguidelines-non-private-member-variables-in-classes)
196
197 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
198 bool expect_autograd_hooks_;
199 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
200 bool require_finalize_;
201 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
202 size_t next_bucket_;
203
204 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
205 bool has_marked_unused_parameters_;
206 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
207 const bool find_unused_parameters_;
208 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
209 const bool gradient_as_bucket_view_;
210 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
211 std::vector<size_t> unused_parameters_;
212 // Previous iteration's unused params, used for checking if unused parameters
213 // change between iterations. Only filled during the first backwards call.
214 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
215 std::vector<size_t> prev_iteration_unused_parameters_;
216 // Whether graph is static or not. When user does not explicitly set static
217 // graph, the only possible dynamism is set of unused parameters changing
218 // between iterations which is tracked by this flag.
219 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
220 bool ddp_graph_static_{true};
221 // Locally used parameter maps indicating if parameters are used locally
222 // during the current iteration or no_sync session if no_sync is on.
223 // Each map is a one-dim int32 tensor of number of parameters. These tensors
224 // are marked in autograd_hook to indicate the corresponding param has been
225 // used, and get allreduced in the end of backward step of current iteration
226 // or no_sync session for figuring out the globally unused parameters.
227 //
228 // local_used_map_: CPU tensor for bookkeeping locally used params
229 // local_used_map_dev_: dev tensor for reducing globally unused params
230 at::Tensor local_used_map_;
231 at::Tensor local_used_map_dev_;
232 // Indicate that reduction is done and D2H copy is done as well.
233 bool local_used_map_reduced_;
234
235 // Weak pointer to associated DDP logger.
236 std::weak_ptr<c10d::Logger> logger_;
237 // List of futures installed by Reducer::install_futures that should be awaited
238 // at the end of backwards pass.
239 c10::optional<c10::List<c10::intrusive_ptr<c10::ivalue::Future>>> installed_futures_{c10::nullopt};
240
241 // Work handle for allreduce on local_used_map_
242 c10::intrusive_ptr<c10d::Work> local_used_work_;
243
244 void mark_variable_ready_dense(size_t variable_index);
245
246 void mark_variable_ready_sparse(size_t variable_index);
247
248 void mark_variable_ready(size_t variable_index);
249
250 void autograd_hook(size_t index);
251
252 void mark_bucket_ready(size_t bucket_index);
253
254 void finalize_bucket_dense(Bucket& bucket);
255
256 void finalize_backward();
257
258 // Returns list of model parameters corresponding to the given bucket.
259 // bucket_index is a key to cache after buckets are rebuilt, after which this
260 // mapping never changes.
261 std::vector<at::Tensor> get_variables_for_bucket(
262 size_t bucket_index, const Bucket& bucket) const;
263
264 // Asserts that the reduction for the previous iteration has finished before
265 // rebuilding buckets or kicking off the next one.
266 void ensure_prior_reduction_finished();
267
268 // Broadcast rebuilt buckets from rank 0 to other ranks before initializing
269 // the buckets
270 void sync_bucket_indices(std::vector<std::vector<size_t>>& bucket_indices);
271
272 // We'd like to use DistAutogradContext::GradCallback here but dist autograd
273 // doesn't exist under Windows. So we just directly use the concrete type but
274 // to preserve and enforce our original intent we do a static assert when dist
275 // autograd is available.
276 using GradCallback = std::function<bool(at::Tensor&)>;
277#ifndef _WIN32
278 static_assert(
279 std::is_same<
280 GradCallback,
281 torch::distributed::autograd::DistAutogradContext::GradCallback>::
282 value,
283 "");
284#endif
285 void runGradCallbackForVariable(at::Tensor& variable, GradCallback&& cb);
286
287 // This function is called inside `initialize_buckets()`. It initializes both
288 // `bucket_views_in` and `bucket_views_out` with views for each variable's
289 // gradient into the bucket's flattened `gradients` tensor. Views serve as
290 // entry points to `copy_()` each grad's data in/out of the flattened
291 // `gradients` tensor.
292 void initialize_bucket_views(Bucket& bucket);
293
294 // This function is called inside `finalize_backward`, it happens only if
295 // DDP communication hook was registered to recreate just bucket_views_out
296 // with the result of `future_work`.
297 void populate_bucket_views_out(Bucket& bucket, at::Tensor& tensor);
298
299 // If gradient_as_bucket_view_ is false, after allreduce buckets,
300 // copy bucket results back to grads.
301 void copy_bucket_to_grad(
302 at::Tensor& variable,
303 Reducer::Bucket& bucket,
304 size_t intra_bucket_index,
305 bool global_unused);
306 // Check layout of grad and bucket_view before copying the grad to bucket.
307 void check_grad_layout(const at::Tensor& grad, const at::Tensor& bucket_view);
308
309 // A bucket contains [1..N] gradients to be reduced, where the gradients
310 // have the same dtype and device.
311 // Coalescing gradients together before reducing can result in lower overhead
312 // and/or faster time to completion. Coalescing requires the constituent
313 // gradients to have the same dtype and device, and the resulting flattened
314 // tensor uses that common dtype and device. The flattened tensor is filled
315 // as the corresponding gradients are computed (triggered by autograd hooks),
316 // and the buckets are reduced in a predetermined order consistent across
317 // processes.
318 struct Bucket {
319 // Gradients of the bucket flattened into a 1-dimensional tensor
320 at::Tensor gradients;
321
322 // Views into the `gradients` tensor for each individual gradient
323 // Each view is created with layout (size and stride) matching the
324 // gradient's expected layout (see the "Gradient Layout Contract" in
325 // torch/csrc/autograd/functions/accumulate_grad.h).
326 // `bucket_views_in[i].copy_(grad)` and `grad.copy_(bucket_views_out[i])`
327 // provide convenient ways to copy gradient data in/out of `gradients`,
328 // respectively.
329 // We keep both `bucket_views_in` and `bucket_views_out` because
330 // registering a DDP communication hook may re-initialize
331 // `bucket_views_out` with the value of the hook's `future_work` but we
332 // still need separate views into the bucket's original flattened gradient
333 // to copy in gradient data.
334 std::vector<at::Tensor> bucket_views_in;
335 std::vector<at::Tensor> bucket_views_out;
336
337 // Variables whose gradients are held in this bucket
338 // We use refcounted tensors here so that we can easily unflatten the
339 // bucket's flattened `gradients` tensor into the participating variables
340 // after reduction has completed.
341 std::vector<at::Tensor> variables;
342
343 // Per-variable offset/length into the flattened `gradients` tensor and
344 // the corresponding `GradBucket` instance for communication hooks
345 std::vector<size_t> offsets;
346 std::vector<size_t> lengths;
347
348 // Per-variable sizes slicing into the bucket's `gradients` tensor
349 std::vector<c10::IntArrayRef> sizes_vec;
350
351 // Number of gradients left to be computed before the bucket is ready to
352 // be reduced
353 size_t pending;
354
355 // Global indices of participating variables in the bucket
356 std::vector<size_t> variable_indices;
357
358 // Future work handle for DDP communication hook
359 // If no hook is registered, a temporary vanilla allreduce hook is used.
360 c10::intrusive_ptr<at::ivalue::Future> future_work;
361
362 // If this bucket should expect a single sparse gradient
363 // If `true`, then this implies that `bucket.variables.size() == 1`.
364 bool expect_sparse_gradient = false;
365
366 // TODO(@pietern)
367 // Memory copies from gradient tensors into the bucket are potentially
368 // done on different CUDA streams. We record an event for every copy
369 // so that we can synchronize with them prior to kicking off the reduction.
370 // std::vector<at::cuda::CUDAEvent> events;
371
372 };
373
374 std::vector<Bucket> buckets_;
375
376 // A variable locator locates a particular variable in the reducer's buckets
377 struct VariableLocator {
378 // Index of the bucket containing the variable in the `buckets_` vector
379 size_t bucket_index;
380 // Index of the variable in the bucket, which may be used consistently
381 // across `bucket_views_in`, `bucket_views_out`, `variables`, `offsets`,
382 // `lengths`, `sizes_vec`, and `variable_indices` in `Bucket`
383 size_t intra_bucket_index;
384
385 VariableLocator() = default;
386
387 VariableLocator(size_t bucket_index_, size_t intra_bucket_index_) : bucket_index(bucket_index_), intra_bucket_index(intra_bucket_index_) {}
388 };
389
390 // Map the index of a variable to its location in the bucket structure.
391 std::vector<VariableLocator> variable_locators_;
392
393 // track the number of iterations to synchronize grads in training so far.
394 long num_iterations_;
395 // track the number of buckets that have been ready for
396 // communication calls like allReduce or communication hooks.
397 int num_buckets_ready_;
398
399 // Timing information.
400 int64_t backward_compute_start_time_ = -1;
401 std::unique_ptr<Timer> timer_;
402
403 // We collect the relative timestamp of every gradient being ready
404 // when executing autograd. This can be used to derive a timeline of
405 // the point in time buckets were ready, or ideal bucket assignment/ordering.
406 std::vector<int64_t> backward_stats_;
407
408 bool should_collect_runtime_stats();
409 void record_forward_compute_start_time();
410 void record_backward_compute_start_time();
411 void record_backward_compute_end_time();
412 void record_backward_comm_start_time();
413 void record_backward_comm_end_time();
414
415 int get_ddp_runtime_logging_sample_rate();
416 int ddp_runtime_logging_sample_rate_ = kDDPRuntimeLoggingSampleRate;
417
418 bool is_multi_device_module_ = false;
419
420 // Following variables are to help build dynamic bucket order
421 bool has_rebuilt_bucket_;
422 std::vector<at::Tensor> rebuilt_params_;
423 std::vector<int64_t> rebuilt_param_indices_;
424 const int64_t bucket_bytes_cap_;
425
426#ifndef _WIN32
427 struct RpcContext {
428 using ContextPtr = torch::distributed::autograd::ContextPtr;
429 // The shared_ptr is to hold the context instance.
430 ContextPtr context_ptr_holder;
431 std::atomic<ContextPtr::element_type*> context_ptr{nullptr};
432
433 void set(ContextPtr&& new_context_ptr);
434 };
435 RpcContext rpc_context_;
436#endif
437
438 // A struct containing work handle and tensor for allreduce scheduled in
439 // forward pass, if applicable.
440 struct ForwardPassAllreduceWork {
441 c10::intrusive_ptr<c10d::Work> workHandle;
442 at::Tensor resultTensor;
443 // whether we should divide by the initial world_size or the no. of
444 // remaining DDP ranks.
445 bool useStaticWorldSize;
446 };
447
448 // Handle for the currently scheduled allreduce in the forward pass, if
449 // applicable.
450 ForwardPassAllreduceWork forwardPassWorkHandle_;
451
452 // Division factor for reduction of gradients.
453 // Equal to the process group size, with an exception of handling uneven
454 // input.
455 int div_factor_;
456
457 bool static_graph_;
458
459 // Key: size_t (index), Value: the number of times that a variable's
460 // autograd_hook() should be triggered before marking this variable's grad as
461 // ready for communication. Map will not change after 1st iteration.
462 std::unordered_map<size_t, int> numGradHooksTriggeredMap_;
463 // Key: size_t (index), Value: the number of times that a variable's
464 // autograd_hook() are left to be triggered before marking this variable's
465 // grad as ready for communication. Map will change after 1st iteration to
466 // track a grad is ready for communication or not.
467 std::unordered_map<size_t, int> numGradHooksTriggeredMapPerIteration_;
468
469 private:
470 // reset counting for buckets before backward starts
471 void reset_bucket_counting();
472 // search unused parameters beore backward starts
473 void search_unused_parameters(
474 const std::vector<torch::autograd::Variable>& outputs);
475 void set_divide_factor();
476 // kick off all reduce for the ready bucket
477 void all_reduce_bucket(Bucket& bucket);
478 // kick off all reduce to local used map, it can help find global unused
479 // parameters
480 void all_reduce_local_used_map();
481 // initialize locally used parameter maps
482 void initialize_local_used_map();
483 // get current cuda stream
484 const c10::Stream get_current_stream();
485 bool dynamic_graph_find_unused();
486 bool static_graph_first_iteration();
487 bool static_graph_after_first_iteration();
488
489 // comm_hook_ is used to access the DDP communication hook if registered.
490 std::unique_ptr<CommHookInterface> comm_hook_;
491 // Debug level setting. It is parsed once when Reducer is constructed, and
492 // remains the same across a single invocation of DDP training.
493 DebugLevel ddp_debug_level_;
494 // Mapping of variable index to fully qualified name of model to notify users
495 // about errors when certain parameters do not get gradient.
496 std::unordered_map<size_t, std::string> param_names_;
497 // Variable indices stored sequentially in order of when the gradient is ready
498 // for the current backwards pass.
499 std::vector<int> grad_ready_order_indices_;
500 // Bytes capacity of first bucket, can be configured by user
501 int64_t first_bucket_bytes_cap_;
502 // Per iteration set of parameter indices that have been marked ready.
503 std::unordered_set<size_t> perIterationReadyParams_;
504 // Retrieves parameter names that have not been marked as ready as part of
505 // previous iteration.
506 std::vector<std::string> getUnmarkedParamsForIteration();
507 // Retrives parameter indices that have not been marked as ready as part of
508 // previous iteration.
509 std::vector<size_t> getUnmarkedParamIndicesForIteration();
510 // Raises appropriate error if mark_variable_ready is called on the same
511 // variable twice, which is unexpected.
512 void checkAndRaiseMarkedTwiceError(size_t curVariableIndex);
513 // Retrieves parameter corresponding to the given VariableIndex.
514 at::Tensor& get_param_from_index(size_t index);
515
516 // Cached bucket index to model parameter mapping. Populated after buckets
517 // are rebuilt after which this mapping is static.
518 mutable std::unordered_map<size_t, std::vector<at::Tensor>> cached_variables_for_bucket_;
519
520 bool set_grads_to_none_{false};
521 friend class Logger;
522};
523
// Assigns tensors to buckets by size. This is equivalent to take_tensors, but
// it returns indices into the `tensors` argument rather than the tensors
// themselves. It is also aware of device placement and will not allow a bucket
// to span devices.
//
// Index mapping: if `tensor_indices` is non-empty, tensors[i] is recorded in
// the returned buckets as tensor_indices[i]; if `tensor_indices` is empty,
// tensors[i] is recorded as i.
//
// Parameters (hedged where the declaration alone does not establish them):
//   tensors                - tensors to distribute across buckets.
//   bucket_size            - bucket size limits (presumably in bytes and in
//                            bucket order; TODO confirm in the implementation).
//   expect_sparse_gradient - optional; presumably parallel to `tensors` —
//                            verify against the implementation.
//   tensor_indices         - optional index remapping (see above).
//   logger                 - optional DDP logger; presumably used for error
//                            reporting — confirm.
//
// Returns a tuple whose first element is the bucket assignment (one list of
// indices per bucket). The second element is a std::vector<size_t>,
// presumably the size limit applied to each bucket, which is the shape the
// Reducer constructor's `per_bucket_size_limits` takes — verify.
TORCH_API std::tuple<std::vector<std::vector<size_t>>, std::vector<size_t>>
compute_bucket_assignment_by_size(
    const std::vector<at::Tensor>& tensors,
    const std::vector<size_t>& bucket_size,
    const std::vector<bool>& expect_sparse_gradient = {},
    const std::vector<int64_t>& tensor_indices = {},
    const c10::optional<std::weak_ptr<c10d::Logger>>& logger = {});
537
// Verifies that the model on every process in `process_group` matches the
// model on rank 0 with respect to the number of parameters and each
// parameter's dtype/size/layout.
//   process_group - group whose ranks are compared against rank 0.
//   params        - this process's model parameters.
//   logger        - optional DDP logger; presumably used to record a mismatch
//                   before erroring out — confirm in the implementation.
TORCH_API void verify_params_across_processes(
    const c10::intrusive_ptr<c10d::ProcessGroup>& process_group,
    const std::vector<at::Tensor>& params,
    const c10::optional<std::weak_ptr<c10d::Logger>>& logger);
544} // namespace c10d
545