#include <chrono>
#include <iostream>

#include <torch/csrc/distributed/c10d/FileStore.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
#include "CUDATest.hpp"
#include "TestUtils.hpp"
#include "c10d/Types.hpp"

#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/util/irange.h>

#include <gtest/gtest.h>
#include <torch/csrc/autograd/profiler.h>

using namespace c10d::test;

using at::cuda::CUDAStream;

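// Fixture that constructs a ProcessGroupNCCL instance using a FileStore at
// `path` for rendezvous.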
class NCCLTestBase {
 public:
  NCCLTestBase(
      const std::string& path,
      const std::chrono::milliseconds pgTimeout = kBackendDefaultTimeout)
      : path_(path), pgTimeout_(pgTimeout) {}

  NCCLTestBase(NCCLTestBase&& other) {
    path_ = std::move(other.path_);
    pg_ = std::move(other.pg_);
  }

  ::c10d::ProcessGroupNCCL& getProcessGroup() {
    return *pg_;
  }

  void initialize(int rank, int size) {
    auto store = c10::make_intrusive<::c10d::FileStore>(path_, size);

    c10::intrusive_ptr<c10d::ProcessGroupNCCL::Options> opts =
        c10::make_intrusive<c10d::ProcessGroupNCCL::Options>();
    opts->timeout = pgTimeout_;
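    // Opt in to the NCCL communicator health check, which is exercised by the
    // health-check failure tests below.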
    setenv("ENABLE_NCCL_HEALTH_CHECK", "1", /* overwrite */ 1);
    pg_ = std::unique_ptr<::c10d::ProcessGroupNCCL>(
        new ::c10d::ProcessGroupNCCL(store, rank, size, std::move(opts)));
  }

 protected:
  std::string path_;
  std::unique_ptr<::c10d::ProcessGroupNCCL> pg_;
  std::chrono::milliseconds pgTimeout_;
};

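// Base helper for the collective tests: allocates one tensor, one input list,
// one output list, and one CUDA stream per device.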
class NCCLTest : public NCCLTestBase {
 public:
  NCCLTest(
      const std::string& path,
      int worldSize,
      std::chrono::milliseconds pgTimeout = kBackendDefaultTimeout)
      : NCCLTestBase(path, pgTimeout),
        numDevices_(cudaNumDevices()),
        worldSize_(worldSize) {
    // Each device has a single tensor on which to perform the NCCL op.
    ::at::globalContext().lazyInitCUDA();
    tensors_.resize(numDevices_);
    inputs_.resize(numDevices_);
    outputs_.resize(numDevices_);
    at::cuda::OptionalCUDAGuard deviceGuard;
    for (const auto i : c10::irange(numDevices_)) {
      deviceGuard.set_index(i);
      tensors_[i] = at::empty({3, 3}, at::kCUDA);
      inputs_[i].resize(worldSize_ * numDevices_);
      outputs_[i].resize(worldSize_ * numDevices_);
      for (auto j = 0; j < worldSize_ * numDevices_; ++j) {
        inputs_[i][j] = at::empty({3, 3}, at::kCUDA);
        outputs_[i][j] = at::empty({3, 3}, at::kCUDA);
      }
    }

    // Allocate a stream per device.
    //
    // The "current stream" is set globally per device in THC, so we
    // can't make two tensors on the same device use different streams
    // and pass this along to the collective (since it uses the THC
    // getters to retrieve the current stream).
    //
    streams_.reserve(numDevices_);
    for (const auto i : c10::irange(numDevices_)) {
      deviceGuard.set_index(i);
      streams_.push_back(at::cuda::getStreamFromPool());
    }
  }

  void wait(
      c10::intrusive_ptr<c10d::Work>& work,
      std::chrono::milliseconds timeout = kNoTimeout) {
    c10::cuda::CUDAMultiStreamGuard guard(streams_);
    work->wait(timeout);
  }

  std::vector<at::Tensor> getTensors() {
    std::vector<at::Tensor> outputs(numDevices_);

    // For the duration of this function, make THC use our streams
    c10::cuda::CUDAMultiStreamGuard guard(streams_);

    // Copy the result tensors to CPU
    for (const auto i : c10::irange(numDevices_)) {
      C10_CUDA_CHECK(cudaStreamSynchronize(streams_[i].stream()));
      outputs[i] = tensors_[i].cpu();
    }

    return outputs;
  }

  std::vector<std::vector<at::Tensor>> getInputTensors() {
    return getTensorLists(inputs_);
  }
  std::vector<std::vector<at::Tensor>> getOutputTensors() {
    return getTensorLists(outputs_);
  }

  int numDevices() const {
    return numDevices_;
  }

 private:
  std::vector<std::vector<at::Tensor>> getTensorLists(
      std::vector<std::vector<at::Tensor>>& tensor_lists) {
    std::vector<std::vector<at::Tensor>> outputs(numDevices_);
    for (auto& output : outputs) {
      output = std::vector<at::Tensor>(worldSize_ * numDevices_);
    }

    // For the duration of this function, make THC use our streams
    c10::cuda::CUDAMultiStreamGuard guard(streams_);

    // Copy the requested tensor lists to CPU
    for (const auto i : c10::irange(numDevices_)) {
      C10_CUDA_CHECK(cudaStreamSynchronize(streams_[i].stream()));
      for (auto j = 0; j < worldSize_ * numDevices_; ++j) {
        outputs[i][j] = tensor_lists[i][j].cpu();
      }
    }
    return outputs;
  }

 protected:
  // Launches sleep on every CUDA device
  void launchDeviceSleep() {
    at::cuda::OptionalCUDAGuard deviceGuard;
    for (const auto i : c10::irange(numDevices_)) {
      deviceGuard.set_index(i);
      cudaSleep(streams_[i], 2000 * 1000 * 1000);
    }
  }

  // Launches value initialization for every tensor
  void valueInitialization() {
    at::cuda::OptionalCUDAGuard deviceGuard;
    for (const auto i : c10::irange(numDevices_)) {
      deviceGuard.set_index(i);
      tensors_[i].fill_(pg_->getRank() * numDevices_ + i);
    }
  }

  const int numDevices_;
  int worldSize_;
  std::vector<at::Tensor> tensors_;
  std::vector<std::vector<at::Tensor>> inputs_;
  std::vector<std::vector<at::Tensor>> outputs_;
  std::vector<CUDAStream> streams_;
};

class AllreduceNCCLTest : public NCCLTest {
 public:
  AllreduceNCCLTest(const std::string& path, int worldSize)
      : NCCLTest(path, worldSize) {}

  c10::intrusive_ptr<c10d::Work> run() {
    // For the duration of this function, make THC use our streams
    c10::cuda::CUDAMultiStreamGuard guard(streams_);

    launchDeviceSleep();
    valueInitialization();

    using namespace torch::autograd::profiler;
    // Make sure that enabling the profiler does not cause any issues. Note
    // that in single-process multi-device mode we do not expect any events to
    // be populated for collective operations, since profiling is not supported
    // for that mode.
    enableProfilerLegacy(ProfilerConfig(ProfilerState::CPU));
    auto results = pg_->allreduce(tensors_);
    disableProfilerLegacy();
    return results;
  }
};

class BroadcastNCCLTest : public NCCLTest {
 public:
  BroadcastNCCLTest(const std::string& path, int worldSize)
      : NCCLTest(path, worldSize) {}

  c10::intrusive_ptr<c10d::Work> run(int rootRank, int rootTensor) {
    // For the duration of this function, make THC use our streams
    c10::cuda::CUDAMultiStreamGuard guard(streams_);

    launchDeviceSleep();
    valueInitialization();

    ::c10d::BroadcastOptions options;
    options.rootRank = rootRank;
    options.rootTensor = rootTensor;
    return pg_->broadcast(tensors_, options);
  }
};

class ReduceNCCLTest : public NCCLTest {
 public:
  ReduceNCCLTest(const std::string& path, int worldSize)
      : NCCLTest(path, worldSize) {}

  c10::intrusive_ptr<c10d::Work> run(int rootRank, int rootTensor) {
    // For the duration of this function, make THC use our streams
    c10::cuda::CUDAMultiStreamGuard guard(streams_);

    launchDeviceSleep();
    valueInitialization();

    ::c10d::ReduceOptions options;
    options.rootRank = rootRank;
    options.rootTensor = rootTensor;
    return pg_->reduce(tensors_, options);
  }
};

class AllgatherNCCLTest : public NCCLTest {
 public:
  AllgatherNCCLTest(const std::string& path, int worldSize)
      : NCCLTest(path, worldSize) {}

  c10::intrusive_ptr<c10d::Work> run() {
    // For the duration of this function, make THC use our streams
    c10::cuda::CUDAMultiStreamGuard guard(streams_);

    launchDeviceSleep();
    valueInitialization();

    return pg_->allgather(outputs_, tensors_);
  }
};

class AllgatherBaseNCCLTest : public NCCLTest {
 public:
  AllgatherBaseNCCLTest(const std::string& path, int worldSize)
      : NCCLTest(path, worldSize) {
    output_tensor_ = at::empty({worldSize_, 3, 3}, at::kCUDA);
  }

  c10::intrusive_ptr<c10d::Work> run() {
    // For the duration of this function, make THC use our streams
    c10::cuda::CUDAMultiStreamGuard guard(streams_);

    launchDeviceSleep();
    valueInitialization();
    // tensors_ contains at least one element, otherwise the test would not
    // run. This is a flattened allgather, so each rank contributes only one
    // tensor, regardless of the number of devices.
    return pg_->_allgather_base(output_tensor_, tensors_[0]);
  }

  at::Tensor getOutputTensor() {
    c10::cuda::CUDAMultiStreamGuard guard(streams_);
    return output_tensor_.cpu();
  }

  at::Tensor getInputTensor() {
    c10::cuda::CUDAMultiStreamGuard guard(streams_);
    return tensors_[0].cpu();
  }

 private:
  at::Tensor output_tensor_;
};

struct ReduceScatterNCCLTest : NCCLTest {
  ReduceScatterNCCLTest(const std::string& path, int worldSize)
      : NCCLTest(path, worldSize) {}

  c10::intrusive_ptr<c10d::Work> run() {
    // For the duration of this function, make THC use our streams
    c10::cuda::CUDAMultiStreamGuard guard(streams_);

    at::cuda::OptionalCUDAGuard deviceGuard;
    launchDeviceSleep();

    // Launch value initialization for every tensor
    for (const auto i : c10::irange(numDevices_)) {
      deviceGuard.set_index(i);
      for (auto j = 0; j < worldSize_ * numDevices_; ++j) {
        inputs_[i][j].fill_(
            pg_->getRank() * numDevices_ * worldSize_ + i * worldSize_ + j);
      }
    }

    return pg_->reduce_scatter(tensors_, inputs_);
  }
};

class ReduceScatterBaseNCCLTest : public NCCLTest {
 public:
  ReduceScatterBaseNCCLTest(const std::string& path, int worldSize)
      : NCCLTest(path, worldSize) {
    output_tensor_ = at::empty({1}, at::kCUDA);
    input_tensor_ = at::empty({worldSize}, at::kCUDA);
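    // Every rank fills its input with 0..worldSize-1, so after the
    // reduce-scatter each rank r receives the value r summed across all ranks.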
    for (const auto i : c10::irange(worldSize)) {
      input_tensor_[i] = i;
    }
  }

  c10::intrusive_ptr<c10d::Work> run() {
    // For the duration of this function, make THC use our streams
    at::cuda::CUDAMultiStreamGuard guard(streams_);

    launchDeviceSleep();
    return pg_->_reduce_scatter_base(output_tensor_, input_tensor_);
  }

  at::Tensor getOutputTensor() {
    at::cuda::CUDAMultiStreamGuard guard(streams_);
    return output_tensor_.cpu();
  }

  at::Tensor getInputTensor() {
    at::cuda::CUDAMultiStreamGuard guard(streams_);
    return input_tensor_.cpu();
  }

 private:
  at::Tensor output_tensor_;
  at::Tensor input_tensor_;
};

void testAllreduce(const std::string& path, int rank, int size) {
  auto test = AllreduceNCCLTest(path, size);
  test.initialize(rank, size);
  auto work = test.run();
  // Wait for work to finish
  test.wait(work);

  // Validation
  const int totalNumGPUs = test.numDevices() * size;
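  // Every GPU filled its tensor with its global index (rank * numDevices + i),
  // so the allreduce sum is 0 + 1 + ... + (totalNumGPUs - 1).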
  const auto expected = (totalNumGPUs * (totalNumGPUs - 1)) / 2;
  const auto tensors = test.getTensors();
  for (const auto& tensor : tensors) {
    const auto* const data = tensor.data_ptr<float>();
    for (const auto k : c10::irange(tensor.numel())) {
      EXPECT_EQ(data[k], expected)
          << "Allreduce outputs do not match expected outputs";
    }
  }
}

void testBroadcast(const std::string& path, int rank, int size) {
  auto test = BroadcastNCCLTest(path, size);
  test.initialize(rank, size);

  const int numDevices = test.numDevices();
  // try every permutation of root rank and root tensor
  for (const auto rootRank : c10::irange(size)) {
    for (const auto rootTensor : c10::irange(numDevices)) {
      auto work = test.run(rootRank, rootTensor);

      // wait for work to complete
      test.wait(work);

      // Check results
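      // The root tensor was filled with rootRank * numDevices + rootTensor,
      // which every rank should now observe on all of its tensors.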
      const auto expected = (rootRank * numDevices + rootTensor);
      const auto tensors = test.getTensors();
      for (const auto& tensor : tensors) {
        const auto* const data = tensor.data_ptr<float>();
        for (const auto k : c10::irange(tensor.numel())) {
          EXPECT_EQ(data[k], expected)
              << "Broadcast outputs do not match expected outputs";
        }
      }
    }
  }
}

void testReduce(const std::string& path, int rank, int size) {
  auto test = ReduceNCCLTest(path, size);
  test.initialize(rank, size);

  const int numDevices = test.numDevices();
  // try every permutation of root rank and root tensor
  for (const auto rootRank : c10::irange(size)) {
    for (const auto rootTensor : c10::irange(numDevices)) {
      auto work = test.run(rootRank, rootTensor);

      // wait for work to complete
      test.wait(work);

      // Validation
      const int totalNumGPUs = numDevices * size;
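      // As with allreduce, the reduced value is the sum of all global GPU
      // indices, 0 + 1 + ... + (totalNumGPUs - 1), but it is only validated on
      // the root rank's root tensor.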
      const auto expected = (totalNumGPUs * (totalNumGPUs - 1)) / 2;
      auto tensors = test.getTensors();
      if (rank == rootRank) {
        auto& tensor = tensors[rootTensor];
        auto data = tensor.data_ptr<float>();
        for (const auto k : c10::irange(tensor.numel())) {
          EXPECT_EQ(data[k], expected)
              << "Reduce outputs do not match expected outputs";
        }
      }
    }
  }
}

void testAllgather(const std::string& path, int rank, int size) {
  auto test = AllgatherNCCLTest(path, size);
  test.initialize(rank, size);
  auto work = test.run();
  // Wait for work to finish
  test.wait(work);

  // Validation
  auto tensors = test.getOutputTensors();
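  // Every GPU filled its tensor with its global index before the allgather,
  // so slot j of the gathered output should hold the value j on every device.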
  // device index
  for (auto& device : tensors) {
    // rank index
    for (const auto j : c10::irange(device.size())) {
      const auto expected = j;
      auto& tensor = device[j];
      auto data = tensor.data_ptr<float>();
      for (const auto k : c10::irange(tensor.numel())) {
        EXPECT_EQ(data[k], expected)
            << "Allgather outputs do not match expected outputs";
      }
    }
  }
}

void testAllgatherBase(const std::string& path, int rank, int size) {
  auto test = AllgatherBaseNCCLTest(path, size);
  test.initialize(rank, size);
  auto work = test.run();
  // Wait for work to finish
  test.wait(work);
  // Validation
  auto output_tensor = test.getOutputTensor();
  auto input_tensor = test.getInputTensor();

  auto data = output_tensor.data_ptr<float>();

  // Rank index
  for (const auto i : c10::irange(output_tensor.numel())) {
    // i / input.numel() gives the rank that produced this block, and each rank
    // filled its tensor with rank * num_gpus, so that is the expected value.
    const auto expected = (i / input_tensor.numel()) * test.numDevices();
    EXPECT_EQ(data[i], expected)
        << "Allgather_base outputs do not match expected outputs";
  }
}

void testReduceScatterBase(const std::string& path, int rank, int size) {
  auto test = ReduceScatterBaseNCCLTest(path, size);
  test.initialize(rank, size);
  auto work = test.run();
  // Wait for work to finish
  test.wait(work);
  // Validation
  auto output_tensor = test.getOutputTensor();
  auto input_tensor = test.getInputTensor();

  auto data = output_tensor.data_ptr<float>();

  // Rank index
  for (const auto i : c10::irange(output_tensor.numel())) {
    // Every rank's input holds the values 0..size-1, so after the sum the
    // value scattered to this rank is size * rank, scaled by the device count
    // used by this test.
    const auto expected = size * rank * test.numDevices();
    EXPECT_EQ(data[i], expected)
        << "Reducescatter_base outputs do not match expected outputs";
  }
}

void testReduceScatter(const std::string& path, int rank, int size) {
  auto test = ReduceScatterNCCLTest(path, size);
  test.initialize(rank, size);
  auto work = test.run();
  // Wait for work to finish
  test.wait(work);

  const auto participants = test.numDevices() * size;
  const auto base = (participants * (participants - 1)) / 2;

  // Validation
  auto tensors = test.getTensors();
  // device index
  for (const auto i : c10::irange(tensors.size())) {
    const auto modifier = participants * (rank * participants + i);
    const auto expected = base + modifier;
    auto& tensor = tensors[i];
    auto data = tensor.data_ptr<float>();
    for (const auto j : c10::irange(tensor.numel())) {
      EXPECT_EQ(data[j], expected)
          << "ReduceScatter outputs do not match expected outputs!";
    }
  }
}

void testProcessGroupNCCLHealthCheckFailHelper(
    const std::string& path,
    bool timeout) {
  // simulate world_size > 1 here via threads.
  const int worldSize = 4;
  std::unordered_set<uint64_t> nums;
  auto runTest = [&](int i) {
    NCCLTest test(path, worldSize, std::chrono::milliseconds(3000));
    // Catch error relating to health check failure
    bool error_caught = false;
    try {
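      // With timeout == true every thread claims rank 0, so rendezvous can
      // never complete and the health check times out; with false, rank -1 is
      // simply invalid and fails immediately.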
      test.initialize(timeout ? 0 : -1, worldSize);
    } catch (const std::exception& e) {
      std::string errMsg = e.what();
      const std::string kTimeoutErr =
          "Failed to initialize NCCL communicator on rank";
      const std::string kInvalidRankErr = "Invalid rank";
      std::string expectedSubstr = timeout ? kTimeoutErr : kInvalidRankErr;
      bool cond = errMsg.find(expectedSubstr) != std::string::npos;
      EXPECT_TRUE(cond);
      error_caught = true;
    }
    EXPECT_TRUE(error_caught);
  };
  std::vector<std::thread> threads;
  threads.reserve(worldSize);
  for (const auto r : c10::irange(worldSize)) {
    threads.emplace_back(std::thread([=]() { runTest(r); }));
  }
  for (auto& t : threads) {
    t.join();
  }
}

void testProcessGroupNCCLHealthCheckFailException(
    const std::string& path,
    int /* unused */,
    int /* unused */) {
  testProcessGroupNCCLHealthCheckFailHelper(path, /* timeout */ false);
}

void testProcessGroupNCCLHealthCheckFailTimeout(
    const std::string& path,
    int /* unused */,
    int /* unused */) {
  testProcessGroupNCCLHealthCheckFailHelper(path, /* timeout */ true);
}

void testSequenceNumInit(
    const std::string& path,
    int /* unused */,
    int /* unused */) {
  // Note: ProcessGroupNCCLTest doesn't support multiprocess testing. So we
  // simulate world_size > 1 here via threads.
  const int worldSize = 2;
  std::mutex m;
  std::unordered_set<uint64_t> nums;
  auto runTest = [&](int i) {
    NCCLTest test(path, worldSize);
    test.initialize(i, worldSize);
    test.getProcessGroup().setSequenceNumberForGroup();
    std::lock_guard<std::mutex> lock(m);
    auto seqNum = test.getProcessGroup().getSequenceNumberForGroup();
    nums.insert(seqNum);
  };
  std::vector<std::thread> threads;
  threads.reserve(worldSize);
  for (const auto r : c10::irange(worldSize)) {
    threads.emplace_back(std::thread([=]() { runTest(r); }));
  }
  for (auto& t : threads) {
    t.join();
  }
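  // Every rank should observe the same initial sequence number.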
  EXPECT_EQ(nums.size(), 1);
}

class ProcessGroupNCCLTest : public ::testing::Test {
 protected:
  void SetUp() override {
    // Use the WORLD_SIZE and RANK environment variables to do multi-node
    // distributed testing.
    auto sizeEnv = std::getenv("WORLD_SIZE");
    auto rankEnv = std::getenv("RANK");

    if (sizeEnv && rankEnv) {
      size_ = std::stoi(std::string(sizeEnv));
      rank_ = std::stoi(std::string(rankEnv));
    }
    LOG(INFO) << "Multi-node world size: " << size_ << " rank: " << rank_;
  }

  void TearDown() override {
    // Reset NCCL_BLOCKING_WAIT environment variable after each run.
    ASSERT_TRUE(setenv(c10d::NCCL_BLOCKING_WAIT, "0", 1) == 0);
  }

  bool skipTest() {
    // Skip tests if CUDA is not available.
    if (!at::cuda::is_available()) {
      LOG(INFO) << "CUDA not available, skipping test";
      return true;
    }
    return false;
  }

  int size_{1};
  int rank_{0};
};

TEST_F(ProcessGroupNCCLTest, testAllreduce) {
  if (skipTest()) {
    return;
  }
  {
    TemporaryFile file;
    testAllreduce(file.path, rank_, size_);
  }
}

TEST_F(ProcessGroupNCCLTest, testBroadcast) {
  if (skipTest()) {
    return;
  }
  {
    TemporaryFile file;
    testBroadcast(file.path, rank_, size_);
  }
}

TEST_F(ProcessGroupNCCLTest, testReduce) {
  if (skipTest()) {
    return;
  }
  {
    TemporaryFile file;
    testReduce(file.path, rank_, size_);
  }
}

TEST_F(ProcessGroupNCCLTest, testAllgather) {
  if (skipTest()) {
    return;
  }
  {
    TemporaryFile file;
    testAllgather(file.path, rank_, size_);
  }
}

TEST_F(ProcessGroupNCCLTest, testAllgatherBase) {
  if (skipTest()) {
    return;
  }
  {
    TemporaryFile file;
    testAllgatherBase(file.path, rank_, size_);
  }
}

TEST_F(ProcessGroupNCCLTest, testReduceScatter) {
  if (skipTest()) {
    return;
  }
  {
    TemporaryFile file;
    testReduceScatter(file.path, rank_, size_);
  }
}

TEST_F(ProcessGroupNCCLTest, testSequenceNumInit) {
  if (skipTest()) {
    return;
  }
  {
    TemporaryFile file;
    testSequenceNumInit(file.path, rank_, size_);
  }
}

TEST_F(ProcessGroupNCCLTest, testProcessGroupNCCLHealthCheckFailTimeout) {
  if (skipTest()) {
    return;
  }
  {
    TemporaryFile file;
    testProcessGroupNCCLHealthCheckFailTimeout(file.path, rank_, size_);
  }
}

TEST_F(ProcessGroupNCCLTest, testProcessGroupNCCLHealthCheckFailException) {
  if (skipTest()) {
    return;
  }
  {
    TemporaryFile file;
    testProcessGroupNCCLHealthCheckFailException(file.path, rank_, size_);
  }
}

TEST_F(ProcessGroupNCCLTest, testReduceScatterBase) {
  if (skipTest()) {
    return;
  }
  {
    TemporaryFile file;
    testReduceScatterBase(file.path, rank_, size_);
  }
}

TEST_F(ProcessGroupNCCLTest, testBackendName) {
  if (skipTest()) {
    return;
  }
  {
    TemporaryFile file;
    auto test = NCCLTestBase(file.path);
    test.initialize(rank_, size_);
    EXPECT_EQ(
        test.getProcessGroup().getBackendName(),
        std::string(c10d::NCCL_BACKEND_NAME));
  }
}