1// Copyright (c) Meta Platforms, Inc. and its affiliates.
2// All rights reserved.
3//
4// This source code is licensed under the BSD-style license found in the
5// LICENSE file in the root directory of this source tree.
6
7#pragma once
8
9#include <chrono>
10#include <cstdint>
11#include <memory>
12#include <string>
13
14#include <c10/macros/Macros.h>
15#include <torch/csrc/distributed/c10d/exception.h>
16
17namespace c10d {
18namespace detail {
19
/// Options that control how a `Socket` is created.
///
/// Setters return `*this`, so calls can be chained:
///   SocketOptions{}.prefer_ipv6(false).connect_timeout(std::chrono::seconds{5});
class SocketOptions {
  // Whether IPv6 should be preferred; on by default.
  bool prefer_ipv6_{true};
  // Connection timeout; 30 seconds by default.
  std::chrono::seconds connect_timeout_{std::chrono::seconds{30}};

 public:
  /// Returns whether IPv6 is preferred.
  bool prefer_ipv6() const noexcept {
    return prefer_ipv6_;
  }

  /// Sets whether IPv6 is preferred and returns this object for chaining.
  SocketOptions& prefer_ipv6(bool enable) noexcept {
    prefer_ipv6_ = enable;
    return *this;
  }

  /// Returns the configured connection timeout.
  std::chrono::seconds connect_timeout() const noexcept {
    return connect_timeout_;
  }

  /// Sets the connection timeout and returns this object for chaining.
  SocketOptions& connect_timeout(std::chrono::seconds timeout) noexcept {
    connect_timeout_ = timeout;
    return *this;
  }
};
46
47class SocketImpl;
48
49class Socket {
50 public:
51 // This function initializes the underlying socket library and must be called
52 // before any other socket function.
53 static void initialize();
54
55 static Socket listen(std::uint16_t port, const SocketOptions& opts = {});
56
57 static Socket connect(
58 const std::string& host,
59 std::uint16_t port,
60 const SocketOptions& opts = {});
61
62 Socket() noexcept = default;
63
64 Socket(const Socket& other) = delete;
65
66 Socket& operator=(const Socket& other) = delete;
67
68 Socket(Socket&& other) noexcept;
69
70 Socket& operator=(Socket&& other) noexcept;
71
72 ~Socket();
73
74 Socket accept() const;
75
76 int handle() const noexcept;
77
78 std::uint16_t port() const;
79
80 private:
81 explicit Socket(std::unique_ptr<SocketImpl>&& impl) noexcept;
82
83 std::unique_ptr<SocketImpl> impl_;
84};
85
86} // namespace detail
87
88class TORCH_API SocketError : public C10dError {
89 public:
90 using C10dError::C10dError;
91
92 SocketError(const SocketError&) = default;
93
94 SocketError& operator=(const SocketError&) = default;
95
96 SocketError(SocketError&&) = default;
97
98 SocketError& operator=(SocketError&&) = default;
99
100 ~SocketError() override;
101};
102
103} // namespace c10d
104