#include <chrono>

#include <c10/util/irange.h>
#include <torch/csrc/cuda/nccl.h>
#include <torch/csrc/distributed/c10d/FileStore.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
#include "CUDATest.hpp"
#include "TestUtils.hpp"

#include <gtest/gtest.h>

using namespace c10d::test;

// NCCL error handling requires NCCL >= 2.4 (version code 2400).
constexpr int kNcclErrorHandlingVersion = 2400;

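// A WorkNCCL test double that can pretend an NCCL error occurred: when
// simulate_error is set, checkForNCCLErrors() returns a synthetic
// std::runtime_error instead of querying the communicators.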
class WorkNCCLSimulateErrors : public c10d::ProcessGroupNCCL::WorkNCCL {
 public:
  WorkNCCLSimulateErrors(
      const std::vector<at::Device>& devices,
      bool simulate_error,
      int rank,
      c10d::OpType opType,
      uint64_t seq)
      : WorkNCCL(devices, rank, opType, seq), simulate_error_(simulate_error) {}

  std::exception_ptr checkForNCCLErrors(
      const std::vector<std::shared_ptr<c10d::NCCLComm>>& ncclComms)
      const override {
    if (simulate_error_) {
      return std::make_exception_ptr(std::runtime_error("Error"));
    }
    return c10d::ProcessGroupNCCL::WorkNCCL::checkForNCCLErrors(ncclComms);
  }

 private:
  bool simulate_error_;
};

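// A ProcessGroupNCCL subclass whose error checking can be toggled at runtime:
// simulate_error() makes both the watchdog-facing checkForNCCLErrors() and any
// newly created work items (via initWork()) report a synthetic failure.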
class ProcessGroupNCCLSimulateErrors : public c10d::ProcessGroupNCCL {
 public:
  ProcessGroupNCCLSimulateErrors(
      const c10::intrusive_ptr<c10d::Store>& store,
      int rank,
      int size,
      c10::intrusive_ptr<c10d::ProcessGroupNCCL::Options> opts)
      : ProcessGroupNCCL(store, rank, size, opts), simulate_error_(false) {}

  std::exception_ptr checkForNCCLErrors(
      const std::vector<std::shared_ptr<c10d::NCCLComm>>& ncclComms) override {
    if (simulate_error_) {
      return std::make_exception_ptr(std::runtime_error("Error"));
    }
    return c10d::ProcessGroupNCCL::checkForNCCLErrors(ncclComms);
  }

  std::chrono::duration<int64_t, std::milli> getWatchdogSleepInterval() {
    return std::chrono::milliseconds(
        ProcessGroupNCCLSimulateErrors::kWatchdogThreadSleepMillis);
  }

  c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> initWork(
      std::vector<at::Device> devices,
      int rank,
      c10d::OpType opType,
      const char* profilingTitle,
      const c10::optional<std::vector<at::Tensor>>& inputs =
          c10::nullopt) override {
    return c10::make_intrusive<WorkNCCLSimulateErrors>(
        devices, simulate_error_, rank, opType, seq_);
  }

  size_t getNCCLCommCacheSize() {
    return devNCCLCommMap_.size();
  }

  void simulate_error() {
    simulate_error_ = true;
  }

  void reset_error() {
    simulate_error_ = false;
  }

 private:
  bool simulate_error_;
};

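// A WorkNCCL test double that can simulate a hang: when set_timedout_error is
// set, isCompleted() always returns false, so a blocking wait runs into its
// configured timeout.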
class WorkNCCLTimedoutErrors : public c10d::ProcessGroupNCCL::WorkNCCL {
 public:
  WorkNCCLTimedoutErrors(
      const std::vector<at::Device>& devices,
      bool set_timedout_error,
      int rank,
      c10d::OpType opType,
      uint64_t seq)
      : WorkNCCL(devices, rank, opType, seq),
        set_timedout_error_(set_timedout_error) {}

 private:
  bool isCompleted() override {
    if (set_timedout_error_) {
      return false;
    }
    return c10d::ProcessGroupNCCL::WorkNCCL::isCompleted();
  }

  bool set_timedout_error_;
};

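// A process group whose work items can be forced to never complete, driving
// the blocking-wait path into its timeout handling.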
class ProcessGroupNCCLTimedOutErrors : public ProcessGroupNCCLSimulateErrors {
 public:
  ProcessGroupNCCLTimedOutErrors(
      const c10::intrusive_ptr<c10d::Store>& store,
      int rank,
      int size,
      c10::intrusive_ptr<c10d::ProcessGroupNCCL::Options> opts)
      : ProcessGroupNCCLSimulateErrors(store, rank, size, opts),
        set_timedout_error_(false) {}

  c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> initWork(
      std::vector<at::Device> devices,
      int rank,
      c10d::OpType opType,
      const char* profilingTitle,
      const c10::optional<std::vector<at::Tensor>>& inputs =
          c10::nullopt) override {
    return c10::make_intrusive<WorkNCCLTimedoutErrors>(
        devices, set_timedout_error_, rank, opType, seq_);
  }

  void set_timedout_error() {
    set_timedout_error_ = true;
  }

  void reset_timedout_error() {
    set_timedout_error_ = false;
  }

 private:
  bool set_timedout_error_;
};

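// Test fixture: skips when no CUDA device is available or the NCCL version
// predates error handling (< 2.4), sets up a single-rank FileStore and one
// tensor per visible device, and restores NCCL_BLOCKING_WAIT after each test.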
class ProcessGroupNCCLErrorsTest : public ::testing::Test {
 protected:
  bool skipTest() {
    if (cudaNumDevices() == 0) {
      LOG(INFO) << "Skipping test since CUDA is not available";
      return true;
    }
#ifdef USE_C10D_NCCL
    if (torch::cuda::nccl::version() < kNcclErrorHandlingVersion) {
      LOG(INFO) << "Skipping test since NCCL version is too old";
      return true;
    }
#endif
    return false;
  }

  void SetUp() override {
    size_t numDevices = cudaNumDevices();
    TemporaryFile file;
    store_ = c10::make_intrusive<::c10d::FileStore>(file.path, 1);

    // Allocate one small CUDA tensor per device for the collectives below.
    at::cuda::OptionalCUDAGuard deviceGuard;
    tensors_.resize(numDevices);
    for (const auto i : c10::irange(numDevices)) {
      deviceGuard.set_index(i);
      tensors_[i] = at::ones({3, 3}, at::kCUDA);
    }
  }

  void TearDown() override {
    // Reset blocking-wait mode so one test cannot leak it into the next.
    ASSERT_TRUE(setenv(c10d::NCCL_BLOCKING_WAIT, "0", 1) == 0);
  }

  std::vector<at::Tensor> tensors_;
  c10::intrusive_ptr<::c10d::FileStore> store_;
};

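// Blocking wait: a simulated NCCL error should surface as a std::runtime_error
// thrown from wait().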
TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsBlocking) {
  if (skipTest()) {
    return;
  }

  ASSERT_TRUE(setenv(c10d::NCCL_BLOCKING_WAIT, "1", 1) == 0);
  auto options = c10d::ProcessGroupNCCL::Options::create();
  options->timeout = std::chrono::milliseconds(1000);
  ProcessGroupNCCLSimulateErrors pg(store_, 0, 1, options);

  // A healthy allreduce should succeed and populate the communicator cache.
  auto work = pg.allreduce(tensors_);
  work->wait();
  EXPECT_TRUE(work->isSuccess());
  EXPECT_EQ(1, pg.getNCCLCommCacheSize());

  // Now run allreduce with errors.
  pg.simulate_error();
  work = pg.allreduce(tensors_);
  EXPECT_THROW(work->wait(), std::runtime_error);

  // Verify the work item failed; wait() keeps throwing on repeated calls.
  EXPECT_TRUE(work->isCompleted());
  EXPECT_FALSE(work->isSuccess());
  EXPECT_THROW(work->wait(), std::runtime_error);

  // Communicators might be aborted here; further operations would fail.
}

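// Blocking wait: a work item that never completes should hit the configured
// timeout and surface as a c10::Error thrown from wait().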
TEST_F(ProcessGroupNCCLErrorsTest, testNCCLTimedoutErrorsBlocking) {
  if (skipTest()) {
    return;
  }

  ASSERT_TRUE(setenv(c10d::NCCL_BLOCKING_WAIT, "1", 1) == 0);
  auto options = c10d::ProcessGroupNCCL::Options::create();
  options->timeout = std::chrono::milliseconds(3000);
  ProcessGroupNCCLTimedOutErrors pg(store_, 0, 1, options);

  // A healthy allreduce should succeed and populate the communicator cache.
  auto work = pg.allreduce(tensors_);
  work->wait();
  EXPECT_TRUE(work->isSuccess());
  EXPECT_EQ(1, pg.getNCCLCommCacheSize());

  // Now run allreduce with a simulated hang; wait() should time out.
  pg.set_timedout_error();
  work = pg.allreduce(tensors_);
  EXPECT_THROW(work->wait(), c10::Error);

  // Communicators might be aborted here; further operations would fail.
}

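// Non-blocking (default) mode: a simulated NCCL error must not throw from
// wait(); the failure is only observable through the work item's state.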
TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNonBlocking) {
  if (skipTest()) {
    return;
  }

  auto options = c10d::ProcessGroupNCCL::Options::create();
  options->timeout = std::chrono::milliseconds(3000);
  ProcessGroupNCCLSimulateErrors pg(store_, 0, 1, options);

  // A healthy allreduce should succeed and populate the communicator cache.
  auto work = pg.allreduce(tensors_);
  pg.barrier()->wait();
  EXPECT_TRUE(work->isSuccess());
  EXPECT_EQ(1, pg.getNCCLCommCacheSize());

  // Now run allreduce with errors.
  pg.simulate_error();
  work = pg.allreduce(tensors_);

  // Should not throw exceptions.
  work->wait();
  pg.barrier()->wait();

  // Verify the work item failed.
  EXPECT_TRUE(work->isCompleted());
  EXPECT_FALSE(work->isSuccess());

  // Communicators might be aborted here; further operations would fail.
}