#include <torch/csrc/distributed/c10d/ProcessGroupWrapper.hpp>

#ifdef USE_C10D_GLOO

#include <c10/core/Allocator.h>
#include <c10/core/DeviceType.h>
#include <c10/core/ScalarType.h>
#include <c10/core/TensorOptions.h>
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>
#include <c10/util/intrusive_ptr.h>
#include <c10/util/irange.h>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupGloo.hpp>
#include <sstream>
#include <stdexcept>
#include <utility>

namespace c10d {

namespace {
// A container for information about a particular collective, including its op
// type and input tensors (if applicable).
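// A fingerprint is constructed by ProcessGroupWrapper::runCollectiveChecks for
// each collective and is cross-checked against all other ranks (over the
// helper Gloo backend) before the collective is dispatched to the wrapped
// backend.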
struct CollectiveFingerPrint {
  // Current collective's operation type.
  OpType op_type_;
  // Number of input tensors.
  std::size_t num_tensors_{};
  // Input tensor data types.
  std::vector<int8_t> tensor_dtypes_;
  // Input tensor device types.
  std::vector<int8_t> tensor_device_types_;
  // Input tensor sizes.
  std::vector<std::vector<int64_t>> tensor_sizes_;

  explicit CollectiveFingerPrint(
      OpType op_type,
      const std::vector<at::Tensor>& input_tensors)
      : op_type_(op_type), num_tensors_(input_tensors.size()) {
    tensor_dtypes_.reserve(num_tensors_);
    tensor_device_types_.reserve(num_tensors_);
    tensor_sizes_.reserve(num_tensors_);
    for (const at::Tensor& t : input_tensors) {
      tensor_dtypes_.push_back(static_cast<int8_t>(t.dtype().toScalarType()));
      tensor_device_types_.push_back(static_cast<int8_t>(t.device().type()));
      tensor_sizes_.push_back(t.sizes().vec());
    }
  }

  // Constructor that rebuilds a fingerprint from deserialized data (see
  // deserialize_fingerprint below).
  CollectiveFingerPrint(
      OpType op_type,
      std::vector<int8_t> tensor_dtypes,
      std::vector<int8_t> tensor_device_types,
      std::vector<std::vector<int64_t>> tensor_sizes)
      : op_type_(op_type),
        num_tensors_(tensor_dtypes.size()),
        tensor_dtypes_(std::move(tensor_dtypes)),
        tensor_device_types_(std::move(tensor_device_types)),
        tensor_sizes_(std::move(tensor_sizes)) {}

  // Logs collective information in case of a failure.
  friend std::ostream& operator<<(
      std::ostream& output,
      const CollectiveFingerPrint& collective_fingerprint);

  // Executes and verifies the collective fingerprint.
  void verify(c10::intrusive_ptr<Backend> backend) {
    at::Tensor serialized_tensor = serialize_fingerprint();
    std::vector<at::Tensor> inp{serialized_tensor};
    // First verify tensor shapes. This is needed because if e.g. tensor dim
    // does not match across processes, directly verifying tensors will result
    // in a crash during allgather, but we'd actually like to report a
    // description about the inconsistency. Since the input is just a 1D tensor
    // the shape will be a single int k_i and we need to make sure k_i is
    // consistent across the whole world.
    std::vector<at::Tensor> sp = c10d::getTensorShapes(inp);
    verify_tensors(sp, backend);
    // Now verify consistency for the actual tensor.
    verify_tensors(inp, backend);
  }

  // Takes a serialized fingerprint from
  // CollectiveFingerPrint::serialize_fingerprint and deserializes it back to a
  // CollectiveFingerPrint struct.
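  // The expected layout of `serialized_tensor` mirrors serialize_fingerprint():
  //   [op_type, num_tensors, dtypes..., device_types...,
  //    (ndims, dims...) for each tensor]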
  CollectiveFingerPrint deserialize_fingerprint(at::Tensor serialized_tensor) {
    OpType optype;
    auto dtypes = std::vector<int8_t>();
    auto device_types = std::vector<int8_t>();
    auto sizes = std::vector<std::vector<int64_t>>();
    int index = 0;
    // 1. OpType
    optype = OpType(serialized_tensor[index].item<int>());
    index++;

    if (index < serialized_tensor.size(0)) {
      // 2. Num tensors
      int num_tensors = serialized_tensor[index].item<int>();
      index++;
      dtypes.reserve(num_tensors);
      device_types.reserve(num_tensors);
      sizes.reserve(num_tensors);

      // 3. Tensor dtypes
      for (int i = 0; i < num_tensors; i++) {
        dtypes.push_back(serialized_tensor[index].item<int8_t>());
        index++;
      }
      // 4. Device types
      for (int i = 0; i < num_tensors; i++) {
        device_types.push_back(serialized_tensor[index].item<int8_t>());
        index++;
      }
      // 5. Tensor shapes
      for (int i = 0; i < num_tensors; i++) {
        // 5a. Shape size
        int size = serialized_tensor[index].item<int>();
        index++;
        // 5b. Shape
        auto shapeVec = std::vector<int64_t>();
        shapeVec.reserve(size);
        for (int j = 0; j < size; j++) {
          shapeVec.push_back(serialized_tensor[index].item<int64_t>());
          index++;
        }
        sizes.push_back(shapeVec);
      }
    }
    return CollectiveFingerPrint(optype, dtypes, device_types, sizes);
  }

 private:
  void verify_tensors(
      std::vector<at::Tensor>& tensors_to_verify,
      c10::intrusive_ptr<Backend>& backend) {
    // Create output tensor data structure to pass into allgather.
    std::vector<std::vector<at::Tensor>> output_tensors;
    // output tensors: [<tensor 0 outputs>, <tensor 1 outputs>, ..., <tensor n
    // outputs>]
    output_tensors.reserve(tensors_to_verify.size());
    for (const auto& tensor_shape : tensors_to_verify) {
      // Each rank has its own outputs shape, e.g.
      // <tensor 0 outputs>: [<rank 0 tensor>, <rank 1 tensor>, ..., <rank n
      // tensor>]
      std::vector<at::Tensor> outputs;
      outputs.reserve(backend->getSize());
      for (const auto i : c10::irange(backend->getSize())) {
        std::ignore = i; // Suppress unused variable warning.
        outputs.emplace_back(at::zeros_like(tensor_shape));
      }
      output_tensors.emplace_back(outputs);
    }
    // Allgather tensor shapes.
    backend->allgather(output_tensors, tensors_to_verify)->wait();
    // Verify equivalence.
    for (const auto i : c10::irange(output_tensors.size())) {
      const std::vector<at::Tensor>& gathered_tensors = output_tensors[i];
      const at::Tensor& reference_tensor = tensors_to_verify[i];
      for (const auto rank : c10::irange(gathered_tensors.size())) {
        const auto& rank_tensor = gathered_tensors[rank];
        if (!rank_tensor.equal(reference_tensor)) {
          CollectiveFingerPrint rank_fingerprint =
              deserialize_fingerprint(rank_tensor);
          std::stringstream ss;
          ss << "Detected mismatch between collectives on ranks. Rank "
             << backend->getRank() << " is running collective: " << *this
             << ", but Rank " << rank
             << " is running collective: " << rank_fingerprint << ".";
          TORCH_CHECK(false, ss.str());
        }
      }
    }
  }

  // Serializes the information (op type, input shapes, data types, device
  // types) about the collective fingerprint into a tensor.
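  // The resulting 1D int64 tensor has the layout
  //   [op_type, num_tensors, dtypes..., device_types...,
  //    (ndims, dims...) for each tensor]
  // For example, an allreduce over a single float CPU tensor of shape [2, 3]
  // would serialize to [ALLREDUCE, 1, Float, CPU, 2, 2, 3] (with the enum
  // values written symbolically here).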
  at::Tensor serialize_fingerprint() {
    auto data = std::make_unique<std::vector<int64_t>>();
    // 1. OpType
    data->push_back(static_cast<int64_t>(op_type_));
    // 2. Num tensors
    data->push_back(static_cast<int64_t>(num_tensors_));
    // 3. Tensor dtypes
    for (const auto& type : tensor_dtypes_) {
      data->push_back(type);
    }
    // 4. Device types
    for (const auto& d : tensor_device_types_) {
      data->push_back(d);
    }
    // 5. Shapes
    for (const auto& sizes : tensor_sizes_) {
      data->push_back(sizes.size());
      for (const auto& s : sizes) {
        data->push_back(s);
      }
    }
    // Serialize data into tensor
    int64_t data_size = data->size();
    // Need to release here and get the ptr due to C++ parameter evaluation
    // order.
    auto d = data.release();
    at::Tensor serialized_tensor =
        at::for_blob(d->data(), {data_size})
            .context(
                d,
                [](void* ctx) {
                  delete static_cast<std::vector<int64_t>*>(ctx);
                })
            .options(at::TensorOptions().dtype(at::kLong))
            .make_tensor();
    return serialized_tensor;
  }
};

std::ostream& operator<<(
    std::ostream& output,
    const CollectiveFingerPrint& collective_fingerprint) {
  std::string collectiveInfo;
  if (collective_fingerprint.num_tensors_ != 0) {
    // Convert dtype and device type info to string.
    std::vector<std::string> dtype_strs;
    std::vector<std::string> device_type_strs;
    std::vector<std::string> size_strs;
    dtype_strs.reserve(collective_fingerprint.tensor_dtypes_.size());
    for (const auto& tensor_dtype : collective_fingerprint.tensor_dtypes_) {
      dtype_strs.emplace_back(
          c10::toString(static_cast<at::ScalarType>(tensor_dtype)));
    }
    device_type_strs.reserve(
        collective_fingerprint.tensor_device_types_.size());
    for (const auto& tensor_device_type :
         collective_fingerprint.tensor_device_types_) {
      device_type_strs.emplace_back(
          c10::toString(static_cast<at::DeviceType>(tensor_device_type)));
    }
    if (!collective_fingerprint.tensor_sizes_.empty()) {
      for (const auto& single_tensor_shape_num :
           collective_fingerprint.tensor_sizes_[0]) {
        size_strs.emplace_back(std::to_string(single_tensor_shape_num));
      }
    }

    collectiveInfo = c10::str(
        "CollectiveFingerPrint(",
        "OpType=",
        opTypeToString(collective_fingerprint.op_type_),
        ", TensorShape=[",
        c10::Join(", ", size_strs),
        "], TensorDtypes=",
        (dtype_strs),
        ", TensorDeviceTypes=",
        (device_type_strs),
        ")");
  } else {
    collectiveInfo = c10::str(
        "CollectiveFingerPrint(",
        "OpType=",
        opTypeToString(collective_fingerprint.op_type_),
        ")");
  }
  return output << collectiveInfo;
}

} // namespace

ProcessGroupWrapper::ProcessGroupWrapper(
    c10::intrusive_ptr<Backend> backend,
    c10::intrusive_ptr<Backend> glooBackend)
    : Backend(backend->getRank(), backend->getSize()),
      backend_(backend),
      glooBackend_(std::move(glooBackend)) {
  // Set the sequence number for the underlying process group.
  backend_->setSequenceNumberForGroup();
}
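
// Note: in the Python frontend this wrapper is typically installed around a
// process group when TORCH_DISTRIBUTED_DEBUG=DETAIL is set, so every
// collective issued on the wrapped backend first goes through
// runCollectiveChecks() below.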

const std::string ProcessGroupWrapper::getBackendName() const {
  return backend_->getBackendName();
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::broadcast(
    std::vector<at::Tensor>& data,
    const BroadcastOptions& opts) {
  runCollectiveChecks(OpType::BROADCAST, data);
  return backend_->broadcast(data, opts);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::allreduce(
    std::vector<at::Tensor>& data,
    const AllreduceOptions& opts) {
  runCollectiveChecks(OpType::ALLREDUCE, data);
  return backend_->allreduce(data, opts);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::allreduce_coalesced(
    std::vector<at::Tensor>& tensors,
    const AllreduceCoalescedOptions& opts) {
  // NOTE: We don't enforce shape checking for allreduce_coalesced because
  // the implementation itself does not enforce it, and we have tests that use
  // inconsistent shapes; see the Python implementation in distributed_c10d for
  // details.
  runCollectiveChecks(OpType::ALLREDUCE_COALESCED, {});
  return backend_->allreduce_coalesced(tensors, opts);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::reduce(
    std::vector<at::Tensor>& tensors,
    const ReduceOptions& opts) {
  runCollectiveChecks(OpType::REDUCE, tensors);
  return backend_->reduce(tensors, opts);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::allgather(
    std::vector<std::vector<at::Tensor>>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const AllgatherOptions& opts) {
  runCollectiveChecks(OpType::ALLGATHER, inputTensors);
  return backend_->allgather(outputTensors, inputTensors, opts);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::_allgather_base(
    at::Tensor& outputBuffer,
    at::Tensor& inputBuffer,
    const AllgatherOptions& opts) {
  std::vector<at::Tensor> inputTensors({inputBuffer});
  runCollectiveChecks(OpType::_ALLGATHER_BASE, inputTensors);
  return backend_->_allgather_base(outputBuffer, inputBuffer, opts);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::allgather_coalesced(
    std::vector<std::vector<at::Tensor>>& outputTensorLists,
    std::vector<at::Tensor>& inputTensors,
    const AllgatherOptions& opts) {
  // NOTE: We don't enforce shape checking for allgather_coalesced because
  // the implementation itself does not enforce it, and we have tests that use
  // inconsistent shapes; see the Python implementation in distributed_c10d for
  // details.
  runCollectiveChecks(OpType::ALLGATHER_COALESCED, {});
  return backend_->allgather_coalesced(outputTensorLists, inputTensors, opts);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::gather(
    std::vector<std::vector<at::Tensor>>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const GatherOptions& opts) {
  runCollectiveChecks(OpType::GATHER, inputTensors);
  return backend_->gather(outputTensors, inputTensors, opts);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::scatter(
    std::vector<at::Tensor>& outputTensors,
    std::vector<std::vector<at::Tensor>>& inputTensors,
    const ScatterOptions& opts) {
  runCollectiveChecks(OpType::SCATTER, outputTensors);
  return backend_->scatter(outputTensors, inputTensors, opts);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::reduce_scatter(
    std::vector<at::Tensor>& outputTensors,
    std::vector<std::vector<at::Tensor>>& inputTensors,
    const ReduceScatterOptions& opts) {
  runCollectiveChecks(OpType::REDUCE_SCATTER, outputTensors);
  return backend_->reduce_scatter(outputTensors, inputTensors, opts);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::alltoall_base(
    at::Tensor& outputTensor,
    at::Tensor& inputTensor,
    std::vector<int64_t>& outputSplitSizes,
    std::vector<int64_t>& inputSplitSizes,
    const AllToAllOptions& opts) {
  // alltoall supports uneven split, so don't enforce shape checking.
  runCollectiveChecks(OpType::ALLTOALL_BASE, {});
  return backend_->alltoall_base(
      outputTensor, inputTensor, outputSplitSizes, inputSplitSizes, opts);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::alltoall(
    std::vector<at::Tensor>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const AllToAllOptions& opts) {
  // alltoall supports uneven split, so don't enforce shape checking.
  runCollectiveChecks(OpType::ALLTOALL, {});
  return backend_->alltoall(outputTensors, inputTensors, opts);
}

void ProcessGroupWrapper::monitoredBarrier(
    const BarrierOptions& opts,
    bool waitAllRanks) {
  return backend_->monitoredBarrier(opts, waitAllRanks);
}

void ProcessGroupWrapper::setSequenceNumberForGroup() {
  // Set underlying pg's sequence number if it is not set.
  if (backend_->getSequenceNumberForGroup() == 0) {
    // Set the sequence number for the underlying process group.
    backend_->setSequenceNumberForGroup();
  }
}

uint64_t ProcessGroupWrapper::getSequenceNumberForGroup() {
  return backend_->getSequenceNumberForGroup();
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::send(
    std::vector<at::Tensor>& tensors,
    int dstRank,
    int tag) {
  return backend_->send(tensors, dstRank, tag);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::recv(
    std::vector<at::Tensor>& tensors,
    int srcRank,
    int tag) {
  return backend_->recv(tensors, srcRank, tag);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::recvAnysource(
    std::vector<at::Tensor>& tensors,
    int tag) {
  return backend_->recvAnysource(tensors, tag);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::barrier(
    const BarrierOptions& opts) {
  runCollectiveChecks(OpType::BARRIER, {});
  return backend_->barrier(opts);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::_reduce_scatter_base(
    at::Tensor& outputBuffer,
    at::Tensor& inputBuffer,
    const ReduceScatterOptions& opts) {
  runCollectiveChecks(
      OpType::_REDUCE_SCATTER_BASE, {inputBuffer, outputBuffer});
  return backend_->_reduce_scatter_base(outputBuffer, inputBuffer, opts);
}

c10::intrusive_ptr<Backend> ProcessGroupWrapper::getWrappedPg() const {
  return backend_;
}
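
// Runs the consistency checks for a collective before it is dispatched to the
// wrapped backend: first a monitored barrier on the helper Gloo backend, so
// that a missing or hung rank is reported together with this collective's
// fingerprint, then CollectiveFingerPrint::verify, which allgathers the
// fingerprints and raises if any rank is running a different collective.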
void ProcessGroupWrapper::runCollectiveChecks(
    OpType op_type,
    const std::vector<at::Tensor>& tensors) const {
  // First perform a monitored barrier to ensure all ranks can synchronize.
  c10d::BarrierOptions options;
  // TODO: we should use the wrapped backend_'s timeout here, but the C++
  // ProcessGroup API does not expose timeout.
  auto finger_print = CollectiveFingerPrint(op_type, tensors);
  try {
    glooBackend_->monitoredBarrier(options, /* waitAllRanks */ true);
  } catch (const std::runtime_error& e) {
    // Attach collective info to the exception and re-raise.
    std::stringstream ss;
    ss << finger_print;
    auto collective_info = ss.str();
    auto err_msg = c10::str(
        "ProcessGroupWrapper: Monitored Barrier encountered error running collective: ",
        collective_info,
        ". Error: \n",
        e.what());
    TORCH_CHECK(false, err_msg);
  }
  // Will throw if an ill-formed collective is detected.
  finger_print.verify(glooBackend_);
}

} // namespace c10d

#endif // USE_C10D_GLOO