1#include <torch/csrc/distributed/c10d/GlooDeviceFactory.hpp>
2
3#ifdef USE_C10D_GLOO
4
5#include <cstdlib>
6
7#include <c10/util/Exception.h>
8
9#if GLOO_HAVE_TRANSPORT_TCP
10#include <gloo/transport/tcp/device.h>
11#endif
12
13#if GLOO_HAVE_TRANSPORT_TCP_TLS
14#include <gloo/transport/tcp/tls/device.h>
15#endif
16
17#if GLOO_HAVE_TRANSPORT_UV
18#include <gloo/transport/uv/device.h>
19#endif
20
21// On Linux, check that the tcp transport is available.
22#ifdef __linux__
23#if !GLOO_HAVE_TRANSPORT_TCP
24#error "Expected the tcp transport to be available on Linux."
25#endif
26#endif
27
28// On macOS, check that the uv transport is available.
29#ifdef __APPLE__
30#if !GLOO_HAVE_TRANSPORT_UV
31#error "Expected the uv transport to be available on macOS."
32#endif
33#endif
34
35namespace c10d {
36
// Defines GlooDeviceRegistry, the registry that the device creators in this
// file are added to with C10_REGISTER_CREATOR. Its entries produce
// ::gloo::transport::Device objects and are invoked with two string
// arguments: the network interface name and the hostname.
C10_DEFINE_SHARED_REGISTRY_WITHOUT_WARNING(
    GlooDeviceRegistry,
    ::gloo::transport::Device,
    const std::string& /* interface */,
    const std::string& /* hostname */);
42
43#if GLOO_HAVE_TRANSPORT_TCP
44static std::shared_ptr<::gloo::transport::Device> makeTCPDevice(
45 const std::string& interfaceName,
46 const std::string& hostname) {
47 TORCH_CHECK(
48 !interfaceName.empty() || !hostname.empty(),
49 "GlooDeviceFactory::makeTCPDevice(): interface or hostname "
50 "can't be empty");
51
52 ::gloo::transport::tcp::attr attr;
53 if (!interfaceName.empty()) {
54 attr.iface = interfaceName;
55 } else {
56 attr.hostname = hostname;
57 }
58 return ::gloo::transport::tcp::CreateDevice(attr);
59}
60
// Registry priority is per key identifier. TCP is registered under `LINUX`
// (the key makeGlooDevice() falls back to on Linux) so that other
// applications have the flexibility to override it by priority, and under
// `TCP` so that it can be selected through the "GLOO_DEVICE_TRANSPORT"
// environment variable.
C10_REGISTER_CREATOR(GlooDeviceRegistry, LINUX, makeTCPDevice);
C10_REGISTER_CREATOR(GlooDeviceRegistry, TCP, makeTCPDevice);
66#endif
67
68#if GLOO_HAVE_TRANSPORT_TCP_TLS
// Converts a possibly-null C string into a std::string. A null pointer yields
// an empty string, which makes the result of std::getenv() safe to pass in
// directly.
static std::string cstr_to_std_string(const char* chars) {
  if (chars == nullptr) {
    return std::string();
  }
  return std::string(chars);
}
72
73static std::shared_ptr<::gloo::transport::Device> makeTCPTLSDevice(
74 const std::string& interface,
75 const std::string& hostname) {
76 TORCH_CHECK(
77 !interface.empty() || !hostname.empty(),
78 "GlooDeviceFactory::makeTCPTLSDevice(): interface or hostname "
79 "can't be empty");
80
81 ::gloo::transport::tcp::attr attr;
82 if (!interface.empty()) {
83 attr.iface = interface;
84 } else {
85 attr.hostname = hostname;
86 }
87 const auto pkey =
88 cstr_to_std_string(std::getenv("GLOO_DEVICE_TRANSPORT_TCP_TLS_PKEY"));
89 const auto cert =
90 cstr_to_std_string(std::getenv("GLOO_DEVICE_TRANSPORT_TCP_TLS_CERT"));
91 const auto caFile =
92 cstr_to_std_string(std::getenv("GLOO_DEVICE_TRANSPORT_TCP_TLS_CA_FILE"));
93 const auto caPath =
94 cstr_to_std_string(std::getenv("GLOO_DEVICE_TRANSPORT_TCP_TLS_CA_PATH"));
95 return ::gloo::transport::tcp::tls::CreateDevice(
96 attr, pkey, cert, caFile, caPath);
97}
98
// Register TCP with TLS under `TCP_TLS`. None of the platform defaults in
// makeGlooDevice() use this key, so within this file it is reached only when
// the "GLOO_DEVICE_TRANSPORT" environment variable names it.
C10_REGISTER_CREATOR(GlooDeviceRegistry, TCP_TLS, makeTCPTLSDevice);
100#endif
101
102#if GLOO_HAVE_TRANSPORT_UV
103static std::shared_ptr<::gloo::transport::Device> makeUVDevice(
104 const std::string& interfaceName,
105 const std::string& hostname) {
106 TORCH_CHECK(
107 !interfaceName.empty() || !hostname.empty(),
108 "GlooDeviceFactory::makeUVDevice(): interface or hostname "
109 "can't be empty");
110
111 ::gloo::transport::uv::attr attr;
112 if (!interfaceName.empty()) {
113 attr.iface = interfaceName;
114 } else {
115 attr.hostname = hostname;
116 }
117 return ::gloo::transport::uv::CreateDevice(attr);
118}
119
// Registry priority is per key identifier. UV is registered under `APPLE`
// and `WIN32` (the keys makeGlooDevice() falls back to on macOS and Windows
// respectively) so that other applications have the flexibility to override
// it by priority, and under `UV` so that it can be selected through the
// "GLOO_DEVICE_TRANSPORT" environment variable.
C10_REGISTER_CREATOR(GlooDeviceRegistry, APPLE, makeUVDevice);
C10_REGISTER_CREATOR(GlooDeviceRegistry, WIN32, makeUVDevice);
C10_REGISTER_CREATOR(GlooDeviceRegistry, UV, makeUVDevice);
126#endif
127
128namespace {
129std::shared_ptr<::gloo::transport::Device> makeGlooDevice(
130 const std::string& interfaceName,
131 const std::string& hostName) {
132 static auto transportName = getenv("GLOO_DEVICE_TRANSPORT");
133 if (transportName) {
134 return GlooDeviceRegistry()->Create(transportName, interfaceName, hostName);
135 }
136
137#ifdef __linux__
138 return GlooDeviceRegistry()->Create("LINUX", interfaceName, hostName);
139#endif
140
141#ifdef __APPLE__
142 return GlooDeviceRegistry()->Create("APPLE", interfaceName, hostName);
143#endif
144
145#ifdef _WIN32
146 return GlooDeviceRegistry()->Create("WIN32", interfaceName, hostName);
147#endif
148
149 return nullptr;
150}
151} // anonymous namespace
152
153std::shared_ptr<::gloo::transport::Device> GlooDeviceFactory::
154 makeDeviceForInterface(const std::string& interfaceName) {
155 auto device = makeGlooDevice(interfaceName, "");
156 if (!device) {
157 TORCH_CHECK(false, "makeDeviceForInterface(): unsupported gloo device");
158 }
159 return device;
160}
161
162std::shared_ptr<::gloo::transport::Device> GlooDeviceFactory::
163 makeDeviceForHostname(const std::string& hostname) {
164 auto device = makeGlooDevice("", hostname);
165 if (!device) {
166 TORCH_CHECK(false, "makeDeviceForHostname(): unsupported gloo device");
167 }
168 return device;
169}
170
171} // namespace c10d
172
173#endif // USE_C10D_GLOO
174