1 | #include <gtest/gtest.h> |
2 | |
3 | #include <c10/util/irange.h> |
4 | #include <torch/csrc/autograd/functions/comm.h> |
5 | #include <torch/nn/module.h> |
6 | #include <torch/nn/modules/conv.h> |
7 | #include <torch/nn/modules/linear.h> |
8 | #include <torch/nn/parallel/data_parallel.h> |
9 | #include <torch/nn/pimpl.h> |
10 | #include <torch/optim/sgd.h> |
11 | #include <torch/types.h> |
12 | #include <torch/utils.h> |
13 | |
14 | #include <test/cpp/api/support.h> |
15 | |
16 | #include <iostream> |
17 | #include <memory> |
18 | #include <utility> |
19 | #include <vector> |
20 | |
21 | using namespace torch::autograd; |
22 | using namespace torch::nn; |
23 | |
24 | struct ParallelTest : torch::test::SeedingFixture {}; |
25 | |
26 | TEST_F(ParallelTest, DifferentiableScatter_MultiCUDA) { |
27 | Scatter scatter( |
28 | {torch::Device(torch::kCUDA, 0), torch::Device(torch::kCUDA, 1)}); |
29 | |
30 | auto input = torch::ones(10, torch::requires_grad(true)); |
31 | auto output = scatter.apply({input}); |
32 | |
33 | ASSERT_EQ(output.size(), 2); |
34 | ASSERT_EQ(output[0].size(0), 5); |
35 | ASSERT_EQ(output[1].size(0), 5); |
36 | |
37 | ASSERT_TRUE(torch::cat({output[0].to(torch::kCPU), output[1].to(torch::kCPU)}) |
38 | .allclose(input)); |
39 | |
40 | torch::Tensor sum = output[0].to({torch::kCUDA, 1}) + output[1]; |
41 | sum.backward(torch::ones_like(sum)); |
42 | |
43 | ASSERT_TRUE(input.grad().defined()); |
44 | ASSERT_TRUE(input.grad().device().is_cpu()); |
45 | ASSERT_EQ(input.grad().sum().item<int32_t>(), 10); |
46 | } |
47 | |
48 | TEST_F(ParallelTest, DifferentiableGather_MultiCUDA) { |
49 | Gather gather(torch::Device(torch::kCUDA, 1)); |
50 | |
51 | auto a = torch::ones(5, torch::requires_grad(true).device(torch::kCUDA, 0)); |
52 | auto b = torch::ones(5, torch::requires_grad(true).device(torch::kCUDA, 1)); |
53 | |
54 | auto outputs = gather.apply({a, b}); |
55 | ASSERT_EQ(outputs.size(), 1); |
56 | torch::Tensor output = outputs.front(); |
57 | |
58 | ASSERT_EQ(output.size(0), 10); |
59 | ASSERT_EQ(output.device(), torch::Device(torch::kCUDA, 1)); |
60 | |
61 | auto chunks = output.chunk(2); |
62 | ASSERT_TRUE(chunks[0].to({torch::kCUDA, 0}).allclose(a)); |
63 | ASSERT_TRUE(chunks[1].allclose(b)); |
64 | |
65 | output.backward(torch::ones_like(output)); |
66 | |
67 | ASSERT_TRUE(a.grad().defined()); |
68 | ASSERT_EQ(a.grad().device(), torch::Device(torch::kCUDA, 0)); |
69 | ASSERT_EQ(a.grad().sum().item<int32_t>(), 5); |
70 | |
71 | ASSERT_TRUE(b.grad().defined()); |
72 | ASSERT_EQ(b.grad().device(), torch::Device(torch::kCUDA, 1)); |
73 | ASSERT_EQ(b.grad().sum().item<int32_t>(), 5); |
74 | } |
75 | |
76 | TEST_F(ParallelTest, Replicate_MultiCUDA) { |
77 | Linear linear(3, 4); |
78 | auto replicas = parallel::replicate( |
79 | linear, {torch::Device(torch::kCUDA, 0), torch::Device(torch::kCUDA, 1)}); |
80 | ASSERT_EQ(replicas.size(), 2); |
81 | |
82 | auto original_parameters = linear->parameters(); |
83 | |
84 | auto replica1_parameters = replicas[0]->parameters(); |
85 | for (auto& parameter : replica1_parameters) { |
86 | ASSERT_EQ(parameter.device(), torch::Device(torch::kCUDA, 0)); |
87 | } |
88 | replicas[0]->to(torch::kCPU); |
89 | ASSERT_EQ(replica1_parameters.size(), original_parameters.size()); |
90 | for (const auto i : c10::irange(original_parameters.size())) { |
91 | ASSERT_TRUE(replica1_parameters[i].allclose(original_parameters[i])); |
92 | ASSERT_TRUE( |
93 | replica1_parameters[i].data_ptr<float>() != |
94 | original_parameters[i].data_ptr<float>()); |
95 | } |
96 | |
97 | auto replica2_parameters = replicas[1]->parameters(); |
98 | for (auto& parameter : replica2_parameters) { |
99 | ASSERT_EQ(parameter.device(), torch::Device(torch::kCUDA, 1)); |
100 | } |
101 | replicas[1]->to(torch::kCPU); |
102 | ASSERT_EQ(replica2_parameters.size(), original_parameters.size()); |
103 | for (const auto i : c10::irange(original_parameters.size())) { |
104 | ASSERT_TRUE(replica2_parameters[i].allclose(original_parameters[i])); |
105 | ASSERT_TRUE( |
106 | replica2_parameters[i].data_ptr<float>() != |
107 | original_parameters[i].data_ptr<float>()); |
108 | } |
109 | } |
110 | |
111 | TEST_F(ParallelTest, ParallelApply_MultiCUDA) { |
112 | Linear a(3, 4); |
113 | |
114 | Linear b(std::dynamic_pointer_cast<LinearImpl>(a->clone())); |
115 | b->to({torch::kCUDA, 0}); |
116 | |
117 | Linear c(std::dynamic_pointer_cast<LinearImpl>(a->clone())); |
118 | c->to({torch::kCUDA, 1}); |
119 | |
120 | std::vector<Linear> modules = {a, b, c}; |
121 | std::vector<torch::Tensor> inputs = { |
122 | torch::ones({2, 3}), |
123 | torch::ones({2, 3}, torch::device({torch::kCUDA, 0})), |
124 | torch::ones({2, 3}, torch::device({torch::kCUDA, 1}))}; |
125 | |
126 | auto outputs = parallel::parallel_apply(modules, inputs); |
127 | |
128 | ASSERT_EQ(outputs.size(), 3); |
129 | ASSERT_TRUE(outputs[0].device().is_cpu()); |
130 | |
131 | ASSERT_EQ(outputs[1].device(), torch::Device(torch::kCUDA, 0)); |
132 | ASSERT_TRUE(outputs[1].to(torch::kCPU).allclose(outputs[0])); |
133 | |
134 | ASSERT_EQ(outputs[2].device(), torch::Device(torch::kCUDA, 1)); |
135 | ASSERT_TRUE(outputs[2].to(torch::kCPU).allclose(outputs[0])); |
136 | } |
137 | |
138 | TEST_F(ParallelTest, ParallelApplyWithDifferentOutputDevice_MultiCUDA) { |
139 | struct M : torch::nn::Module { |
140 | torch::Tensor forward(torch::Tensor input) { |
141 | return torch::ones(5, torch::kInt32); |
142 | } |
143 | }; |
144 | |
145 | std::vector<std::shared_ptr<M>> modules = { |
146 | std::make_shared<M>(), std::make_shared<M>(), std::make_shared<M>()}; |
147 | std::vector<torch::Tensor> inputs = { |
148 | torch::empty({}), torch::empty({}), torch::empty({})}; |
149 | std::vector<torch::Device> devices = { |
150 | {torch::kCUDA, 1}, {torch::kCUDA, 0}, {torch::kCPU}}; |
151 | |
152 | auto outputs = parallel::parallel_apply(modules, inputs, devices); |
153 | |
154 | ASSERT_EQ(outputs.size(), 3); |
155 | ASSERT_TRUE(outputs[0].device().is_cuda()); |
156 | ASSERT_EQ(outputs[0].device(), torch::Device(torch::kCUDA, 1)); |
157 | |
158 | ASSERT_TRUE(outputs[1].device().is_cuda()); |
159 | ASSERT_EQ(outputs[1].device(), torch::Device(torch::kCUDA, 0)); |
160 | |
161 | ASSERT_TRUE(outputs[2].device().is_cpu()); |
162 | } |
163 | |
164 | TEST_F(ParallelTest, ParallelApplyRethrowsException_MultiCUDA) { |
165 | struct M : torch::nn::Cloneable<M> { |
166 | void reset() override {} |
167 | torch::Tensor forward(torch::Tensor input) { |
168 | throw std::runtime_error("Badness!" ); |
169 | } |
170 | }; |
171 | |
172 | auto m = std::make_shared<M>(); |
173 | auto input = torch::ones({10, 3}); |
174 | ASSERT_THROWS_WITH(parallel::data_parallel(m, input), "Badness!" ); |
175 | } |
176 | |
177 | TEST_F( |
178 | ParallelTest, |
179 | DataParallelPlacesTheOutputOnTheRequestedDevice_MultiCUDA) { |
180 | struct M : torch::nn::Cloneable<M> { |
181 | void reset() override {} |
182 | torch::Tensor forward(torch::Tensor input) { |
183 | // The returned tensor should be on the output device. |
184 | return torch::ones(3); |
185 | } |
186 | }; |
187 | auto m = std::make_shared<M>(); |
188 | auto input = torch::ones({10, 3}); |
189 | { |
190 | auto output = parallel::data_parallel( |
191 | m, |
192 | input, |
193 | /*devices=*/torch::nullopt, |
194 | /*output_device=*/torch::Device(torch::kCUDA, 1)); |
195 | ASSERT_TRUE(output.defined()); |
196 | ASSERT_TRUE(output.device().is_cuda()); |
197 | ASSERT_EQ(output.device().index(), 1); |
198 | } |
199 | { |
200 | // Verify for the single-device case (where we don't scatter/gather). |
201 | auto output = parallel::data_parallel( |
202 | m, |
203 | input, |
204 | /*devices=*/std::vector<torch::Device>{torch::Device(torch::kCUDA, 0)}, |
205 | /*output_device=*/torch::Device(torch::kCUDA, 1)); |
206 | ASSERT_TRUE(output.defined()); |
207 | ASSERT_TRUE(output.device().is_cuda()); |
208 | ASSERT_EQ(output.device().index(), 1); |
209 | } |
210 | } |
211 | |
212 | TEST_F(ParallelTest, DataParallelUsesAllAvailableCUDADevices_CUDA) { |
213 | struct M : torch::nn::Cloneable<M> { |
214 | void reset() override {} |
215 | torch::Tensor forward(torch::Tensor input) { |
216 | return torch::tensor({input.device().index()}); |
217 | } |
218 | }; |
219 | |
220 | auto m = std::make_shared<M>(); |
221 | const auto device_count = torch::cuda::device_count(); |
222 | auto input = torch::ones({std::max(10, int(2 * device_count)), 3}); |
223 | auto output = parallel::data_parallel(m, input); |
224 | |
225 | ASSERT_EQ(output.numel(), device_count); |
226 | for (const auto i : c10::irange(device_count)) { |
227 | ASSERT_EQ(output[i].item<int32_t>(), i); |
228 | } |
229 | } |
230 | |
231 | TEST_F(ParallelTest, DataParallelNumericalEquivalence_MultiCUDA) { |
232 | struct M : torch::nn::Cloneable<M> { |
233 | M() { |
234 | reset(); |
235 | } |
236 | |
237 | void reset() override { |
238 | conv = register_module( |
239 | "conv" , |
240 | torch::nn::Conv2d(torch::nn::Conv2dOptions(2, 2, /*kernel_size=*/2))); |
241 | fc = register_module("fc" , torch::nn::Linear(8, 2)); |
242 | } |
243 | |
244 | torch::Tensor forward(torch::Tensor x) { |
245 | x = conv->forward(x); |
246 | x = torch::relu(x); |
247 | x = x.view({-1, 8}); |
248 | x = fc->forward(x); |
249 | return torch::log_softmax(x, /*dim=*/1); |
250 | } |
251 | |
252 | torch::nn::Conv2d conv{nullptr}; |
253 | torch::nn::Linear fc{nullptr}; |
254 | }; |
255 | |
256 | // prepare modules and inputs |
257 | auto input = torch::ones({16, 2, 3, 3}); |
258 | auto input_dp = torch::ones({16, 2, 3, 3}); |
259 | auto model = std::make_shared<M>(); |
260 | auto model_dp = std::dynamic_pointer_cast<M>(model->clone()); |
261 | |
262 | // run 3 training iterations |
263 | for (const auto i : c10::irange(3)) { |
264 | input += i; |
265 | input_dp += i; |
266 | |
267 | // non-prallel training |
268 | torch::optim::SGD optim(model->parameters(), torch::optim::SGDOptions(0.1)); |
269 | auto output = model->forward(input); |
270 | auto loss = torch::mse_loss(output, torch::zeros_like(output)); |
271 | loss.backward(); |
272 | optim.step(); |
273 | |
274 | // data-parallel training |
275 | torch::optim::SGD optim_dp( |
276 | model_dp->parameters(), torch::optim::SGDOptions(0.1)); |
277 | auto output_dp = parallel::data_parallel(model_dp, input_dp); |
278 | auto loss_dp = torch::mse_loss(output_dp, torch::zeros_like(output_dp)); |
279 | loss_dp.backward(); |
280 | optim_dp.step(); |
281 | |
282 | // make sure that weights are the same |
283 | model->to(torch::kCPU); |
284 | model_dp->to(torch::kCPU); |
285 | auto params = model->parameters(); |
286 | auto params_dp = model_dp->parameters(); |
287 | ASSERT_EQ(params.size(), params_dp.size()); |
288 | for (auto it = params.begin(), it_dp = params_dp.begin(); |
289 | it != params.end() && it_dp != params.end(); |
290 | ++it, ++it_dp) { |
291 | ASSERT_TRUE(torch::allclose(*it, *it_dp)); |
292 | } |
293 | } |
294 | } |
295 | |