1 | // Copyright (c) Meta Platforms, Inc. and its affiliates. |
2 | // All rights reserved. |
3 | // |
4 | // This source code is licensed under the BSD-style license found in the |
5 | // LICENSE file in the root directory of this source tree. |
6 | |
7 | #pragma once |
8 | |
9 | #include <chrono> |
10 | #include <cstdint> |
11 | #include <memory> |
12 | #include <string> |
13 | |
14 | #include <c10/macros/Macros.h> |
15 | #include <torch/csrc/distributed/c10d/exception.h> |
16 | |
17 | namespace c10d { |
18 | namespace detail { |
19 | |
/// Configuration consumed by `Socket::listen` / `Socket::connect`.
///
/// Every setter returns `*this`, so options can be built fluently:
///   SocketOptions{}.prefer_ipv6(false).connect_timeout(std::chrono::seconds{5});
class SocketOptions {
 private:
  // Defaults: prefer IPv6, 30-second connect timeout.
  bool prefer_ipv6_{true};
  std::chrono::seconds connect_timeout_{30};

 public:
  // ---- setters (chainable) -------------------------------------------------

  /// Records whether IPv6 should be preferred. How the flag is applied lives
  /// in the socket implementation (not visible in this header).
  SocketOptions& prefer_ipv6(bool value) noexcept {
    prefer_ipv6_ = value;
    return *this;
  }

  /// Records the connect timeout. How it is enforced lives in the socket
  /// implementation (not visible in this header).
  SocketOptions& connect_timeout(std::chrono::seconds value) noexcept {
    connect_timeout_ = value;
    return *this;
  }

  // ---- getters -------------------------------------------------------------

  /// Current IPv6 preference; `true` unless changed.
  bool prefer_ipv6() const noexcept {
    return prefer_ipv6_;
  }

  /// Current connect timeout; 30 seconds unless changed.
  std::chrono::seconds connect_timeout() const noexcept {
    return connect_timeout_;
  }
};
46 | |
47 | class SocketImpl; |
48 | |
49 | class Socket { |
50 | public: |
51 | // This function initializes the underlying socket library and must be called |
52 | // before any other socket function. |
53 | static void initialize(); |
54 | |
55 | static Socket listen(std::uint16_t port, const SocketOptions& opts = {}); |
56 | |
57 | static Socket connect( |
58 | const std::string& host, |
59 | std::uint16_t port, |
60 | const SocketOptions& opts = {}); |
61 | |
62 | Socket() noexcept = default; |
63 | |
64 | Socket(const Socket& other) = delete; |
65 | |
66 | Socket& operator=(const Socket& other) = delete; |
67 | |
68 | Socket(Socket&& other) noexcept; |
69 | |
70 | Socket& operator=(Socket&& other) noexcept; |
71 | |
72 | ~Socket(); |
73 | |
74 | Socket accept() const; |
75 | |
76 | int handle() const noexcept; |
77 | |
78 | std::uint16_t port() const; |
79 | |
80 | private: |
81 | explicit Socket(std::unique_ptr<SocketImpl>&& impl) noexcept; |
82 | |
83 | std::unique_ptr<SocketImpl> impl_; |
84 | }; |
85 | |
86 | } // namespace detail |
87 | |
88 | class TORCH_API SocketError : public C10dError { |
89 | public: |
90 | using C10dError::C10dError; |
91 | |
92 | SocketError(const SocketError&) = default; |
93 | |
94 | SocketError& operator=(const SocketError&) = default; |
95 | |
96 | SocketError(SocketError&&) = default; |
97 | |
98 | SocketError& operator=(SocketError&&) = default; |
99 | |
100 | ~SocketError() override; |
101 | }; |
102 | |
103 | } // namespace c10d |
104 | |