1#pragma once
2
3#include <ATen/ATen.h>
4#include <ATen/core/ivalue.h>
5#include <torch/csrc/Export.h>
6#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
7#include <utility>
8
9namespace c10d {
10
11// Broadcast many tensors to all processes in the process group.
12TORCH_API void broadcast_coalesced(
13 c10::intrusive_ptr<c10d::ProcessGroup> process_group,
14 at::TensorList tensors,
15 size_t buffer_size,
16 int rank = 0);
17
18// This class passes bucket contents tensor to DDP communication hook.
19class TORCH_API GradBucket {
20 public:
21 explicit GradBucket(
22 size_t index,
23 size_t bucket_count,
24 at::Tensor tensor,
25 std::vector<size_t> offsets,
26 std::vector<size_t> lengths,
27 std::vector<c10::IntArrayRef> sizes_vec,
28 std::vector<at::Tensor> parameters)
29 : index_(index),
30 bucket_count_(bucket_count),
31 buffer_(std::move(tensor)),
32 offsets_(std::move(offsets)),
33 lengths_(std::move(lengths)),
34 sizes_vec_(std::move(sizes_vec)),
35 parameters_(std::move(parameters)) {}
36
37 // Returns the index of the bucket, which is unique across all the buckets.
38 size_t getIndex() const {
39 return index_;
40 }
41
42 const at::Tensor& getBuffer() const {
43 return buffer_;
44 }
45
46 // Returns a mutable buffer compared with the above method.
47 at::Tensor& getBufferRef() {
48 return buffer_;
49 }
50
51 // Overwrites the buffer at a specific index.
52 void setBuffer(at::Tensor& buffer) {
53 buffer_ = buffer;
54 }
55
56 // Each tensor in the list that getGradients corresponds to a
57 // parameter.
58 std::vector<at::Tensor> getGradients() const;
59
60 // Returns model parameters belonging to this bucket. They are returned in the
61 // same order as gradient tensors via getGradients(). For example,
62 // getParameters[i] will have its gradient stored in
63 // getGradients[i]
64 const std::vector<at::Tensor> getParameters() const {
65 return parameters_;
66 }
67
68 // Returns whther this bucket is the last bucket to allreduce in an iteration.
69 bool isLast() const {
70 return index_ == bucket_count_ - 1;
71 }
72
73 private:
74 size_t index_;
75 size_t bucket_count_;
76 at::Tensor buffer_;
77
78 // Per-variable info in buffer_.
79 std::vector<size_t> offsets_;
80 std::vector<size_t> lengths_;
81 std::vector<c10::IntArrayRef> sizes_vec_;
82 // Model parameters for this bucket.
83 const std::vector<at::Tensor> parameters_;
84};
85
86// Base class of both `PythonCommHook` and `CppCommHook`.
87// Requires implementing 1) `runHook` method that communicates gradients
88// asynchronously, and 2) `parseHookResult` method that converts the hook
89// result into a tensor.
90class TORCH_API CommHookInterface {
91 public:
92 virtual ~CommHookInterface() = default;
93
94 // Passes the input grad bucket to the registered communication hook.
95 // Once the tensor in the bucket are ready, kicks off the hook asynchronously
96 // and returns a future that holds the communication results.
97 virtual c10::intrusive_ptr<c10::ivalue::Future> runHook(
98 GradBucket& bucket) = 0;
99
100 // Returns the resulting tensor once the communication hook result is
101 // ready. The resulting tensor will then be copied to the grads of
102 // individual parameters.
103 virtual at::Tensor parseHookResult(
104 const c10::IValue& result) = 0;
105};
106
107namespace detail {
108// This helper function is called both by CppCommHookInterface below and inside
109// reducer.
110 at::Tensor parseCppCommHookResult(const c10::IValue& result);
111} // namespace detail
112
113// This CppCommHook interface only requires implementing runHook method that
114// potentially uses a state.
115template <typename T>
116class CppCommHookInterface : public CommHookInterface {
117 public:
118 explicit CppCommHookInterface(T state) : state_(std::move(state)) {}
119
120 ~CppCommHookInterface() override = default;
121
122 at::Tensor parseHookResult(const c10::IValue& result) override {
123 return detail::parseCppCommHookResult(result);
124 }
125
126 protected:
127 T state_;
128};
129
130} // namespace c10d
131