1#ifdef USE_C10D_UCC
2
#include <torch/csrc/distributed/c10d/UCCTracing.hpp>
#include <torch/csrc/distributed/c10d/UCCUtils.hpp>

#include <cctype>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
10
11namespace c10d {
12
namespace {
// Key fragments for the store entries used by the out-of-band allgather
// callbacks below; every fragment is passed through getKey() before use.

// Per-rank payload entry: the rank number is appended to this fragment.
constexpr char kTeamRank[]{"teamr"};
// Counter entry incremented by each rank that enters oob_allgather_free().
constexpr char kAllGatherDone[]{"ag_done"};
// Per-rank release entry: the rank number is appended to this fragment.
constexpr char kAllGatherFree[]{"ag_free"};
} // namespace
19
20ucc_status_t oob_allgather(
21 void* sbuf,
22 void* rbuf,
23 size_t msglen,
24 void* coll_info,
25 void** req) {
26 auto* info = reinterpret_cast<torch_ucc_oob_coll_info_t*>(coll_info);
27 TORCH_CHECK(info != nullptr);
28 std::vector<uint8_t> val = std::vector<uint8_t>(
29 reinterpret_cast<uint8_t*>(sbuf),
30 reinterpret_cast<uint8_t*>(sbuf) + msglen);
31 try {
32 info->store->set(info->getKey(kTeamRank + std::to_string(info->rank)), val);
33 info->rbuf = rbuf;
34 info->msglen = msglen;
35 *req = coll_info;
36 } catch (std::exception& ex) {
37 LOG(ERROR) << "(oob_allgather) Caught exception in Store Operation .. "
38 << "[" << ex.what() << "]";
39 return UCC_ERR_NO_MESSAGE;
40 }
41 return UCC_OK;
42}
43
44ucc_status_t oob_allgather_test(void* req) {
45 auto* info = reinterpret_cast<torch_ucc_oob_coll_info_t*>(req);
46 TORCH_CHECK(info != nullptr);
47
48 try {
49 for (int r = 0; r < info->size; r++) {
50 if (!info->store->check({info->getKey(kTeamRank + std::to_string(r))})) {
51 return UCC_INPROGRESS;
52 }
53 }
54 for (int r = 0; r < info->size; r++) {
55 std::vector<uint8_t> data =
56 info->store->get(info->getKey(kTeamRank + std::to_string(r)));
57 memcpy(
58 (void*)((ptrdiff_t)info->rbuf + info->msglen * r),
59 data.data(),
60 info->msglen);
61 }
62 } catch (std::exception& ex) {
63 LOG(ERROR) << "(oob_allgather) Caught exception in Store Operation .. "
64 << "[" << ex.what() << "]";
65 return UCC_ERR_NO_MESSAGE;
66 }
67 return UCC_OK;
68}
69
70ucc_status_t oob_allgather_free(void* req) {
71 auto* info = reinterpret_cast<torch_ucc_oob_coll_info_t*>(req);
72 TORCH_CHECK(info != nullptr);
73 try {
74 int num_done = info->store->add({info->getKey(kAllGatherDone)}, 1);
75 if (num_done == info->size) {
76 info->store->deleteKey(info->getKey(kAllGatherDone));
77 // Note: to avoid race condition, it's important to remove all keys in
78 // oob_allgather_free first and only after that signal completion to
79 // other ranks
80 for (const auto r : c10::irange(info->size)) {
81 info->store->deleteKey(info->getKey(kTeamRank + std::to_string(r)));
82 }
83 for (const auto r : c10::irange(info->size)) {
84 info->store->add({info->getKey(kAllGatherFree + std::to_string(r))}, 1);
85 }
86 } else {
87 info->store->wait(
88 {info->getKey(kAllGatherFree + std::to_string(info->rank))});
89 }
90 info->store->deleteKey(
91 info->getKey(kAllGatherFree + std::to_string(info->rank)));
92 } catch (std::exception& ex) {
93 LOG(ERROR) << "(oob_allgather) Caught exception in Store Operation .. "
94 << "[" << ex.what() << "]";
95 return UCC_ERR_NO_MESSAGE;
96 }
97 return UCC_OK;
98}
99
100CommUCC::CommUCC(
101 std::shared_ptr<torch_ucc_oob_coll_info_t> oob,
102 const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger)
103 : CommBase(logger) {
104 ucc_lib_config_h lib_config;
105 ucc_context_config_h context_config;
106 ucc_lib_params_t lib_params;
107 ucc_context_params_t context_params;
108 ucc_status_t st;
109
110 TORCH_UCC_CHECK(
111 ucc_lib_config_read("TORCH", nullptr, &lib_config),
112 "failed to read UCC lib config");
113 memset(&lib_params, 0, sizeof(ucc_lib_params_t));
114 lib_params.mask = UCC_LIB_PARAM_FIELD_THREAD_MODE;
115 lib_params.thread_mode = UCC_THREAD_MULTIPLE;
116 TORCH_UCC_CHECK(
117 ucc_init(&lib_params, lib_config, &lib), "failed to init UCC lib");
118 ucc_lib_config_release(lib_config);
119 ucc_lib_attr_t lib_attr;
120 lib_attr.mask = UCC_LIB_ATTR_FIELD_THREAD_MODE;
121 TORCH_UCC_CHECK(
122 ucc_lib_get_attr(lib, &lib_attr), "failed to query for lib attr");
123 TORCH_CHECK(
124 lib_attr.thread_mode == UCC_THREAD_MULTIPLE,
125 "ucc library wasn't initialized with multithreading support, "
126 "please check ucc build options");
127 st = ucc_context_config_read(lib, NULL, &context_config);
128 if (st != UCC_OK) {
129 // FIXME: would this cause deadlock if only one rank fails?
130 TORCH_UCC_CHECK(
131 ucc_finalize(lib),
132 "failed to finalize UCC library when failing to read UCC context config");
133 TORCH_UCC_LOG_ERROR(
134 TORCH_UCC_INIT,
135 c10::str("failed to read UCC context config: ", ucc_status_string(st)));
136 throw std::runtime_error(ucc_status_string(st));
137 }
138 st = ucc_context_config_modify(
139 context_config,
140 NULL,
141 "ESTIMATED_NUM_EPS",
142 std::to_string(oob->size).c_str());
143 if (st != UCC_OK) {
144 ucc_context_config_release(context_config);
145 ucc_finalize(lib);
146 TORCH_UCC_LOG_ERROR(
147 TORCH_UCC_INIT,
148 c10::str(
149 "UCC failed to modify UCC context config: ",
150 ucc_status_string(st)));
151 throw std::runtime_error(ucc_status_string(st));
152 }
153 memset(&context_params, 0, sizeof(ucc_context_params_t));
154 context_params.mask =
155 UCC_CONTEXT_PARAM_FIELD_TYPE | UCC_CONTEXT_PARAM_FIELD_OOB;
156 context_params.type = UCC_CONTEXT_SHARED;
157 context_params.oob.n_oob_eps = oob->size;
158 context_params.oob.oob_ep = oob->rank;
159 context_params.oob.allgather = oob_allgather;
160 context_params.oob.req_test = oob_allgather_test;
161 context_params.oob.req_free = oob_allgather_free;
162 context_params.oob.coll_info = oob.get();
163 st = ucc_context_create(lib, &context_params, context_config, &context);
164 ucc_context_config_release(context_config);
165 if (st != UCC_OK) {
166 TORCH_UCC_CHECK(
167 ucc_finalize(lib),
168 "failed to finalize UCC library when failing to creat UCC context");
169 TORCH_UCC_LOG_ERROR(
170 TORCH_UCC_INIT,
171 c10::str("UCC failed to create UCC context: ", ucc_status_string(st)));
172 throw std::runtime_error(ucc_status_string(st));
173 }
174}
175
176void CommUCC::progress() {
177 TORCH_UCC_CHECK(
178 ucc_context_progress(context), "failed to progress UCC collective");
179}
180
// Passes `request` to ucc_collective_finalize(); the returned status is
// checked by TORCH_UCC_CHECK.
void CommUCC::free_request(ucc_coll_req_h request) {
  TORCH_UCC_CHECK(
      ucc_collective_finalize(request), "failed to release UCC request");
}
185
186CommUCC::~CommUCC() {
187 if (context != nullptr) {
188 TORCH_UCC_CHECK(
189 ucc_context_destroy(context), "failed to destroy UCC context");
190 }
191 if (lib != nullptr) {
192 TORCH_UCC_CHECK(ucc_finalize(lib), "failed to finalize UCC library");
193 }
194 context = nullptr;
195 lib = nullptr;
196}
197
198std::string ProcessGroupUCCLogger::getLogPrefix(torch_ucc_phase_t phase) {
199 // caller can override the phase stored locally
200 torch_ucc_phase_t phase_ =
201 (local_phase != phase && phase != TORCH_UCC_UNKNOWN) ? phase
202 : local_phase;
203 return c10::str(log_prefix, "[", ucc_phase_map.at(phase_), "]");
204}
205void ProcessGroupUCCLogger::setLogPrefix(std::string log_prefix_) {
206 log_prefix = log_prefix_;
207}
208
209ProcessGroupUCCLogger::ProcessGroupUCCLogger() {
210 setLogPrefix("[ProcessGroupUCC]");
211}
212ProcessGroupUCCLogger::ProcessGroupUCCLogger(
213 std::string log_prefix,
214 torch_ucc_phase_t phase)
215 : local_phase(phase) {
216 setLogPrefix(log_prefix);
217}
218
219} // namespace c10d
220
221#endif // USE_C10D_UCC
222