#include <torch/csrc/distributed/c10d/ProcessGroupWrapper.hpp>

#ifdef USE_C10D_GLOO

#include <c10/core/Allocator.h>
#include <c10/core/DeviceType.h>
#include <c10/core/ScalarType.h>
#include <c10/core/TensorOptions.h>
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>
#include <c10/util/intrusive_ptr.h>
#include <c10/util/irange.h>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupGloo.hpp>
#include <sstream>
#include <stdexcept>
#include <utility>

namespace c10d {

namespace {
// A container for information about a particular collective, including optype
// and input tensors (if applicable).
struct CollectiveFingerPrint {
  // Current collective's operation type.
  OpType op_type_;
  // Number of input tensors.
  std::size_t num_tensors_{};
  // Input tensor data types.
  std::vector<int8_t> tensor_dtypes_;
  // Input tensor device types.
  std::vector<int8_t> tensor_device_types_;
  // Input tensor sizes.
  std::vector<std::vector<int64_t>> tensor_sizes_;

  explicit CollectiveFingerPrint(
      OpType op_type,
      const std::vector<at::Tensor>& input_tensors)
      : op_type_(op_type), num_tensors_(input_tensors.size()) {
    tensor_dtypes_.reserve(num_tensors_);
    tensor_device_types_.reserve(num_tensors_);
    tensor_sizes_.reserve(num_tensors_);
    for (const at::Tensor& t : input_tensors) {
      tensor_dtypes_.push_back(static_cast<int8_t>(t.dtype().toScalarType()));
      tensor_device_types_.push_back(static_cast<int8_t>(t.device().type()));
      tensor_sizes_.push_back(t.sizes().vec());
    }
  }

  // Constructor used when rebuilding a fingerprint from deserialized data.
  // num_tensors_ is recovered from the dtype list so that the remote rank's
  // fingerprint prints with full tensor details.
  CollectiveFingerPrint(
      OpType op_type,
      std::vector<int8_t> tensor_dtypes,
      std::vector<int8_t> tensor_device_types,
      std::vector<std::vector<int64_t>> tensor_sizes)
      : op_type_(op_type),
        num_tensors_(tensor_dtypes.size()),
        tensor_dtypes_(std::move(tensor_dtypes)),
        tensor_device_types_(std::move(tensor_device_types)),
        tensor_sizes_(std::move(tensor_sizes)) {}

  // Logs collective information in case of a failure.
  friend std::ostream& operator<<(
      std::ostream& output,
      const CollectiveFingerPrint& collective_fingerprint);

  // Executes and verifies the collective fingerprint.
  void verify(c10::intrusive_ptr<Backend> backend) {
    at::Tensor serialized_tensor = serialize_fingerprint();
    std::vector<at::Tensor> inp{serialized_tensor};
    // First verify tensor shapes. This is needed because if e.g. tensor dim
    // does not match across processes, directly verifying tensors will result
    // in a crash during allgather, but we'd actually like to report a
    // description about the inconsistency. Since the input is just a 1D tensor
    // the shape will be a single int k_i and we need to make sure k_i is
    // consistent across the whole world.
    std::vector<at::Tensor> sp = c10d::getTensorShapes(inp);
    verify_tensors(sp, backend);
    // Now verify consistency for the actual tensor.
    verify_tensors(inp, backend);
  }

  // Takes a serialized fingerprint from
  // CollectiveFingerPrint::serialize_fingerprint and deserializes it back to a
  // CollectiveFingerPrint struct.
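  // Expects the flat int64 layout produced by serialize_fingerprint():
  // op type, tensor count, dtypes, device types, then each shape encoded as
  // its length followed by its sizes.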
  CollectiveFingerPrint deserialize_fingerprint(at::Tensor serialized_tensor) {
    OpType optype;
    auto dtypes = std::vector<int8_t>();
    auto device_types = std::vector<int8_t>();
    auto sizes = std::vector<std::vector<int64_t>>();
    int index = 0;
    // 1. OpType
    optype = OpType(serialized_tensor[index].item<int>());
    index++;

    if (index < serialized_tensor.size(0)) {
      // 2. Num tensors
      int num_tensors = serialized_tensor[index].item<int>();
      index++;
      dtypes.reserve(num_tensors);
      device_types.reserve(num_tensors);
      sizes.reserve(num_tensors);

      // 3. Tensor dtypes
      for (int i = 0; i < num_tensors; i++) {
        dtypes.push_back(serialized_tensor[index].item<int8_t>());
        index++;
      }
      // 4. Device types
      for (int i = 0; i < num_tensors; i++) {
        device_types.push_back(serialized_tensor[index].item<int8_t>());
        index++;
      }
      // 5. Tensor shapes
      for (int i = 0; i < num_tensors; i++) {
        // 5a. Shape size
        int size = serialized_tensor[index].item<int>();
        index++;
        // 5b. Shape
        auto shapeVec = std::vector<int64_t>();
        shapeVec.reserve(size);
        for (int j = 0; j < size; j++) {
          shapeVec.push_back(serialized_tensor[index].item<int64_t>());
          index++;
        }
        sizes.push_back(shapeVec);
      }
    }
    return CollectiveFingerPrint(optype, dtypes, device_types, sizes);
  }

 private:
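  // All-gathers `tensors_to_verify` across the group and checks that every
  // rank's contribution matches this rank's copy element-wise; throws with a
  // detailed message identifying the first mismatching rank.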
  void verify_tensors(
      std::vector<at::Tensor>& tensors_to_verify,
      c10::intrusive_ptr<Backend>& backend) {
    // Create output tensor data structure to pass into allgather.
    std::vector<std::vector<at::Tensor>> output_tensors;
    // output tensors: [<tensor 0 outputs>, <tensor 1 outputs>, ..., <tensor n
    // outputs>]
    output_tensors.reserve(tensors_to_verify.size());
    for (const auto& tensor_shape : tensors_to_verify) {
      // Each rank has its own outputs shape, e.g.
      // <tensor 0 outputs>: [<rank 0 tensor>, <rank 1 tensor>, ..., <rank n
      // tensor>]
      std::vector<at::Tensor> outputs;
      outputs.reserve(backend->getSize());
      for (const auto i : c10::irange(backend->getSize())) {
        std::ignore = i; // Suppress unused variable warning
        outputs.emplace_back(at::zeros_like(tensor_shape));
      }
      output_tensors.emplace_back(outputs);
    }
    // Allgather tensor shapes.
    backend->allgather(output_tensors, tensors_to_verify)->wait();
    // Verify equivalence
    for (const auto i : c10::irange(output_tensors.size())) {
      const std::vector<at::Tensor>& gathered_tensors = output_tensors[i];
      const at::Tensor& reference_tensor = tensors_to_verify[i];
      for (const auto rank : c10::irange(gathered_tensors.size())) {
        const auto& rank_tensor = gathered_tensors[rank];
        if (!rank_tensor.equal(reference_tensor)) {
          CollectiveFingerPrint rank_fingerprint =
              deserialize_fingerprint(rank_tensor);
          std::stringstream ss;
          ss << "Detected mismatch between collectives on ranks. Rank "
             << backend->getRank() << " is running collective: " << *this
             << ", but Rank " << rank
             << " is running collective: " << rank_fingerprint << ".";
          TORCH_CHECK(false, ss.str());
        }
      }
    }
  }

  // Serializes the information (op type, input shapes, data types, device
  // types) about the collective fingerprint into a tensor.
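  // Layout of the resulting 1-D int64 tensor:
  //   [op_type, num_tensors,
  //    dtype_0..dtype_{n-1}, device_type_0..device_type_{n-1},
  //    len(shape_0), shape_0..., len(shape_1), shape_1..., ...]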
  at::Tensor serialize_fingerprint() {
    auto data = std::make_unique<std::vector<int64_t>>();
    // std::vector<int64_t> data;
    // 1. OpType
    data->push_back(static_cast<int64_t>(op_type_));
    // 2. Num tensors
    data->push_back(static_cast<int64_t>(num_tensors_));
    // 3. Tensor dtypes
    for (const auto& type : tensor_dtypes_) {
      data->push_back(type);
    }
    // 4. Device types
    for (const auto& d : tensor_device_types_) {
      data->push_back(d);
    }
    // 5. Shapes
    for (const auto& sizes : tensor_sizes_) {
      data->push_back(sizes.size());
      for (const auto& s : sizes) {
        data->push_back(s);
      }
    }
    // Serialize data into tensor
    int64_t data_size = data->size();
    // Need to release here and get the ptr due to C++ parameter evaluation
    // order.
    auto d = data.release();
    at::Tensor serialized_tensor =
        at::for_blob(d->data(), {data_size})
            .context(
                d,
                [](void* ctx) {
                  delete static_cast<std::vector<int64_t>*>(ctx);
                })
            .options(at::TensorOptions().dtype(at::kLong))
            .make_tensor();
    return serialized_tensor;
  }
};

std::ostream& operator<<(
    std::ostream& output,
    const CollectiveFingerPrint& collective_fingerprint) {
  std::string collectiveInfo;
  if (collective_fingerprint.num_tensors_ != 0) {
    // Convert dtype and device type info to string.
    std::vector<std::string> dtype_strs;
    std::vector<std::string> device_type_strs;
    std::vector<std::string> size_strs;
    dtype_strs.reserve(collective_fingerprint.tensor_dtypes_.size());
    for (const auto& tensor_dtype : collective_fingerprint.tensor_dtypes_) {
      dtype_strs.emplace_back(
          c10::toString(static_cast<at::ScalarType>(tensor_dtype)));
    }
    device_type_strs.reserve(
        collective_fingerprint.tensor_device_types_.size());
    for (const auto& tensor_device_type :
         collective_fingerprint.tensor_device_types_) {
      device_type_strs.emplace_back(
          c10::toString(static_cast<at::DeviceType>(tensor_device_type)));
    }
    if (!collective_fingerprint.tensor_sizes_.empty()) {
      for (const auto& single_tensor_shape_num :
           collective_fingerprint.tensor_sizes_[0]) {
        size_strs.emplace_back(std::to_string(single_tensor_shape_num));
      }
    }

    collectiveInfo = c10::str(
        "CollectiveFingerPrint(",
        "OpType=",
        opTypeToString(collective_fingerprint.op_type_),
        ", TensorShape=[",
        c10::Join(", ", size_strs),
        "], TensorDtypes=",
        (dtype_strs),
        ", TensorDeviceTypes=",
        (device_type_strs),
        ")");
  } else {
    collectiveInfo = c10::str(
        "CollectiveFingerPrint(",
        "OpType=",
        opTypeToString(collective_fingerprint.op_type_),
        ")");
  }
  return output << collectiveInfo;
}

} // namespace

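// ProcessGroupWrapper wraps a "real" backend (e.g. NCCL) together with a
// helper gloo backend. The gloo backend is used only for the monitored
// barrier and the fingerprint allgather in runCollectiveChecks(); all actual
// collectives are delegated to the wrapped backend.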
ProcessGroupWrapper::ProcessGroupWrapper(
    c10::intrusive_ptr<Backend> backend,
    c10::intrusive_ptr<Backend> glooBackend)
    : Backend(backend->getRank(), backend->getSize()),
      backend_(backend),
      glooBackend_(std::move(glooBackend)) {
  // Set the sequence number for the underlying process group.
  backend_->setSequenceNumberForGroup();
}

const std::string ProcessGroupWrapper::getBackendName() const {
  return backend_->getBackendName();
}

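// Each collective entry point below first runs the wrapper's consistency
// checks via runCollectiveChecks() and then delegates to the wrapped backend.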
c10::intrusive_ptr<Work> ProcessGroupWrapper::broadcast(
    std::vector<at::Tensor>& data,
    const BroadcastOptions& opts) {
  runCollectiveChecks(OpType::BROADCAST, data);
  return backend_->broadcast(data, opts);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::allreduce(
    std::vector<at::Tensor>& data,
    const AllreduceOptions& opts) {
  runCollectiveChecks(OpType::ALLREDUCE, data);
  return backend_->allreduce(data, opts);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::allreduce_coalesced(
    std::vector<at::Tensor>& tensors,
    const AllreduceCoalescedOptions& opts) {
  // NOTE: We don't enforce shape checking for allreduce_coalesced because
  // the implementation itself does not enforce it, and we have tests that use
  // inconsistent shapes; see the Python implementation in distributed_c10d
  // for details.
  runCollectiveChecks(OpType::ALLREDUCE_COALESCED, {});
  return backend_->allreduce_coalesced(tensors, opts);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::reduce(
    std::vector<at::Tensor>& tensors,
    const ReduceOptions& opts) {
  runCollectiveChecks(OpType::REDUCE, tensors);
  return backend_->reduce(tensors, opts);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::allgather(
    std::vector<std::vector<at::Tensor>>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const AllgatherOptions& opts) {
  runCollectiveChecks(OpType::ALLGATHER, inputTensors);
  return backend_->allgather(outputTensors, inputTensors, opts);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::_allgather_base(
    at::Tensor& outputBuffer,
    at::Tensor& inputBuffer,
    const AllgatherOptions& opts) {
  std::vector<at::Tensor> inputTensors({inputBuffer});
  runCollectiveChecks(OpType::_ALLGATHER_BASE, inputTensors);
  return backend_->_allgather_base(outputBuffer, inputBuffer, opts);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::allgather_coalesced(
    std::vector<std::vector<at::Tensor>>& outputTensorLists,
    std::vector<at::Tensor>& inputTensors,
    const AllgatherOptions& opts) {
  // NOTE: We don't enforce shape checking for allgather_coalesced because
  // the implementation itself does not enforce it, and we have tests that use
  // inconsistent shapes; see the Python implementation in distributed_c10d
  // for details.
  runCollectiveChecks(OpType::ALLGATHER_COALESCED, {});
  return backend_->allgather_coalesced(outputTensorLists, inputTensors, opts);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::gather(
    std::vector<std::vector<at::Tensor>>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const GatherOptions& opts) {
  runCollectiveChecks(OpType::GATHER, inputTensors);
  return backend_->gather(outputTensors, inputTensors, opts);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::scatter(
    std::vector<at::Tensor>& outputTensors,
    std::vector<std::vector<at::Tensor>>& inputTensors,
    const ScatterOptions& opts) {
  runCollectiveChecks(OpType::SCATTER, outputTensors);
  return backend_->scatter(outputTensors, inputTensors, opts);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::reduce_scatter(
    std::vector<at::Tensor>& outputTensors,
    std::vector<std::vector<at::Tensor>>& inputTensors,
    const ReduceScatterOptions& opts) {
  runCollectiveChecks(OpType::REDUCE_SCATTER, outputTensors);
  return backend_->reduce_scatter(outputTensors, inputTensors, opts);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::alltoall_base(
    at::Tensor& outputTensor,
    at::Tensor& inputTensor,
    std::vector<int64_t>& outputSplitSizes,
    std::vector<int64_t>& inputSplitSizes,
    const AllToAllOptions& opts) {
  // alltoall supports uneven split, so don't enforce shape checking.
  runCollectiveChecks(OpType::ALLTOALL_BASE, {});
  return backend_->alltoall_base(
      outputTensor, inputTensor, outputSplitSizes, inputSplitSizes, opts);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::alltoall(
    std::vector<at::Tensor>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const AllToAllOptions& opts) {
  // alltoall supports uneven split, so don't enforce shape checking.
  runCollectiveChecks(OpType::ALLTOALL, {});
  return backend_->alltoall(outputTensors, inputTensors, opts);
}

void ProcessGroupWrapper::monitoredBarrier(
    const BarrierOptions& opts,
    bool waitAllRanks) {
  return backend_->monitoredBarrier(opts, waitAllRanks);
}

void ProcessGroupWrapper::setSequenceNumberForGroup() {
  // Set underlying pg's sequence number if it is not set.
  if (backend_->getSequenceNumberForGroup() == 0) {
    // Set the sequence number for the underlying process group.
    backend_->setSequenceNumberForGroup();
  }
}

uint64_t ProcessGroupWrapper::getSequenceNumberForGroup() {
  return backend_->getSequenceNumberForGroup();
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::send(
    std::vector<at::Tensor>& tensors,
    int dstRank,
    int tag) {
  return backend_->send(tensors, dstRank, tag);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::recv(
    std::vector<at::Tensor>& tensors,
    int srcRank,
    int tag) {
  return backend_->recv(tensors, srcRank, tag);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::recvAnysource(
    std::vector<at::Tensor>& tensors,
    int tag) {
  return backend_->recvAnysource(tensors, tag);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::barrier(
    const BarrierOptions& opts) {
  runCollectiveChecks(OpType::BARRIER, {});
  return backend_->barrier(opts);
}

c10::intrusive_ptr<Work> ProcessGroupWrapper::_reduce_scatter_base(
    at::Tensor& outputBuffer,
    at::Tensor& inputBuffer,
    const ReduceScatterOptions& opts) {
  runCollectiveChecks(
      OpType::_REDUCE_SCATTER_BASE, {inputBuffer, outputBuffer});
  return backend_->_reduce_scatter_base(outputBuffer, inputBuffer, opts);
}

c10::intrusive_ptr<Backend> ProcessGroupWrapper::getWrappedPg() const {
  return backend_;
}

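// Runs the wrapper's consistency checks before a collective is dispatched:
// first a gloo monitored barrier to confirm that all ranks are alive and have
// reached a collective call, then an allgather-based comparison of every
// rank's CollectiveFingerPrint (op type, tensor dtypes, device types, shapes).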
void ProcessGroupWrapper::runCollectiveChecks(
    OpType op_type,
    const std::vector<at::Tensor>& tensors) const {
  // First perform a monitored barrier to ensure all ranks can synchronize.
  c10d::BarrierOptions options;
  // TODO: we should use the wrapped backend_'s timeout here, but the C++
  // ProcessGroup API does not expose a timeout.
  auto finger_print = CollectiveFingerPrint(op_type, tensors);
  try {
    glooBackend_->monitoredBarrier(options, /* waitAllRanks */ true);
  } catch (const std::runtime_error& e) {
    // Attach collective info to the exception and re-raise.
    std::stringstream ss;
    ss << finger_print;
    auto collective_info = ss.str();
    auto err_msg = c10::str(
        "ProcessGroupWrapper: Monitored Barrier encountered error running collective: ",
        collective_info,
        ". Error: \n",
        e.what());
    TORCH_CHECK(false, err_msg);
  }
  // Will throw if an ill-formed collective is detected.
  finger_print.verify(glooBackend_);
}

} // namespace c10d

#endif // USE_C10D_GLOO
472 | |