#include <torch/csrc/distributed/c10d/comm.hpp>

#include <deque>

#include <ATen/core/functional.h>
#include <c10/util/irange.h>
#include <torch/csrc/distributed/c10d/reducer.hpp>
#include <torch/csrc/utils/tensor_flatten.h>

namespace c10d {
namespace {

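// Flattens the given bucket of tensors into a single contiguous tensor,
// kicks off an asynchronous broadcast of that flat tensor from root_rank
// upon construction, and scatters the result back into the original
// bucket tensors when finish() is called.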
class BroadcastWork {
 public:
  BroadcastWork(
      const c10::intrusive_ptr<c10d::ProcessGroup>& process_group,
      std::vector<at::Tensor> bucket_tensors,
      int root_rank = 0)
      : bucket_tensors_(std::move(bucket_tensors)),
        flat_tensor_({torch::utils::flatten_dense_tensors(bucket_tensors_)}) {
    BroadcastOptions broadcastOptions;
    broadcastOptions.rootRank = root_rank;
    work_ = process_group->broadcast(flat_tensor_, broadcastOptions);
  }

  void finish() {
    work_->wait();

    // Copy the output of the broadcast operation back.
    auto output_tensors = torch::utils::unflatten_dense_tensors(
        flat_tensor_.front(), bucket_tensors_);
    TORCH_INTERNAL_ASSERT(output_tensors.size() == bucket_tensors_.size());
    for (const auto i : c10::irange(output_tensors.size())) {
      // If the output tensor is empty, there is no need to copy it back.
      // Skipping the copy also avoids an error when both the bucket tensor
      // and the output tensor are empty but have different shapes; see
      // https://github.com/pytorch/pytorch/issues/87280.
      if (output_tensors[i].numel() != 0) {
        bucket_tensors_[i].copy_(output_tensors[i], /*non_blocking=*/true);
      }
    }
  }

 protected:
  // The list of tensors to broadcast. They are guaranteed to be
  // placed on the same device and have the same dtype.
  std::vector<at::Tensor> bucket_tensors_;

  // The vector with a single flattened tensor containing the contents
  // of the tensors in bucket_tensors_. It must be stored in a vector
  // because c10d::ProcessGroup::broadcast takes a vector argument.
  std::vector<at::Tensor> flat_tensor_;

 private:
  // The broadcast work that is kicked off upon construction.
  c10::intrusive_ptr<c10d::Work> work_;
};

} // namespace

// Broadcast many tensors to all processes in the process group.
void broadcast_coalesced(
    c10::intrusive_ptr<c10d::ProcessGroup> process_group,
    at::TensorList tensors,
    size_t buffer_size,
    int rank) {
  // Coalesce tensors into buckets taking into account the maximum buffer size.
  // This routine is multi-device aware, so the tensors can be split across
  // multiple devices and can contain a mix of CPU and CUDA tensors.
  std::vector<std::vector<size_t>> buckets;
  std::tie(buckets, std::ignore) =
      compute_bucket_assignment_by_size(tensors.vec(), {buffer_size});

  // Returns the tensor at the specified index in the input tensor list.
  const auto lookup = [&tensors](size_t index) { return tensors[index]; };

  // We maintain a maximum of 2 in-flight broadcast operations to avoid
  // allocating too much memory (in case the specified tensors are very large).
  std::deque<BroadcastWork> in_flight;
  constexpr auto max_in_flight = 2;
  for (const auto& bucket : buckets) {
    if (in_flight.size() >= max_in_flight) {
      in_flight.front().finish();
      in_flight.pop_front();
    }

    in_flight.emplace_back(process_group, c10::fmap(bucket, lookup), rank);
  }

  while (!in_flight.empty()) {
    in_flight.front().finish();
    in_flight.pop_front();
  }
}
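
// Example usage (an illustrative sketch, not part of this translation unit):
// broadcasting a module's parameters from rank 0, assuming `pg` is an
// already-initialized process group and `params` holds the parameter
// tensors. The 256MB buffer size is an arbitrary illustrative choice:
//
//   std::vector<at::Tensor> params = ...;
//   broadcast_coalesced(pg, params, /*buffer_size=*/256 * 1024 * 1024,
//                       /*rank=*/0);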
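// Returns a vector of per-parameter gradient views into the flat gradient
// buffer. Each view is created by slicing the buffer at the parameter's
// offset and length, then reshaping it to the parameter's original sizes.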
std::vector<at::Tensor> GradBucket::getGradients() const {
  std::vector<at::Tensor> per_parameter_tensors;
  size_t num_parameters = offsets_.size();
  per_parameter_tensors.reserve(num_parameters);
  for (const auto i : c10::irange(num_parameters)) {
    per_parameter_tensors.push_back(
        buffer_.slice(0, offsets_[i], offsets_[i] + lengths_[i])
            .view(sizes_vec_[i]));
  }
  return per_parameter_tensors;
}
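
// For intuition (an illustrative sketch with made-up numbers): a bucket
// holding a 2x2 parameter followed by a 2x3 parameter has offsets_ = {0, 4},
// lengths_ = {4, 6}, and sizes_vec_ = {{2, 2}, {2, 3}}, so getGradients()
// returns the views buffer_.slice(0, 0, 4).view({2, 2}) and
// buffer_.slice(0, 4, 10).view({2, 3}).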

namespace detail {

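// Extracts a single flat tensor from the IValue returned by a communication
// hook. The result may be a Python object wrapping tensors (for Python
// hooks), a Tensor, or a TensorList; in the wrapped and list cases the
// first tensor is returned.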
at::Tensor parseCppCommHookResult(const c10::IValue& result) {
  if (result.isPyObject()) {
    std::vector<at::Tensor> tensors =
        result.toPyObjectHolder()->extractTensors();
    return tensors[0];
  }
  TORCH_INTERNAL_ASSERT(
      result.isTensor() || result.isTensorList(),
117 "expected the hook result is either a Tensor or a TensorList found ",
      result.tagKind());

  if (result.isTensor()) {
    return result.toTensor();
  }

  return result.toTensorVector()[0];
}
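
// Illustrative sketch of a hook whose future this function can parse
// (`allreduceHook` is a hypothetical free function used only for
// illustration; real C++ hooks derive from CommHookInterface). An
// allreduce-style hook returns a future that resolves to a tensor list,
// from which parseCppCommHookResult takes the first tensor:
//
//   c10::intrusive_ptr<c10::ivalue::Future> allreduceHook(
//       c10d::ProcessGroup& pg, c10d::GradBucket& bucket) {
//     std::vector<at::Tensor> tensors = {bucket.getBufferRef()};
//     return pg.allreduce(tensors)->getFuture();
//   }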

} // namespace detail

} // namespace c10d