#include <gtest/gtest.h>

#include <c10/util/irange.h>
#include <torch/csrc/autograd/functions/comm.h>
#include <torch/nn/module.h>
#include <torch/nn/modules/conv.h>
#include <torch/nn/modules/linear.h>
#include <torch/nn/parallel/data_parallel.h>
#include <torch/nn/pimpl.h>
#include <torch/optim/sgd.h>
#include <torch/types.h>
#include <torch/utils.h>

#include <test/cpp/api/support.h>

#include <algorithm> // std::max
#include <iostream>
#include <memory>
#include <stdexcept> // std::runtime_error
#include <utility>
#include <vector>
using namespace torch::autograd;
using namespace torch::nn;

struct ParallelTest : torch::test::SeedingFixture {};

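// Scatter should split a CPU tensor evenly across the given CUDA devices
// along dim 0, and remain differentiable: gradients from both shards must
// accumulate back into the original CPU input.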
TEST_F(ParallelTest, DifferentiableScatter_MultiCUDA) {
  Scatter scatter(
      {torch::Device(torch::kCUDA, 0), torch::Device(torch::kCUDA, 1)});

  auto input = torch::ones(10, torch::requires_grad(true));
  auto output = scatter.apply({input});

  ASSERT_EQ(output.size(), 2);
  ASSERT_EQ(output[0].size(0), 5);
  ASSERT_EQ(output[1].size(0), 5);

  ASSERT_TRUE(torch::cat({output[0].to(torch::kCPU), output[1].to(torch::kCPU)})
                  .allclose(input));

  torch::Tensor sum = output[0].to({torch::kCUDA, 1}) + output[1];
  sum.backward(torch::ones_like(sum));

  ASSERT_TRUE(input.grad().defined());
  ASSERT_TRUE(input.grad().device().is_cpu());
  ASSERT_EQ(input.grad().sum().item<int32_t>(), 10);
}

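// Gather is the inverse operation: it concatenates tensors living on
// several devices onto a single target device, and routes gradients back
// to each source device on the backward pass.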
TEST_F(ParallelTest, DifferentiableGather_MultiCUDA) {
  Gather gather(torch::Device(torch::kCUDA, 1));

  auto a = torch::ones(5, torch::requires_grad(true).device(torch::kCUDA, 0));
  auto b = torch::ones(5, torch::requires_grad(true).device(torch::kCUDA, 1));

  auto outputs = gather.apply({a, b});
  ASSERT_EQ(outputs.size(), 1);
  torch::Tensor output = outputs.front();

  ASSERT_EQ(output.size(0), 10);
  ASSERT_EQ(output.device(), torch::Device(torch::kCUDA, 1));

  auto chunks = output.chunk(2);
  ASSERT_TRUE(chunks[0].to({torch::kCUDA, 0}).allclose(a));
  ASSERT_TRUE(chunks[1].allclose(b));

  output.backward(torch::ones_like(output));

  ASSERT_TRUE(a.grad().defined());
  ASSERT_EQ(a.grad().device(), torch::Device(torch::kCUDA, 0));
  ASSERT_EQ(a.grad().sum().item<int32_t>(), 5);

  ASSERT_TRUE(b.grad().defined());
  ASSERT_EQ(b.grad().device(), torch::Device(torch::kCUDA, 1));
  ASSERT_EQ(b.grad().sum().item<int32_t>(), 5);
}

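// replicate() should place one copy of the module on each device. The
// replicas must hold the same parameter values as the original but in
// distinct storage (different data pointers).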
TEST_F(ParallelTest, Replicate_MultiCUDA) {
  Linear linear(3, 4);
  auto replicas = parallel::replicate(
      linear, {torch::Device(torch::kCUDA, 0), torch::Device(torch::kCUDA, 1)});
  ASSERT_EQ(replicas.size(), 2);

  auto original_parameters = linear->parameters();

  auto replica1_parameters = replicas[0]->parameters();
  for (auto& parameter : replica1_parameters) {
    ASSERT_EQ(parameter.device(), torch::Device(torch::kCUDA, 0));
  }
  replicas[0]->to(torch::kCPU);
  ASSERT_EQ(replica1_parameters.size(), original_parameters.size());
  for (const auto i : c10::irange(original_parameters.size())) {
    ASSERT_TRUE(replica1_parameters[i].allclose(original_parameters[i]));
    ASSERT_TRUE(
        replica1_parameters[i].data_ptr<float>() !=
        original_parameters[i].data_ptr<float>());
  }

  auto replica2_parameters = replicas[1]->parameters();
  for (auto& parameter : replica2_parameters) {
    ASSERT_EQ(parameter.device(), torch::Device(torch::kCUDA, 1));
  }
  replicas[1]->to(torch::kCPU);
  ASSERT_EQ(replica2_parameters.size(), original_parameters.size());
  for (const auto i : c10::irange(original_parameters.size())) {
    ASSERT_TRUE(replica2_parameters[i].allclose(original_parameters[i]));
    ASSERT_TRUE(
        replica2_parameters[i].data_ptr<float>() !=
        original_parameters[i].data_ptr<float>());
  }
}

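// parallel_apply() runs each module on its paired input; every output is
// expected to stay on its module's device and to match the CPU result.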
TEST_F(ParallelTest, ParallelApply_MultiCUDA) {
  Linear a(3, 4);

  Linear b(std::dynamic_pointer_cast<LinearImpl>(a->clone()));
  b->to({torch::kCUDA, 0});

  Linear c(std::dynamic_pointer_cast<LinearImpl>(a->clone()));
  c->to({torch::kCUDA, 1});

  std::vector<Linear> modules = {a, b, c};
  std::vector<torch::Tensor> inputs = {
      torch::ones({2, 3}),
      torch::ones({2, 3}, torch::device({torch::kCUDA, 0})),
      torch::ones({2, 3}, torch::device({torch::kCUDA, 1}))};

  auto outputs = parallel::parallel_apply(modules, inputs);

  ASSERT_EQ(outputs.size(), 3);
  ASSERT_TRUE(outputs[0].device().is_cpu());

  ASSERT_EQ(outputs[1].device(), torch::Device(torch::kCUDA, 0));
  ASSERT_TRUE(outputs[1].to(torch::kCPU).allclose(outputs[0]));

  ASSERT_EQ(outputs[2].device(), torch::Device(torch::kCUDA, 1));
  ASSERT_TRUE(outputs[2].to(torch::kCPU).allclose(outputs[0]));
}

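// When an explicit device list is passed, parallel_apply() should move each
// module's output to the requested device rather than leaving it where the
// module produced it.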
TEST_F(ParallelTest, ParallelApplyWithDifferentOutputDevice_MultiCUDA) {
  struct M : torch::nn::Module {
    torch::Tensor forward(torch::Tensor input) {
      return torch::ones(5, torch::kInt32);
    }
  };

  std::vector<std::shared_ptr<M>> modules = {
      std::make_shared<M>(), std::make_shared<M>(), std::make_shared<M>()};
  std::vector<torch::Tensor> inputs = {
      torch::empty({}), torch::empty({}), torch::empty({})};
  std::vector<torch::Device> devices = {
      {torch::kCUDA, 1}, {torch::kCUDA, 0}, {torch::kCPU}};

  auto outputs = parallel::parallel_apply(modules, inputs, devices);

  ASSERT_EQ(outputs.size(), 3);
  ASSERT_TRUE(outputs[0].device().is_cuda());
  ASSERT_EQ(outputs[0].device(), torch::Device(torch::kCUDA, 1));

  ASSERT_TRUE(outputs[1].device().is_cuda());
  ASSERT_EQ(outputs[1].device(), torch::Device(torch::kCUDA, 0));

  ASSERT_TRUE(outputs[2].device().is_cpu());
}

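// An exception thrown in one worker thread must be captured and rethrown to
// the caller instead of being swallowed.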
TEST_F(ParallelTest, ParallelApplyRethrowsException_MultiCUDA) {
  struct M : torch::nn::Cloneable<M> {
    void reset() override {}
    torch::Tensor forward(torch::Tensor input) {
      throw std::runtime_error("Badness!");
    }
  };

  auto m = std::make_shared<M>();
  auto input = torch::ones({10, 3});
  ASSERT_THROWS_WITH(parallel::data_parallel(m, input), "Badness!");
}

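// data_parallel() should gather the result onto the requested output_device,
// both on the multi-device path and on the single-device shortcut that skips
// scatter/gather.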
TEST_F(
    ParallelTest,
    DataParallelPlacesTheOutputOnTheRequestedDevice_MultiCUDA) {
  struct M : torch::nn::Cloneable<M> {
    void reset() override {}
    torch::Tensor forward(torch::Tensor input) {
      // The returned tensor should be on the output device.
      return torch::ones(3);
    }
  };
  auto m = std::make_shared<M>();
  auto input = torch::ones({10, 3});
  {
    auto output = parallel::data_parallel(
        m,
        input,
        /*devices=*/torch::nullopt,
        /*output_device=*/torch::Device(torch::kCUDA, 1));
    ASSERT_TRUE(output.defined());
    ASSERT_TRUE(output.device().is_cuda());
    ASSERT_EQ(output.device().index(), 1);
  }
  {
    // Verify for the single-device case (where we don't scatter/gather).
    auto output = parallel::data_parallel(
        m,
        input,
        /*devices=*/std::vector<torch::Device>{torch::Device(torch::kCUDA, 0)},
        /*output_device=*/torch::Device(torch::kCUDA, 1));
    ASSERT_TRUE(output.defined());
    ASSERT_TRUE(output.device().is_cuda());
    ASSERT_EQ(output.device().index(), 1);
  }
}

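// With no explicit device list, data_parallel() should shard the batch over
// every visible CUDA device; each shard's forward() reports the device index
// it ran on.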
TEST_F(ParallelTest, DataParallelUsesAllAvailableCUDADevices_CUDA) {
  struct M : torch::nn::Cloneable<M> {
    void reset() override {}
    torch::Tensor forward(torch::Tensor input) {
      return torch::tensor({input.device().index()});
    }
  };

  auto m = std::make_shared<M>();
  const auto device_count = torch::cuda::device_count();
  auto input = torch::ones({std::max(10, int(2 * device_count)), 3});
  auto output = parallel::data_parallel(m, input);

  ASSERT_EQ(output.numel(), device_count);
  for (const auto i : c10::irange(device_count)) {
    ASSERT_EQ(output[i].item<int32_t>(), i);
  }
}

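// Training the same model with and without data_parallel() should yield
// numerically equivalent parameters after each optimizer step.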
TEST_F(ParallelTest, DataParallelNumericalEquivalence_MultiCUDA) {
  struct M : torch::nn::Cloneable<M> {
    M() {
      reset();
    }

    void reset() override {
      conv = register_module(
          "conv",
          torch::nn::Conv2d(torch::nn::Conv2dOptions(2, 2, /*kernel_size=*/2)));
      fc = register_module("fc", torch::nn::Linear(8, 2));
    }

    torch::Tensor forward(torch::Tensor x) {
      x = conv->forward(x);
      x = torch::relu(x);
      x = x.view({-1, 8});
      x = fc->forward(x);
      return torch::log_softmax(x, /*dim=*/1);
    }

    torch::nn::Conv2d conv{nullptr};
    torch::nn::Linear fc{nullptr};
  };

  // prepare modules and inputs
  auto input = torch::ones({16, 2, 3, 3});
  auto input_dp = torch::ones({16, 2, 3, 3});
  auto model = std::make_shared<M>();
  auto model_dp = std::dynamic_pointer_cast<M>(model->clone());

  // run 3 training iterations
  for (const auto i : c10::irange(3)) {
    input += i;
    input_dp += i;

    // non-parallel training
    torch::optim::SGD optim(model->parameters(), torch::optim::SGDOptions(0.1));
    auto output = model->forward(input);
    auto loss = torch::mse_loss(output, torch::zeros_like(output));
    loss.backward();
    optim.step();

    // data-parallel training
    torch::optim::SGD optim_dp(
        model_dp->parameters(), torch::optim::SGDOptions(0.1));
    auto output_dp = parallel::data_parallel(model_dp, input_dp);
    auto loss_dp = torch::mse_loss(output_dp, torch::zeros_like(output_dp));
    loss_dp.backward();
    optim_dp.step();

    // make sure that weights are the same
    model->to(torch::kCPU);
    model_dp->to(torch::kCPU);
    auto params = model->parameters();
    auto params_dp = model_dp->parameters();
    ASSERT_EQ(params.size(), params_dp.size());
    for (auto it = params.begin(), it_dp = params_dp.begin();
         it != params.end() && it_dp != params.end();
         ++it, ++it_dp) {
      ASSERT_TRUE(torch::allclose(*it, *it_dp));
    }
  }
}