1 | #ifdef USE_C10D_UCC |
2 | |
3 | #include <torch/csrc/distributed/c10d/ProcessGroupUCC.hpp> |
4 | #include <torch/csrc/distributed/c10d/UCCTracing.hpp> |
5 | #include <torch/csrc/distributed/c10d/UCCUtils.hpp> |
6 | #include <list> |
7 | #include <memory> |
8 | #include <unordered_map> |
9 | #include <unordered_set> |
10 | |
11 | namespace c10d { |
12 | |
13 | namespace { |
14 | constexpr int64_t kBusyWaitMillis = 10; |
15 | |
16 | const std::map<c10::DeviceType, ucc_memory_type_t> ucc_mtype_map = { |
17 | {c10::kCPU, UCC_MEMORY_TYPE_HOST}, |
18 | {c10::kCUDA, UCC_MEMORY_TYPE_CUDA}, |
19 | }; |
20 | |
21 | ucc_memory_type_t to_ucc_memType(c10::DeviceType _c10_type) { |
22 | if (ucc_mtype_map.find(_c10_type) != ucc_mtype_map.end()) |
23 | return ucc_mtype_map.at(_c10_type); |
24 | else |
25 | return UCC_MEMORY_TYPE_UNKNOWN; |
26 | } |
27 | |
28 | const std::map<at::ScalarType, ucc_datatype_t> ucc_dtype_map = { |
29 | {at::kByte, UCC_DT_UINT8}, |
30 | {at::kChar, UCC_DT_INT8}, |
31 | {at::kHalf, UCC_DT_FLOAT16}, |
32 | {at::kBFloat16, UCC_DT_BFLOAT16}, |
33 | {at::kDouble, UCC_DT_FLOAT64}, |
34 | {at::kFloat, UCC_DT_FLOAT32}, |
35 | {at::kInt, UCC_DT_INT32}, |
36 | {at::kLong, UCC_DT_INT64}, |
37 | {at::kBool, UCC_DT_UINT8}, |
38 | }; |
39 | |
40 | ucc_datatype_t to_ucc_dType(at::Tensor _tensor) { |
41 | if (_tensor.scalar_type() == at::kBool && _tensor.element_size() != 1) { |
    TORCH_CHECK(
        false, "Size of Boolean type larger than 1 is not supported in UCC");
44 | } |
45 | try { |
46 | return ucc_dtype_map.at(_tensor.scalar_type()); |
47 | } catch (const std::out_of_range& e) { |
    TORCH_CHECK(false, "Unsupported data type for UCC");
49 | } |
50 | } |
51 | |
52 | const std::map<ReduceOp, ucc_reduction_op_t> ucc_op_map = { |
53 | {ReduceOp::SUM, UCC_OP_SUM}, |
54 | {ReduceOp::PRODUCT, UCC_OP_PROD}, |
55 | {ReduceOp::MIN, UCC_OP_MIN}, |
56 | {ReduceOp::MAX, UCC_OP_MAX}, |
57 | {ReduceOp::BAND, UCC_OP_BAND}, |
58 | {ReduceOp::BOR, UCC_OP_BOR}, |
59 | {ReduceOp::BXOR, UCC_OP_BXOR}, |
60 | {ReduceOp::AVG, UCC_OP_AVG}, |
61 | }; |
62 | |
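// UCC has no native boolean reduction, so logical OR/AND on bool tensors are
// emulated with MAX/MIN over their uint8 representation. Illustrative mapping:
//   to_ucc_reduceOp(ReduceOp::SUM, at::kBool)     -> UCC_OP_MAX // logical OR
//   to_ucc_reduceOp(ReduceOp::PRODUCT, at::kBool) -> UCC_OP_MIN // logical AND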
63 | ucc_reduction_op_t to_ucc_reduceOp( |
64 | const ReduceOp _op, |
65 | const at::ScalarType _dt) { |
66 | if (_dt == at::kBool) { |
67 | if (_op == ReduceOp::SUM) { |
68 | // bitwise or |
69 | return UCC_OP_MAX; |
70 | } else if (_op == ReduceOp::PRODUCT) { |
71 | // bitwise and |
72 | return UCC_OP_MIN; |
73 | } else if (_op == ReduceOp::AVG) { |
      TORCH_CHECK(false, "Cannot use ReduceOp.AVG with boolean inputs");
75 | } |
76 | } |
77 | |
78 | try { |
79 | return ucc_op_map.at(_op); |
80 | } catch (const std::out_of_range& e) { |
    TORCH_CHECK(false, "Unsupported ReduceOp for UCC");
82 | } |
83 | } |
84 | |
85 | struct torch_ucc_config_t { |
86 | c10::once_flag flag; |
87 | std::array<bool, 32> blocking_wait; |
88 | bool enable_comms_logger; |
89 | bool use_future; |
  // Sharing the UCC communicator among multiple PGs to save resources.
91 | bool shared_comm; |
92 | // Using allgatherv to achieve allgather, without flattening the list of |
93 | // (potentially non-contiguous) tensors. |
94 | bool use_allgatherv; |
95 | bool enable_health_check; |
96 | } torch_ucc_config; |
97 | |
98 | std::unordered_map<std::string, std::string> torch_ucc_envs_map = { |
99 | // TORCH_UCC_BLOCKING_WAIT allowed syntax: |
100 | // - TORCH_UCC_BLOCKING_WAIT=none --> blocking wait completely disabled |
101 | // - TORCH_UCC_BLOCKING_WAIT=all --> blocking wait completely enabled |
102 | // - TORCH_UCC_BLOCKING_WAIT=allreduce,send,recv --> blocking wait enabled |
103 | // on selected operations |
104 | // Supported operations: |
105 | // [allgather,allgather_base,allreduce,alltoall,broadcast, |
106 | // gather,reduce,reduce_scatter,scatter,send,recv] |
    {"TORCH_UCC_BLOCKING_WAIT", "none"},

    {"TORCH_UCC_USE_FUTURE", "1"},
    {"TORCH_UCC_PROFILING_ENABLE", "0"},
    {"TORCH_UCC_SHARED_COMM", "1"},
    {"TORCH_UCC_USE_ALLGATHERV", "0"},
    {"TORCH_UCC_ENABLE_HEALTH_CHECK", "0"},
    {"TORCH_UCC_ENABLE_COMMS_LOGGER", "0"},
};
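// Example (illustrative): enabling blocking wait for selected collectives and
// communicator sharing from the shell before launching a job that uses the UCC
// backend (script name is hypothetical):
//   TORCH_UCC_BLOCKING_WAIT=allreduce,send,recv TORCH_UCC_SHARED_COMM=1 \
//       python train.py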
116 | |
117 | std::vector<OpType> parse_blocking_wait(std::string op_list_string) { |
  const static std::unordered_map<std::string, OpType> str2op = {
      {"allgather", OpType::ALLGATHER},
      {"allgather_base", OpType::_ALLGATHER_BASE},
      {"allreduce", OpType::ALLREDUCE},
      {"alltoall_base", OpType::ALLTOALL_BASE},
      {"broadcast", OpType::BROADCAST},
      {"gather", OpType::GATHER},
      {"reduce", OpType::REDUCE},
      {"reduce_scatter", OpType::REDUCE_SCATTER},
      {"scatter", OpType::SCATTER},
      {"send", OpType::SEND},
      {"recv", OpType::RECV},
  };
131 | auto op_list = parse_list(op_list_string); |
  if (op_list == std::vector<std::string>{"none"}) {
133 | return {}; |
134 | } |
135 | std::vector<OpType> result; |
  if (op_list == std::vector<std::string>{"all"}) {
137 | for (auto entry : str2op) { |
138 | result.push_back(entry.second); |
139 | } |
140 | } else { |
141 | for (auto op_string : op_list) { |
142 | result.push_back(str2op.at(op_string)); |
143 | } |
144 | } |
145 | return result; |
146 | } |
147 | |
148 | } // namespace |
149 | |
150 | void read_config() { |
151 | // default configuration |
152 | torch_ucc_config.blocking_wait.fill(false); |
153 | torch_ucc_config.use_future = true; |
154 | torch_ucc_config.shared_comm = false; |
155 | torch_ucc_config.use_allgatherv = false; |
156 | torch_ucc_config.enable_health_check = false; |
157 | torch_ucc_config.enable_comms_logger = false; |
158 | |
159 | // read all torch_ucc env. variables and update the map |
160 | char* env; |
161 | for (auto& torch_ucc_env : torch_ucc_envs_map) { |
162 | env = std::getenv(torch_ucc_env.first.c_str()); |
163 | if (env) { |
164 | torch_ucc_envs_map[torch_ucc_env.first] = std::string(env); |
165 | } |
166 | } |
167 | |
  auto blocking_wait_str = torch_ucc_envs_map.at("TORCH_UCC_BLOCKING_WAIT");
169 | for (auto op : parse_blocking_wait(blocking_wait_str)) { |
170 | torch_ucc_config.blocking_wait[(std::uint8_t)op] = true; |
171 | } |
172 | // barrier is always blocking |
173 | torch_ucc_config.blocking_wait[(std::uint8_t)OpType::BARRIER] = true; |
177 | |
  torch_ucc_config.use_future =
      std::stoi(torch_ucc_envs_map.at("TORCH_UCC_USE_FUTURE"));
  torch_ucc_config.shared_comm =
      std::stoi(torch_ucc_envs_map.at("TORCH_UCC_SHARED_COMM"));
  torch_ucc_config.use_allgatherv =
      std::stoi(torch_ucc_envs_map.at("TORCH_UCC_USE_ALLGATHERV"));
  torch_ucc_config.enable_health_check =
      std::stoi(torch_ucc_envs_map.at("TORCH_UCC_ENABLE_HEALTH_CHECK"));
  torch_ucc_config.enable_comms_logger =
      std::stoi(torch_ucc_envs_map.at("TORCH_UCC_ENABLE_COMMS_LOGGER"));
188 | } |
189 | |
190 | void check_device(c10::Device dev1, c10::Device dev2) { |
191 | if (dev1.is_cuda() && dev2.is_cuda() && dev1 != dev2) { |
192 | throw std::runtime_error("ProcessGroupUCC multidevice is not supported" ); |
193 | } |
194 | } |
195 | |
196 | void check_tensor(const std::vector<at::Tensor>& tensors) { |
197 | if (tensors.size() != 1) { |
198 | throw std::runtime_error( |
199 | "ProcessGroupUCC takes 1 tensor. Got " + |
200 | std::to_string(tensors.size()) + ". " ); |
201 | } |
202 | if (!tensors[0].is_contiguous()) { |
203 | throw std::runtime_error( |
204 | "ProcessGroupUCC input tensor has to be contiguous" ); |
205 | } |
206 | if (tensors[0].is_sparse()) { |
207 | throw std::runtime_error("ProcessGroupUCC input tensor has to be dense" ); |
208 | } |
209 | // TODO: check cuda case |
210 | } |
211 | |
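// Return the CUDA event that served as this work's completion fence to the
// owning ProcessGroupUCC's event pool so later collectives can reuse it.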
212 | ProcessGroupUCC::WorkUCC::~WorkUCC() { |
213 | #ifdef USE_CUDA |
214 | if (fence && ep) { |
215 | std::lock_guard<std::mutex> lock(ep->event_pool_mutex); |
216 | ep->event_pool.push(std::move(fence)); |
217 | } |
218 | #endif |
219 | } |
220 | |
221 | void ProcessGroupUCC::WorkUCC::setException() { |
222 | if (exception() || !entry_) { |
223 | return; |
224 | } |
225 | exception_ = entry_->eptr_; |
226 | } |
227 | |
228 | void ProcessGroupUCC::WorkUCC::setAndThrowException() { |
229 | setException(); |
230 | if (exception()) { |
231 | std::rethrow_exception(exception()); |
232 | } |
233 | } |
234 | |
235 | bool ProcessGroupUCC::WorkUCC::isCompleted() { |
236 | if (!entry_) { |
237 | return true; |
238 | } |
239 | setException(); |
240 | // status_ <= 0 to avoid listing all possible status codes. The main thread |
241 | // needs to be unblocked when UCC (in progress thread) returns success (== 0) |
242 | // or any error code (< 0). |
243 | return exception() || entry_->status_ <= 0; |
244 | } |
245 | |
246 | bool ProcessGroupUCC::WorkUCC::isSuccess() const { |
247 | if (!entry_) { |
248 | return true; |
249 | } |
250 | return !exception() && entry_->status_ == 0; |
251 | } |
252 | |
253 | bool ProcessGroupUCC::WorkUCC::wait(std::chrono::milliseconds /* unused */) { |
254 | if (torch_ucc_config.enable_comms_logger && logger_) { |
255 | logger_->trace_generator->recordComms("wait" , (uintptr_t)this, rank_); |
256 | } |
257 | #ifdef USE_CUDA |
258 | if (fence && !torch_ucc_config.blocking_wait[(int)opType_]) { |
259 | // block user stream |
260 | setAndThrowException(); |
261 | fence->block(at::cuda::getCurrentCUDAStream()); |
262 | return true; |
263 | } |
264 | #endif |
265 | // wait for complete. For blocking case, the main thread will be blocked in |
266 | // this loop until the progress thread changes the status of this request. |
267 | // If timeout occurs, UCC will return UCC_ERR_TIMEOUT as the status. The |
268 | // main thread will throw out the exception then. There is no "abort" |
269 | // function in UCC currently. |
270 | while (!isCompleted()) |
271 | ; |
272 | setAndThrowException(); |
273 | // manually call profiling end callbacks if they are set, |
274 | // since progress thread does not own WorkUCC |
275 | if (Work::recordFunctionEndCallback_) { |
276 | Work::recordFunctionEndCallback_(); |
277 | Work::recordFunctionEndCallback_ = nullptr; |
278 | } |
279 | return true; |
280 | } |
281 | |
282 | c10::intrusive_ptr<c10::ivalue::Future> ProcessGroupUCC::WorkUCC::getFuture() { |
283 | return future_; |
284 | } |
285 | |
286 | int ProcessGroupUCC::WorkUCC::sourceRank() const { |
287 | if (opType_ != OpType::RECV && opType_ != OpType::RECVANYSOURCE) { |
288 | // Throw an error |
289 | return Work::sourceRank(); |
290 | } |
291 | return sourceRank_; |
292 | } |
293 | |
294 | std::vector<at::Tensor> ProcessGroupUCC::WorkUCC::result() { |
295 | return *outputs_; |
296 | } |
297 | |
298 | void ProcessGroupUCC::ProgressEntry::finalize(std::exception_ptr eptr) { |
299 | ucc_status_t status = UCC_OK; |
300 | |
301 | if (request_ != nullptr) { |
302 | status = request_->status; |
303 | comm_->free_request(request_); |
304 | } |
305 | if (eptr) { |
306 | eptr_ = eptr; |
307 | } else { |
308 | status_ = status; |
309 | } |
310 | if (future_) { |
311 | if (eptr) { |
312 | future_->setError(eptr); |
313 | } else { |
314 | future_->markCompleted( |
315 | c10::IValue(data ? data->dst : std::vector<at::Tensor>())); |
316 | } |
317 | } |
318 | } |
319 | |
320 | Comm::Comm( |
321 | const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger_, |
322 | std::shared_ptr<torch_ucc_oob_coll_info_t> oob_, |
323 | c10::Device dev, |
324 | bool is_health_check) |
325 | : logger(logger_), |
326 | oob(oob_), |
327 | ucc_comm(oob, logger), |
328 | finalize_phase( |
329 | is_health_check ? TORCH_UCC_HEALTH_CHECK : TORCH_UCC_FINALIZE), |
330 | cuda_device_index(TORCH_UCC_DEVICE_NOT_SET) { |
331 | if (dev.is_cuda()) { |
332 | cuda_device_index = dev.index(); |
333 | } |
334 | stop_progress_loop = false; |
335 | collective_inprogress = false; |
336 | progress_thread = std::thread(&Comm::progress_loop, this); |
337 | #ifdef _GNU_SOURCE |
338 | pthread_setname_np(progress_thread.native_handle(), "ucc-progress" ); |
339 | #endif |
340 | } |
341 | |
342 | Comm::~Comm() { |
343 | std::unique_lock<std::mutex> lock(mutex); |
344 | queue_consume_cv.wait( |
345 | lock, [&] { return progress_queue.empty() && !collective_inprogress; }); |
346 | stop_progress_loop = true; |
347 | lock.unlock(); |
348 | queue_produce_cv.notify_all(); |
349 | progress_thread.join(); |
350 | } |
351 | |
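// Returns a Comm for the given device. With TORCH_UCC_SHARED_COMM=1 a single,
// process-wide communicator is shared across process groups; otherwise every
// group gets its own. The store exchange below also agrees on a unique comm id:
// non-zero ranks publish their candidate, rank 0 takes the maximum and
// publishes the final value, and all ranks read it back before bumping the
// local counter for the next group.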
352 | std::shared_ptr<Comm> Comm::get_comm( |
353 | uint32_t& id, |
354 | c10::Device dev, |
355 | std::shared_ptr<torch_ucc_oob_coll_info_t> oob, |
356 | const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger, |
357 | bool is_health_check) { |
358 | static std::mutex m; |
359 | static std::weak_ptr<Comm> comm; |
360 | static uint32_t comm_id; |
361 | |
362 | std::lock_guard<std::mutex> lock(m); |
363 | id = comm_id; |
364 | |
365 | std::string group_id = "group_id" ; |
366 | if (is_health_check) { |
367 | group_id = c10::str(dev.type()) + "/" + group_id; |
368 | } |
369 | |
370 | std::vector<uint8_t> remote_comm_id; |
371 | oob->store->deleteKey(group_id + std::to_string(0)); |
372 | if (oob->rank != 0) { |
373 | std::vector<uint8_t> val = std::vector<uint8_t>( |
374 | reinterpret_cast<uint8_t*>(&id), |
375 | reinterpret_cast<uint8_t*>(&id) + sizeof(id)); |
376 | oob->store->set(group_id + std::to_string(oob->rank), val); |
377 | } else { |
378 | for (int i = 1; i < oob->size; i++) { |
379 | remote_comm_id = oob->store->get(group_id + std::to_string(i)); |
380 | oob->store->deleteKey(group_id + std::to_string(i)); |
381 | // Find the highest id. |
382 | id = std::max(id, *(reinterpret_cast<uint32_t*>(remote_comm_id.data()))); |
383 | } |
384 | std::vector<uint8_t> val = std::vector<uint8_t>( |
385 | reinterpret_cast<uint8_t*>(&id), |
386 | reinterpret_cast<uint8_t*>(&id) + sizeof(id)); |
387 | oob->store->set(group_id + std::to_string(oob->rank), val); |
388 | } |
389 | remote_comm_id = oob->store->get(group_id + std::to_string(0)); |
390 | oob->comm_id = *(reinterpret_cast<uint32_t*>(remote_comm_id.data())); |
391 | // Prepare comm_id (static variable) to the next id. |
392 | comm_id = oob->comm_id + 1; |
393 | |
394 | if (torch_ucc_config.shared_comm) { |
395 | std::shared_ptr<Comm> shared_comm = comm.lock(); |
396 | if (!shared_comm) { |
397 | shared_comm = std::make_shared<Comm>(logger, oob, dev, is_health_check); |
398 | comm = shared_comm; |
399 | } else { |
400 | if (dev.is_cuda() && !is_health_check) { |
401 | if ((shared_comm->cuda_device_index != TORCH_UCC_DEVICE_NOT_SET) && |
402 | (shared_comm->cuda_device_index != dev.index())) { |
403 | TORCH_UCC_LOG_ERROR( |
404 | is_health_check ? TORCH_UCC_HEALTH_CHECK : TORCH_UCC_INIT, |
              "ucc communicator was initialized with different cuda device, "
              "multi device is not supported");
407 | throw std::runtime_error(ucc_status_string(UCC_ERR_NOT_SUPPORTED)); |
408 | } |
409 | shared_comm->cuda_device_index = dev.index(); |
410 | } |
411 | } |
412 | return shared_comm; |
413 | } else { |
414 | return std::make_shared<Comm>(logger, oob, dev, is_health_check); |
415 | } |
416 | } |
417 | |
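// Creates a UCC team spanning all ranks. Team creation uses the store-backed
// out-of-band allgather callbacks (oob_allgather / oob_allgather_test /
// oob_allgather_free) and busy-polls ucc_team_create_test, progressing the UCC
// context, until the team is ready.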
418 | void Comm::ucc_create_team( |
419 | ucc_team_h& team, |
420 | std::shared_ptr<torch_ucc_oob_coll_info_t> oob) { |
421 | ucc_status_t st; |
422 | ucc_team_params_t team_params; |
423 | team_params.mask = UCC_TEAM_PARAM_FIELD_EP | UCC_TEAM_PARAM_FIELD_EP_RANGE | |
424 | UCC_TEAM_PARAM_FIELD_OOB; |
425 | team_params.oob.allgather = oob_allgather; |
426 | team_params.oob.req_test = oob_allgather_test; |
427 | team_params.oob.req_free = oob_allgather_free; |
428 | team_params.oob.coll_info = oob.get(); |
429 | team_params.oob.n_oob_eps = oob->size; |
430 | team_params.oob.oob_ep = oob->rank; |
431 | team_params.ep = oob->rank; |
432 | team_params.ep_range = UCC_COLLECTIVE_EP_RANGE_CONTIG; |
433 | TORCH_UCC_CHECK( |
434 | ucc_team_create_post(&ucc_comm.context, 1, &team_params, &team), |
435 | "failed to post team create" ); |
436 | do { |
437 | st = ucc_team_create_test(team); |
438 | ucc_context_progress(ucc_comm.context); |
439 | } while (st == UCC_INPROGRESS); |
440 | TORCH_UCC_CHECK(st, "failed to create UCC team" ); |
441 | } |
442 | |
443 | void Comm::ucc_destroy_team(ucc_team_h& team) { |
444 | std::unique_lock<std::mutex> lock(mutex); |
445 | queue_consume_cv.wait( |
446 | lock, [&] { return progress_queue.empty() && !collective_inprogress; }); |
447 | |
  ucc_status_t status;
  do {
    status = ucc_team_destroy(team);
  } while (UCC_INPROGRESS == status);
  if (UCC_OK != status) {
    TORCH_UCC_LOG_ERROR(
        finalize_phase,
        c10::str("ucc team destroy error: ", ucc_status_string(status)));
  }
457 | |
458 | lock.unlock(); |
459 | } |
460 | |
461 | void Comm::enqueue_collective( |
462 | std::unique_ptr<ProcessGroupUCC::WorkData> data, |
463 | c10::intrusive_ptr<ProcessGroupUCC::WorkUCC> work, |
464 | ucc_coll_args_t& coll, |
465 | ucc_team_h team) { |
466 | ucc_coll_req_h request; |
467 | TORCH_UCC_CHECK( |
468 | ucc_collective_init(&coll, &request, team), "failed to init collective" ); |
469 | TORCH_UCC_CHECK_REQUEST( |
470 | request, ucc_collective_post(request), "failed to post collective" ); |
471 | |
472 | auto entry = |
473 | std::make_shared<ProcessGroupUCC::ProgressEntry>(&ucc_comm, request); |
474 | entry->data = std::move(data); |
475 | entry->future_ = work->getFuture(); |
476 | work->entry_ = entry; |
477 | std::unique_lock<std::mutex> lock(mutex); |
478 | progress_queue.push_back(entry); |
479 | lock.unlock(); |
480 | queue_produce_cv.notify_one(); |
481 | } |
482 | |
483 | #ifdef USE_CUDA |
484 | void Comm::enqueue_cuda_collective( |
485 | std::unique_ptr<ProcessGroupUCC::WorkData> data, |
486 | c10::intrusive_ptr<ProcessGroupUCC::WorkUCC> work, |
487 | ucc_coll_args_t& coll, |
488 | ucc_team_h team, |
489 | ucc_ee_h ee) { |
490 | ucc_coll_req_h request; |
491 | TORCH_UCC_CHECK( |
492 | ucc_collective_init(&coll, &request, team), |
493 | "failed to init cuda collective" ); |
494 | ucc_ev_t comp_ev, *post_ev; |
495 | comp_ev.ev_type = UCC_EVENT_COMPUTE_COMPLETE; |
496 | comp_ev.ev_context = nullptr; |
497 | comp_ev.ev_context_size = 0; |
498 | comp_ev.req = request; |
499 | TORCH_UCC_CHECK_REQUEST( |
500 | request, |
501 | ucc_collective_triggered_post(ee, &comp_ev), |
502 | "failed to post triggered collective" ); |
503 | ucc_status_t st = ucc_ee_get_event(ee, &post_ev); |
504 | TORCH_CHECK(st == UCC_OK && post_ev->ev_type == UCC_EVENT_COLLECTIVE_POST); |
505 | ucc_ee_ack_event(ee, post_ev); |
506 | auto entry = |
507 | std::make_shared<ProcessGroupUCC::ProgressEntry>(&ucc_comm, request); |
508 | entry->data = std::move(data); |
509 | work->entry_ = entry; |
510 | std::unique_lock<std::mutex> lock(mutex); |
511 | progress_queue.push_back(entry); |
512 | lock.unlock(); |
513 | queue_produce_cv.notify_one(); |
514 | } |
515 | #endif |
516 | |
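// Progress-thread main loop: pops entries from progress_queue, drives
// ucc_comm.progress() until the request leaves the in-progress state, then
// finalizes the entry (recording its status or an exception) and notifies
// waiters blocked in the destructor or in ucc_destroy_team.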
517 | void Comm::progress_loop() { |
518 | std::unique_lock<std::mutex> lock(mutex); |
519 | #ifdef USE_CUDA |
520 | bool device_set = false; |
521 | #endif |
522 | while (!stop_progress_loop) { |
523 | if (progress_queue.empty()) { |
524 | queue_produce_cv.wait(lock); |
525 | continue; |
526 | } |
527 | collective_inprogress = true; |
528 | auto work = progress_queue.front(); |
529 | progress_queue.pop_front(); |
530 | lock.unlock(); |
531 | #ifdef USE_CUDA |
532 | if ((!device_set) && (cuda_device_index != TORCH_UCC_DEVICE_NOT_SET)) { |
533 | c10::cuda::set_device(cuda_device_index); |
534 | device_set = true; |
535 | } |
536 | #endif |
537 | std::exception_ptr eptr; |
538 | try { |
539 | while (work->request_->status > 0) { |
540 | ucc_comm.progress(); |
541 | } |
542 | if (work->request_->status < 0) { |
543 | eptr = std::make_exception_ptr( |
544 | std::runtime_error(ucc_status_string(work->request_->status))); |
        std::string err_log = c10::str(
            "Failed to progress communication: ", // TODO: report exact op type
                                                  // or id?
            ucc_status_string(work->request_->status));
549 | TORCH_UCC_LOG_ERROR(TORCH_UCC_COLL_PROGRESS, err_log); |
550 | } |
551 | } catch (...) { |
552 | eptr = std::current_exception(); |
553 | } |
554 | work->finalize(eptr); |
555 | work = nullptr; |
556 | collective_inprogress = false; |
557 | queue_consume_cv.notify_one(); |
558 | lock.lock(); |
559 | } |
560 | } |
561 | |
562 | ProcessGroupUCC::ProcessGroupUCC( |
563 | const c10::intrusive_ptr<Store>& store, |
564 | int rank, |
565 | int size, |
566 | std::chrono::duration<float> timeout) |
567 | : Backend(rank, size), timeout_(timeout) { |
568 | c10::call_once(torch_ucc_config.flag, read_config); |
569 | oob = std::make_shared<torch_ucc_oob_coll_info_t>(); |
570 | oob->rank = rank; |
571 | oob->size = size; |
572 | oob->store = store; |
573 | comm = nullptr; |
574 | cuda_ee = nullptr; |
575 | static uint32_t id = 0; |
576 | uint32_t pg_id = id++; |
577 | |
  logger = c10::make_intrusive<ProcessGroupUCCLogger>(
      c10::str("[Rank ", rank_, "]", "[ProcessGroupUCC-", pg_id, "]"),
      TORCH_UCC_INIT);
  TORCH_UCC_LOG_INFO(
      TORCH_UCC_INIT,
      c10::str(
          "Created ProcessGroupUCC with ",
          size,
          " ranks, with timeout ",
          timeout_.count(),
          " secs"));
  std::string envs = "";
  for (auto& torch_ucc_env : torch_ucc_envs_map) {
    envs += ("\n\t" + torch_ucc_env.first + "=" + torch_ucc_env.second);
  }
  TORCH_UCC_LOG_INFO(
      TORCH_UCC_INIT,
      c10::str(
          "Successfully read and set ProcessGroupUCC env. variables as follows",
          envs));
598 | |
599 | if (torch_ucc_config.enable_health_check) { |
600 | // Perform health check by initializing dummy communicators and destroying |
601 | // them. This will help indicate any UCC/UCX-related issues prior to the |
602 | // first collective. Run it in a separate thread and wait on CV to handle |
603 | // timeouts so that if there are hangs, the main thread can still run |
604 | // correctly. |
605 | runHealthCheck(); |
606 | } |
607 | if (torch_ucc_config.enable_comms_logger) { |
608 | logger->initCommsTracer(); |
609 | } |
610 | } |
611 | |
612 | ProcessGroupUCC::~ProcessGroupUCC() { |
613 | if (torch_ucc_config.enable_comms_logger) { |
614 | logger->flushComms(this->getRank(), this->getSize()); |
615 | } |
616 | if (comm) { |
617 | logger->setPhase(TORCH_UCC_FINALIZE); |
618 | comm->ucc_destroy_team(team); |
619 | TORCH_UCC_LOG_INFO( |
620 | TORCH_UCC_FINALIZE, "Successfully destroyed UCC library" ); |
621 | try { |
622 | if (cuda_ee) { |
623 | ucc_ee_destroy(cuda_ee); |
624 | } |
    } catch (std::exception& ex) {
      TORCH_UCC_LOG_INFO(
          TORCH_UCC_FINALIZE,
          c10::str(
              "(~ProcessGroupUCC) Caught error while destroying the CUDA execution engine .. ",
              "[",
              ex.what(),
              "]"));
633 | } |
634 | comm = nullptr; |
635 | } |
636 | } |
637 | |
638 | #ifdef USE_CUDA |
639 | // Return CUDA device with ordinal given by input rank. |
640 | c10::Device getCUDADeviceForRank(int rank) { |
641 | TORCH_CHECK(rank >= 0, "Invalid rank " , rank); |
642 | auto numGPUs = at::cuda::getNumGPUs(); |
643 | auto deviceIdx = static_cast<c10::DeviceIndex>(rank % numGPUs); |
644 | return c10::Device(c10::DeviceType::CUDA, deviceIdx); |
645 | } |
646 | #endif |
647 | |
648 | void ProcessGroupUCC::runHealthCheck() { |
649 | // Run health check in a separate thread and wait on CV to handle timeouts. |
650 | // This design allows us to handle hangs. |
651 | |
652 | // When size_ is 1, there is no need to do any communication at all. |
653 | if (size_ == 1) |
654 | return; |
655 | |
656 | struct HealthCheckData { |
657 | std::mutex healthCheckMutex; |
658 | std::condition_variable healthCheckCv; |
659 | bool uccHealthCheckSuccess = false; |
660 | std::exception_ptr healthCheckException; |
661 | } healthCheckData; |
662 | |
663 | auto t = std::thread([&healthCheckData, this]() { |
664 | std::list<c10::Device> devices{c10::kCPU}; |
665 | #ifdef USE_CUDA |
666 | c10::cuda::OptionalCUDAGuard gpuGuard; |
667 | if (at::cuda::is_available()) { |
668 | devices.emplace_front(getCUDADeviceForRank(rank_)); |
669 | } |
670 | #endif |
671 | for (auto device : devices) { |
672 | bool is_last_device = (device == devices.back()); |
673 | try { |
674 | auto oob = std::make_shared<torch_ucc_oob_coll_info_t>(); |
675 | oob->rank = this->oob->rank; |
676 | oob->size = this->oob->size; |
677 | oob->store = this->oob->store; |
678 | ucc_team_h team = nullptr; |
679 | uint32_t comm_id; |
680 | #ifdef USE_CUDA |
681 | if (device.is_cuda()) { |
682 | gpuGuard.set_index(device.index()); |
683 | } |
684 | #endif |
685 | auto comm = Comm::get_comm(comm_id, device, oob, logger, true); |
686 | comm->ucc_create_team(team, oob); |
687 | comm->ucc_destroy_team(team); |
688 | TORCH_UCC_LOG_INFO( |
689 | TORCH_UCC_HEALTH_CHECK, |
690 | c10::str( |
                "UCC library health check succeeded for device ",
692 | c10::DeviceTypeName(device.type()))); |
693 | // Mark ucc health check as complete. |
694 | if (is_last_device) { |
695 | std::lock_guard<std::mutex> lk(healthCheckData.healthCheckMutex); |
696 | healthCheckData.uccHealthCheckSuccess = true; |
697 | } |
698 | |
699 | comm = nullptr; |
700 | oob = nullptr; |
701 | // Notify main thread the health check is complete. |
702 | if (is_last_device) { |
703 | healthCheckData.healthCheckCv.notify_one(); |
704 | } |
705 | } catch (const std::exception& e) { |
706 | // Populate exception ptr. |
707 | healthCheckData.healthCheckException = std::current_exception(); |
708 | // Unblock waiting main thread which will report exception. |
709 | healthCheckData.healthCheckCv.notify_one(); |
710 | } // Unknown exceptions will just cause the program to terminate. |
711 | } |
712 | }); |
713 | // We don't need to join the thread, just need to verify health check via the |
714 | // CV. Hence we detach the thread here. |
715 | t.detach(); // NOLINT |
716 | TORCH_UCC_LOG_INFO( |
717 | TORCH_UCC_HEALTH_CHECK, |
718 | c10::str( |
719 | "will wait up to " , |
720 | timeout_.count(), |
          " sec for UCC health check to complete."));
722 | std::unique_lock<std::mutex> lock(healthCheckData.healthCheckMutex); |
723 | healthCheckData.healthCheckCv.wait_for(lock, timeout_, [&healthCheckData]() { |
724 | return healthCheckData.uccHealthCheckSuccess; |
725 | }); |
726 | |
727 | if (healthCheckData.healthCheckException) { |
728 | std::rethrow_exception(healthCheckData.healthCheckException); |
729 | } |
730 | // If there is no exception, the likely culprit is a timeout/hang |
731 | TORCH_CHECK( |
732 | healthCheckData.uccHealthCheckSuccess, |
733 | "ProcessGroupUCC: Health check failure: Failed to initialize UCC on rank " , |
734 | rank_); |
735 | } |
736 | |
737 | void ProcessGroupUCC::set_timeout(ucc_coll_args_t& args) { |
738 | args.mask |= UCC_COLL_ARGS_FIELD_FLAGS; |
739 | args.flags |= UCC_COLL_ARGS_FLAG_TIMEOUT; |
740 | args.timeout = timeout_.count(); |
741 | } |
742 | |
743 | #ifdef USE_CUDA |
744 | std::unique_ptr<at::cuda::CUDAEvent> ProcessGroupUCC::getPooledEvent() { |
745 | std::unique_ptr<at::cuda::CUDAEvent> ev; |
746 | std::lock_guard<std::mutex> lock(ep.event_pool_mutex); |
747 | if (ep.event_pool.empty()) { |
748 | ev = std::make_unique<at::cuda::CUDAEvent>(); |
749 | } else { |
750 | ev = std::move(ep.event_pool.front()); |
751 | ep.event_pool.pop(); |
752 | } |
753 | return ev; |
754 | } |
755 | #endif |
756 | |
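// Common post path for every collective. On CPU the request is simply queued
// for the progress thread. On CUDA the collective is triggered on an internal
// stream via the UCC execution engine (cuda_ee), and a pooled CUDAEvent is
// recorded around it so WorkUCC::wait() can fence the user's stream instead of
// blocking the host.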
757 | template <typename PreProcess, typename PostProcess> |
758 | c10::intrusive_ptr<Work> ProcessGroupUCC::collective_post( |
759 | OpType opType, |
760 | PreProcess preproc, |
761 | PostProcess postproc, |
762 | ucc_coll_args_t& coll, |
763 | std::unique_ptr<ProcessGroupUCC::WorkData> data, |
764 | c10::Device dev, |
765 | std::vector<at::Tensor>& inputTensors, |
766 | std::vector<at::Tensor>& outputTensors, |
767 | const char* prof_title) { |
768 | seq_++; |
769 | set_timeout(coll); |
770 | auto work = c10::make_intrusive<ProcessGroupUCC::WorkUCC>( |
771 | opType, seq_, prof_title, inputTensors, logger); |
772 | |
773 | if (opType == OpType::RECV) { |
774 | work->sourceRank_ = coll.root; |
775 | } |
776 | |
777 | RECORD_COMMS_TRACE( |
778 | logger->trace_generator, |
779 | work, |
780 | opType, |
781 | this->getRank(), |
782 | this->getSize(), |
783 | inputTensors, |
784 | outputTensors); |
785 | |
786 | // Store references to outputs to be used by result |
787 | work->outputs_ = std::make_shared<std::vector<at::Tensor>>(outputTensors); |
788 | switch (dev.type()) { |
789 | case c10::DeviceType::CPU: { |
790 | if (torch_ucc_config.use_future) { |
791 | work->future_ = c10::make_intrusive<at::ivalue::Future>( |
792 | c10::ListType::create(c10::TensorType::get())); |
793 | } |
794 | preproc(); |
795 | comm->enqueue_collective(std::move(data), work, coll, team); |
796 | postproc(); |
797 | return work; |
798 | } |
799 | #ifdef USE_CUDA |
800 | case c10::DeviceType::CUDA: { |
801 | auto cuda_ev = getPooledEvent(); |
802 | cuda_ev->record(at::cuda::getCurrentCUDAStream(dev.index())); |
803 | cuda_ev->block(*stream); |
804 | at::cuda::CUDAStreamGuard guard(*stream); |
805 | preproc(); |
806 | comm->enqueue_cuda_collective(std::move(data), work, coll, team, cuda_ee); |
807 | postproc(); |
808 | cuda_ev->record(*stream); |
809 | work->fence = std::move(cuda_ev); |
810 | work->ep = &ep; |
811 | if (torch_ucc_config.use_future) { |
812 | c10::cuda::CUDAMultiStreamGuard streamGuard(*stream); |
813 | std::vector<c10::Device> devList{dev}; |
814 | work->future_ = c10::make_intrusive<at::ivalue::Future>( |
815 | c10::ListType::create(c10::TensorType::get()), devList); |
816 | // Add a callback that runs profiling end callbacks |
817 | if (work->recordFunctionEndCallback_) { |
818 | work->future_->addCallback([work](at::ivalue::Future& /* unused */) { |
819 | work->recordFunctionEndCallback_(); |
820 | }); |
821 | } |
822 | |
823 | work->future_->markCompleted(c10::IValue(outputTensors)); |
824 | } |
825 | return work; |
826 | } |
827 | #endif // #ifdef USE_CUDA |
828 | default: { |
829 | TORCH_UCC_LOG_ERROR( |
830 | TORCH_UCC_COLL_POST, c10::str("unsupported device type " , dev.str())); |
831 | throw std::runtime_error(ucc_status_string(UCC_ERR_NOT_SUPPORTED)); |
832 | } |
833 | } |
834 | } |
835 | |
836 | c10::intrusive_ptr<Work> ProcessGroupUCC::allgather( |
837 | std::vector<std::vector<at::Tensor>>& outputTensors, |
838 | std::vector<at::Tensor>& inputTensors, |
839 | const AllgatherOptions& /* unused */) { |
840 | auto& tensor = inputTensors[0]; |
841 | check_device(tensor.device(), outputTensors[0][0].device()); |
842 | initComm(tensor.device()); |
843 | |
844 | if (tensor.device().is_cpu() || torch_ucc_config.use_allgatherv) { |
845 | AllgathervWorkData* data = new AllgathervWorkData(size_); |
846 | for (int i = 0; i < size_; i++) { |
847 | data->recv_lengths[i] = tensor.element_size() * tensor.numel(); |
848 | data->recv_offsets[i] = (uint64_t)outputTensors[0][i].data_ptr(); |
849 | } |
850 | ucc_coll_args_t coll; |
851 | coll.mask = UCC_COLL_ARGS_FIELD_FLAGS; |
852 | coll.flags = |
853 | UCC_COLL_ARGS_FLAG_COUNT_64BIT | UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT; |
854 | coll.coll_type = UCC_COLL_TYPE_ALLGATHERV; |
855 | coll.src.info.buffer = tensor.data_ptr(); |
856 | coll.src.info.count = tensor.element_size() * tensor.numel(); |
857 | coll.src.info.datatype = UCC_DT_UINT8; |
858 | coll.src.info.mem_type = to_ucc_memType(tensor.device().type()); |
859 | coll.dst.info_v.buffer = nullptr; |
860 | coll.dst.info_v.counts = (ucc_count_t*)data->recv_lengths.data(); |
861 | coll.dst.info_v.displacements = (ucc_aint_t*)data->recv_offsets.data(); |
862 | coll.dst.info_v.datatype = UCC_DT_UINT8; |
863 | coll.dst.info_v.mem_type = |
864 | to_ucc_memType(outputTensors[0][0].device().type()); |
865 | SAVE_TENSORS(inputTensors, data->src); |
866 | SAVE_TENSORS(outputTensors[0], data->dst); |
867 | |
868 | return collective_post( |
869 | OpType::ALLGATHER, |
870 | []() {}, |
871 | []() {}, |
872 | coll, |
873 | std::unique_ptr<WorkData>(data), |
874 | tensor.device(), |
875 | inputTensors, |
876 | outputTensors[0], |
877 | "ucc:all_gather" ); |
878 | } else { |
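    // Fallback path (typically CUDA without TORCH_UCC_USE_ALLGATHERV): gather
    // into per-input flat buffers with a plain allgather, then copy the flat
    // result back into the user's (possibly non-contiguous) output tensors in
    // copy_from_flat, which runs stream-ordered after the collective.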
879 | WorkData* data = new WorkData(); |
880 | std::vector<at::Tensor> flat_output(outputTensors.size()); |
881 | for (size_t i = 0; i < outputTensors.size(); i++) { |
882 | TORCH_CHECK( |
883 | outputTensors[i].size() == outputTensors.size() * size_, |
884 | "Tensor output list is not valid for the number of participants" ); |
885 | flat_output[i] = c10d::newLikeFlat(outputTensors, i); |
886 | } |
887 | SAVE_TENSORS(flat_output, data->flat); |
888 | ucc_coll_args_t coll; |
889 | coll.mask = 0; |
890 | coll.flags = 0; |
891 | coll.coll_type = UCC_COLL_TYPE_ALLGATHER; |
892 | coll.src.info.buffer = tensor.data_ptr(); |
893 | coll.src.info.count = tensor.numel(); |
894 | coll.src.info.datatype = to_ucc_dType(tensor); |
895 | coll.src.info.mem_type = to_ucc_memType(tensor.device().type()); |
896 | coll.dst.info.buffer = flat_output[0].data_ptr(); |
897 | coll.dst.info.count = flat_output[0].numel(); |
898 | coll.dst.info.datatype = to_ucc_dType(flat_output[0]); |
899 | coll.dst.info.mem_type = |
900 | to_ucc_memType(outputTensors[0][0].device().type()); |
901 | |
902 | auto copy_from_flat = [&] { |
903 | bool asyncCopy = false; |
904 | #ifdef USE_CUDA |
      bool isCuda = outputTensors[0][0].device().is_cuda();
907 | #endif |
908 | for (size_t i = 0; i < outputTensors.size(); i++) { |
909 | auto inumel = inputTensors[i].numel(); |
910 | for (size_t j = 0; j < outputTensors[i].size(); j++) { |
911 | TORCH_CHECK( |
912 | (outputTensors[i][j].numel() == inumel), |
913 | "Tensor operand counts must be same" ); |
914 | #ifdef USE_CUDA |
915 | if (isCuda) { |
916 | c10::cuda::CUDACachingAllocator::recordStream( |
917 | outputTensors[i][j].storage().data_ptr(), (*stream)); |
918 | asyncCopy = true; |
919 | } |
920 | #endif |
921 | outputTensors[i][j].copy_(flat_output[i][j], asyncCopy); |
922 | } |
923 | } |
924 | }; |
925 | return collective_post( |
926 | OpType::ALLGATHER, |
927 | []() {}, |
928 | copy_from_flat, |
929 | coll, |
930 | std::unique_ptr<WorkData>(data), |
931 | tensor.device(), |
932 | inputTensors, |
933 | outputTensors[0], |
934 | "ucc:all_gather" ); |
935 | } |
936 | } |
937 | |
938 | c10::intrusive_ptr<Work> ProcessGroupUCC::_allgather_base( |
939 | at::Tensor& outputTensor, |
940 | at::Tensor& inputTensor, |
941 | const AllgatherOptions& opts) { |
942 | check_tensor({outputTensor}); |
943 | check_tensor({inputTensor}); |
944 | initComm(outputTensor.device()); |
945 | |
946 | WorkData* data = new WorkData(); |
947 | |
948 | ucc_coll_args_t coll; |
949 | coll.mask = 0; |
950 | coll.flags = 0; |
951 | coll.coll_type = UCC_COLL_TYPE_ALLGATHER; |
952 | coll.src.info.buffer = inputTensor.data_ptr(); |
953 | coll.src.info.count = inputTensor.numel(); |
954 | coll.src.info.datatype = ucc_dtype_map.at(inputTensor.scalar_type()); |
955 | coll.src.info.mem_type = to_ucc_memType(inputTensor.device().type()); |
956 | coll.dst.info.buffer = outputTensor.data_ptr(); |
957 | coll.dst.info.count = outputTensor.numel(); |
958 | coll.dst.info.datatype = ucc_dtype_map.at(outputTensor.scalar_type()); |
959 | coll.dst.info.mem_type = to_ucc_memType(outputTensor.device().type()); |
960 | |
961 | std::vector<at::Tensor> inputTensors = {inputTensor}; |
962 | std::vector<at::Tensor> outputTensors = {outputTensor}; |
963 | SAVE_TENSORS(inputTensors, data->src); |
964 | SAVE_TENSORS(outputTensors, data->dst); |
965 | |
966 | return collective_post( |
967 | OpType::_ALLGATHER_BASE, |
968 | []() {}, |
969 | []() {}, |
970 | coll, |
971 | std::unique_ptr<WorkData>(data), |
972 | outputTensor.device(), |
973 | inputTensors, |
974 | outputTensors, |
975 | "ucc:allgather_base" ); |
976 | } |
977 | |
978 | c10::intrusive_ptr<Work> ProcessGroupUCC::allreduce( |
979 | std::vector<at::Tensor>& tensors, |
980 | const AllreduceOptions& opts) { |
981 | check_tensor(tensors); |
982 | auto& tensor = tensors[0]; |
983 | initComm(tensor.device()); |
984 | WorkData* data = new WorkData(); |
985 | |
986 | ucc_coll_args_t coll; |
987 | coll.mask = UCC_COLL_ARGS_FIELD_FLAGS; |
988 | coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; |
989 | coll.coll_type = UCC_COLL_TYPE_ALLREDUCE; |
990 | coll.op = to_ucc_reduceOp(opts.reduceOp, tensor.scalar_type()); |
991 | coll.src.info.buffer = nullptr; |
992 | coll.src.info.count = tensor.numel(); |
993 | coll.src.info.datatype = to_ucc_dType(tensor); |
994 | coll.src.info.mem_type = to_ucc_memType(tensor.device().type()); |
995 | coll.dst.info.buffer = tensor.data_ptr(); |
996 | coll.dst.info.count = tensor.numel(); |
997 | coll.dst.info.datatype = to_ucc_dType(tensor); |
998 | coll.dst.info.mem_type = to_ucc_memType(tensor.device().type()); |
999 | SAVE_TENSORS(tensors, data->dst); |
1000 | return collective_post( |
1001 | OpType::ALLREDUCE, |
1002 | []() {}, |
1003 | []() {}, |
1004 | coll, |
1005 | std::unique_ptr<WorkData>(data), |
1006 | tensor.device(), |
1007 | tensors, |
1008 | tensors, |
1009 | "ucc:all_reduce" ); |
1010 | } |
1011 | |
1012 | c10::intrusive_ptr<Work> ProcessGroupUCC::allreduce_coalesced( |
1013 | std::vector<at::Tensor>& /* unused */, |
1014 | const AllreduceCoalescedOptions& /* unused */) { |
1015 | throw std::runtime_error( |
1016 | "ProcessGroupUCC does not support allreduce_coalesced" ); |
1017 | } |
1018 | |
1019 | c10::intrusive_ptr<Work> ProcessGroupUCC::alltoall( |
1020 | std::vector<at::Tensor>& outputTensors, |
1021 | std::vector<at::Tensor>& inputTensors, |
1022 | const AllToAllOptions& /* unused */) { |
1023 | auto device = outputTensors[0].device(); |
1024 | for (const auto r : c10::irange(outputTensors.size())) { |
1025 | TORCH_CHECK( |
1026 | device == outputTensors[r].device() && |
1027 | device == inputTensors[r].device(), |
        "Tensors must be on the same device");
1029 | } |
1030 | |
1031 | initComm(device); |
1032 | ucc_coll_args_t coll; |
1033 | AlltoallWorkData* data; |
1034 | data = new AlltoallWorkData(size_); |
1035 | |
  /* To avoid flattening the tensors, we use alltoallv to implement Alltoall as
     follows:
      1. store the address of each tensor directly in displacements, keep
         buffer set to nullptr, i.e., 0
      2. convert the datatype to UINT8, which is always 1 byte, to avoid wrong
         size calculation in the UCC layer
      3. post the Alltoallv
  */
1044 | for (const auto i : c10::irange(size_)) { |
1045 | data->send_lengths[i] = |
1046 | (uint64_t)(inputTensors[i].element_size() * inputTensors[i].numel()); |
1047 | data->send_offsets[i] = (uint64_t)inputTensors[i].data_ptr(); |
1048 | data->recv_lengths[i] = |
1049 | (uint64_t)(outputTensors[i].element_size() * outputTensors[i].numel()); |
1050 | data->recv_offsets[i] = (uint64_t)outputTensors[i].data_ptr(); |
1051 | } |
1052 | |
1053 | coll.mask = UCC_COLL_ARGS_FIELD_FLAGS; |
1054 | coll.flags = |
1055 | UCC_COLL_ARGS_FLAG_COUNT_64BIT | UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT; |
1056 | coll.coll_type = UCC_COLL_TYPE_ALLTOALLV; |
1057 | coll.src.info_v.buffer = 0; |
1058 | coll.src.info_v.counts = (ucc_count_t*)data->send_lengths.data(); |
1059 | coll.src.info_v.displacements = (ucc_aint_t*)data->send_offsets.data(); |
1060 | coll.src.info_v.datatype = UCC_DT_UINT8; |
1061 | coll.src.info_v.mem_type = to_ucc_memType(inputTensors[0].device().type()); |
1062 | coll.dst.info_v.buffer = 0; |
1063 | coll.dst.info_v.counts = (ucc_count_t*)data->recv_lengths.data(); |
1064 | coll.dst.info_v.displacements = (ucc_aint_t*)data->recv_offsets.data(); |
1065 | coll.dst.info_v.datatype = UCC_DT_UINT8; |
1066 | coll.dst.info_v.mem_type = to_ucc_memType(outputTensors[0].device().type()); |
1067 | |
1068 | SAVE_TENSORS(inputTensors, data->src); |
1069 | SAVE_TENSORS(outputTensors, data->dst); |
1070 | |
1071 | return collective_post( |
1072 | OpType::ALLTOALL, |
1073 | []() {}, |
1074 | []() {}, |
1075 | coll, |
1076 | std::unique_ptr<WorkData>(data), |
1077 | device, |
1078 | inputTensors, |
1079 | outputTensors, |
1080 | "ucc:alltoall" ); |
1081 | } |
1082 | |
1083 | c10::intrusive_ptr<Work> ProcessGroupUCC::alltoall_base( |
1084 | at::Tensor& outputTensor, |
1085 | at::Tensor& inputTensor, |
1086 | std::vector<int64_t>& outputSplitSizes, |
1087 | std::vector<int64_t>& inputSplitSizes, |
1088 | const AllToAllOptions& /* unused */) { |
1089 | check_device(inputTensor.device(), outputTensor.device()); |
1090 | initComm(inputTensor.device()); |
1091 | ucc_coll_args_t coll; |
1092 | AlltoallWorkData* data; |
1093 | |
1094 | if ((outputSplitSizes.size() == 0) && (inputSplitSizes.size() == 0)) { |
1095 | data = new AlltoallWorkData(0); |
1096 | TORCH_CHECK( |
1097 | (outputTensor.size(0) % size_ == 0) && |
1098 | (inputTensor.size(0) % size_ == 0), |
1099 | "Tensor's dim 0 does not divide equally across group size" ); |
1100 | coll.mask = 0; |
1101 | coll.flags = 0; |
1102 | coll.coll_type = UCC_COLL_TYPE_ALLTOALL; |
1103 | coll.src.info.buffer = inputTensor.data_ptr(); |
1104 | coll.src.info.count = inputTensor.element_size() * inputTensor.numel(); |
1105 | coll.src.info.datatype = UCC_DT_UINT8; |
1106 | coll.src.info.mem_type = to_ucc_memType(inputTensor.device().type()); |
1107 | coll.dst.info.buffer = outputTensor.data_ptr(); |
1108 | coll.dst.info.count = outputTensor.element_size() * outputTensor.numel(); |
1109 | coll.dst.info.datatype = UCC_DT_UINT8; |
1110 | coll.dst.info.mem_type = to_ucc_memType(outputTensor.device().type()); |
1111 | coll.flags = 0; |
1112 | } else { |
1113 | data = new AlltoallWorkData(size_); |
1114 | c10d::checkSplitSizes(inputSplitSizes, inputTensor, size_); |
1115 | c10d::checkSplitSizes(outputSplitSizes, outputTensor, size_); |
1116 | computeLengthsAndOffsets( |
1117 | outputSplitSizes, |
1118 | outputTensor, |
1119 | &data->recv_lengths, |
1120 | &data->recv_offsets); |
1121 | computeLengthsAndOffsets( |
1122 | inputSplitSizes, inputTensor, &data->send_lengths, &data->send_offsets); |
1123 | coll.mask = UCC_COLL_ARGS_FIELD_FLAGS; |
1124 | coll.coll_type = UCC_COLL_TYPE_ALLTOALLV; |
1125 | coll.src.info_v.buffer = inputTensor.data_ptr(); |
1126 | coll.src.info_v.counts = (ucc_count_t*)data->send_lengths.data(); |
1127 | coll.src.info_v.displacements = (ucc_aint_t*)data->send_offsets.data(); |
1128 | coll.src.info_v.datatype = to_ucc_dType(inputTensor); |
1129 | coll.src.info_v.mem_type = to_ucc_memType(inputTensor.device().type()); |
1130 | coll.dst.info_v.buffer = outputTensor.data_ptr(); |
1131 | coll.dst.info_v.counts = (ucc_count_t*)data->recv_lengths.data(); |
1132 | coll.dst.info_v.displacements = (ucc_aint_t*)data->recv_offsets.data(); |
1133 | coll.dst.info_v.datatype = to_ucc_dType(outputTensor); |
1134 | coll.dst.info_v.mem_type = to_ucc_memType(outputTensor.device().type()); |
1135 | coll.flags = UCC_COLL_ARGS_FLAG_CONTIG_SRC_BUFFER | |
1136 | UCC_COLL_ARGS_FLAG_CONTIG_DST_BUFFER | UCC_COLL_ARGS_FLAG_COUNT_64BIT | |
1137 | UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT; |
1138 | |
1139 | if (torch_ucc_config.enable_comms_logger) { |
1140 | logger->trace_generator->recordOptionalInfo( |
1141 | outputSplitSizes, inputSplitSizes); |
1142 | } |
1143 | } |
1144 | std::vector<at::Tensor> inputTensors = {inputTensor}; |
1145 | std::vector<at::Tensor> outputTensors = {outputTensor}; |
1146 | SAVE_TENSORS(inputTensors, data->src); |
1147 | SAVE_TENSORS(outputTensors, data->dst); |
1148 | |
1149 | return collective_post( |
1150 | OpType::ALLTOALL_BASE, |
1151 | []() {}, |
1152 | []() {}, |
1153 | coll, |
1154 | std::unique_ptr<WorkData>(data), |
1155 | inputTensor.device(), |
1156 | inputTensors, |
1157 | outputTensors, |
1158 | "ucc:alltoall" ); |
1159 | } |
1160 | |
1161 | c10::intrusive_ptr<Work> ProcessGroupUCC::barrier(const BarrierOptions& opts) { |
1162 | c10::Device device = c10::Device(c10::DeviceType::CPU); |
1163 | #ifdef USE_CUDA |
1164 | auto numGPUs = c10::cuda::device_count(); |
1165 | if (!opts.device_ids.empty()) { |
1166 | device = c10::Device(c10::DeviceType::CUDA, opts.device_ids.front()); |
1167 | } else if (comm && comm->cuda_device_index != TORCH_UCC_DEVICE_NOT_SET) { |
1168 | device = c10::Device(c10::DeviceType::CUDA, comm->cuda_device_index); |
1169 | } else if (numGPUs > 0) { |
1170 | int8_t deviceIdx = static_cast<int8_t>(c10::cuda::current_device()); |
1171 | // if current device is 0, likely the device is not set, use the best guess |
1172 | if (0 == (int)deviceIdx) { |
1173 | deviceIdx = static_cast<int8_t>(this->getRank() % numGPUs); |
1174 | } |
1175 | TORCH_UCC_LOG_INFO( |
1176 | TORCH_UCC_COLL_POST, |
1177 | c10::str( |
1178 | "post barrier before specifying any GPU while there are " , |
1179 | numGPUs, |
1180 | " GPUs available. " , |
1181 | "Not clear if GPU barrier is required, using GPU " , |
1182 | (int)deviceIdx, |
1183 | " to perform barrier. " , |
1184 | "Specify device_ids option in barrier() to force " , |
1185 | "use of a particular device" )); |
1186 | device = c10::Device(c10::DeviceType::CUDA, deviceIdx); |
1187 | } |
1188 | #endif |
1189 | initComm(device); |
1190 | |
1191 | ucc_coll_args_t coll; |
1192 | coll.mask = 0; |
1193 | coll.flags = 0; |
1194 | coll.coll_type = UCC_COLL_TYPE_BARRIER; |
1195 | auto dummy_tensor = std::vector<at::Tensor>(); |
1196 | return collective_post( |
1197 | OpType::BARRIER, |
1198 | []() {}, |
1199 | []() {}, |
1200 | coll, |
1201 | nullptr, |
1202 | device, |
1203 | dummy_tensor, |
1204 | dummy_tensor, |
1205 | "ucc:barrier" ); |
1206 | } |
1207 | |
1208 | c10::intrusive_ptr<Work> ProcessGroupUCC::broadcast( |
1209 | std::vector<at::Tensor>& tensors, |
1210 | const BroadcastOptions& opts) { |
1211 | check_tensor(tensors); |
1212 | auto& tensor = tensors[0]; |
1213 | initComm(tensor.device()); |
1214 | WorkData* data = new WorkData(); |
1215 | |
1216 | ucc_coll_args_t coll; |
1217 | coll.mask = 0; |
1218 | coll.flags = 0; |
1219 | coll.coll_type = UCC_COLL_TYPE_BCAST; |
1220 | coll.src.info.buffer = tensor.data_ptr(); |
1221 | coll.src.info.count = tensor.numel(); |
1222 | coll.src.info.datatype = to_ucc_dType(tensor); |
1223 | coll.src.info.mem_type = to_ucc_memType(tensor.device().type()); |
1224 | coll.root = opts.rootRank; |
1225 | SAVE_TENSORS(tensors, data->dst); |
1226 | |
1227 | if (torch_ucc_config.enable_comms_logger) { |
1228 | logger->trace_generator->recordOptionalInfo(opts.rootRank); |
1229 | } |
1230 | |
1231 | return collective_post( |
1232 | OpType::BROADCAST, |
1233 | []() {}, |
1234 | []() {}, |
1235 | coll, |
1236 | std::unique_ptr<WorkData>(data), |
1237 | tensor.device(), |
1238 | tensors, |
1239 | tensors, |
1240 | "ucc:broadcast" ); |
1241 | } |
1242 | |
1243 | c10::intrusive_ptr<Work> ProcessGroupUCC::gather( |
1244 | std::vector<std::vector<at::Tensor>>& outputTensors, |
1245 | std::vector<at::Tensor>& inputTensors, |
1246 | const GatherOptions& opts) { |
1247 | std::vector<at::Tensor> outputs; |
1248 | auto& input = inputTensors[0]; |
1249 | initComm(input.device()); |
1250 | |
1251 | AllgathervWorkData* data = new AllgathervWorkData(size_); |
1252 | ucc_coll_args_t coll; |
1253 | coll.root = opts.rootRank; |
1254 | coll.mask = UCC_COLL_ARGS_FIELD_FLAGS; |
1255 | coll.flags = |
1256 | UCC_COLL_ARGS_FLAG_COUNT_64BIT | UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT; |
1257 | coll.coll_type = UCC_COLL_TYPE_GATHERV; |
1258 | |
1259 | /* for non-root ranks, only src is valid */ |
1260 | coll.src.info.buffer = input.data_ptr(); |
1261 | coll.src.info.count = (uint64_t)(input.element_size() * input.numel()); |
1262 | coll.src.info.datatype = UCC_DT_UINT8; |
1263 | coll.src.info.mem_type = to_ucc_memType(input.device().type()); |
1264 | |
1265 | if (getRank() == opts.rootRank) { |
1266 | if (outputTensors.size() != 1) { |
1267 | TORCH_UCC_LOG_ERROR( |
1268 | TORCH_UCC_COLL_POST, |
1269 | c10::str( |
1270 | "gather requires a single-element output list containing a list with " , |
1271 | getSize(), |
1272 | " tensors." )); |
1273 | } else if (outputTensors[0].size() != static_cast<size_t>(getSize())) { |
1274 | TORCH_UCC_LOG_ERROR( |
1275 | TORCH_UCC_COLL_POST, |
1276 | c10::str( |
1277 | "Incorrect output list size " , |
1278 | outputTensors[0].size(), |
1279 | ". Output list size should be " , |
1280 | getSize(), |
1281 | ", same as size of the process group." )); |
1282 | } |
1283 | outputs = outputTensors[0]; |
1284 | |
1285 | for (int i = 0; i < size_; i++) { |
1286 | data->recv_lengths[i] = |
1287 | (uint64_t)(outputs[i].element_size() * outputs[i].numel()); |
1288 | data->recv_offsets[i] = (uint64_t)outputs[i].data_ptr(); |
1289 | } |
    /* use gatherv and store non-contiguous addresses in displacements to avoid
     * flattening outputTensors */
1292 | coll.dst.info_v.buffer = nullptr; |
1293 | coll.dst.info_v.counts = (ucc_count_t*)data->recv_lengths.data(); |
1294 | coll.dst.info_v.displacements = (ucc_aint_t*)data->recv_offsets.data(); |
1295 | coll.dst.info_v.datatype = UCC_DT_UINT8; |
1296 | coll.dst.info_v.mem_type = to_ucc_memType(outputs[0].device().type()); |
1297 | |
1298 | SAVE_TENSORS(outputs, data->dst); |
1299 | } else { |
1300 | // for non-root ranks, outputTensors should be an empty list |
1301 | if (outputTensors.size() != 0) { |
1302 | TORCH_UCC_LOG_ERROR( |
1303 | TORCH_UCC_COLL_POST, "requires empty output on non-root" ); |
1304 | } |
1305 | outputs = {}; |
    // append an empty tensor to the list, to be used by the future mark
1307 | outputs.emplace_back(); |
1308 | } |
1309 | |
1310 | SAVE_TENSORS(inputTensors, data->src); |
1311 | |
1312 | return collective_post( |
1313 | OpType::GATHER, |
1314 | []() {}, |
1315 | []() {}, |
1316 | coll, |
1317 | std::unique_ptr<WorkData>(data), |
1318 | input.device(), |
1319 | inputTensors, |
1320 | outputs, |
1321 | "ucc:gather" ); |
1322 | } |
1323 | |
1324 | c10::intrusive_ptr<Work> ProcessGroupUCC::reduce( |
1325 | std::vector<at::Tensor>& tensors, |
1326 | const ReduceOptions& opts) { |
1327 | check_tensor(tensors); |
1328 | auto& tensor = tensors[0]; |
1329 | initComm(tensor.device()); |
1330 | WorkData* data = new WorkData(); |
1331 | |
1332 | ucc_coll_args_t coll; |
1333 | coll.mask = UCC_COLL_ARGS_FIELD_FLAGS; |
1334 | coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; |
1335 | coll.coll_type = UCC_COLL_TYPE_REDUCE; |
1336 | coll.op = ucc_op_map.at(opts.reduceOp); |
1337 | coll.root = opts.rootRank; |
1338 | coll.src.info.buffer = tensor.data_ptr(); |
1339 | coll.src.info.count = tensor.numel(); |
1340 | coll.src.info.datatype = ucc_dtype_map.at(tensor.scalar_type()); |
1341 | coll.src.info.mem_type = to_ucc_memType(tensor.device().type()); |
1342 | coll.dst.info.buffer = tensor.data_ptr(); |
1343 | coll.dst.info.count = tensor.numel(); |
1344 | coll.dst.info.datatype = ucc_dtype_map.at(tensor.scalar_type()); |
1345 | coll.dst.info.mem_type = to_ucc_memType(tensor.device().type()); |
1346 | SAVE_TENSORS(tensors, data->dst); |
1347 | return collective_post( |
1348 | OpType::REDUCE, |
1349 | []() {}, |
1350 | []() {}, |
1351 | coll, |
1352 | std::unique_ptr<WorkData>(data), |
1353 | tensor.device(), |
1354 | tensors, |
1355 | tensors, |
1356 | "ucc:reduce" ); |
1357 | } |
1358 | |
1359 | c10::intrusive_ptr<Work> ProcessGroupUCC::reduce_scatter( |
1360 | std::vector<at::Tensor>& outputTensors, |
1361 | std::vector<std::vector<at::Tensor>>& inputTensors, |
1362 | const ReduceScatterOptions& opts) { |
1363 | TORCH_CHECK( |
1364 | (outputTensors.size() == inputTensors.size()), |
1365 | "Tensor input/output list for reduce_scatter must have same size" ); |
1366 | check_tensor(outputTensors); |
1367 | check_device(inputTensors[0][0].device(), outputTensors[0].device()); |
1368 | initComm(inputTensors[0][0].device()); |
1369 | auto data = std::make_unique<WorkData>(); |
1370 | std::vector<at::Tensor> flat_input(inputTensors.size()); |
1371 | for (size_t i = 0; i < inputTensors.size(); i++) { |
1372 | TORCH_CHECK( |
1373 | inputTensors[i].size() == inputTensors.size() * size_, |
1374 | "Tensor input list is not valid for the number of participants" ); |
1375 | flat_input[i] = c10d::newLikeFlat(inputTensors, i); |
1376 | } |
1377 | SAVE_TENSORS(flat_input, data->flat); |
1378 | check_tensor(flat_input); |
1379 | ucc_coll_args_t coll; |
1380 | coll.mask = 0; |
1381 | coll.flags = 0; |
1382 | coll.coll_type = UCC_COLL_TYPE_REDUCE_SCATTER; |
1383 | coll.op = to_ucc_reduceOp(opts.reduceOp, flat_input[0].scalar_type()); |
1384 | |
1385 | coll.src.info.buffer = flat_input[0].data_ptr(); |
1386 | coll.src.info.count = flat_input[0].numel(); |
1387 | coll.src.info.datatype = to_ucc_dType(flat_input[0]); |
1388 | coll.src.info.mem_type = to_ucc_memType(flat_input[0].device().type()); |
1389 | coll.dst.info.buffer = outputTensors[0].data_ptr(); |
1390 | coll.dst.info.count = outputTensors[0].numel(); |
1391 | coll.dst.info.datatype = to_ucc_dType(outputTensors[0]); |
1392 | coll.dst.info.mem_type = to_ucc_memType(outputTensors[0].device().type()); |
1393 | |
1394 | SAVE_TENSORS(inputTensors[0], data->src); |
1395 | SAVE_TENSORS(outputTensors, data->dst); |
1396 | |
1397 | auto copy_to_flat = [&] { |
1398 | bool asyncCopy = false; |
1399 | auto isize = inputTensors.size(); |
1400 | #ifdef USE_CUDA |
1401 | bool isCuda = inputTensors[0][0].device().is_cuda(); |
1402 | #endif |
1403 | for (size_t i = 0; i < isize; i++) { |
1404 | auto onumel = outputTensors[i].numel(); |
1405 | for (size_t j = 0; j < inputTensors[i].size(); j++) { |
1406 | TORCH_CHECK( |
1407 | (inputTensors[i][j].numel() == onumel), |
1408 | "Tensor operand counts must be same" ); |
1409 | #ifdef USE_CUDA |
1410 | if (isCuda) { |
1411 | c10::cuda::CUDACachingAllocator::recordStream( |
1412 | inputTensors[i][j].storage().data_ptr(), (*stream)); |
1413 | asyncCopy = true; |
1414 | } |
1415 | #endif |
1416 | flat_input[i][j].copy_(inputTensors[i][j], asyncCopy); |
1417 | } |
1418 | } |
1419 | }; |
1420 | |
1421 | return collective_post( |
1422 | OpType::REDUCE_SCATTER, |
1423 | copy_to_flat, |
1424 | []() {}, |
1425 | coll, |
1426 | std::move(data), |
1427 | inputTensors[0][0].device(), |
1428 | inputTensors[0], |
1429 | outputTensors, |
1430 | "ucc:reduce_scatter" ); |
1431 | } |
1432 | |
1433 | c10::intrusive_ptr<Work> ProcessGroupUCC::scatter( |
1434 | std::vector<at::Tensor>& outputTensors, |
1435 | std::vector<std::vector<at::Tensor>>& inputTensors, |
1436 | const ScatterOptions& opts) { |
1437 | auto& tensor = outputTensors[0]; |
1438 | initComm(tensor.device()); |
1439 | |
1440 | ScattervWorkData* data = new ScattervWorkData(size_); |
1441 | ucc_coll_args_t coll; |
1442 | coll.root = opts.rootRank; |
1443 | coll.mask = UCC_COLL_ARGS_FIELD_FLAGS; |
1444 | coll.flags = |
1445 | UCC_COLL_ARGS_FLAG_COUNT_64BIT | UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT; |
1446 | coll.coll_type = UCC_COLL_TYPE_SCATTERV; |
1447 | |
1448 | if (getRank() == opts.rootRank) { |
    /* src is only valid at the root rank */
1450 | if (inputTensors.size() != 1) { |
1451 | TORCH_UCC_LOG_ERROR( |
1452 | TORCH_UCC_COLL_POST, |
          c10::str(
              "scatter requires a single-element input list containing a list with ",
              getSize(),
              " tensors."));
1457 | } else if (inputTensors[0].size() != static_cast<size_t>(getSize())) { |
1458 | TORCH_UCC_LOG_ERROR( |
1459 | TORCH_UCC_COLL_POST, |
          c10::str(
              "Incorrect input list size ",
              inputTensors[0].size(),
              ". Input list size should be ",
              getSize(),
              ", same as size of the process group."));
1466 | } |
1467 | |
1468 | for (int i = 0; i < size_; i++) { |
1469 | data->send_lengths[i] = (uint64_t)tensor.element_size() * tensor.numel(); |
1470 | data->send_offsets[i] = (uint64_t)inputTensors[0][i].data_ptr(); |
1471 | } |
    /* use scatterv and store non-contiguous addresses in displacements to
     * avoid flattening inputTensors */
1474 | coll.src.info_v.buffer = nullptr; |
1475 | coll.src.info_v.counts = (ucc_count_t*)data->send_lengths.data(); |
1476 | coll.src.info_v.displacements = (ucc_aint_t*)data->send_offsets.data(); |
1477 | coll.src.info_v.datatype = UCC_DT_UINT8; |
1478 | coll.src.info_v.mem_type = |
1479 | to_ucc_memType(inputTensors[0][0].device().type()); |
1480 | |
1481 | SAVE_TENSORS(inputTensors[0], data->src); |
1482 | } else { |
1483 | // for non-root ranks, inputTensors should be an empty list |
1484 | if (inputTensors.size() != 0) { |
1485 | TORCH_UCC_LOG_ERROR( |
          TORCH_UCC_COLL_POST, "requires empty input on non-root");
1487 | } |
1488 | } |
1489 | |
1490 | coll.dst.info.buffer = tensor.data_ptr(); |
1491 | coll.dst.info.count = (uint64_t)tensor.element_size() * tensor.numel(); |
1492 | coll.dst.info.datatype = UCC_DT_UINT8; |
1493 | coll.dst.info.mem_type = to_ucc_memType(tensor.device().type()); |
1494 | SAVE_TENSORS(outputTensors, data->dst); |
1495 | |
1496 | return collective_post( |
1497 | OpType::SCATTER, |
1498 | []() {}, |
1499 | []() {}, |
1500 | coll, |
1501 | std::unique_ptr<WorkData>(data), |
1502 | tensor.device(), |
1503 | inputTensors[0], |
1504 | outputTensors, |
1505 | "ucc:scatter" ); |
1506 | } |
1507 | |
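// UCC exposes no native point-to-point primitives, so send/recv are expressed
// as a broadcast over a two-rank "active set": the sender is the root and the
// stride selects the peer, while coll.tag matches up send/recv pairs.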
1508 | c10::intrusive_ptr<Work> ProcessGroupUCC::send( |
1509 | std::vector<at::Tensor>& tensors, |
1510 | int dstRank, |
1511 | int tag) { |
1512 | check_tensor(tensors); |
1513 | auto& tensor = tensors[0]; |
1514 | initComm(tensor.device()); |
1515 | |
1516 | WorkData* data = new WorkData(); |
1517 | ucc_coll_args_t coll; |
1518 | coll.tag = tag; |
1519 | coll.mask = UCC_COLL_ARGS_FIELD_ACTIVE_SET | UCC_COLL_ARGS_FIELD_TAG; |
1520 | coll.flags = 0; |
1521 | coll.coll_type = UCC_COLL_TYPE_BCAST; |
1522 | coll.src.info.buffer = tensor.data_ptr(); |
1523 | coll.src.info.count = tensor.numel(); |
1524 | coll.src.info.datatype = to_ucc_dType(tensor); |
1525 | coll.src.info.mem_type = to_ucc_memType(tensor.device().type()); |
1526 | coll.root = getRank(); |
1527 | |
1528 | coll.active_set.size = 2; |
1529 | coll.active_set.start = getRank(); |
1530 | coll.active_set.stride = dstRank - getRank(); |
1531 | SAVE_TENSORS(tensors, data->dst); |
1532 | |
1533 | return collective_post( |
1534 | OpType::SEND, |
1535 | []() {}, |
1536 | []() {}, |
1537 | coll, |
1538 | std::unique_ptr<WorkData>(data), |
1539 | tensor.device(), |
1540 | tensors, |
1541 | tensors, |
1542 | "ucc:send" ); |
1543 | } |
1544 | |
1545 | c10::intrusive_ptr<Work> ProcessGroupUCC::recv( |
1546 | std::vector<at::Tensor>& tensors, |
1547 | int srcRank, |
1548 | int tag) { |
1549 | check_tensor(tensors); |
1550 | auto& tensor = tensors[0]; |
1551 | initComm(tensor.device()); |
1552 | |
1553 | WorkData* data = new WorkData(); |
1554 | ucc_coll_args_t coll; |
1555 | coll.tag = tag; |
1556 | coll.mask = UCC_COLL_ARGS_FIELD_ACTIVE_SET | UCC_COLL_ARGS_FIELD_TAG; |
1557 | coll.flags = 0; |
1558 | coll.coll_type = UCC_COLL_TYPE_BCAST; |
1559 | coll.src.info.buffer = tensor.data_ptr(); |
1560 | coll.src.info.count = tensor.numel(); |
1561 | coll.src.info.datatype = to_ucc_dType(tensor); |
1562 | coll.src.info.mem_type = to_ucc_memType(tensor.device().type()); |
1563 | coll.root = srcRank; |
1564 | |
1565 | coll.active_set.size = 2; |
1566 | coll.active_set.start = srcRank; |
1567 | coll.active_set.stride = getRank() - srcRank; |
1568 | SAVE_TENSORS(tensors, data->dst); |
1569 | |
1570 | return collective_post( |
1571 | OpType::RECV, |
1572 | []() {}, |
1573 | []() {}, |
1574 | coll, |
1575 | std::unique_ptr<WorkData>(data), |
1576 | tensor.device(), |
1577 | tensors, |
1578 | tensors, |
1579 | "ucc:recv" ); |
1580 | } |
1581 | |
1582 | void ProcessGroupUCC::setSequenceNumberForGroup() {} |
1583 | |
1584 | uint64_t ProcessGroupUCC::getSequenceNumberForGroup() { |
1585 | return seq_; |
1586 | } |
1587 | |
1588 | c10::intrusive_ptr<Backend> ProcessGroupUCC::createProcessGroupUCC( |
1589 | const c10::intrusive_ptr<::c10d::Store>& store, |
1590 | int rank, |
1591 | int size, |
1592 | const std::chrono::duration<float>& timeout) { |
1593 | return c10::make_intrusive<ProcessGroupUCC>(store, rank, size, timeout); |
1594 | } |
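// Minimal usage sketch (illustrative only; assumes a c10d store named `store`
// and that `rank` / `world_size` are already known to this process):
//
//   auto pg = ProcessGroupUCC::createProcessGroupUCC(
//       store, rank, world_size, std::chrono::seconds(1800));
//   std::vector<at::Tensor> tensors{at::ones({8})};
//   pg->allreduce(tensors)->wait(); // sums the tensor across all ranks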
1595 | |
1596 | void ProcessGroupUCC::initComm(c10::Device dev) { |
1597 | if (!comm) { |
1598 | #ifdef USE_CUDA |
1599 | if (dev.is_cuda()) { |
1600 | c10::cuda::set_device(dev.index()); |
1601 | } |
1602 | #endif |
1603 | comm = Comm::get_comm(comm_id, dev, oob, logger); |
1604 | TORCH_UCC_LOG_INFO(TORCH_UCC_INIT, "Successfully initialized UCX library" ); |
1605 | comm->ucc_create_team(team, oob); |
1606 | TORCH_UCC_LOG_INFO(TORCH_UCC_INIT, "Successfully initialized UCC library" ); |
1607 | logger->setPhase(TORCH_UCC_READY); |
1608 | } else { |
1609 | if (dev.is_cuda()) { |
1610 | if ((comm->cuda_device_index != TORCH_UCC_DEVICE_NOT_SET) && |
1611 | (comm->cuda_device_index != dev.index())) { |
1612 | TORCH_UCC_LOG_ERROR( |
1613 | TORCH_UCC_INIT, |
          "ucc communicator was initialized with different cuda device, "
          "multi device is not supported");
1616 | throw std::runtime_error(ucc_status_string(UCC_ERR_NOT_SUPPORTED)); |
1617 | } |
1618 | comm->cuda_device_index = dev.index(); |
1619 | } |
1620 | } |
1621 | #ifdef USE_CUDA |
1622 | // Create UCC execution engine. |
1623 | if (!cuda_ee && dev.is_cuda()) { |
1624 | stream = std::make_unique<at::cuda::CUDAStream>( |
1625 | at::cuda::getStreamFromPool(true, dev.index())); |
1626 | ucc_ee_params_t params; |
1627 | params.ee_type = UCC_EE_CUDA_STREAM; |
1628 | params.ee_context = (void*)stream->stream(); |
1629 | params.ee_context_size = sizeof(cudaStream_t); |
1630 | TORCH_UCC_CHECK( |
1631 | ucc_ee_create(team, ¶ms, &cuda_ee), |
1632 | "failed to create UCC execution engine" ); |
1633 | } |
1634 | #endif |
1635 | } |
1636 | |
1637 | } // namespace c10d |
1638 | |
1639 | #endif // USE_C10D_UCC |
1640 | |