1 | #include <chrono> |
2 | |
3 | #include <c10/util/irange.h> |
4 | #include <torch/csrc/cuda/nccl.h> |
5 | #include <torch/csrc/distributed/c10d/FileStore.hpp> |
6 | #include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp> |
7 | #include "CUDATest.hpp" |
8 | #include "TestUtils.hpp" |
9 | |
10 | #include <gtest/gtest.h> |
11 | |
12 | using namespace c10d::test; |
13 | |
// Minimum NCCL version with asynchronous error-handling support, encoded the
// way torch::cuda::nccl::version() reports it (2400 corresponds to NCCL 2.4.0,
// which introduced communicator error detection). Tests below are skipped on
// older NCCL builds.
constexpr int kNcclErrorHandlingVersion = 2400;
15 | |
16 | class WorkNCCLSimulateErrors : public c10d::ProcessGroupNCCL::WorkNCCL { |
17 | public: |
18 | WorkNCCLSimulateErrors( |
19 | const std::vector<at::Device>& devices, |
20 | bool simulate_error, |
21 | int rank, |
22 | c10d::OpType opType, |
23 | uint64_t seq) |
24 | : WorkNCCL(devices, rank, opType, seq), simulate_error_(simulate_error) {} |
25 | |
26 | std::exception_ptr checkForNCCLErrors( |
27 | const std::vector<std::shared_ptr<c10d::NCCLComm>>& ncclComms) |
28 | const override { |
29 | if (simulate_error_) { |
30 | return std::make_exception_ptr(std::runtime_error("Error" )); |
31 | } |
32 | return c10d::ProcessGroupNCCL::WorkNCCL::checkForNCCLErrors(ncclComms); |
33 | } |
34 | |
35 | private: |
36 | bool simulate_error_; |
37 | }; |
38 | |
39 | class ProcessGroupNCCLSimulateErrors : public c10d::ProcessGroupNCCL { |
40 | public: |
41 | ProcessGroupNCCLSimulateErrors( |
42 | const c10::intrusive_ptr<c10d::Store>& store, |
43 | int rank, |
44 | int size, |
45 | c10::intrusive_ptr<c10d::ProcessGroupNCCL::Options> opts) |
46 | : ProcessGroupNCCL(store, rank, size, opts), simulate_error_(false) {} |
47 | |
48 | std::exception_ptr checkForNCCLErrors( |
49 | const std::vector<std::shared_ptr<c10d::NCCLComm>>& ncclComms) override { |
50 | if (simulate_error_) { |
51 | return std::make_exception_ptr(std::runtime_error("Error" )); |
52 | } |
53 | return c10d::ProcessGroupNCCL::checkForNCCLErrors(ncclComms); |
54 | } |
55 | |
56 | std::chrono::duration<int64_t, std::milli> getWatchdogSleepInterval() { |
57 | return std::chrono::milliseconds( |
58 | ProcessGroupNCCLSimulateErrors::kWatchdogThreadSleepMillis); |
59 | } |
60 | |
61 | c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> initWork( |
62 | std::vector<at::Device> devices, |
63 | int rank, |
64 | c10d::OpType opType, |
65 | const char* profilingTitle, |
66 | const c10::optional<std::vector<at::Tensor>>& inputs = |
67 | c10::nullopt) override { |
68 | return c10::make_intrusive<WorkNCCLSimulateErrors>( |
69 | devices, simulate_error_, rank, opType, seq_); |
70 | } |
71 | |
72 | size_t getNCCLCommCacheSize() { |
73 | return devNCCLCommMap_.size(); |
74 | } |
75 | |
76 | void simulate_error() { |
77 | simulate_error_ = true; |
78 | } |
79 | |
80 | void reset_error() { |
81 | simulate_error_ = false; |
82 | } |
83 | |
84 | private: |
85 | bool simulate_error_; |
86 | }; |
87 | |
88 | class WorkNCCLTimedoutErrors : public c10d::ProcessGroupNCCL::WorkNCCL { |
89 | public: |
90 | WorkNCCLTimedoutErrors( |
91 | const std::vector<at::Device>& devices, |
92 | bool set_timedout_error, |
93 | int rank, |
94 | c10d::OpType opType, |
95 | uint64_t seq) |
96 | : WorkNCCL(devices, rank, opType, seq), |
97 | set_timedout_error_(set_timedout_error) {} |
98 | |
99 | private: |
100 | bool isCompleted() override { |
101 | if (set_timedout_error_) { |
102 | return false; |
103 | } |
104 | return c10d::ProcessGroupNCCL::WorkNCCL::isCompleted(); |
105 | } |
106 | |
107 | private: |
108 | bool set_timedout_error_; |
109 | }; |
110 | |
111 | class ProcessGroupNCCLTimedOutErrors : public ProcessGroupNCCLSimulateErrors { |
112 | public: |
113 | ProcessGroupNCCLTimedOutErrors( |
114 | const c10::intrusive_ptr<c10d::Store>& store, |
115 | int rank, |
116 | int size, |
117 | c10::intrusive_ptr<c10d::ProcessGroupNCCL::Options> opts) |
118 | : ProcessGroupNCCLSimulateErrors(store, rank, size, opts), |
119 | set_timedout_error_(false) {} |
120 | |
121 | c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> initWork( |
122 | std::vector<at::Device> devices, |
123 | int rank, |
124 | c10d::OpType opType, |
125 | const char* profilingTitle, |
126 | const c10::optional<std::vector<at::Tensor>>& inputs = |
127 | c10::nullopt) override { |
128 | return c10::make_intrusive<WorkNCCLTimedoutErrors>( |
129 | devices, set_timedout_error_, rank, opType, seq_); |
130 | } |
131 | |
132 | void set_timedout_error() { |
133 | set_timedout_error_ = true; |
134 | } |
135 | |
136 | void reset_timedout_error() { |
137 | set_timedout_error_ = false; |
138 | } |
139 | |
140 | private: |
141 | bool set_timedout_error_; |
142 | }; |
143 | |
144 | class ProcessGroupNCCLErrorsTest : public ::testing::Test { |
145 | protected: |
146 | bool skipTest() { |
147 | if (cudaNumDevices() == 0) { |
148 | LOG(INFO) << "Skipping test since CUDA is not available" ; |
149 | return true; |
150 | } |
151 | #ifdef USE_C10D_NCCL |
152 | if (torch::cuda::nccl::version() < kNcclErrorHandlingVersion) { |
153 | LOG(INFO) << "Skipping test since NCCL version is too old" ; |
154 | return true; |
155 | } |
156 | #endif |
157 | return false; |
158 | } |
159 | |
160 | void SetUp() override { |
161 | size_t numDevices = cudaNumDevices(); |
162 | TemporaryFile file; |
163 | store_ = c10::make_intrusive<::c10d::FileStore>(file.path, 1); |
164 | |
165 | at::cuda::OptionalCUDAGuard deviceGuard; |
166 | tensors_.resize(numDevices); |
167 | for (const auto i : c10::irange(numDevices)) { |
168 | deviceGuard.set_index(i); |
169 | tensors_[i] = at::ones({3, 3}, at::kCUDA); |
170 | } |
171 | } |
172 | |
173 | void TearDown() override { |
174 | ASSERT_TRUE(setenv(c10d::NCCL_BLOCKING_WAIT, "0" , 1) == 0); |
175 | } |
176 | |
177 | std::vector<at::Tensor> tensors_; |
178 | c10::intrusive_ptr<::c10d::FileStore> store_; |
179 | }; |
180 | |
// With blocking wait enabled, an injected NCCL error must surface as a
// std::runtime_error thrown from wait(), and the failed work must stay
// observable (completed but unsuccessful) afterwards.
TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsBlocking) {
  if (skipTest()) {
    return;
  }

  // Must be set before the process group is constructed so it picks up
  // blocking-wait mode.
  ASSERT_TRUE(setenv(c10d::NCCL_BLOCKING_WAIT, "1" , 1) == 0);
  auto options = c10d::ProcessGroupNCCL::Options::create();
  options->timeout = std::chrono::milliseconds(1000);
  ProcessGroupNCCLSimulateErrors pg(store_, 0, 1, options);

  // Sanity check: a normal allreduce succeeds and caches one communicator set.
  auto work = pg.allreduce(tensors_);
  work->wait();
  EXPECT_TRUE(work->isSuccess());
  EXPECT_EQ(1, pg.getNCCLCommCacheSize());

  // Now run all reduce with errors.
  pg.simulate_error();
  work = pg.allreduce(tensors_);
  EXPECT_THROW(work->wait(), std::runtime_error);

  // Verify the work item failed.
  EXPECT_TRUE(work->isCompleted());
  EXPECT_FALSE(work->isSuccess());
  // A second wait() on the same failed work must rethrow as well.
  EXPECT_THROW(work->wait(), std::runtime_error);

  // Communicators might be aborted here, further operations would fail.
}
208 | |
// With blocking wait enabled, work that never completes must make wait()
// throw a c10::Error once the configured timeout elapses.
TEST_F(ProcessGroupNCCLErrorsTest, testNCCLTimedoutErrorsBlocking) {
  if (skipTest()) {
    return;
  }

  // Must be set before the process group is constructed so it picks up
  // blocking-wait mode.
  ASSERT_TRUE(setenv(c10d::NCCL_BLOCKING_WAIT, "1" , 1) == 0);
  auto options = c10d::ProcessGroupNCCL::Options::create();
  options->timeout = std::chrono::milliseconds(3000);
  ProcessGroupNCCLTimedOutErrors pg(store_, 0, 1, options);

  // Sanity check: a normal allreduce succeeds and caches one communicator set.
  auto work = pg.allreduce(tensors_);
  work->wait();
  EXPECT_TRUE(work->isSuccess());
  EXPECT_EQ(1, pg.getNCCLCommCacheSize());

  // Now run all reduce with errors.
  pg.set_timedout_error();
  work = pg.allreduce(tensors_);
  // The work never reports completion, so the 3s timeout fires.
  EXPECT_THROW(work->wait(), c10::Error);

  // Communicators might be aborted here, further operations would fail.
}
231 | |
// Without blocking wait (the default), an injected NCCL error must NOT throw
// from wait(); the failure is only visible via isCompleted()/isSuccess().
TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNonBlocking) {
  if (skipTest()) {
    return;
  }

  auto options = c10d::ProcessGroupNCCL::Options::create();
  options->timeout = std::chrono::milliseconds(3000);
  ProcessGroupNCCLSimulateErrors pg(store_, 0, 1, options);

  // Sanity check: a normal allreduce succeeds and caches one communicator set.
  auto work = pg.allreduce(tensors_);
  pg.barrier()->wait();
  EXPECT_TRUE(work->isSuccess());
  EXPECT_EQ(1, pg.getNCCLCommCacheSize());

  // Now run all reduce with errors.
  pg.simulate_error();
  work = pg.allreduce(tensors_);

  // Should not throw exceptions.
  work->wait();
  pg.barrier()->wait();

  // Verify the work item failed.
  EXPECT_TRUE(work->isCompleted());
  EXPECT_FALSE(work->isSuccess());

  // Communicators might be aborted here, further operations would fail.
}
260 | |