1 | #pragma once |
2 | |
3 | #include <ATen/ATen.h> |
4 | #include <ATen/core/ivalue.h> |
5 | #include <torch/csrc/Export.h> |
6 | #include <torch/csrc/distributed/c10d/ProcessGroup.hpp> |
7 | #include <utility> |
8 | |
9 | namespace c10d { |
10 | |
11 | // Broadcast many tensors to all processes in the process group. |
12 | TORCH_API void broadcast_coalesced( |
13 | c10::intrusive_ptr<c10d::ProcessGroup> process_group, |
14 | at::TensorList tensors, |
15 | size_t buffer_size, |
16 | int rank = 0); |
17 | |
18 | // This class passes bucket contents tensor to DDP communication hook. |
19 | class TORCH_API GradBucket { |
20 | public: |
21 | explicit GradBucket( |
22 | size_t index, |
23 | size_t bucket_count, |
24 | at::Tensor tensor, |
25 | std::vector<size_t> offsets, |
26 | std::vector<size_t> lengths, |
27 | std::vector<c10::IntArrayRef> sizes_vec, |
28 | std::vector<at::Tensor> parameters) |
29 | : index_(index), |
30 | bucket_count_(bucket_count), |
31 | buffer_(std::move(tensor)), |
32 | offsets_(std::move(offsets)), |
33 | lengths_(std::move(lengths)), |
34 | sizes_vec_(std::move(sizes_vec)), |
35 | parameters_(std::move(parameters)) {} |
36 | |
37 | // Returns the index of the bucket, which is unique across all the buckets. |
38 | size_t getIndex() const { |
39 | return index_; |
40 | } |
41 | |
42 | const at::Tensor& getBuffer() const { |
43 | return buffer_; |
44 | } |
45 | |
46 | // Returns a mutable buffer compared with the above method. |
47 | at::Tensor& getBufferRef() { |
48 | return buffer_; |
49 | } |
50 | |
51 | // Overwrites the buffer at a specific index. |
52 | void setBuffer(at::Tensor& buffer) { |
53 | buffer_ = buffer; |
54 | } |
55 | |
56 | // Each tensor in the list that getGradients corresponds to a |
57 | // parameter. |
58 | std::vector<at::Tensor> getGradients() const; |
59 | |
60 | // Returns model parameters belonging to this bucket. They are returned in the |
61 | // same order as gradient tensors via getGradients(). For example, |
62 | // getParameters[i] will have its gradient stored in |
63 | // getGradients[i] |
64 | const std::vector<at::Tensor> getParameters() const { |
65 | return parameters_; |
66 | } |
67 | |
68 | // Returns whther this bucket is the last bucket to allreduce in an iteration. |
69 | bool isLast() const { |
70 | return index_ == bucket_count_ - 1; |
71 | } |
72 | |
73 | private: |
74 | size_t index_; |
75 | size_t bucket_count_; |
76 | at::Tensor buffer_; |
77 | |
78 | // Per-variable info in buffer_. |
79 | std::vector<size_t> offsets_; |
80 | std::vector<size_t> lengths_; |
81 | std::vector<c10::IntArrayRef> sizes_vec_; |
82 | // Model parameters for this bucket. |
83 | const std::vector<at::Tensor> parameters_; |
84 | }; |
85 | |
86 | // Base class of both `PythonCommHook` and `CppCommHook`. |
87 | // Requires implementing 1) `runHook` method that communicates gradients |
88 | // asynchronously, and 2) `parseHookResult` method that converts the hook |
89 | // result into a tensor. |
90 | class TORCH_API CommHookInterface { |
91 | public: |
92 | virtual ~CommHookInterface() = default; |
93 | |
94 | // Passes the input grad bucket to the registered communication hook. |
95 | // Once the tensor in the bucket are ready, kicks off the hook asynchronously |
96 | // and returns a future that holds the communication results. |
97 | virtual c10::intrusive_ptr<c10::ivalue::Future> runHook( |
98 | GradBucket& bucket) = 0; |
99 | |
100 | // Returns the resulting tensor once the communication hook result is |
101 | // ready. The resulting tensor will then be copied to the grads of |
102 | // individual parameters. |
103 | virtual at::Tensor parseHookResult( |
104 | const c10::IValue& result) = 0; |
105 | }; |
106 | |
107 | namespace detail { |
108 | // This helper function is called both by CppCommHookInterface below and inside |
109 | // reducer. |
110 | at::Tensor parseCppCommHookResult(const c10::IValue& result); |
111 | } // namespace detail |
112 | |
113 | // This CppCommHook interface only requires implementing runHook method that |
114 | // potentially uses a state. |
115 | template <typename T> |
116 | class CppCommHookInterface : public CommHookInterface { |
117 | public: |
118 | explicit CppCommHookInterface(T state) : state_(std::move(state)) {} |
119 | |
120 | ~CppCommHookInterface() override = default; |
121 | |
122 | at::Tensor parseHookResult(const c10::IValue& result) override { |
123 | return detail::parseCppCommHookResult(result); |
124 | } |
125 | |
126 | protected: |
127 | T state_; |
128 | }; |
129 | |
130 | } // namespace c10d |
131 | |