1 | #ifdef USE_C10D_UCC |
2 | |
#include <torch/csrc/distributed/c10d/UCCTracing.hpp>
#include <torch/csrc/distributed/c10d/UCCUtils.hpp>

#include <cctype>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
10 | |
11 | namespace c10d { |
12 | |
namespace {
// Constants for store keys.
// The oob_allgather* callbacks in this file build their store keys from these
// strings (optionally followed by a rank number) and pass the result through
// getKey() before touching the store.

// Prefix of the per-rank key under which a rank publishes its allgather
// contribution ("teamr" + rank).
constexpr char kTeamRank[] = "teamr";
// Counter key incremented once by every rank that enters
// oob_allgather_free().
constexpr char kAllGatherDone[] = "ag_done";
// Prefix of the per-rank key ("ag_free" + rank) used to release ranks that
// are waiting inside oob_allgather_free().
constexpr char kAllGatherFree[] = "ag_free";
} // namespace
19 | |
20 | ucc_status_t oob_allgather( |
21 | void* sbuf, |
22 | void* rbuf, |
23 | size_t msglen, |
24 | void* coll_info, |
25 | void** req) { |
26 | auto* info = reinterpret_cast<torch_ucc_oob_coll_info_t*>(coll_info); |
27 | TORCH_CHECK(info != nullptr); |
28 | std::vector<uint8_t> val = std::vector<uint8_t>( |
29 | reinterpret_cast<uint8_t*>(sbuf), |
30 | reinterpret_cast<uint8_t*>(sbuf) + msglen); |
31 | try { |
32 | info->store->set(info->getKey(kTeamRank + std::to_string(info->rank)), val); |
33 | info->rbuf = rbuf; |
34 | info->msglen = msglen; |
35 | *req = coll_info; |
36 | } catch (std::exception& ex) { |
37 | LOG(ERROR) << "(oob_allgather) Caught exception in Store Operation .. " |
38 | << "[" << ex.what() << "]" ; |
39 | return UCC_ERR_NO_MESSAGE; |
40 | } |
41 | return UCC_OK; |
42 | } |
43 | |
44 | ucc_status_t oob_allgather_test(void* req) { |
45 | auto* info = reinterpret_cast<torch_ucc_oob_coll_info_t*>(req); |
46 | TORCH_CHECK(info != nullptr); |
47 | |
48 | try { |
49 | for (int r = 0; r < info->size; r++) { |
50 | if (!info->store->check({info->getKey(kTeamRank + std::to_string(r))})) { |
51 | return UCC_INPROGRESS; |
52 | } |
53 | } |
54 | for (int r = 0; r < info->size; r++) { |
55 | std::vector<uint8_t> data = |
56 | info->store->get(info->getKey(kTeamRank + std::to_string(r))); |
57 | memcpy( |
58 | (void*)((ptrdiff_t)info->rbuf + info->msglen * r), |
59 | data.data(), |
60 | info->msglen); |
61 | } |
62 | } catch (std::exception& ex) { |
63 | LOG(ERROR) << "(oob_allgather) Caught exception in Store Operation .. " |
64 | << "[" << ex.what() << "]" ; |
65 | return UCC_ERR_NO_MESSAGE; |
66 | } |
67 | return UCC_OK; |
68 | } |
69 | |
70 | ucc_status_t oob_allgather_free(void* req) { |
71 | auto* info = reinterpret_cast<torch_ucc_oob_coll_info_t*>(req); |
72 | TORCH_CHECK(info != nullptr); |
73 | try { |
74 | int num_done = info->store->add({info->getKey(kAllGatherDone)}, 1); |
75 | if (num_done == info->size) { |
76 | info->store->deleteKey(info->getKey(kAllGatherDone)); |
77 | // Note: to avoid race condition, it's important to remove all keys in |
78 | // oob_allgather_free first and only after that signal completion to |
79 | // other ranks |
80 | for (const auto r : c10::irange(info->size)) { |
81 | info->store->deleteKey(info->getKey(kTeamRank + std::to_string(r))); |
82 | } |
83 | for (const auto r : c10::irange(info->size)) { |
84 | info->store->add({info->getKey(kAllGatherFree + std::to_string(r))}, 1); |
85 | } |
86 | } else { |
87 | info->store->wait( |
88 | {info->getKey(kAllGatherFree + std::to_string(info->rank))}); |
89 | } |
90 | info->store->deleteKey( |
91 | info->getKey(kAllGatherFree + std::to_string(info->rank))); |
92 | } catch (std::exception& ex) { |
93 | LOG(ERROR) << "(oob_allgather) Caught exception in Store Operation .. " |
94 | << "[" << ex.what() << "]" ; |
95 | return UCC_ERR_NO_MESSAGE; |
96 | } |
97 | return UCC_OK; |
98 | } |
99 | |
100 | CommUCC::CommUCC( |
101 | std::shared_ptr<torch_ucc_oob_coll_info_t> oob, |
102 | const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger) |
103 | : CommBase(logger) { |
104 | ucc_lib_config_h lib_config; |
105 | ucc_context_config_h context_config; |
106 | ucc_lib_params_t lib_params; |
107 | ucc_context_params_t context_params; |
108 | ucc_status_t st; |
109 | |
110 | TORCH_UCC_CHECK( |
111 | ucc_lib_config_read("TORCH" , nullptr, &lib_config), |
112 | "failed to read UCC lib config" ); |
113 | memset(&lib_params, 0, sizeof(ucc_lib_params_t)); |
114 | lib_params.mask = UCC_LIB_PARAM_FIELD_THREAD_MODE; |
115 | lib_params.thread_mode = UCC_THREAD_MULTIPLE; |
116 | TORCH_UCC_CHECK( |
117 | ucc_init(&lib_params, lib_config, &lib), "failed to init UCC lib" ); |
118 | ucc_lib_config_release(lib_config); |
119 | ucc_lib_attr_t lib_attr; |
120 | lib_attr.mask = UCC_LIB_ATTR_FIELD_THREAD_MODE; |
121 | TORCH_UCC_CHECK( |
122 | ucc_lib_get_attr(lib, &lib_attr), "failed to query for lib attr" ); |
123 | TORCH_CHECK( |
124 | lib_attr.thread_mode == UCC_THREAD_MULTIPLE, |
125 | "ucc library wasn't initialized with multithreading support, " |
126 | "please check ucc build options" ); |
127 | st = ucc_context_config_read(lib, NULL, &context_config); |
128 | if (st != UCC_OK) { |
129 | // FIXME: would this cause deadlock if only one rank fails? |
130 | TORCH_UCC_CHECK( |
131 | ucc_finalize(lib), |
132 | "failed to finalize UCC library when failing to read UCC context config" ); |
133 | TORCH_UCC_LOG_ERROR( |
134 | TORCH_UCC_INIT, |
135 | c10::str("failed to read UCC context config: " , ucc_status_string(st))); |
136 | throw std::runtime_error(ucc_status_string(st)); |
137 | } |
138 | st = ucc_context_config_modify( |
139 | context_config, |
140 | NULL, |
141 | "ESTIMATED_NUM_EPS" , |
142 | std::to_string(oob->size).c_str()); |
143 | if (st != UCC_OK) { |
144 | ucc_context_config_release(context_config); |
145 | ucc_finalize(lib); |
146 | TORCH_UCC_LOG_ERROR( |
147 | TORCH_UCC_INIT, |
148 | c10::str( |
149 | "UCC failed to modify UCC context config: " , |
150 | ucc_status_string(st))); |
151 | throw std::runtime_error(ucc_status_string(st)); |
152 | } |
153 | memset(&context_params, 0, sizeof(ucc_context_params_t)); |
154 | context_params.mask = |
155 | UCC_CONTEXT_PARAM_FIELD_TYPE | UCC_CONTEXT_PARAM_FIELD_OOB; |
156 | context_params.type = UCC_CONTEXT_SHARED; |
157 | context_params.oob.n_oob_eps = oob->size; |
158 | context_params.oob.oob_ep = oob->rank; |
159 | context_params.oob.allgather = oob_allgather; |
160 | context_params.oob.req_test = oob_allgather_test; |
161 | context_params.oob.req_free = oob_allgather_free; |
162 | context_params.oob.coll_info = oob.get(); |
163 | st = ucc_context_create(lib, &context_params, context_config, &context); |
164 | ucc_context_config_release(context_config); |
165 | if (st != UCC_OK) { |
166 | TORCH_UCC_CHECK( |
167 | ucc_finalize(lib), |
168 | "failed to finalize UCC library when failing to creat UCC context" ); |
169 | TORCH_UCC_LOG_ERROR( |
170 | TORCH_UCC_INIT, |
171 | c10::str("UCC failed to create UCC context: " , ucc_status_string(st))); |
172 | throw std::runtime_error(ucc_status_string(st)); |
173 | } |
174 | } |
175 | |
176 | void CommUCC::progress() { |
177 | TORCH_UCC_CHECK( |
178 | ucc_context_progress(context), "failed to progress UCC collective" ); |
179 | } |
180 | |
181 | void CommUCC::free_request(ucc_coll_req_h request) { |
182 | TORCH_UCC_CHECK( |
183 | ucc_collective_finalize(request), "failed to release UCC request" ); |
184 | } |
185 | |
186 | CommUCC::~CommUCC() { |
187 | if (context != nullptr) { |
188 | TORCH_UCC_CHECK( |
189 | ucc_context_destroy(context), "failed to destroy UCC context" ); |
190 | } |
191 | if (lib != nullptr) { |
192 | TORCH_UCC_CHECK(ucc_finalize(lib), "failed to finalize UCC library" ); |
193 | } |
194 | context = nullptr; |
195 | lib = nullptr; |
196 | } |
197 | |
198 | std::string ProcessGroupUCCLogger::getLogPrefix(torch_ucc_phase_t phase) { |
199 | // caller can override the phase stored locally |
200 | torch_ucc_phase_t phase_ = |
201 | (local_phase != phase && phase != TORCH_UCC_UNKNOWN) ? phase |
202 | : local_phase; |
203 | return c10::str(log_prefix, "[" , ucc_phase_map.at(phase_), "]" ); |
204 | } |
205 | void ProcessGroupUCCLogger::setLogPrefix(std::string log_prefix_) { |
206 | log_prefix = log_prefix_; |
207 | } |
208 | |
209 | ProcessGroupUCCLogger::ProcessGroupUCCLogger() { |
210 | setLogPrefix("[ProcessGroupUCC]" ); |
211 | } |
212 | ProcessGroupUCCLogger::ProcessGroupUCCLogger( |
213 | std::string log_prefix, |
214 | torch_ucc_phase_t phase) |
215 | : local_phase(phase) { |
216 | setLogPrefix(log_prefix); |
217 | } |
218 | |
219 | } // namespace c10d |
220 | |
221 | #endif // USE_C10D_UCC |
222 | |