1 | #pragma once |
2 | |
3 | #include <ATen/ATen.h> |
4 | #include <c10/util/accumulate.h> |
5 | #include <c10/util/irange.h> |
6 | #include <torch/csrc/distributed/c10d/Types.hpp> |
7 | |
8 | #ifdef _WIN32 |
9 | #include <winsock2.h> |
10 | #include <ws2tcpip.h> |
11 | typedef SSIZE_T ssize_t; |
12 | #pragma comment(lib, "Ws2_32.lib") |
13 | #else |
14 | #include <fcntl.h> |
15 | #include <netdb.h> |
16 | #include <sys/poll.h> |
17 | #include <sys/socket.h> |
18 | #include <unistd.h> |
19 | #endif |
20 | |
21 | #include <sys/types.h> |
22 | |
23 | #include <chrono> |
24 | #include <cstdint> |
25 | #include <cstdlib> |
26 | #include <functional> |
27 | #include <limits> |
28 | #include <string> |
29 | #include <system_error> |
30 | #include <tuple> |
31 | #include <vector> |
32 | |
33 | namespace c10d { |
34 | |
35 | TORCH_API std::string parse_env(const char* env_var_name); |
36 | |
37 | // Retrieve tensor shapes from a given tensor. |
38 | TORCH_API std::vector<at::Tensor> getTensorShapes(const std::vector<at::Tensor>& tensors); |
39 | |
40 | // Use -2 to represent unset state of env vars |
41 | #define C10D_ENV_NOT_SET -2 |
42 | |
43 | // Turns at::IntArrayRef into "(1, 2, 3, 4)". |
44 | inline std::string toString(at::IntArrayRef l) { |
45 | std::stringstream ss; |
46 | ss << "(" ; |
47 | for (const auto i : c10::irange(l.size())) { |
48 | if (i > 0) { |
49 | ss << ", " ; |
50 | } |
51 | ss << l[i]; |
52 | } |
53 | ss << ")" ; |
54 | return ss.str(); |
55 | } |
56 | |
57 | inline std::string toString(const c10::Layout& layout) { |
58 | std::stringstream ss; |
59 | ss << layout; |
60 | return ss.str(); |
61 | } |
62 | |
63 | inline void assertSameType( |
64 | const at::DeprecatedTypeProperties& type, |
65 | const std::vector<at::Tensor>& tensors) { |
66 | for (const auto i : c10::irange(tensors.size())) { |
67 | if (!tensors[i].options().type_equal(type.options())) { |
68 | const std::string expected = type.toString(); |
69 | const std::string actual = tensors[i].toString(); |
70 | throw std::invalid_argument( |
71 | "mixed types (" + expected + " and " + actual + ")" ); |
72 | } |
73 | } |
74 | } |
75 | |
// Split `string` on every occurrence of `separator`.
// Empty pieces between separators are kept; a trailing separator does not
// produce a trailing empty piece, and an empty input yields no pieces
// (std::getline semantics).
inline std::vector<std::string> split(char separator, const std::string& string) {
  std::vector<std::string> result;
  std::stringstream input(string);
  for (std::string piece; std::getline(input, piece, separator);) {
    result.emplace_back(std::move(piece));
  }
  return result;
}
85 | |
86 | inline int parseEnvVarInt(const char* envVarName) { |
87 | char* stringValue = std::getenv(envVarName); |
88 | if (stringValue != nullptr) { |
89 | int val; |
90 | try { |
91 | val = std::stoi(stringValue); |
92 | } catch (std::exception& e) { |
93 | TORCH_CHECK(false, |
94 | "Invalid value for environment variable: " + std::string(envVarName)); |
95 | } |
96 | return val; |
97 | } |
98 | return C10D_ENV_NOT_SET; |
99 | } |
100 | |
101 | inline int parseEnvVarIntDefault(const char* envVarName, int defaultVal) { |
102 | int val = parseEnvVarInt(envVarName); |
103 | if (val == C10D_ENV_NOT_SET) |
104 | return defaultVal; |
105 | return val; |
106 | } |
107 | |
108 | inline bool parseEnvVarFlag(const char* envVarName) { |
109 | int val = parseEnvVarInt(envVarName); |
110 | if (val == 1) { |
111 | return true; |
112 | } else if (val == 0 || val == C10D_ENV_NOT_SET) { |
113 | return false; |
114 | } |
115 | TORCH_CHECK(false, |
116 | "Invalid value for environment variable: " + std::string(envVarName)); |
117 | return false; |
118 | } |
119 | |
120 | inline void assertSameSizes( |
121 | const at::IntArrayRef& sizes, |
122 | const std::vector<at::Tensor>& tensors) { |
123 | for (const auto i : c10::irange(tensors.size())) { |
124 | if (!tensors[i].sizes().equals(sizes)) { |
125 | const auto expected = toString(sizes); |
126 | const auto actual = toString(tensors[i].sizes()); |
127 | throw std::invalid_argument( |
128 | "mixed sizes (" + expected + " and " + actual + ")" ); |
129 | } |
130 | } |
131 | } |
132 | |
133 | inline void assertSameSizeAndType(const std::vector<at::Tensor>& tensors) { |
134 | // Ensure we have at least one tensor |
135 | if (tensors.empty()) { |
136 | throw std::invalid_argument("argument is empty" ); |
137 | } |
138 | |
139 | // Ensure all tensors have identical type and shape |
140 | auto options = tensors[0].options(); |
141 | auto sizes = tensors[0].sizes(); |
142 | for (const auto i : c10::irange(1, tensors.size())) { |
143 | if (!tensors[i].options().type_equal(options)) { |
144 | const auto expected = toString(options); |
145 | const auto actual = toString(tensors[i].options()); |
146 | throw std::invalid_argument( |
147 | "argument contains mixed types (" + expected + " and " + actual + |
148 | ")" ); |
149 | } |
150 | if (!tensors[i].sizes().equals(sizes)) { |
151 | const auto expected = toString(sizes); |
152 | const auto actual = toString(tensors[i].sizes()); |
153 | throw std::invalid_argument( |
154 | "argument contains mixed sizes (" + expected + " and " + actual + |
155 | ")" ); |
156 | } |
157 | } |
158 | } |
159 | |
160 | inline void assertTypeMatch( |
161 | std::function<void(const std::string&)> fn, |
162 | const at::DeprecatedTypeProperties& type, |
163 | const at::ArrayRef<at::Tensor> tensors, |
164 | size_t index) { |
165 | if (!tensors[index].options().type_equal(type.options())) { |
166 | fn("invalid tensor type at index " + std::to_string(index) + " (expected " + |
167 | type.toString() + ", got " + tensors[index].toString() + ")" ); |
168 | } |
169 | } |
170 | |
171 | inline void assertTypeMatch( |
172 | std::function<void(const std::string&)> fn, |
173 | const at::TensorOptions& options, |
174 | const at::ArrayRef<at::Tensor> tensors, |
175 | size_t index) { |
176 | if (!tensors[index].options().type_equal(options)) { |
177 | fn("invalid tensor type at index " + std::to_string(index) + " (expected " + |
178 | toString(options) + ", got " + toString(tensors[index].options()) + ")" ); |
179 | } |
180 | } |
181 | |
182 | inline void assertSizesMatch( |
183 | std::function<void(const std::string&)> fn, |
184 | const at::IntArrayRef& sizes, |
185 | const at::ArrayRef<at::Tensor> tensors, |
186 | size_t index) { |
187 | if (tensors[index].sizes() != sizes) { |
188 | fn("invalid tensor size at index " + std::to_string(index) + " (expected " + |
189 | toString(sizes) + ", got " + toString(tensors[index].sizes()) + ")" ); |
190 | } |
191 | } |
192 | |
193 | inline void assertLayoutMatch( |
194 | std::function<void(const std::string&)> fn, |
195 | const c10::Layout& expected, |
196 | const at::ArrayRef<at::Tensor> tensors, |
197 | size_t index) { |
198 | const auto& actual = tensors[index].layout(); |
199 | if (actual != expected) { |
200 | fn("invalid tensor layout at index " + std::to_string(index) + |
201 | " (expected " + toString(expected) + ", got " + toString(actual) + ")" ); |
202 | } |
203 | } |
204 | |
205 | inline void assertLayoutMatch( |
206 | std::function<void(const std::string&)> fn, |
207 | const at::ArrayRef<at::Tensor> tensors) { |
208 | const auto& layout = tensors[0].layout(); |
209 | for (const auto i : c10::irange(1, tensors.size())) { |
210 | assertLayoutMatch(fn, layout, tensors, i); |
211 | } |
212 | } |
213 | |
214 | inline void assertNonEmpty( |
215 | std::function<void(const std::string&)> fn, |
216 | const at::ArrayRef<at::Tensor> tensors) { |
217 | if (tensors.empty()) { |
218 | fn("requires non-empty tensor list" ); |
219 | } |
220 | } |
221 | |
222 | inline void assertSingleElement( |
223 | std::function<void(const std::string&)> fn, |
224 | const at::ArrayRef<at::Tensor> tensors) { |
225 | if (tensors.size() != 1) { |
226 | fn("requires a single-element tensor list" ); |
227 | } |
228 | } |
229 | |
230 | inline void assertSingleElementInput( |
231 | std::function<void(const std::string&)> fn, |
232 | const at::ArrayRef<at::Tensor> tensors) { |
233 | if (tensors.size() != 1) { |
234 | fn("requires a single-element input tensor list" ); |
235 | } |
236 | } |
237 | |
238 | inline void assertSingleElementOutput( |
239 | std::function<void(const std::string&)> fn, |
240 | const at::ArrayRef<at::Tensor> tensors) { |
241 | if (tensors.size() != 1) { |
242 | fn("requires a single-element output tensor list" ); |
243 | } |
244 | } |
245 | |
// Report through `fn` when `rank` is outside the valid range [0, size).
inline void assertRootRank(
    std::function<void(const std::string&)> fn,
    int rank,
    int size) {
  const bool inRange = rank >= 0 && rank < size;
  if (!inRange) {
    fn("invalid root rank: " + std::to_string(rank));
  }
}
254 | |
// Report through `fn` when the root tensor index `rank` is outside [0, size).
inline void assertRootTensor(
    std::function<void(const std::string&)> fn,
    int rank,
    int size) {
  if (rank >= 0 && rank < size) {
    return;
  }
  fn("invalid root tensor: " + std::to_string(rank));
}
263 | |
264 | inline void assertDense( |
265 | std::function<void(const std::string&)> fn, |
266 | const at::ArrayRef<at::Tensor> tensors) { |
267 | const auto& layout = tensors[0].layout(); |
268 | if (layout != at::kStrided) { |
269 | fn("only supports dense tensors" ); |
270 | } |
271 | } |
272 | |
273 | inline void assertCPU( |
274 | std::function<void(const std::string&)> fn, |
275 | const at::ArrayRef<at::Tensor> tensors) { |
276 | const auto& device = tensors[0].device(); |
277 | if (device.type() != at::kCPU) { |
278 | fn("only supports CPU tensors" ); |
279 | } |
280 | } |
281 | |
282 | inline void assertSameDevice( |
283 | std::function<void(const std::string&)> fn, |
284 | const at::ArrayRef<at::Tensor> tensors) { |
285 | if (tensors.size() < 2) { |
286 | return; |
287 | } |
288 | const auto& device = tensors[0].device(); |
289 | for (const auto i : c10::irange(1, tensors.size())) { |
290 | if (tensors[i].device() != device) { |
291 | fn("tensors should be on the same device" ); |
292 | } |
293 | } |
294 | } |
295 | |
296 | inline void assertTypeAndSizesMatch( |
297 | std::function<void(const std::string&)> fn, |
298 | const at::ArrayRef<at::Tensor> tensors, |
299 | const at::DeprecatedTypeProperties& type, |
300 | const at::IntArrayRef& sizes) { |
301 | for (const auto i : c10::irange(tensors.size())) { |
302 | assertTypeMatch(fn, type, tensors, i); |
303 | assertSizesMatch(fn, sizes, tensors, i); |
304 | } |
305 | } |
306 | |
307 | inline void assertTypeAndSizesMatch( |
308 | std::function<void(const std::string&)> fn, |
309 | const at::ArrayRef<at::Tensor> tensors, |
310 | const at::TensorOptions& options, |
311 | const at::IntArrayRef& sizes) { |
312 | for (const auto i : c10::irange(tensors.size())) { |
313 | assertTypeMatch(fn, options, tensors, i); |
314 | assertSizesMatch(fn, sizes, tensors, i); |
315 | } |
316 | } |
317 | |
318 | inline void assertTypeAndSizesMatch( |
319 | std::function<void(const std::string&)> fn, |
320 | const at::ArrayRef<at::Tensor> tensors) { |
321 | const auto& options = tensors[0].options(); |
322 | const auto sizes = tensors[0].sizes(); |
323 | assertTypeAndSizesMatch(fn, tensors.slice(1), options, sizes); |
324 | } |
325 | |
// Copied from ATen/core/functional.h.
//
// Apply `fn` to each element of `inputs` and collect the results, in iteration
// order, into a std::vector.
template <typename F, typename T>
inline auto fmap(T& inputs, const F& fn)
    -> std::vector<decltype(fn(*inputs.begin()))> {
  using Mapped = std::vector<decltype(fn(*inputs.begin()))>;
  Mapped mapped;
  mapped.reserve(inputs.size());
  for (auto&& element : inputs) {
    mapped.emplace_back(fn(element));
  }
  return mapped;
}
337 | |
338 | // Copied from torch/csrc/utils/tensor_flatten.h. |
339 | inline at::Tensor flattenDenseTensors(at::TensorList tensors) { |
340 | static const auto flatten = [](const at::Tensor& t) { |
341 | return t.contiguous().view({-1}); |
342 | }; |
343 | if (tensors.size() == 1) { |
344 | return flatten(tensors[0]); |
345 | } |
346 | return at::cat(::c10d::fmap(tensors, flatten)); |
347 | } |
348 | |
349 | inline at::Tensor newLikeFlat( |
350 | std::vector<std::vector<at::Tensor>>& tensors, |
351 | size_t deviceIdx) { |
352 | if (tensors.empty() || tensors[0].empty()) { |
353 | TORCH_CHECK(false, "Received an empty list" ); |
354 | } |
355 | if (deviceIdx >= tensors.size()) { |
356 | TORCH_CHECK(false, "Invalid device index" ); |
357 | } |
358 | auto& t = tensors[deviceIdx][0]; |
359 | auto device = t.device(); |
360 | for (const auto i : c10::irange(1, tensors[deviceIdx].size())) { |
361 | if (tensors[deviceIdx][i].device() != device) { |
362 | TORCH_CHECK(false, "Expecting all tensors on the same device" ); |
363 | } |
364 | } |
365 | at::DeviceGuard gpuGuard(device); |
366 | std::vector<int64_t> sizes{static_cast<int64_t>(tensors[deviceIdx].size())}; |
367 | std::vector<int64_t> strides{static_cast<int64_t>(t.numel())}; |
368 | sizes.insert(sizes.end(), t.sizes().begin(), t.sizes().end()); |
369 | strides.insert(strides.end(), t.strides().begin(), t.strides().end()); |
370 | return at::empty_strided( |
371 | sizes, strides, t.options().memory_format(c10::nullopt)); |
372 | } |
373 | |
374 | inline at::Tensor newLikeFlat(std::vector<at::Tensor>& tensors) { |
375 | if (tensors.empty()) { |
376 | TORCH_CHECK(false, "Received an empty list" ); |
377 | } |
378 | auto& t = tensors[0]; |
379 | at::DeviceGuard gpuGuard(t.device()); |
380 | std::vector<int64_t> sizes{static_cast<int64_t>(tensors.size())}; |
381 | sizes.insert(sizes.end(), t.sizes().begin(), t.sizes().end()); |
382 | return at::empty(sizes, t.options()); |
383 | } |
384 | |
385 | inline std::vector<std::vector<int64_t>> getSizes( |
386 | const std::vector<at::Tensor>& tensors) { |
387 | std::vector<std::vector<int64_t>> sizes(tensors.size()); |
388 | for (const auto i : c10::irange(tensors.size())) { |
389 | sizes[i] = tensors[i].sizes().vec(); |
390 | } |
391 | return sizes; |
392 | } |
393 | |
394 | inline std::vector<int> getDevices(const std::vector<at::Tensor>& tensors) { |
395 | std::vector<int> devices(tensors.size(), -1); |
396 | if (tensors[0].device().is_cuda()) { |
397 | for (const auto i : c10::irange(tensors.size())) { |
398 | devices[i] = tensors[i].storage().device().index(); |
399 | } |
400 | } |
401 | return devices; |
402 | } |
403 | |
404 | template <typename T> |
405 | inline T* getDataPointer(const at::Tensor& tensor) { |
406 | // This method is only used in ProcessGroupGloo for now. Call sites must make |
407 | // sure that the input tensor is contiguous. It is OK if the tensor does not |
408 | // start from the beginning of the storage. For example, it could come from |
409 | // chunk(..., dim=0)[1]. Hence, we need to use data_ptr() instead of |
410 | // tensor.storage().data() |
411 | // NB: not using tensor.data<T>() because tensor is not aware of gloo::TYPE |
412 | return static_cast<T*>(tensor.data_ptr()); |
413 | } |
414 | |
415 | template <typename T> |
416 | std::vector<T*> getDataPointers(const std::vector<at::Tensor>& tensors) { |
417 | std::vector<T*> ptrs(tensors.size()); |
418 | for (const auto i : c10::irange(tensors.size())) { |
419 | ptrs[i] = getDataPointer<T>(tensors[i]); |
420 | } |
421 | return ptrs; |
422 | } |
423 | |
424 | // For alltoall split size sanity check |
425 | inline void checkSplitSizes( |
426 | const std::vector<int64_t>& split_sizes, |
427 | const at::Tensor& tensor, |
428 | int group_size) { |
429 | if (split_sizes.empty()) { |
430 | TORCH_CHECK( |
431 | tensor.size(0) % group_size == 0, |
432 | "Tensor's dim 0 does not divide equally across group size" ); |
433 | } else { |
434 | TORCH_CHECK( |
435 | split_sizes.size() == static_cast<size_t>(group_size), |
436 | "Number of tensor splits not equal to group size" ); |
437 | const auto sum = c10::sum_integers(split_sizes); |
438 | TORCH_CHECK( |
439 | sum == tensor.size(0), "Split sizes doesn't match total dim 0 size" ); |
440 | } |
441 | } |
442 | |
443 | // Compute alltoall lengths and offsets, handling multi-dimension tensors |
444 | template <typename T> |
445 | size_t computeLengthsAndOffsets( |
446 | const std::vector<int64_t>& split_sizes, |
447 | const at::Tensor& tensor, |
448 | std::vector<T>* lengths, |
449 | std::vector<T>* offsets) { |
450 | size_t group_size = lengths->size(); |
451 | bool equal_splits = false; |
452 | size_t dim0_size = tensor.size(0); |
453 | size_t row_size = (dim0_size ? tensor.numel() / dim0_size : 1); |
454 | size_t split_size = 0; |
455 | size_t offset = 0; |
456 | |
457 | if (split_sizes.empty()) { |
458 | equal_splits = true; |
459 | split_size = tensor.size(0) / group_size; |
460 | } |
461 | for(const auto i : c10::irange(group_size)) { |
462 | size_t length = row_size * (equal_splits ? split_size : split_sizes[i]); |
463 | TORCH_INTERNAL_ASSERT( |
464 | length <= std::numeric_limits<int>::max() && |
465 | offset <= std::numeric_limits<int>::max(), |
466 | "Length or offset larger than INT_MAX not supported" ); |
467 | (*lengths)[i] = length; |
468 | (*offsets)[i] = offset; |
469 | offset += length; |
470 | } |
471 | return offset; |
472 | } |
473 | |
474 | template <typename T> |
475 | size_t computeLengthsAndOffsets( |
476 | const std::vector<at::Tensor>& tensors, |
477 | std::vector<T>* lengths, |
478 | std::vector<T>* offsets) { |
479 | size_t group_size = lengths->size(); |
480 | size_t offset = 0; |
481 | for(const auto i : c10::irange(group_size)) { |
482 | size_t length = tensors[i].numel(); |
483 | TORCH_INTERNAL_ASSERT( |
484 | length <= std::numeric_limits<int>::max() && |
485 | offset <= std::numeric_limits<int>::max(), |
486 | "Length or offset larger than INT_MAX not supported" ); |
487 | (*lengths)[i] = length; |
488 | (*offsets)[i] = offset; |
489 | offset += length; |
490 | } |
491 | return offset; |
492 | } |
493 | |
// Fixed-width integer aliases. SizeType is the length prefix written by
// tcputil::sendVector / sendString; RankType presumably holds process ranks.
using SizeType = std::uint64_t;
using RankType = std::uint32_t;
496 | |
// `errno` is only meaningful when a call fails. E.g., a successful `fork()`
// sets `errno` to `EINVAL` in the child process on some macOS versions
// (https://stackoverflow.com/a/20295079), so `errno` should only be inspected
// after an error has occurred.
//
// `success_cond` is an expression used to check whether an error has happened.
// So for `fork()`, we can use `SYSCHECK(pid = fork(), pid != -1)`. The function
// output is stored in variable `__output` and may be used in `success_cond`.
//
// On failure:
//   - an interrupted call (EINTR; WSAEINTR on Windows) is retried;
//   - EAGAIN/EWOULDBLOCK (WSAETIMEDOUT/WSAEWOULDBLOCK on Windows) raises
//     "Socket Timeout" via TORCH_CHECK;
//   - anything else throws std::system_error carrying the error code.
// The expansion is a bare `while (true) { ... }` statement; call sites in this
// file invoke it without a trailing semicolon.
//
// Winsock reports errors through WSAGetLastError(), not `errno`, so the
// Windows branch checks the WSA code for interruption too. Testing the
// (possibly stale) `errno` there could retry a failing call forever.
#ifdef _WIN32
#define SYSCHECK(expr, success_cond)                                      \
  while (true) {                                                          \
    auto __output = (expr);                                               \
    auto errno_local = WSAGetLastError();                                 \
    (void)__output;                                                       \
    if (!(success_cond)) {                                                \
      if (errno_local == WSAEINTR) {                                      \
        continue;                                                         \
      } else if (                                                         \
          errno_local == WSAETIMEDOUT || errno_local == WSAEWOULDBLOCK) { \
        TORCH_CHECK(false, "Socket Timeout");                             \
      } else {                                                            \
        throw std::system_error(errno_local, std::system_category());     \
      }                                                                   \
    } else {                                                              \
      break;                                                              \
    }                                                                     \
  }
#else
#define SYSCHECK(expr, success_cond)                            \
  while (true) {                                                \
    auto __output = (expr);                                     \
    (void)__output;                                             \
    if (!(success_cond)) {                                      \
      if (errno == EINTR) {                                     \
        continue;                                               \
      } else if (errno == EAGAIN || errno == EWOULDBLOCK) {     \
        TORCH_CHECK(false, "Socket Timeout");                   \
      } else {                                                  \
        throw std::system_error(errno, std::system_category()); \
      }                                                         \
    } else {                                                    \
      break;                                                    \
    }                                                           \
  }
#endif

// Most functions indicate error by returning `-1`. This is a helper macro for
// that common case with `SYSCHECK`. Because SOCKET_ERROR is also -1 on Windows
// (Winsock), SYSCHECK_ERR_RETURN_NEG1 covers Winsock calls as well.
#define SYSCHECK_ERR_RETURN_NEG1(expr) SYSCHECK(expr, __output != -1)
547 | |
548 | namespace tcputil { |
549 | |
550 | // Send and receive |
551 | template <typename T> |
552 | void sendBytes( |
553 | int socket, |
554 | const T* buffer, |
555 | size_t length, |
556 | bool moreData = false) { |
557 | size_t bytesToSend = sizeof(T) * length; |
558 | if (bytesToSend == 0) { |
559 | return; |
560 | } |
561 | |
562 | auto bytes = reinterpret_cast<const uint8_t*>(buffer); |
563 | uint8_t* currentBytes = const_cast<uint8_t*>(bytes); |
564 | |
565 | int flags = 0; |
566 | |
567 | #ifdef MSG_MORE |
568 | if (moreData) { // there is more data to send |
569 | flags |= MSG_MORE; |
570 | } |
571 | #endif |
572 | |
573 | // Ignore SIGPIPE as the send() return value is always checked for error |
574 | #ifdef MSG_NOSIGNAL |
575 | flags |= MSG_NOSIGNAL; |
576 | #endif |
577 | |
578 | while (bytesToSend > 0) { |
579 | ssize_t bytesSent; |
580 | SYSCHECK_ERR_RETURN_NEG1( |
581 | bytesSent = |
582 | ::send(socket, (const char*)currentBytes, bytesToSend, flags)) |
583 | if (bytesSent == 0) { |
584 | throw std::system_error(ECONNRESET, std::system_category()); |
585 | } |
586 | |
587 | bytesToSend -= bytesSent; |
588 | currentBytes += bytesSent; |
589 | } |
590 | } |
591 | |
592 | template <typename T> |
593 | void recvBytes(int socket, T* buffer, size_t length) { |
594 | size_t bytesToReceive = sizeof(T) * length; |
595 | if (bytesToReceive == 0) { |
596 | return; |
597 | } |
598 | |
599 | auto bytes = reinterpret_cast<uint8_t*>(buffer); |
600 | uint8_t* currentBytes = bytes; |
601 | |
602 | while (bytesToReceive > 0) { |
603 | ssize_t bytesReceived; |
604 | SYSCHECK_ERR_RETURN_NEG1( |
605 | bytesReceived = recv(socket, (char*)currentBytes, bytesToReceive, 0)) |
606 | if (bytesReceived == 0) { |
607 | throw std::system_error(ECONNRESET, std::system_category()); |
608 | } |
609 | |
610 | bytesToReceive -= bytesReceived; |
611 | currentBytes += bytesReceived; |
612 | } |
613 | } |
614 | |
615 | // send a vector's length and data |
616 | template <typename T> |
617 | void sendVector(int socket, const std::vector<T>& vec, bool moreData = false) { |
618 | SizeType size = vec.size(); |
619 | sendBytes<SizeType>(socket, &size, 1, true); |
620 | sendBytes<T>(socket, vec.data(), size, moreData); |
621 | } |
622 | |
623 | // receive a vector as sent in sendVector |
624 | template <typename T> |
625 | std::vector<T> recvVector(int socket) { |
626 | SizeType valueSize; |
627 | recvBytes<SizeType>(socket, &valueSize, 1); |
628 | std::vector<T> value(valueSize); |
629 | recvBytes<T>(socket, value.data(), value.size()); |
630 | return value; |
631 | } |
632 | |
633 | // this is only for convenience when sending rvalues |
634 | template <typename T> |
635 | void sendValue(int socket, const T& value, bool moreData = false) { |
636 | sendBytes<T>(socket, &value, 1, moreData); |
637 | } |
638 | |
639 | template <typename T> |
640 | T recvValue(int socket) { |
641 | T value; |
642 | recvBytes<T>(socket, &value, 1); |
643 | return value; |
644 | } |
645 | |
646 | // send a string's length and data |
647 | inline void sendString( |
648 | int socket, |
649 | const std::string& str, |
650 | bool moreData = false) { |
651 | SizeType size = str.size(); |
652 | sendBytes<SizeType>(socket, &size, 1, true); |
653 | sendBytes<char>(socket, str.data(), size, moreData); |
654 | } |
655 | |
656 | // receive a string as sent in sendString |
657 | inline std::string recvString(int socket) { |
658 | SizeType valueSize; |
659 | recvBytes<SizeType>(socket, &valueSize, 1); |
660 | std::vector<char> value(valueSize); |
661 | recvBytes<char>(socket, value.data(), value.size()); |
662 | return std::string(value.data(), value.size()); |
663 | } |
664 | |
665 | } // namespace tcputil |
666 | } // namespace c10d |
667 | |