1 | #include <gtest/gtest.h> |
2 | |
3 | #include <c10/util/irange.h> |
4 | #include <torch/torch.h> |
5 | |
6 | #include <test/cpp/api/support.h> |
7 | |
8 | #include <torch/expanding_array.h> |
9 | #include <torch/nn/functional/activation.h> |
10 | #include <torch/nn/options/activation.h> |
11 | #include <limits> |
12 | #include <random> |
13 | |
14 | using namespace torch::nn; |
15 | using namespace torch::test; |
16 | |
17 | class TestModel : public torch::nn::Module { |
18 | public: |
19 | TestModel() |
      : l1(register_module("l1", Linear(10, 3))),
        l2(register_module("l2", Linear(3, 5))),
        l3(register_module("l3", Linear(5, 100))) {}
23 | |
24 | Linear l1, l2, l3; |
25 | }; |
26 | |
27 | class NestedModel : public torch::nn::Module { |
28 | public: |
29 | NestedModel() |
      : param_(register_parameter("param", torch::empty({3, 2, 21}))),
        l1(register_module("l1", Linear(5, 20))),
        t(register_module("test", std::make_shared<TestModel>())) {}
33 | |
34 | torch::Tensor param_; |
35 | Linear l1; |
36 | std::shared_ptr<TestModel> t; |
37 | }; |
38 | |
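// ModulesTest inherits SeedingFixture, which seeds torch's global RNG so that
// tests drawing random inputs (e.g. torch::randn) are reproducible.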
39 | struct ModulesTest : torch::test::SeedingFixture {}; |
40 | |
41 | TEST_F(ModulesTest, Conv1d) { |
42 | Conv1d model(Conv1dOptions(3, 2, 3).stride(1).bias(false)); |
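  // Fill the weights deterministically (and disable bias) so the expected
  // outputs below are exact.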
43 | model->weight.set_data( |
44 | torch::arange(18, torch::dtype(torch::kFloat)).reshape({2, 3, 3})); |
45 | auto x = torch::arange(30, torch::dtype(torch::kFloat).requires_grad(true)) |
46 | .reshape({2, 3, 5}); |
47 | auto y = model(x); |
48 | auto expected = torch::tensor( |
49 | {{{312., 348., 384.}, {798., 915., 1032.}}, |
50 | |
51 | {{852., 888., 924.}, {2553., 2670., 2787.}}}, |
52 | torch::kFloat); |
53 | ASSERT_TRUE(torch::allclose(y, expected)); |
54 | |
55 | torch::Tensor s = y.sum(); |
56 | s.backward(); |
57 | ASSERT_EQ(s.ndimension(), 0); |
58 | ASSERT_EQ(model->weight.grad().numel(), 3 * 2 * 3); |
59 | } |
60 | |
61 | TEST_F(ModulesTest, Conv1dSameStrided) { |
62 | auto options = Conv1dOptions(3, 2, 3); |
63 | options.stride(1).padding(torch::kSame); |
64 | Conv1d model_valid(options); |
  ASSERT_THROWS_WITH(
      [&] { Conv1d model_invalid(options.stride(2)); }(),
      "padding='same' is not supported for strided convolutions");
68 | } |
69 | |
70 | TEST_F(ModulesTest, Conv2dEven) { |
71 | Conv2d model(Conv2dOptions(3, 2, 3).stride(1).bias(false)); |
72 | model->weight.set_data( |
73 | torch::arange(54, torch::dtype(torch::kFloat)).reshape({2, 3, 3, 3})); |
74 | auto x = torch::arange(75, torch::dtype(torch::kFloat).requires_grad(true)) |
75 | .reshape({1, 3, 5, 5}); |
76 | auto y = model(x); |
77 | auto expected = torch::tensor( |
78 | {{{{15219., 15570., 15921.}, |
79 | {16974., 17325., 17676.}, |
80 | {18729., 19080., 19431.}}, |
81 | |
82 | {{37818., 38898., 39978.}, |
83 | {43218., 44298., 45378.}, |
84 | {48618., 49698., 50778.}}}}, |
85 | torch::kFloat); |
86 | ASSERT_TRUE(torch::allclose(y, expected)); |
87 | |
88 | torch::Tensor s = y.sum(); |
89 | s.backward(); |
90 | ASSERT_EQ(s.ndimension(), 0); |
91 | ASSERT_EQ(model->weight.grad().numel(), 3 * 2 * 3 * 3); |
92 | } |
93 | |
94 | TEST_F(ModulesTest, Conv2dUneven) { |
95 | Conv2d model(Conv2dOptions(3, 2, {3, 2}).stride({1, 1}).bias(false)); |
96 | model->weight.set_data( |
97 | torch::arange(36, torch::dtype(torch::kFloat)).reshape({2, 3, 3, 2})); |
98 | auto x = torch::arange(60, torch::dtype(torch::kFloat).requires_grad(true)) |
99 | .reshape({1, 3, 5, 4}); |
100 | auto y = model(x); |
101 | auto expected = torch::tensor( |
102 | {{{{5289., 5442., 5595.}, {5901., 6054., 6207.}, {6513., 6666., 6819.}}, |
103 | |
104 | {{13227., 13704., 14181.}, |
105 | {15135., 15612., 16089.}, |
106 | {17043., 17520., 17997.}}}}, |
107 | torch::kFloat); |
108 | ASSERT_TRUE(torch::allclose(y, expected)); |
109 | |
110 | torch::Tensor s = y.sum(); |
111 | s.backward(); |
112 | ASSERT_EQ(s.ndimension(), 0); |
113 | ASSERT_EQ(model->weight.grad().numel(), 3 * 2 * 3 * 2); |
114 | } |
115 | |
116 | TEST_F(ModulesTest, Conv2dSameStrided) { |
117 | auto options = Conv2dOptions(3, 2, {3, 4}); |
118 | options.stride(1).padding(torch::kSame); |
119 | Conv2d model_valid(options); |
  ASSERT_THROWS_WITH(
      [&] { Conv2d model_invalid(options.stride(2)); }(),
      "padding='same' is not supported for strided convolutions");
  ASSERT_THROWS_WITH(
      [&] {
        Conv2d model_invalid(options.stride({1, 2}));
      }(),
      "padding='same' is not supported for strided convolutions");
128 | } |
129 | |
130 | TEST_F(ModulesTest, Conv3d) { |
131 | Conv3d model(Conv3dOptions(3, 2, 3).stride(1).bias(false)); |
132 | model->weight.set_data( |
133 | torch::arange(162, torch::dtype(torch::kFloat)).reshape({2, 3, 3, 3, 3})); |
134 | auto x = torch::arange(375, torch::dtype(torch::kFloat).requires_grad(true)) |
135 | .reshape({1, 3, 5, 5, 5}); |
136 | auto y = model(x); |
137 | auto expected = torch::tensor( |
138 | {{{{{700704., 703944., 707184.}, |
139 | {716904., 720144., 723384.}, |
140 | {733104., 736344., 739584.}}, |
141 | |
142 | {{781704., 784944., 788184.}, |
143 | {797904., 801144., 804384.}, |
144 | {814104., 817344., 820584.}}, |
145 | |
146 | {{862704., 865944., 869184.}, |
147 | {878904., 882144., 885384.}, |
148 | {895104., 898344., 901584.}}}, |
149 | |
150 | {{{1724220., 1734021., 1743822.}, |
151 | {1773225., 1783026., 1792827.}, |
152 | {1822230., 1832031., 1841832.}}, |
153 | |
154 | {{1969245., 1979046., 1988847.}, |
155 | {2018250., 2028051., 2037852.}, |
156 | {2067255., 2077056., 2086857.}}, |
157 | |
158 | {{2214270., 2224071., 2233872.}, |
159 | {2263275., 2273076., 2282877.}, |
160 | {2312280., 2322081., 2331882.}}}}}, |
161 | torch::kFloat); |
162 | ASSERT_TRUE(torch::allclose(y, expected)); |
163 | |
164 | torch::Tensor s = y.sum(); |
165 | s.backward(); |
166 | ASSERT_EQ(s.ndimension(), 0); |
  ASSERT_EQ(model->weight.grad().numel(), 3 * 2 * 3 * 3 * 3);
168 | } |
169 | |
170 | TEST_F(ModulesTest, Conv3dSameStrided) { |
171 | auto options = Conv3dOptions(3, 2, {3, 4, 5}); |
172 | options.stride(1).padding(torch::kSame); |
173 | Conv3d model_valid(options); |
  ASSERT_THROWS_WITH(
      [&] { Conv3d model_invalid(options.stride(2)); }(),
      "padding='same' is not supported for strided convolutions");
  ASSERT_THROWS_WITH(
      [&] {
        Conv3d model_invalid(options.stride({1, 2, 1}));
      }(),
      "padding='same' is not supported for strided convolutions");
182 | } |
183 | |
184 | TEST_F(ModulesTest, ConvTranspose1d) { |
185 | ConvTranspose1d model(ConvTranspose1dOptions(3, 2, 3).stride(1).bias(false)); |
186 | model->weight.set_data(torch::arange(18.).view({2, 3, 3})); |
187 | auto x = torch::arange(20.).reshape({2, 2, 5}); |
188 | auto y = model(x); |
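  // Transposed convolution lengthens the signal:
  // L_out = (L_in - 1) * stride + kernel_size = (5 - 1) * 1 + 3 = 7.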
189 | auto expected = torch::tensor( |
190 | {{{45., 104., 179., 212., 245., 188., 107.}, |
191 | {60., 140., 242., 293., 344., 260., 146.}, |
192 | {75., 176., 305., 374., 443., 332., 185.}}, |
193 | {{135., 304., 509., 542., 575., 428., 237.}, |
194 | {210., 460., 752., 803., 854., 620., 336.}, |
195 | {285., 616., 995., 1064., 1133., 812., 435.}}}); |
196 | ASSERT_TRUE(torch::allclose(y, expected)); |
197 | |
198 | torch::Tensor s = y.sum(); |
199 | s.backward(); |
200 | ASSERT_EQ(s.ndimension(), 0); |
201 | ASSERT_EQ(model->weight.grad().numel(), 3 * 2 * 3); |
202 | } |
203 | |
204 | TEST_F(ModulesTest, ConvTranspose2dEven) { |
205 | ConvTranspose2d model(ConvTranspose2dOptions(3, 2, 3).stride(1).bias(false)); |
206 | model->weight.set_data(torch::arange(54.).view({2, 3, 3, 3})); |
207 | auto x = torch::arange(50.).view({1, 2, 5, 5}); |
208 | auto y = model(x); |
209 | auto expected = torch::tensor( |
210 | {{{{675., 1402., 2183., 2270., 2357., 1634., 849.}, |
211 | {1560., 3240., 5044., 5236., 5428., 3760., 1952.}, |
212 | {2685., 5574., 8673., 8988., 9303., 6438., 3339.}, |
213 | {3180., 6594., 10248., 10563., 10878., 7518., 3894.}, |
214 | {3675., 7614., 11823., 12138., 12453., 8598., 4449.}, |
215 | {2820., 5832., 9040., 9268., 9496., 6544., 3380.}, |
216 | {1605., 3314., 5129., 5252., 5375., 3698., 1907.}}, |
217 | {{900., 1870., 2912., 3053., 3194., 2210., 1146.}, |
218 | {2100., 4356., 6772., 7072., 7372., 5092., 2636.}, |
219 | {3630., 7518., 11670., 12147., 12624., 8706., 4500.}, |
220 | {4395., 9078., 14055., 14532., 15009., 10326., 5325.}, |
221 | {5160., 10638., 16440., 16917., 17394., 11946., 6150.}, |
222 | {3900., 8028., 12388., 12724., 13060., 8956., 4604.}, |
223 | {2190., 4502., 6938., 7115., 7292., 4994., 2564.}}, |
224 | {{1125., 2338., 3641., 3836., 4031., 2786., 1443.}, |
225 | {2640., 5472., 8500., 8908., 9316., 6424., 3320.}, |
226 | {4575., 9462., 14667., 15306., 15945., 10974., 5661.}, |
227 | {5610., 11562., 17862., 18501., 19140., 13134., 6756.}, |
228 | {6645., 13662., 21057., 21696., 22335., 15294., 7851.}, |
229 | {4980., 10224., 15736., 16180., 16624., 11368., 5828.}, |
230 | {2775., 5690., 8747., 8978., 9209., 6290., 3221.}}}}); |
231 | ASSERT_TRUE(torch::allclose(y, expected)); |
232 | |
233 | torch::Tensor s = y.sum(); |
234 | s.backward(); |
235 | ASSERT_EQ(s.ndimension(), 0); |
236 | ASSERT_EQ(model->weight.grad().numel(), 3 * 2 * 3 * 3); |
237 | } |
238 | |
239 | TEST_F(ModulesTest, ConvTranspose2dUneven) { |
240 | ConvTranspose2d model( |
241 | ConvTranspose2dOptions(3, 2, {3, 2}).stride({1, 1}).bias(false)); |
242 | model->weight.set_data(torch::arange(36.).view({2, 3, 3, 2})); |
243 | auto x = torch::arange(40.).view({1, 2, 5, 4}); |
244 | auto y = model(x); |
245 | auto expected = torch::tensor( |
246 | {{{{360., 758., 796., 834., 440.}, |
247 | {832., 1752., 1836., 1920., 1012.}, |
248 | {1432., 3014., 3152., 3290., 1732.}, |
249 | {1696., 3566., 3704., 3842., 2020.}, |
250 | {1960., 4118., 4256., 4394., 2308.}, |
251 | {1504., 3152., 3252., 3352., 1756.}, |
252 | {856., 1790., 1844., 1898., 992.}}, |
253 | {{480., 1010., 1072., 1134., 596.}, |
254 | {1120., 2352., 2484., 2616., 1372.}, |
255 | {1936., 4058., 4268., 4478., 2344.}, |
256 | {2344., 4898., 5108., 5318., 2776.}, |
257 | {2752., 5738., 5948., 6158., 3208.}, |
258 | {2080., 4328., 4476., 4624., 2404.}, |
259 | {1168., 2426., 2504., 2582., 1340.}}, |
260 | {{600., 1262., 1348., 1434., 752.}, |
261 | {1408., 2952., 3132., 3312., 1732.}, |
262 | {2440., 5102., 5384., 5666., 2956.}, |
263 | {2992., 6230., 6512., 6794., 3532.}, |
264 | {3544., 7358., 7640., 7922., 4108.}, |
265 | {2656., 5504., 5700., 5896., 3052.}, |
266 | {1480., 3062., 3164., 3266., 1688.}}}}); |
267 | ASSERT_TRUE(torch::allclose(y, expected)); |
268 | |
269 | torch::Tensor s = y.sum(); |
270 | s.backward(); |
271 | ASSERT_EQ(s.ndimension(), 0); |
272 | ASSERT_EQ(model->weight.grad().numel(), 3 * 2 * 3 * 2); |
273 | } |
274 | |
275 | TEST_F(ModulesTest, ConvTranspose3d) { |
276 | ConvTranspose3d model(ConvTranspose3dOptions(2, 2, 2).stride(1).bias(false)); |
277 | model->weight.set_data(torch::arange(32.).reshape({2, 2, 2, 2, 2})); |
278 | auto x = torch::arange(16.).reshape({1, 2, 2, 2, 2}); |
279 | auto y = model(x); |
280 | auto expected = torch::tensor( |
281 | {{{{{128., 280., 154.}, {304., 664., 364.}, {184., 400., 218.}}, |
282 | {{352., 768., 420.}, {832., 1808., 984.}, {496., 1072., 580.}}, |
283 | {{256., 552., 298.}, {592., 1272., 684.}, {344., 736., 394.}}}, |
284 | {{{192., 424., 234.}, {464., 1016., 556.}, {280., 608., 330.}}, |
285 | {{544., 1184., 644.}, {1280., 2768., 1496.}, {752., 1616., 868.}}, |
286 | {{384., 824., 442.}, {880., 1880., 1004.}, {504., 1072., 570.}}}}}); |
287 | ASSERT_TRUE(torch::allclose(y, expected)); |
288 | |
289 | torch::Tensor s = y.sum(); |
290 | s.backward(); |
291 | ASSERT_EQ(s.ndimension(), 0); |
  ASSERT_EQ(model->weight.grad().numel(), 2 * 2 * 2 * 2 * 2);
293 | } |
294 | |
295 | TEST_F(ModulesTest, MaxPool1d) { |
296 | MaxPool1d model(MaxPool1dOptions(3).stride(2)); |
297 | auto x = torch::ones({1, 1, 5}, torch::requires_grad()); |
298 | auto y = model(x); |
299 | torch::Tensor s = y.sum(); |
300 | |
301 | s.backward(); |
302 | ASSERT_EQ(y.ndimension(), 3); |
303 | ASSERT_TRUE(torch::allclose(y, torch::ones({1, 1, 2}))); |
304 | ASSERT_EQ(s.ndimension(), 0); |
305 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 2})); |
306 | } |
307 | |
308 | TEST_F(ModulesTest, MaxPool1dReturnIndices) { |
309 | MaxPool1d model(MaxPool1dOptions(3).stride(2)); |
310 | auto x = torch::ones({1, 1, 5}, torch::requires_grad()); |
311 | torch::Tensor y, indices; |
312 | std::tie(y, indices) = model->forward_with_indices(x); |
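  // The indices are offsets into the flattened input plane; with an all-ones
  // input, the max of each window is its first element, at offsets 0 and 2.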
313 | |
314 | ASSERT_EQ(y.dim(), 3); |
315 | ASSERT_TRUE(torch::allclose(y, torch::ones({1, 1, 2}))); |
316 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 2})); |
317 | |
318 | ASSERT_TRUE( |
319 | torch::allclose(indices, torch::tensor({{{0, 2}}}, torch::kLong))); |
320 | ASSERT_EQ(indices.sizes(), std::vector<int64_t>({1, 1, 2})); |
321 | } |
322 | |
323 | TEST_F(ModulesTest, MaxPool2dEven) { |
324 | MaxPool2d model(MaxPool2dOptions(3).stride(2)); |
325 | auto x = torch::ones({2, 5, 5}, torch::requires_grad()); |
326 | auto y = model(x); |
327 | torch::Tensor s = y.sum(); |
328 | |
329 | s.backward(); |
330 | ASSERT_EQ(y.ndimension(), 3); |
331 | ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2}))); |
332 | ASSERT_EQ(s.ndimension(), 0); |
333 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2})); |
334 | } |
335 | |
336 | TEST_F(ModulesTest, MaxPool2dUneven) { |
337 | MaxPool2d model(MaxPool2dOptions({3, 2}).stride({2, 2})); |
338 | auto x = torch::ones({2, 5, 4}, torch::requires_grad()); |
339 | auto y = model(x); |
340 | torch::Tensor s = y.sum(); |
341 | |
342 | s.backward(); |
343 | ASSERT_EQ(y.ndimension(), 3); |
344 | ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2}))); |
345 | ASSERT_EQ(s.ndimension(), 0); |
346 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2})); |
347 | } |
348 | |
349 | TEST_F(ModulesTest, MaxPool2dReturnIndices) { |
350 | MaxPool2d model(MaxPool2dOptions(3).stride(2)); |
351 | auto x = torch::ones({2, 5, 5}, torch::requires_grad()); |
352 | torch::Tensor y, indices; |
353 | std::tie(y, indices) = model->forward_with_indices(x); |
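  // 2-D pooling indices are flattened as row * width + col: the 3x3 windows
  // start at (0, 0), (0, 2), (2, 0), (2, 2), giving 0, 2, 10 and 12.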
354 | |
355 | ASSERT_EQ(y.dim(), 3); |
356 | ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2}))); |
357 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2})); |
358 | ASSERT_TRUE(torch::allclose( |
359 | indices, |
360 | torch::tensor({{{0, 2}, {10, 12}}, {{0, 2}, {10, 12}}}, torch::kLong))); |
361 | ASSERT_EQ(indices.sizes(), std::vector<int64_t>({2, 2, 2})); |
362 | } |
363 | |
364 | TEST_F(ModulesTest, MaxPool3d) { |
365 | MaxPool3d model(MaxPool3dOptions(3).stride(2)); |
366 | auto x = torch::ones({2, 5, 5, 5}, torch::requires_grad()); |
367 | auto y = model(x); |
368 | torch::Tensor s = y.sum(); |
369 | |
370 | s.backward(); |
371 | ASSERT_EQ(y.ndimension(), 4); |
372 | ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2, 2}))); |
373 | ASSERT_EQ(s.ndimension(), 0); |
374 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2, 2})); |
375 | } |
376 | |
377 | TEST_F(ModulesTest, MaxPool3dReturnIndices) { |
378 | MaxPool3d model(MaxPool3dOptions(3).stride(2)); |
379 | auto x = torch::ones({2, 5, 5, 5}, torch::requires_grad()); |
380 | torch::Tensor y, indices; |
381 | std::tie(y, indices) = model->forward_with_indices(x); |
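  // For 3-D pooling the flattened offset is depth * 25 + row * 5 + col within
  // each 5x5x5 plane, hence 50, 52, 60, 62 for the second depth slice.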
382 | |
383 | ASSERT_EQ(y.dim(), 4); |
384 | ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2, 2}))); |
385 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2, 2})); |
386 | |
387 | ASSERT_TRUE(torch::allclose( |
388 | indices, |
389 | torch::tensor( |
390 | {{{{0, 2}, {10, 12}}, {{50, 52}, {60, 62}}}, |
391 | {{{0, 2}, {10, 12}}, {{50, 52}, {60, 62}}}}, |
392 | torch::kLong))); |
393 | ASSERT_EQ(indices.sizes(), std::vector<int64_t>({2, 2, 2, 2})); |
394 | } |
395 | |
396 | TEST_F(ModulesTest, AvgPool1d) { |
397 | AvgPool1d model(AvgPool1dOptions(3).stride(2)); |
398 | auto x = torch::ones({1, 1, 5}, torch::requires_grad()); |
399 | auto y = model(x); |
400 | torch::Tensor s = y.sum(); |
401 | |
402 | s.backward(); |
403 | ASSERT_EQ(y.ndimension(), 3); |
404 | ASSERT_TRUE(torch::allclose(y, torch::ones({1, 1, 2}))); |
405 | ASSERT_EQ(s.ndimension(), 0); |
406 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 2})); |
407 | } |
408 | |
409 | TEST_F(ModulesTest, AvgPool2dEven) { |
410 | AvgPool2d model(AvgPool2dOptions(3).stride(2)); |
411 | auto x = torch::ones({2, 5, 5}, torch::requires_grad()); |
412 | auto y = model(x); |
413 | torch::Tensor s = y.sum(); |
414 | |
415 | s.backward(); |
416 | ASSERT_EQ(y.ndimension(), 3); |
417 | ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2}))); |
418 | ASSERT_EQ(s.ndimension(), 0); |
419 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2})); |
420 | } |
421 | |
422 | TEST_F(ModulesTest, AvgPool2dUneven) { |
423 | AvgPool2d model(AvgPool2dOptions({3, 2}).stride({2, 2})); |
424 | auto x = torch::ones({2, 5, 4}, torch::requires_grad()); |
425 | auto y = model(x); |
426 | torch::Tensor s = y.sum(); |
427 | |
428 | s.backward(); |
429 | ASSERT_EQ(y.ndimension(), 3); |
430 | ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2}))); |
431 | ASSERT_EQ(s.ndimension(), 0); |
432 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2})); |
433 | } |
434 | |
435 | TEST_F(ModulesTest, AvgPool3d) { |
436 | AvgPool3d model(AvgPool3dOptions(3).stride(2)); |
437 | auto x = torch::ones({2, 5, 5, 5}, torch::requires_grad()); |
438 | auto y = model(x); |
439 | torch::Tensor s = y.sum(); |
440 | |
441 | s.backward(); |
442 | ASSERT_EQ(y.ndimension(), 4); |
443 | ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2, 2}))); |
444 | ASSERT_EQ(s.ndimension(), 0); |
445 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2, 2})); |
446 | } |
447 | |
448 | TEST_F(ModulesTest, FractionalMaxPool2d) { |
449 | FractionalMaxPool2d model(FractionalMaxPool2dOptions(3).output_size(2)); |
450 | auto x = torch::ones({2, 5, 5}, torch::requires_grad()); |
451 | auto y = model(x); |
452 | torch::Tensor s = y.sum(); |
453 | |
454 | s.backward(); |
455 | ASSERT_EQ(y.ndimension(), 3); |
456 | ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2}))); |
457 | ASSERT_EQ(s.ndimension(), 0); |
458 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2})); |
459 | } |
460 | |
461 | TEST_F(ModulesTest, FractionalMaxPool2dReturnIndices) { |
462 | FractionalMaxPool2d model(FractionalMaxPool2dOptions(3).output_size(2)); |
463 | auto x = torch::ones({2, 5, 5}, torch::requires_grad()); |
464 | torch::Tensor y, indices; |
465 | std::tie(y, indices) = model->forward_with_indices(x); |
466 | |
467 | ASSERT_EQ(y.dim(), 3); |
468 | ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2}))); |
469 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2})); |
470 | ASSERT_TRUE(torch::allclose( |
471 | indices, torch::tensor({{{0, 2}, {10, 12}}, {{0, 2}, {10, 12}}}))); |
472 | ASSERT_EQ(indices.sizes(), std::vector<int64_t>({2, 2, 2})); |
473 | } |
474 | |
475 | TEST_F(ModulesTest, FractionalMaxPool3d) { |
476 | FractionalMaxPool3d model(FractionalMaxPool3dOptions(3).output_size(2)); |
477 | auto x = torch::ones({2, 5, 5, 5}, torch::requires_grad()); |
478 | auto y = model(x); |
479 | torch::Tensor s = y.sum(); |
480 | |
481 | s.backward(); |
482 | ASSERT_EQ(y.ndimension(), 4); |
483 | ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2, 2}))); |
484 | ASSERT_EQ(s.ndimension(), 0); |
485 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2, 2})); |
486 | } |
487 | |
488 | TEST_F(ModulesTest, FractionalMaxPool3dReturnIndices) { |
489 | FractionalMaxPool3d model(FractionalMaxPool3dOptions(3).output_size(2)); |
490 | auto x = torch::ones({2, 5, 5, 5}, torch::requires_grad()); |
491 | torch::Tensor y, indices; |
492 | std::tie(y, indices) = model->forward_with_indices(x); |
493 | |
494 | ASSERT_EQ(y.dim(), 4); |
495 | ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2, 2}))); |
496 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2, 2})); |
497 | |
498 | ASSERT_TRUE(torch::allclose( |
499 | indices, |
500 | torch::tensor( |
501 | {{{{0, 2}, {10, 12}}, {{50, 52}, {60, 62}}}, |
502 | {{{0, 2}, {10, 12}}, {{50, 52}, {60, 62}}}}))); |
503 | ASSERT_EQ(indices.sizes(), std::vector<int64_t>({2, 2, 2, 2})); |
504 | } |
505 | |
506 | TEST_F(ModulesTest, LPPool1d) { |
507 | int norm_type = 2; |
508 | int stride = 2; |
509 | int kernel_size = 3; |
510 | |
511 | LPPool1d model(LPPool1dOptions(norm_type, kernel_size).stride(stride)); |
512 | auto x = torch::ones({1, 1, 5}); |
513 | auto y = model(x); |
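  // LPPool raises each element to norm_type, sums over the window, and takes
  // the norm_type-th root, so an all-ones input yields
  // kernel_size^(1 / norm_type) everywhere.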
514 | auto expected = |
515 | (torch::pow(torch::tensor({{{1, 1}}}, torch::kFloat), norm_type) * |
516 | kernel_size) |
517 | .pow(1. / norm_type); |
518 | |
519 | ASSERT_EQ(y.ndimension(), 3); |
520 | ASSERT_TRUE(torch::allclose(y, expected)); |
521 | ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 1, 2})); |
522 | } |
523 | |
524 | TEST_F(ModulesTest, LPPool2d) { |
525 | int norm_type = 2; |
526 | int stride = 2; |
527 | std::vector<int64_t> kernel_size({2, 3}); |
528 | |
529 | LPPool2d model(LPPool2dOptions(norm_type, kernel_size).stride(stride)); |
530 | auto x = torch::ones({1, 2, 5}); |
531 | auto y = model(x); |
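  // Each 2x3 window of ones sums to 6, so every output element equals sqrt(6).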
532 | auto expected = |
533 | (torch::pow(torch::tensor({{{1, 1}}}, torch::kFloat), norm_type) * |
534 | (kernel_size[0] * kernel_size[1])) |
535 | .pow(1. / norm_type); |
536 | |
537 | ASSERT_EQ(y.ndimension(), 3); |
538 | ASSERT_TRUE(torch::allclose(y, expected)); |
539 | ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 1, 2})); |
540 | } |
541 | |
542 | TEST_F(ModulesTest, Identity) { |
543 | Identity identity; |
544 | auto input = torch::tensor( |
545 | {{1, 3, 4}, {2, 3, 4}}, torch::dtype(torch::kFloat).requires_grad(true)); |
546 | auto output = identity->forward(input); |
547 | auto expected = torch::tensor({{1, 3, 4}, {2, 3, 4}}, torch::kFloat); |
548 | auto s = output.sum(); |
549 | s.backward(); |
550 | |
551 | ASSERT_TRUE(torch::equal(output, expected)); |
552 | ASSERT_TRUE(torch::equal(input.grad(), torch::ones_like(input))); |
553 | } |
554 | |
555 | TEST_F(ModulesTest, Flatten) { |
556 | Flatten flatten; |
557 | auto input = torch::tensor( |
558 | {{1, 3, 4}, {2, 5, 6}}, torch::dtype(torch::kFloat).requires_grad(true)); |
559 | auto output = flatten->forward(input); |
560 | auto expected = torch::tensor({{1, 3, 4}, {2, 5, 6}}, torch::kFloat); |
561 | auto s = output.sum(); |
562 | |
563 | s.backward(); |
564 | ASSERT_TRUE(torch::equal(output, expected)); |
565 | ASSERT_TRUE(torch::equal(input.grad(), torch::ones_like(input))); |
566 | |
567 | // Testing with optional arguments start_dim and end_dim |
568 | Flatten flatten_optional_dims(FlattenOptions().start_dim(2).end_dim(3)); |
569 | input = torch::tensor( |
570 | {{{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}, |
571 | {{{9, 10}, {11, 12}}, {{13, 14}, {15, 16}}}}, |
572 | torch::dtype(torch::kFloat) |
573 | .requires_grad(true)); // Tensor with sizes (2, 2, 2, 2) |
574 | |
575 | output = flatten_optional_dims->forward(input); |
576 | expected = torch::tensor( |
577 | {{{1, 2, 3, 4}, {5, 6, 7, 8}}, {{9, 10, 11, 12}, {13, 14, 15, 16}}}, |
578 | torch::kFloat); // Tensor with sizes (2, 2, 4) |
579 | |
580 | s = output.sum(); |
581 | s.backward(); |
582 | ASSERT_TRUE(torch::equal(output, expected)); |
583 | ASSERT_TRUE(torch::equal(input.grad(), torch::ones_like(input))); |
584 | } |
585 | |
586 | TEST_F(ModulesTest, Unflatten) { |
587 | // Non-named tensor |
588 | Unflatten unflatten(UnflattenOptions(0, {2, 2})); |
589 | auto output = unflatten->forward(torch::tensor({1, 2, 3, 4})); |
590 | auto expected = torch::tensor({{1, 2}, {3, 4}}); |
591 | ASSERT_TRUE(torch::equal(output, expected)); |
592 | |
593 | // Named tensor |
594 | auto make_dimnames = [](std::vector<std::string> names) { |
595 | std::vector<torch::Dimname> dimnames; |
596 | // NOLINTNEXTLINE(performance-for-range-copy) |
597 | for (auto name : names) { |
598 | // NOLINTNEXTLINE(performance-inefficient-vector-operation) |
599 | dimnames.push_back( |
600 | torch::Dimname::fromSymbol(torch::Symbol::dimname(name))); |
601 | } |
602 | return dimnames; |
603 | }; |
604 | |
  unflatten = Unflatten(UnflattenOptions(
      "B",
      {std::pair<std::string, int64_t>{"B1", 2},
       std::pair<std::string, int64_t>{"B2", 2}}));
  output = unflatten->forward(
      torch::tensor({{1, 2, 3, 4}}).refine_names(make_dimnames({"A", "B"})));
  expected = torch::tensor({{{1, 2}, {3, 4}}})
                 .refine_names(make_dimnames({"A", "B1", "B2"}));
613 | ASSERT_TRUE(torch::equal(output, expected)); |
614 | } |
615 | |
616 | TEST_F(ModulesTest, AdaptiveMaxPool1d) { |
617 | AdaptiveMaxPool1d model(3); |
618 | auto x = torch::tensor( |
619 | {{{1, 2, 3, 4, 5}}}, torch::dtype(torch::kFloat).requires_grad(true)); |
620 | auto y = model(x); |
621 | torch::Tensor s = y.sum(); |
622 | |
623 | s.backward(); |
624 | ASSERT_EQ(y.ndimension(), 3); |
625 | ASSERT_TRUE(torch::allclose(y, torch::tensor({{{2, 4, 5}}}, torch::kFloat))); |
626 | ASSERT_EQ(s.ndimension(), 0); |
627 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 3})); |
628 | } |
629 | |
630 | TEST_F(ModulesTest, AdaptiveMaxPool1dReturnIndices) { |
631 | AdaptiveMaxPool1d model(3); |
632 | auto x = torch::tensor( |
633 | {{{1, 2, 3, 4, 5}}}, torch::dtype(torch::kFloat).requires_grad(true)); |
634 | torch::Tensor y, indices; |
635 | std::tie(y, indices) = model->forward_with_indices(x); |
636 | |
637 | ASSERT_EQ(y.dim(), 3); |
638 | ASSERT_TRUE(torch::allclose(y, torch::tensor({{{2, 4, 5}}}, torch::kFloat))); |
639 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 3})); |
640 | ASSERT_TRUE( |
641 | torch::allclose(indices, torch::tensor({{{1, 3, 4}}}, torch::kLong))); |
642 | ASSERT_EQ(indices.sizes(), std::vector<int64_t>({1, 1, 3})); |
643 | } |
644 | |
645 | TEST_F(ModulesTest, AdaptiveMaxPool2dEven) { |
646 | AdaptiveMaxPool2d model(3); |
647 | auto x = torch::arange(0., 50); |
648 | x.resize_({2, 5, 5}).set_requires_grad(true); |
649 | auto y = model(x); |
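  // The input increases monotonically, so the max of each adaptive window is
  // its bottom-right element.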
650 | torch::Tensor s = y.sum(); |
651 | |
652 | s.backward(); |
653 | ASSERT_EQ(y.ndimension(), 3); |
654 | ASSERT_TRUE(torch::allclose( |
655 | y, |
656 | torch::tensor( |
657 | { |
658 | {{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}, |
659 | {{31, 33, 34}, {41, 43, 44}, {46, 48, 49}}, |
660 | }, |
661 | torch::kFloat))); |
662 | ASSERT_EQ(s.ndimension(), 0); |
663 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 3, 3})); |
664 | } |
665 | |
666 | TEST_F(ModulesTest, AdaptiveMaxPool2dUneven) { |
667 | AdaptiveMaxPool2d model(AdaptiveMaxPool2dOptions({3, 2})); |
668 | auto x = torch::arange(0., 40); |
669 | x.resize_({2, 5, 4}).set_requires_grad(true); |
670 | auto y = model(x); |
671 | torch::Tensor s = y.sum(); |
672 | |
673 | s.backward(); |
674 | ASSERT_EQ(y.ndimension(), 3); |
675 | ASSERT_TRUE(torch::allclose( |
676 | y, |
677 | torch::tensor( |
678 | { |
679 | {{5, 7}, {13, 15}, {17, 19}}, |
680 | {{25, 27}, {33, 35}, {37, 39}}, |
681 | }, |
682 | torch::kFloat))); |
683 | ASSERT_EQ(s.ndimension(), 0); |
684 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 3, 2})); |
685 | } |
686 | |
687 | TEST_F(ModulesTest, AdaptiveMaxPool2dReturnIndicesEven) { |
688 | AdaptiveMaxPool2d model(3); |
689 | auto x = torch::arange(0., 50); |
690 | x.resize_({2, 5, 5}).set_requires_grad(true); |
691 | torch::Tensor y, indices; |
692 | std::tie(y, indices) = model->forward_with_indices(x); |
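  // Indices are per-plane flattened offsets; they coincide with the channel-0
  // values because those values equal their own positions.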
693 | torch::Tensor s = y.sum(); |
694 | |
695 | s.backward(); |
696 | ASSERT_EQ(s.ndimension(), 0); |
697 | |
698 | ASSERT_EQ(y.ndimension(), 3); |
699 | ASSERT_TRUE(torch::allclose( |
700 | y, |
701 | torch::tensor( |
702 | { |
703 | {{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}, |
704 | {{31, 33, 34}, {41, 43, 44}, {46, 48, 49}}, |
705 | }, |
706 | torch::kFloat))); |
707 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 3, 3})); |
708 | |
709 | ASSERT_EQ(indices.ndimension(), 3); |
710 | ASSERT_TRUE(torch::allclose( |
711 | indices, |
712 | torch::tensor( |
713 | { |
714 | {{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}, |
715 | {{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}, |
716 | }, |
717 | torch::kLong))); |
718 | ASSERT_EQ(indices.sizes(), std::vector<int64_t>({2, 3, 3})); |
719 | } |
720 | |
721 | TEST_F(ModulesTest, AdaptiveMaxPool2dReturnIndicesUneven) { |
722 | AdaptiveMaxPool2d model(AdaptiveMaxPool2dOptions({3, 2})); |
723 | auto x = torch::arange(0., 40); |
724 | x.resize_({2, 5, 4}).set_requires_grad(true); |
725 | torch::Tensor y, indices; |
726 | std::tie(y, indices) = model->forward_with_indices(x); |
727 | torch::Tensor s = y.sum(); |
728 | |
729 | s.backward(); |
730 | ASSERT_EQ(s.ndimension(), 0); |
731 | |
732 | ASSERT_EQ(y.ndimension(), 3); |
733 | ASSERT_TRUE(torch::allclose( |
734 | y, |
735 | torch::tensor( |
736 | { |
737 | {{5, 7}, {13, 15}, {17, 19}}, |
738 | {{25, 27}, {33, 35}, {37, 39}}, |
739 | }, |
740 | torch::kFloat))); |
741 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 3, 2})); |
742 | |
743 | ASSERT_EQ(indices.ndimension(), 3); |
744 | ASSERT_TRUE(torch::allclose( |
745 | indices, |
746 | torch::tensor( |
747 | { |
748 | {{5, 7}, {13, 15}, {17, 19}}, |
749 | {{5, 7}, {13, 15}, {17, 19}}, |
750 | }, |
751 | torch::kLong))); |
752 | ASSERT_EQ(indices.sizes(), std::vector<int64_t>({2, 3, 2})); |
753 | } |
754 | |
755 | TEST_F(ModulesTest, AdaptiveMaxPool3d) { |
756 | AdaptiveMaxPool3d model(3); |
757 | auto x = torch::arange(0., 64); |
758 | x.resize_({1, 4, 4, 4}).set_requires_grad(true); |
759 | auto y = model(x); |
760 | torch::Tensor s = y.sum(); |
761 | |
762 | s.backward(); |
763 | ASSERT_EQ(s.ndimension(), 0); |
764 | |
765 | ASSERT_EQ(y.ndimension(), 4); |
766 | ASSERT_TRUE(torch::allclose( |
767 | y, |
768 | torch::tensor( |
769 | { |
770 | {{21, 22, 23}, {25, 26, 27}, {29, 30, 31}}, |
771 | {{37, 38, 39}, {41, 42, 43}, {45, 46, 47}}, |
772 | {{53, 54, 55}, {57, 58, 59}, {61, 62, 63}}, |
773 | }, |
774 | torch::kFloat))); |
775 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 3, 3, 3})); |
776 | } |
777 | |
778 | TEST_F(ModulesTest, AdaptiveMaxPool3dReturnIndices) { |
779 | AdaptiveMaxPool3d model(3); |
780 | auto x = torch::arange(0., 64); |
781 | x.resize_({1, 4, 4, 4}).set_requires_grad(true); |
782 | torch::Tensor y, indices; |
783 | std::tie(y, indices) = model->forward_with_indices(x); |
784 | torch::Tensor s = y.sum(); |
785 | |
786 | s.backward(); |
787 | ASSERT_EQ(s.ndimension(), 0); |
788 | |
789 | ASSERT_EQ(y.ndimension(), 4); |
790 | ASSERT_TRUE(torch::allclose( |
791 | y, |
792 | torch::tensor( |
793 | { |
794 | {{21, 22, 23}, {25, 26, 27}, {29, 30, 31}}, |
795 | {{37, 38, 39}, {41, 42, 43}, {45, 46, 47}}, |
796 | {{53, 54, 55}, {57, 58, 59}, {61, 62, 63}}, |
797 | }, |
798 | torch::kFloat))); |
799 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 3, 3, 3})); |
800 | |
801 | ASSERT_EQ(indices.ndimension(), 4); |
802 | ASSERT_TRUE(torch::allclose( |
803 | indices, |
804 | torch::tensor( |
805 | { |
806 | {{21, 22, 23}, {25, 26, 27}, {29, 30, 31}}, |
807 | {{37, 38, 39}, {41, 42, 43}, {45, 46, 47}}, |
808 | {{53, 54, 55}, {57, 58, 59}, {61, 62, 63}}, |
809 | }, |
810 | torch::kLong))); |
811 | ASSERT_EQ(indices.sizes(), std::vector<int64_t>({1, 3, 3, 3})); |
812 | } |
813 | |
814 | TEST_F(ModulesTest, AdaptiveAvgPool1d) { |
815 | AdaptiveAvgPool1d model(3); |
816 | auto x = torch::tensor( |
817 | {{{1, 2, 3, 4, 5}}}, torch::dtype(torch::kFloat).requires_grad(true)); |
818 | auto y = model(x); |
819 | torch::Tensor s = y.sum(); |
820 | |
821 | s.backward(); |
822 | ASSERT_EQ(s.ndimension(), 0); |
823 | |
824 | ASSERT_EQ(y.ndimension(), 3); |
825 | ASSERT_TRUE( |
826 | torch::allclose(y, torch::tensor({{{1.5, 3.0, 4.5}}}, torch::kFloat))); |
827 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 3})); |
828 | } |
829 | |
830 | TEST_F(ModulesTest, AdaptiveAvgPool2dEven) { |
831 | AdaptiveAvgPool2d model(3); |
832 | auto x = torch::arange(0., 50); |
833 | x.resize_({2, 5, 5}).set_requires_grad(true); |
834 | auto y = model(x); |
835 | torch::Tensor s = y.sum(); |
836 | |
837 | s.backward(); |
838 | ASSERT_EQ(s.ndimension(), 0); |
839 | |
840 | ASSERT_EQ(y.ndimension(), 3); |
841 | ASSERT_TRUE(torch::allclose( |
842 | y, |
843 | torch::tensor( |
844 | { |
845 | {{3.0, 4.5, 6.0}, {10.5, 12.0, 13.5}, {18.0, 19.5, 21.0}}, |
846 | {{28.0, 29.5, 31.0}, {35.5, 37.0, 38.5}, {43.0, 44.5, 46.0}}, |
847 | }, |
848 | torch::kFloat))); |
849 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 3, 3})); |
850 | } |
851 | |
852 | TEST_F(ModulesTest, AdaptiveAvgPool2dUneven) { |
853 | AdaptiveAvgPool2d model(AdaptiveAvgPool2dOptions({3, 2})); |
854 | auto x = torch::arange(0., 40); |
855 | x.resize_({2, 5, 4}).set_requires_grad(true); |
856 | auto y = model(x); |
857 | torch::Tensor s = y.sum(); |
858 | |
859 | s.backward(); |
860 | ASSERT_EQ(s.ndimension(), 0); |
861 | |
862 | ASSERT_EQ(y.ndimension(), 3); |
863 | ASSERT_TRUE(torch::allclose( |
864 | y, |
865 | torch::tensor( |
866 | { |
867 | {{2.5, 4.5}, {8.5, 10.5}, {14.5, 16.5}}, |
868 | {{22.5, 24.5}, {28.5, 30.5}, {34.5, 36.5}}, |
869 | }, |
870 | torch::kFloat))); |
871 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 3, 2})); |
872 | } |
873 | |
874 | TEST_F(ModulesTest, AdaptiveAvgPool3d) { |
875 | AdaptiveAvgPool3d model(3); |
876 | auto x = torch::arange(0., 64); |
877 | x.resize_({1, 4, 4, 4}).set_requires_grad(true); |
878 | auto y = model(x); |
879 | torch::Tensor s = y.sum(); |
880 | |
881 | s.backward(); |
882 | ASSERT_EQ(s.ndimension(), 0); |
883 | |
884 | ASSERT_EQ(y.ndimension(), 4); |
885 | ASSERT_TRUE(torch::allclose( |
886 | y, |
887 | torch::tensor( |
888 | { |
889 | {{10.5, 11.5, 12.5}, {14.5, 15.5, 16.5}, {18.5, 19.5, 20.5}}, |
890 | {{26.5, 27.5, 28.5}, {30.5, 31.5, 32.5}, {34.5, 35.5, 36.5}}, |
891 | {{42.5, 43.5, 44.5}, {46.5, 47.5, 48.5}, {50.5, 51.5, 52.5}}, |
892 | }, |
893 | torch::kFloat))); |
894 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 3, 3, 3})); |
895 | } |
896 | |
897 | TEST_F(ModulesTest, MaxUnpool1d) { |
898 | auto indices = torch::tensor({{{1, 3, 4}}}, torch::kLong); |
899 | auto x = torch::tensor( |
900 | {{{2, 4, 5}}}, torch::dtype(torch::kFloat).requires_grad(true)); |
901 | auto model = MaxUnpool1d{3}; |
902 | auto y = model->forward(x, indices); |
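  // MaxUnpool scatters each value to its recorded index and zero-fills the
  // rest; the default output length is (3 - 1) * 3 + 3 = 9.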
903 | |
904 | ASSERT_EQ(y.dim(), 3); |
905 | ASSERT_TRUE(torch::allclose( |
906 | y, torch::tensor({{{0, 2, 0, 4, 5, 0, 0, 0, 0}}}, torch::kFloat))); |
907 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 9})); |
908 | |
909 | indices = torch::tensor({{{1, 3, 4}}}, torch::kLong); |
910 | x = torch::tensor( |
911 | {{{2, 4, 5}}}, torch::dtype(torch::kFloat).requires_grad(true)); |
912 | model = MaxUnpool1d{MaxUnpool1dOptions(3).stride(2).padding(1)}; |
913 | y = model->forward(x, indices, std::vector<int64_t>({1, 1, 5})); |
914 | |
915 | ASSERT_EQ(y.dim(), 3); |
916 | ASSERT_TRUE( |
917 | torch::allclose(y, torch::tensor({{{0, 2, 0, 4, 5}}}, torch::kFloat))); |
918 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 5})); |
919 | } |
920 | |
921 | TEST_F(ModulesTest, MaxPool1d_MaxUnpool1d) { |
922 | MaxPool1d pool{MaxPool1dOptions(2).stride(2)}; |
923 | MaxUnpool1d unpool{MaxUnpool1dOptions(2).stride(2)}; |
924 | auto input = torch::tensor({{{1, 2, 3, 4, 5, 6, 7, 8}}}, torch::kFloat); |
925 | torch::Tensor output, indices; |
926 | std::tie(output, indices) = pool->forward_with_indices(input); |
927 | ASSERT_TRUE(torch::allclose( |
928 | unpool(output, indices), |
929 | torch::tensor({{{0, 2, 0, 4, 0, 6, 0, 8}}}, torch::kFloat))); |
930 | |
931 | // Example showcasing the use of output_size |
932 | input = torch::tensor({{{1, 2, 3, 4, 5, 6, 7, 8, 9}}}, torch::kFloat); |
933 | std::tie(output, indices) = pool->forward_with_indices(input); |
934 | ASSERT_TRUE(torch::allclose( |
935 | unpool(output, indices, input.sizes().vec()), |
936 | torch::tensor({{{0, 2, 0, 4, 0, 6, 0, 8, 0}}}, torch::kFloat))); |
937 | ASSERT_TRUE(torch::allclose( |
938 | unpool(output, indices), |
939 | torch::tensor({{{0, 2, 0, 4, 0, 6, 0, 8}}}, torch::kFloat))); |
940 | } |
941 | |
942 | TEST_F(ModulesTest, MaxUnpool2d) { |
943 | auto indices = torch::tensor( |
944 | {{{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}}, |
945 | {{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}}}, |
946 | torch::kLong); |
947 | auto x = torch::tensor( |
948 | {{{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}}, |
949 | {{{31, 33, 34}, {41, 43, 44}, {46, 48, 49}}}}, |
950 | torch::dtype(torch::kFloat).requires_grad(true)); |
951 | auto model = MaxUnpool2d{MaxUnpool2dOptions(3).stride(2).padding(1)}; |
952 | auto y = model->forward(x, indices); |
953 | |
954 | ASSERT_EQ(y.dim(), 4); |
955 | ASSERT_TRUE(torch::allclose( |
956 | y, |
957 | torch::tensor( |
958 | {{{{0, 0, 0, 0, 0}, |
959 | {0, 6, 0, 8, 9}, |
960 | {0, 0, 0, 0, 0}, |
961 | {0, 16, 0, 18, 19}, |
962 | {0, 21, 0, 23, 24}}}, |
963 | {{{0, 0, 0, 0, 0}, |
964 | {0, 31, 0, 33, 34}, |
965 | {0, 0, 0, 0, 0}, |
966 | {0, 41, 0, 43, 44}, |
967 | {0, 46, 0, 48, 49}}}}, |
968 | torch::kFloat))); |
969 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 1, 5, 5})); |
970 | } |
971 | |
972 | TEST_F(ModulesTest, MaxPool2d_MaxUnpool2d) { |
973 | MaxPool2d pool{MaxPool2dOptions(2).stride(2)}; |
974 | MaxUnpool2d unpool{MaxUnpool2dOptions(2).stride(2)}; |
975 | auto input = torch::tensor( |
976 | {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}, {13, 14, 15, 16}}}}, |
977 | torch::kFloat); |
978 | torch::Tensor output, indices; |
979 | std::tie(output, indices) = pool->forward_with_indices(input); |
980 | ASSERT_TRUE(torch::allclose( |
981 | unpool(output, indices), |
982 | torch::tensor( |
983 | {{{{0, 0, 0, 0}, {0, 6, 0, 8}, {0, 0, 0, 0}, {0, 14, 0, 16}}}}, |
984 | torch::kFloat))); |
985 | |
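  // With an explicit 5x5 output_size, the flat indices recorded for the 4x4
  // input (5, 7, 13, 15) are reinterpreted in a width-5 plane, so the values
  // land at different 2-D coordinates.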
986 | ASSERT_TRUE(torch::allclose( |
987 | unpool(output, indices, std::vector<int64_t>{1, 1, 5, 5}), |
988 | torch::tensor( |
989 | {{{{0, 0, 0, 0, 0}, |
990 | {6, 0, 8, 0, 0}, |
991 | {0, 0, 0, 14, 0}, |
992 | {16, 0, 0, 0, 0}, |
993 | {0, 0, 0, 0, 0}}}}, |
994 | torch::kFloat))); |
995 | } |
996 | |
997 | TEST_F(ModulesTest, MaxUnpool3d) { |
998 | auto indices = torch::tensor({{{{{26}}}}}, torch::kLong); |
999 | auto x = torch::tensor( |
1000 | {{{{{26}}}}}, torch::dtype(torch::kFloat).requires_grad(true)); |
1001 | auto model = MaxUnpool3d{3}; |
1002 | auto y = model->forward(x, indices); |
1003 | |
1004 | ASSERT_EQ(y.dim(), 5); |
1005 | ASSERT_TRUE(torch::allclose( |
1006 | y, |
1007 | torch::tensor( |
1008 | {{{{{0, 0, 0}, {0, 0, 0}, {0, 0, 0}}, |
1009 | {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}}, |
1010 | {{0, 0, 0}, {0, 0, 0}, {0, 0, 26}}}}}, |
1011 | torch::kFloat))); |
1012 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 3, 3, 3})); |
1013 | } |
1014 | |
1015 | TEST_F(ModulesTest, MaxUnpool3dOutputSize) { |
1016 | auto indices = torch::tensor( |
1017 | {{{{{21, 23}, {29, 31}}, {{53, 55}, {61, 63}}}}}, torch::kLong); |
1018 | auto x = torch::tensor( |
1019 | {{{{{21, 23}, {29, 31}}, {{53, 55}, {61, 63}}}}}, |
1020 | torch::dtype(torch::kFloat).requires_grad(true)); |
1021 | auto model = MaxUnpool3d{MaxUnpool3dOptions(3).stride(2).padding(1)}; |
1022 | auto y = model->forward(x, indices, std::vector<int64_t>({1, 1, 4, 4, 4})); |
1023 | |
1024 | ASSERT_EQ(y.dim(), 5); |
1025 | ASSERT_TRUE(torch::allclose( |
1026 | y, |
1027 | torch::tensor( |
1028 | {{{{{0, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 0, 0}}, |
1029 | {{0, 0, 0, 0}, {0, 21, 0, 23}, {0, 0, 0, 0}, {0, 29, 0, 31}}, |
1030 | {{0, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 0, 0}}, |
1031 | {{0, 0, 0, 0}, {0, 53, 0, 55}, {0, 0, 0, 0}, {0, 61, 0, 63}}}}}, |
1032 | torch::kFloat))); |
1033 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 4, 4, 4})); |
1034 | } |
1035 | |
1036 | TEST_F(ModulesTest, MaxPool3d_MaxUnpool3d) { |
1037 | MaxPool3d pool{MaxPool3dOptions(3).stride(2)}; |
1038 | MaxUnpool3d unpool{MaxUnpool3dOptions(3).stride(2)}; |
1039 | auto input = torch::randn({20, 16, 51, 33, 15}); |
1040 | torch::Tensor output, indices; |
1041 | std::tie(output, indices) = pool->forward_with_indices(input); |
1042 | auto unpooled_output = unpool(output, indices); |
1043 | ASSERT_EQ( |
1044 | unpooled_output.sizes(), std::vector<int64_t>({20, 16, 51, 33, 15})); |
1045 | } |
1046 | |
1047 | TEST_F(ModulesTest, Linear) { |
1048 | { |
1049 | Linear model(5, 2); |
1050 | auto x = torch::randn({10, 5}, torch::requires_grad()); |
1051 | auto y = model(x); |
1052 | torch::Tensor s = y.sum(); |
1053 | |
1054 | s.backward(); |
1055 | ASSERT_EQ(y.ndimension(), 2); |
1056 | ASSERT_EQ(s.ndimension(), 0); |
1057 | ASSERT_EQ(y.size(0), 10); |
1058 | ASSERT_EQ(y.size(1), 2); |
1059 | |
1060 | ASSERT_EQ(model->weight.grad().numel(), 2 * 5); |
1061 | |
1062 | auto y_exp = torch::addmm(model->bias, x, model->weight.t()); |
1063 | ASSERT_TRUE(torch::allclose(y, y_exp)); |
1064 | } |
1065 | { |
1066 | Linear model(LinearOptions(5, 2).bias(false)); |
1067 | auto x = torch::randn({10, 5}, torch::requires_grad()); |
1068 | auto y = model(x); |
1069 | torch::Tensor s = y.sum(); |
1070 | |
1071 | s.backward(); |
1072 | ASSERT_EQ(y.ndimension(), 2); |
1073 | ASSERT_EQ(s.ndimension(), 0); |
1074 | ASSERT_EQ(y.size(0), 10); |
1075 | ASSERT_EQ(y.size(1), 2); |
1076 | |
1077 | ASSERT_EQ(model->weight.grad().numel(), 2 * 5); |
1078 | |
1079 | auto y_exp = torch::mm(x, model->weight.t()); |
1080 | ASSERT_TRUE(torch::allclose(y, y_exp)); |
1081 | } |
1082 | } |
1083 | |
1084 | TEST_F(ModulesTest, LocalResponseNorm) { |
1085 | { |
1086 | LocalResponseNorm model(LocalResponseNormOptions(2)); |
1087 | const auto x = |
1088 | torch::arange(100., 136, torch::requires_grad()).reshape({2, 3, 3, 2}); |
1089 | auto y = model(x); |
1090 | const auto y_exp = torch::tensor( |
1091 | {{{{73.7788, 74.1462}, {74.5031, 74.8572}, {75.2010, 75.5420}}, |
1092 | |
1093 | {{61.6057, 61.7227}, {61.8347, 61.9418}, {62.0441, 62.1418}}, |
1094 | |
1095 | {{62.2349, 62.3235}, {62.4077, 62.4877}, {62.5635, 62.6353}}}, |
1096 | |
1097 | {{{79.3915, 79.6491}, {79.8978, 80.1446}, {80.3827, 80.6190}}, |
1098 | |
1099 | {{63.0317, 63.0742}, {63.1135, 63.1496}, {63.1826, 63.2126}}, |
1100 | |
1101 | {{63.2396, 63.2637}, {63.2850, 63.3036}, {63.3195, 63.3328}}}}, |
1102 | torch::kFloat); |
1103 | torch::Tensor s = y.sum(); |
1104 | |
1105 | s.backward(); |
1106 | ASSERT_EQ(y.ndimension(), 4); |
1107 | ASSERT_EQ(s.ndimension(), 0); |
1108 | ASSERT_EQ(y.sizes(), x.sizes()); |
1109 | ASSERT_TRUE(torch::allclose(y, y_exp, 1e-4, 1e-7)); |
1110 | } |
1111 | } |
1112 | |
1113 | TEST_F(ModulesTest, LayerNorm) { |
1114 | LayerNorm model(LayerNormOptions({2, 2}).eps(2e-5)); |
1115 | auto x = torch::randn({2, 2}, torch::requires_grad()); |
1116 | auto y = model(x); |
1117 | auto y_exp = torch::layer_norm(x, {2, 2}, model->weight, model->bias, 2e-5); |
1118 | torch::Tensor s = y.sum(); |
1119 | |
1120 | s.backward(); |
1121 | ASSERT_EQ(y.ndimension(), 2); |
1122 | ASSERT_EQ(s.ndimension(), 0); |
1123 | for (const auto i : c10::irange(2)) { |
1124 | ASSERT_EQ(y.size(i), 2); |
1125 | } |
1126 | |
1127 | ASSERT_EQ(model->weight.grad().numel(), 2 * 2); |
1128 | ASSERT_TRUE(torch::allclose(y, y_exp)); |
1129 | } |
1130 | |
1131 | TEST_F(ModulesTest, GroupNorm) { |
1132 | GroupNorm model(GroupNormOptions(2, 2).eps(2e-5)); |
1133 | auto x = torch::randn({2, 2}, torch::requires_grad()); |
1134 | auto y = model(x); |
1135 | auto y_exp = torch::group_norm(x, 2, model->weight, model->bias, 2e-5); |
1136 | torch::Tensor s = y.sum(); |
1137 | |
1138 | s.backward(); |
1139 | ASSERT_EQ(y.ndimension(), 2); |
1140 | ASSERT_EQ(s.ndimension(), 0); |
1141 | for (const auto i : c10::irange(2)) { |
1142 | ASSERT_EQ(y.size(i), 2); |
1143 | } |
1144 | |
1145 | ASSERT_EQ(model->weight.grad().numel(), 2); |
1146 | ASSERT_TRUE(torch::allclose(y, y_exp)); |
1147 | } |
1148 | |
1149 | TEST_F(ModulesTest, Bilinear) { |
1150 | Bilinear model(5, 3, 2); |
1151 | auto x1 = torch::randn({10, 5}, torch::requires_grad()); |
1152 | auto x2 = torch::randn({10, 3}, torch::requires_grad()); |
1153 | auto y = model(x1, x2); |
1154 | torch::Tensor s = y.sum(); |
1155 | |
1156 | s.backward(); |
1157 | ASSERT_EQ(y.ndimension(), 2); |
1158 | ASSERT_EQ(s.ndimension(), 0); |
1159 | ASSERT_EQ(y.size(0), 10); |
1160 | ASSERT_EQ(y.size(1), 2); |
1161 | |
1162 | ASSERT_EQ(model->weight.grad().numel(), 2 * 5 * 3); |
1163 | } |
1164 | |
1165 | TEST_F(ModulesTest, Fold) { |
1166 | { |
1167 | Fold model(FoldOptions({3, 2}, {2, 2})); |
1168 | auto input = torch::ones({1, 3 * 2 * 2, 2}, torch::requires_grad()); |
1169 | auto output = model(input); |
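    // Fold sums overlapping blocks: the two 2x2 blocks of ones overlap on the
    // middle row of the 3x2 output, which therefore accumulates to 2.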
1170 | auto expected = torch::tensor( |
1171 | {{{{1.0, 1.0}, {2.0, 2.0}, {1.0, 1.0}}, |
1172 | {{1.0, 1.0}, {2.0, 2.0}, {1.0, 1.0}}, |
1173 | {{1.0, 1.0}, {2.0, 2.0}, {1.0, 1.0}}}}, |
1174 | torch::kFloat); |
1175 | auto s = output.sum(); |
1176 | s.backward(); |
1177 | |
1178 | ASSERT_EQ(s.ndimension(), 0); |
1179 | ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 3, 3, 2})); |
1180 | ASSERT_TRUE(output.allclose(expected)); |
1181 | } |
1182 | { |
1183 | // input wrong dimension |
1184 | Fold model(FoldOptions({8, 8}, {3, 3})); |
    ASSERT_THROWS_WITH(
        model(torch::randn({1, 3, 16, 16})),
        "Input Error: Only unbatched (2D) or batched (3D) input Tensors are supported (got 4D)");
1188 | } |
1189 | } |
1190 | |
1191 | TEST_F(ModulesTest, Unfold) { |
1192 | { |
1193 | Unfold model(UnfoldOptions({2, 2}).padding(1).stride(2)); |
1194 | auto input = |
1195 | torch::arange(2., 14, torch::requires_grad()).view({1, 2, 2, 3}); |
1196 | auto output = model(input); |
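    // With one pixel of zero padding and stride 2, the four extracted 2x2
    // patches overlap the border, so most patch entries are zeros.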
1197 | auto expected = torch::tensor( |
1198 | {{{0.0, 0.0, 0.0, 6.0}, |
1199 | {0.0, 0.0, 5.0, 7.0}, |
1200 | {0.0, 3.0, 0.0, 0.0}, |
1201 | {2.0, 4.0, 0.0, 0.0}, |
1202 | {0.0, 0.0, 0.0, 12.0}, |
1203 | {0.0, 0.0, 11.0, 13.0}, |
1204 | {0.0, 9.0, 0.0, 0.0}, |
1205 | {8.0, 10.0, 0.0, 0.0}}}, |
1206 | torch::kFloat); |
1207 | auto s = output.sum(); |
1208 | s.backward(); |
1209 | |
1210 | ASSERT_EQ(s.ndimension(), 0); |
1211 | ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 8, 4})); |
1212 | ASSERT_TRUE(output.allclose(expected)); |
1213 | } |
1214 | { |
1215 | // input wrong dimension |
1216 | Unfold model(UnfoldOptions({2, 4})); |
    ASSERT_THROWS_WITH(
        model(torch::randn({1, 5, 2})),
        "Input Error: Only 4D input Tensors are supported (got 3D)");
1220 | } |
1221 | { |
1222 | // calculated output shape is too small |
1223 | Unfold model(UnfoldOptions({2, 3})); |
    ASSERT_THROWS_WITH(
        model(torch::randn({1, 2, 2, 2})),
        "Given input with spatial size (2, 2), kernel_size=(2, 3), "
        "dilation=(1, 1), padding=(0, 0), calculated shape of the array of "
        "sliding blocks as (1, 0), but its components must be at least one.");
1229 | } |
1230 | } |
1231 | |
1232 | TEST_F(ModulesTest, SimpleContainer) { |
1233 | auto model = std::make_shared<SimpleContainer>(); |
  auto l1 = model->add(Linear(10, 3), "l1");
  auto l2 = model->add(Linear(3, 5), "l2");
  auto l3 = model->add(Linear(5, 100), "l3");
1237 | |
1238 | auto x = torch::randn({1000, 10}, torch::requires_grad()); |
1239 | x = l1(x).clamp_min(0); |
1240 | x = l2(x).clamp_min(0); |
1241 | x = l3(x).clamp_min(0); |
1242 | |
1243 | x.backward(torch::ones_like(x)); |
1244 | ASSERT_EQ(x.ndimension(), 2); |
1245 | ASSERT_EQ(x.size(0), 1000); |
1246 | ASSERT_EQ(x.size(1), 100); |
1247 | ASSERT_EQ(x.min().item<float>(), 0); |
1248 | } |
1249 | |
1250 | TEST_F(ModulesTest, EmbeddingBasic) { |
1251 | const int64_t dict_size = 10; |
1252 | Embedding model(dict_size, 2); |
  ASSERT_TRUE(model->named_parameters().contains("weight"));
1254 | ASSERT_EQ(model->weight.ndimension(), 2); |
1255 | ASSERT_EQ(model->weight.size(0), dict_size); |
1256 | ASSERT_EQ(model->weight.size(1), 2); |
1257 | |
  // Gradients cannot flow to the integer indices (the input); only the
  // embedding parameters receive gradients.
1260 | auto x = torch::full({10}, dict_size - 1, torch::kInt64); |
1261 | auto y = model(x); |
1262 | torch::Tensor s = y.sum(); |
1263 | |
1264 | s.backward(); |
1265 | ASSERT_EQ(y.ndimension(), 2); |
1266 | ASSERT_EQ(s.ndimension(), 0); |
1267 | ASSERT_EQ(y.size(0), 10); |
1268 | ASSERT_EQ(y.size(1), 2); |
1269 | |
1270 | ASSERT_EQ(model->weight.grad().numel(), 2 * dict_size); |
1271 | } |
1272 | |
1273 | TEST_F(ModulesTest, EmbeddingList) { |
1274 | Embedding model(6, 4); |
1275 | auto x = torch::full({2, 3}, 5, torch::kInt64); |
1276 | auto y = model(x); |
1277 | torch::Tensor s = y.sum(); |
1278 | |
1279 | s.backward(); |
1280 | ASSERT_EQ(y.ndimension(), 3); |
1281 | ASSERT_EQ(y.size(0), 2); |
1282 | ASSERT_EQ(y.size(1), 3); |
1283 | ASSERT_EQ(y.size(2), 4); |
1284 | } |
1285 | |
1286 | TEST_F(ModulesTest, EmbeddingFromPretrained) { |
1287 | auto weight = torch::tensor({{1., 2.3, 3.}, {4., 5.1, 6.3}}); |
1288 | Embedding embedding = torch::nn::Embedding::from_pretrained(weight); |
1289 | auto input = torch::tensor({1}, torch::kLong); |
1290 | ASSERT_TRUE(torch::allclose( |
1291 | embedding(input), torch::tensor({4.0000, 5.1000, 6.3000}))); |
1292 | } |
1293 | |
1294 | TEST_F(ModulesTest, EmbeddingBagFromPretrained) { |
1295 | auto weight = torch::tensor({{1., 2.3, 3.}, {4., 5.1, 6.3}}); |
1296 | EmbeddingBag embeddingbag = torch::nn::EmbeddingBag::from_pretrained(weight); |
  auto input = torch::zeros({1, 2}, torch::kLong);
1298 | input[0] = torch::tensor({1, 0}); |
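  // EmbeddingBag defaults to "mean" mode, so the bag {1, 0} yields the
  // elementwise average of rows 1 and 0 of the pretrained weight.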
1299 | ASSERT_TRUE(torch::allclose( |
1300 | embeddingbag(input), torch::tensor({2.5000, 3.7000, 4.6500}))); |
1301 | } |
1302 | |
1303 | TEST_F(ModulesTest, AlphaDropout) { |
1304 | AlphaDropout alpha_dropout(0.5); |
1305 | torch::Tensor x = torch::ones(100, torch::requires_grad()); |
1306 | torch::Tensor y = alpha_dropout(x); |
1307 | |
1308 | y.backward(torch::ones_like(y)); |
1309 | |
1310 | ASSERT_EQ(y.ndimension(), 1); |
1311 | ASSERT_EQ(y.size(0), 100); |
1312 | ASSERT_LT(y.sum().item<float>(), 130); // Probably |
1313 | ASSERT_GT(y.sum().item<float>(), 40); // Probably |
1314 | |
1315 | alpha_dropout->eval(); |
1316 | y = alpha_dropout(x); |
1317 | |
1318 | ASSERT_EQ(y.sum().item<float>(), 100); |
1319 | } |
1320 | |
1321 | TEST_F(ModulesTest, FeatureAlphaDropout) { |
1322 | FeatureAlphaDropout feature_alpha_dropout(0.5); |
1323 | torch::Tensor x = torch::ones({10, 10}, torch::requires_grad()); |
1324 | torch::Tensor y = feature_alpha_dropout(x); |
1325 | |
1326 | y.backward(torch::ones_like(y)); |
1327 | |
1328 | ASSERT_EQ(y.ndimension(), 2); |
1329 | ASSERT_EQ(y.size(0), 10); |
1330 | ASSERT_EQ(y.size(1), 10); |
1331 | ASSERT_LT(y.sum().item<float>(), 130); // Probably |
1332 | ASSERT_GT(y.sum().item<float>(), 40); // Probably |
1333 | |
1334 | feature_alpha_dropout->eval(); |
1335 | y = feature_alpha_dropout(x); |
1336 | |
1337 | ASSERT_EQ(y.sum().item<float>(), 100); |
1338 | } |
1339 | |
1340 | TEST_F(ModulesTest, Dropout) { |
1341 | for (const auto inplace : {false, true}) { |
1342 | Dropout dropout(DropoutOptions(0.5).inplace(inplace)); |
1343 | torch::Tensor x = torch::ones(100); |
1344 | if (!inplace) { |
1345 | x.requires_grad_(true); |
1346 | } |
1347 | torch::Tensor y = dropout(x); |
1348 | |
1349 | ASSERT_EQ(y.ndimension(), 1); |
1350 | ASSERT_EQ(y.size(0), 100); |
1351 | ASSERT_LT(y.sum().item<float>(), 130); // Probably |
1352 | ASSERT_GT(y.sum().item<float>(), 70); // Probably |
1353 | if (inplace) { |
1354 | ASSERT_TRUE(y.allclose(x)); |
1355 | } else { |
1356 | y.backward(torch::ones_like(y)); |
1357 | } |
1358 | |
1359 | dropout->eval(); |
1360 | y = dropout(torch::ones(100)); |
1361 | ASSERT_EQ(y.sum().item<float>(), 100); |
1362 | } |
1363 | } |
1364 | |
1365 | TEST_F(ModulesTest, Dropout2d) { |
1366 | auto p = 0.5; |
1367 | for (const auto inplace : {false, true}) { |
1368 | Dropout2d dropout(Dropout2dOptions(p).inplace(inplace)); |
1369 | torch::Tensor x = torch::empty({50, 50, 2, 2}).fill_(1 - p); |
1370 | if (!inplace) { |
1371 | x.requires_grad_(true); |
1372 | } |
1373 | torch::Tensor y = dropout(x); |
1374 | |
1375 | ASSERT_EQ(y.ndimension(), 4); |
1376 | ASSERT_EQ(y.size(0), 50); |
1377 | ASSERT_EQ(y.size(1), 50); |
1378 | ASSERT_EQ(y.size(2), 2); |
1379 | ASSERT_EQ(y.size(3), 2); |
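    // Surviving channels are scaled by 1 / (1 - p), so the mean should stay
    // near (1 - p) in expectation.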
1380 | ASSERT_LT((y.mean() - (1 - p)).abs().item<float>(), 0.05); |
1381 | |
1382 | if (inplace) { |
1383 | ASSERT_TRUE(y.allclose(x)); |
1384 | } else { |
1385 | y.backward(torch::ones_like(y)); |
1386 | } |
1387 | |
1388 | dropout->eval(); |
1389 | y = dropout(torch::ones({2, 2, 10, 10})); |
1390 | ASSERT_EQ(y.sum().item<float>(), 400); |
1391 | } |
1392 | } |
1393 | |
1394 | TEST_F(ModulesTest, Dropout3d) { |
1395 | for (const auto inplace : {false, true}) { |
1396 | auto p = 0.5; |
1397 | Dropout3d dropout(Dropout3dOptions(p).inplace(inplace)); |
1398 | torch::Tensor x = torch::empty({50, 50, 2, 2, 2}).fill_(1 - p); |
1399 | if (!inplace) { |
1400 | x.requires_grad_(true); |
1401 | } |
1402 | torch::Tensor y = dropout(x); |
1403 | |
1404 | ASSERT_EQ(y.ndimension(), 5); |
1405 | ASSERT_EQ(y.size(0), 50); |
1406 | ASSERT_EQ(y.size(1), 50); |
1407 | ASSERT_EQ(y.size(2), 2); |
1408 | ASSERT_EQ(y.size(3), 2); |
1409 | ASSERT_EQ(y.size(4), 2); |
1410 | ASSERT_LT((y.mean() - (1 - p)).abs().item<float>(), 0.05); |
1411 | |
1412 | if (inplace) { |
1413 | ASSERT_TRUE(y.allclose(x)); |
1414 | } else { |
1415 | y.backward(torch::ones_like(y)); |
1416 | } |
1417 | |
1418 | dropout->eval(); |
1419 | y = dropout(torch::ones({4, 4, 5, 5})); |
1420 | ASSERT_EQ(y.sum().item<float>(), 400); |
1421 | } |
1422 | } |
1423 | |
1424 | TEST_F(ModulesTest, Parameters) { |
1425 | auto model = std::make_shared<NestedModel>(); |
1426 | auto parameters = model->named_parameters(); |
  ASSERT_EQ(parameters["param"].size(0), 3);
  ASSERT_EQ(parameters["param"].size(1), 2);
  ASSERT_EQ(parameters["param"].size(2), 21);
  ASSERT_EQ(parameters["l1.bias"].size(0), 20);
  ASSERT_EQ(parameters["l1.weight"].size(0), 20);
  ASSERT_EQ(parameters["l1.weight"].size(1), 5);
  ASSERT_EQ(parameters["test.l1.bias"].size(0), 3);
  ASSERT_EQ(parameters["test.l1.weight"].size(0), 3);
  ASSERT_EQ(parameters["test.l1.weight"].size(1), 10);
  ASSERT_EQ(parameters["test.l2.bias"].size(0), 5);
  ASSERT_EQ(parameters["test.l2.weight"].size(0), 5);
  ASSERT_EQ(parameters["test.l2.weight"].size(1), 3);
  ASSERT_EQ(parameters["test.l3.bias"].size(0), 100);
  ASSERT_EQ(parameters["test.l3.weight"].size(0), 100);
  ASSERT_EQ(parameters["test.l3.weight"].size(1), 5);
1442 | } |
1443 | |
1444 | TEST_F(ModulesTest, FunctionalCallsSuppliedFunction) { |
1445 | bool was_called = false; |
1446 | auto functional = Functional([&was_called](torch::Tensor input) { |
1447 | was_called = true; |
1448 | return input; |
1449 | }); |
1450 | auto output = functional(torch::ones(5, torch::requires_grad())); |
1451 | ASSERT_TRUE(was_called); |
1452 | ASSERT_TRUE(output.equal(torch::ones(5, torch::requires_grad()))); |
1453 | |
1454 | was_called = false; |
1455 | // Use the call operator overload here. |
1456 | output = functional(torch::ones(5, torch::requires_grad())); |
1457 | ASSERT_TRUE(was_called); |
1458 | ASSERT_TRUE(output.equal(torch::ones(5, torch::requires_grad()))); |
1459 | } |
1460 | |
1461 | TEST_F(ModulesTest, FunctionalWithTorchFunction) { |
1462 | auto functional = Functional(torch::relu); |
1463 | ASSERT_EQ(functional(torch::ones({})).item<float>(), 1); |
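  // The call is repeated to check that the wrapped function can be invoked
  // more than once.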
1464 | ASSERT_EQ(functional(torch::ones({})).item<float>(), 1); |
1465 | ASSERT_EQ(functional(torch::ones({}) * -1).item<float>(), 0); |
1466 | } |
1467 | |
1468 | TEST_F(ModulesTest, FunctionalArgumentBinding) { |
1469 | auto functional = |
1470 | Functional(torch::elu, /*alpha=*/1, /*scale=*/0, /*input_scale=*/1); |
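  // Extra constructor arguments are bound to the function's trailing
  // parameters; with scale = 0 the result is identically zero.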
1471 | ASSERT_EQ(functional(torch::ones({})).item<float>(), 0); |
1472 | } |
1473 | |
1474 | TEST_F(ModulesTest, BatchNorm1dStateful) { |
1475 | BatchNorm1d bn(5); |
1476 | |
1477 | ASSERT_TRUE(bn->options.track_running_stats()); |
1478 | |
1479 | ASSERT_TRUE(bn->running_mean.defined()); |
1480 | ASSERT_EQ(bn->running_mean.dim(), 1); |
1481 | ASSERT_EQ(bn->running_mean.size(0), 5); |
1482 | |
1483 | ASSERT_TRUE(bn->running_var.defined()); |
1484 | ASSERT_EQ(bn->running_var.dim(), 1); |
1485 | ASSERT_EQ(bn->running_var.size(0), 5); |
1486 | |
1487 | ASSERT_TRUE(bn->num_batches_tracked.defined()); |
1488 | ASSERT_EQ(bn->num_batches_tracked.dim(), 0); |
1489 | |
1490 | ASSERT_TRUE(bn->options.affine()); |
1491 | |
1492 | ASSERT_TRUE(bn->weight.defined()); |
1493 | ASSERT_EQ(bn->weight.dim(), 1); |
1494 | ASSERT_EQ(bn->weight.size(0), 5); |
1495 | |
1496 | ASSERT_TRUE(bn->bias.defined()); |
1497 | ASSERT_EQ(bn->bias.dim(), 1); |
1498 | ASSERT_EQ(bn->bias.size(0), 5); |
1499 | } |
1500 | |
1501 | TEST_F(ModulesTest, BatchNorm1dStateless) { |
1502 | BatchNorm1d bn( |
1503 | BatchNorm1dOptions(5).track_running_stats(false).affine(false)); |
1504 | |
1505 | ASSERT_FALSE(bn->running_mean.defined()); |
1506 | ASSERT_FALSE(bn->running_var.defined()); |
1507 | ASSERT_FALSE(bn->num_batches_tracked.defined()); |
1508 | ASSERT_FALSE(bn->weight.defined()); |
1509 | ASSERT_FALSE(bn->bias.defined()); |
1510 | } |
1511 | |
1512 | TEST_F(ModulesTest, BatchNorm1d) { |
1513 | BatchNorm1d bn(5); |
1514 | bn->eval(); |
1515 | |
1516 | auto input = torch::arange(2. * 5 * 2).view({2, 5, 2}).requires_grad_(); |
1517 | auto output = bn->forward(input); |
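  // Sanity check for the expected values below: in eval mode BatchNorm1d
  // normalizes with the running statistics, and a freshly constructed module
  // has running_mean == 0 and running_var == 1. With the default eps of 1e-5
  // the output is x / sqrt(1 + 1e-5), i.e. the input scaled by a factor just
  // below 1, which is why the larger entries read 10.9999 rather than 11.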
1518 | auto expected = torch::tensor( |
1519 | {{{0.0000, 1.0000}, |
1520 | {2.0000, 3.0000}, |
1521 | {4.0000, 5.0000}, |
1522 | {6.0000, 7.0000}, |
1523 | {8.0000, 9.0000}}, |
1524 | {{10.0000, 10.9999}, |
1525 | {11.9999, 12.9999}, |
1526 | {13.9999, 14.9999}, |
1527 | {15.9999, 16.9999}, |
1528 | {17.9999, 18.9999}}}); |
1529 | ASSERT_TRUE(output.allclose(expected)); |
1530 | auto s = output.sum(); |
1531 | s.backward(); |
1532 | |
1533 | ASSERT_EQ(input.sizes(), input.grad().sizes()); |
1534 | } |
1535 | |
1536 | TEST_F(ModulesTest, BatchNorm2dStateful) { |
1537 | BatchNorm2d bn(5); |
1538 | |
1539 | ASSERT_TRUE(bn->options.track_running_stats()); |
1540 | |
1541 | ASSERT_TRUE(bn->running_mean.defined()); |
1542 | ASSERT_EQ(bn->running_mean.dim(), 1); |
1543 | ASSERT_EQ(bn->running_mean.size(0), 5); |
1544 | |
1545 | ASSERT_TRUE(bn->running_var.defined()); |
1546 | ASSERT_EQ(bn->running_var.dim(), 1); |
1547 | ASSERT_EQ(bn->running_var.size(0), 5); |
1548 | |
1549 | ASSERT_TRUE(bn->num_batches_tracked.defined()); |
1550 | ASSERT_EQ(bn->num_batches_tracked.dim(), 0); |
1551 | |
1552 | ASSERT_TRUE(bn->options.affine()); |
1553 | |
1554 | ASSERT_TRUE(bn->weight.defined()); |
1555 | ASSERT_EQ(bn->weight.dim(), 1); |
1556 | ASSERT_EQ(bn->weight.size(0), 5); |
1557 | |
1558 | ASSERT_TRUE(bn->bias.defined()); |
1559 | ASSERT_EQ(bn->bias.dim(), 1); |
1560 | ASSERT_EQ(bn->bias.size(0), 5); |
1561 | } |
1562 | |
1563 | TEST_F(ModulesTest, BatchNorm2dStateless) { |
1564 | BatchNorm2d bn( |
1565 | BatchNorm2dOptions(5).track_running_stats(false).affine(false)); |
1566 | |
1567 | ASSERT_FALSE(bn->running_mean.defined()); |
1568 | ASSERT_FALSE(bn->running_var.defined()); |
1569 | ASSERT_FALSE(bn->num_batches_tracked.defined()); |
1570 | ASSERT_FALSE(bn->weight.defined()); |
1571 | ASSERT_FALSE(bn->bias.defined()); |
1572 | } |
1573 | |
1574 | TEST_F(ModulesTest, BatchNorm2d) { |
1575 | BatchNorm2d bn(5); |
1576 | bn->eval(); |
1577 | |
1578 | auto input = |
1579 | torch::arange(2. * 5 * 2 * 2).view({2, 5, 2, 2}).requires_grad_(); |
1580 | auto output = bn->forward(input); |
1581 | auto expected = torch::tensor( |
1582 | {{{{0.0000, 1.0000}, {2.0000, 3.0000}}, |
1583 | {{4.0000, 5.0000}, {6.0000, 7.0000}}, |
1584 | {{8.0000, 9.0000}, {10.0000, 10.9999}}, |
1585 | {{11.9999, 12.9999}, {13.9999, 14.9999}}, |
1586 | {{15.9999, 16.9999}, {17.9999, 18.9999}}}, |
1587 | {{{19.9999, 20.9999}, {21.9999, 22.9999}}, |
1588 | {{23.9999, 24.9999}, {25.9999, 26.9999}}, |
1589 | {{27.9999, 28.9999}, {29.9998, 30.9998}}, |
1590 | {{31.9998, 32.9998}, {33.9998, 34.9998}}, |
1591 | {{35.9998, 36.9998}, {37.9998, 38.9998}}}}); |
1592 | ASSERT_TRUE(output.allclose(expected)); |
1593 | auto s = output.sum(); |
1594 | s.backward(); |
1595 | |
1596 | ASSERT_EQ(input.sizes(), input.grad().sizes()); |
1597 | } |
1598 | |
1599 | TEST_F(ModulesTest, BatchNorm3dStateful) { |
1600 | BatchNorm3d bn(5); |
1601 | |
1602 | ASSERT_TRUE(bn->options.track_running_stats()); |
1603 | |
1604 | ASSERT_TRUE(bn->running_mean.defined()); |
1605 | ASSERT_EQ(bn->running_mean.dim(), 1); |
1606 | ASSERT_EQ(bn->running_mean.size(0), 5); |
1607 | |
1608 | ASSERT_TRUE(bn->running_var.defined()); |
1609 | ASSERT_EQ(bn->running_var.dim(), 1); |
1610 | ASSERT_EQ(bn->running_var.size(0), 5); |
1611 | |
1612 | ASSERT_TRUE(bn->num_batches_tracked.defined()); |
1613 | ASSERT_EQ(bn->num_batches_tracked.dim(), 0); |
1614 | |
1615 | ASSERT_TRUE(bn->options.affine()); |
1616 | |
1617 | ASSERT_TRUE(bn->weight.defined()); |
1618 | ASSERT_EQ(bn->weight.dim(), 1); |
1619 | ASSERT_EQ(bn->weight.size(0), 5); |
1620 | |
1621 | ASSERT_TRUE(bn->bias.defined()); |
1622 | ASSERT_EQ(bn->bias.dim(), 1); |
1623 | ASSERT_EQ(bn->bias.size(0), 5); |
1624 | } |
1625 | |
1626 | TEST_F(ModulesTest, BatchNorm3dStateless) { |
1627 | BatchNorm3d bn( |
1628 | BatchNorm3dOptions(5).track_running_stats(false).affine(false)); |
1629 | |
1630 | ASSERT_FALSE(bn->running_mean.defined()); |
1631 | ASSERT_FALSE(bn->running_var.defined()); |
1632 | ASSERT_FALSE(bn->num_batches_tracked.defined()); |
1633 | ASSERT_FALSE(bn->weight.defined()); |
1634 | ASSERT_FALSE(bn->bias.defined()); |
1635 | } |
1636 | |
1637 | TEST_F(ModulesTest, BatchNorm3d) { |
1638 | BatchNorm3d bn(5); |
1639 | bn->eval(); |
1640 | |
1641 | auto input = |
1642 | torch::arange(2. * 5 * 2 * 2 * 2).view({2, 5, 2, 2, 2}).requires_grad_(); |
1643 | auto output = bn->forward(input); |
1644 | auto expected = torch::tensor( |
1645 | {{{{{0.0000, 1.0000}, {2.0000, 3.0000}}, |
1646 | {{4.0000, 5.0000}, {6.0000, 7.0000}}}, |
1647 | {{{8.0000, 9.0000}, {10.0000, 10.9999}}, |
1648 | {{11.9999, 12.9999}, {13.9999, 14.9999}}}, |
1649 | {{{15.9999, 16.9999}, {17.9999, 18.9999}}, |
1650 | {{19.9999, 20.9999}, {21.9999, 22.9999}}}, |
1651 | {{{23.9999, 24.9999}, {25.9999, 26.9999}}, |
1652 | {{27.9999, 28.9999}, {29.9998, 30.9998}}}, |
1653 | {{{31.9998, 32.9998}, {33.9998, 34.9998}}, |
1654 | {{35.9998, 36.9998}, {37.9998, 38.9998}}}}, |
1655 | {{{{39.9998, 40.9998}, {41.9998, 42.9998}}, |
1656 | {{43.9998, 44.9998}, {45.9998, 46.9998}}}, |
1657 | {{{47.9998, 48.9998}, {49.9997, 50.9997}}, |
1658 | {{51.9997, 52.9997}, {53.9997, 54.9997}}}, |
1659 | {{{55.9997, 56.9997}, {57.9997, 58.9997}}, |
1660 | {{59.9997, 60.9997}, {61.9997, 62.9997}}}, |
1661 | {{{63.9997, 64.9997}, {65.9997, 66.9997}}, |
1662 | {{67.9997, 68.9997}, {69.9996, 70.9996}}}, |
1663 | {{{71.9996, 72.9996}, {73.9996, 74.9996}}, |
1664 | {{75.9996, 76.9996}, {77.9996, 78.9996}}}}}); |
1665 | ASSERT_TRUE(output.allclose(expected)); |
1666 | auto s = output.sum(); |
1667 | s.backward(); |
1668 | |
1669 | ASSERT_EQ(input.sizes(), input.grad().sizes()); |
1670 | } |
1671 | |
1672 | TEST_F(ModulesTest, InstanceNorm1dStateful) { |
1673 | InstanceNorm1d instance_norm( |
1674 | InstanceNorm1dOptions(5).track_running_stats(true).affine(true)); |
1675 | |
1676 | ASSERT_TRUE(instance_norm->options.track_running_stats()); |
1677 | |
1678 | ASSERT_TRUE(instance_norm->running_mean.defined()); |
1679 | ASSERT_EQ(instance_norm->running_mean.dim(), 1); |
1680 | ASSERT_EQ(instance_norm->running_mean.size(0), 5); |
1681 | |
1682 | ASSERT_TRUE(instance_norm->running_var.defined()); |
1683 | ASSERT_EQ(instance_norm->running_var.dim(), 1); |
1684 | ASSERT_EQ(instance_norm->running_var.size(0), 5); |
1685 | |
1686 | ASSERT_TRUE(instance_norm->num_batches_tracked.defined()); |
1687 | ASSERT_EQ(instance_norm->num_batches_tracked.dim(), 0); |
1688 | |
1689 | ASSERT_TRUE(instance_norm->options.affine()); |
1690 | |
1691 | ASSERT_TRUE(instance_norm->weight.defined()); |
1692 | ASSERT_EQ(instance_norm->weight.dim(), 1); |
1693 | ASSERT_EQ(instance_norm->weight.size(0), 5); |
1694 | |
1695 | ASSERT_TRUE(instance_norm->bias.defined()); |
1696 | ASSERT_EQ(instance_norm->bias.dim(), 1); |
1697 | ASSERT_EQ(instance_norm->bias.size(0), 5); |
1698 | } |
1699 | |
1700 | TEST_F(ModulesTest, InstanceNorm1dStateless) { |
1701 | InstanceNorm1d instance_norm( |
1702 | InstanceNorm1dOptions(5).track_running_stats(false).affine(false)); |
1703 | |
1704 | ASSERT_FALSE(instance_norm->running_mean.defined()); |
1705 | ASSERT_FALSE(instance_norm->running_var.defined()); |
1706 | ASSERT_FALSE(instance_norm->num_batches_tracked.defined()); |
1707 | ASSERT_FALSE(instance_norm->weight.defined()); |
1708 | ASSERT_FALSE(instance_norm->bias.defined()); |
1709 | } |
1710 | |
1711 | TEST_F(ModulesTest, InstanceNorm1d) { |
1712 | InstanceNorm1d instance_norm(5); |
1713 | instance_norm->eval(); |
1714 | |
1715 | auto input = torch::arange(2. * 5 * 2).view({2, 5, 2}).requires_grad_(); |
1716 | auto output = instance_norm->forward(input); |
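  // Sanity check for the expected values below: each (sample, channel) slice
  // holds two consecutive values {a, a + 1}, whose mean is a + 0.5 and whose
  // biased standard deviation is 0.5, so every slice normalizes to roughly
  // {-1, 1} (exactly so up to the default eps of 1e-5).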
1717 | auto expected = torch::tensor( |
1718 | {{{-1.0000, 1.0000}, |
1719 | {-1.0000, 1.0000}, |
1720 | {-1.0000, 1.0000}, |
1721 | {-1.0000, 1.0000}, |
1722 | {-1.0000, 1.0000}}, |
1723 | {{-1.0000, 1.0000}, |
1724 | {-1.0000, 1.0000}, |
1725 | {-1.0000, 1.0000}, |
1726 | {-1.0000, 1.0000}, |
1727 | {-1.0000, 1.0000}}}); |
1728 | ASSERT_TRUE(output.allclose(expected, 1e-3)); |
1729 | auto s = output.sum(); |
1730 | s.backward(); |
1731 | |
1732 | ASSERT_EQ(input.sizes(), input.grad().sizes()); |
1733 | } |
1734 | |
1735 | TEST_F(ModulesTest, InstanceNorm2dStateful) { |
1736 | InstanceNorm2d instance_norm( |
1737 | InstanceNorm2dOptions(5).track_running_stats(true).affine(true)); |
1738 | |
1739 | ASSERT_TRUE(instance_norm->options.track_running_stats()); |
1740 | |
1741 | ASSERT_TRUE(instance_norm->running_mean.defined()); |
1742 | ASSERT_EQ(instance_norm->running_mean.dim(), 1); |
1743 | ASSERT_EQ(instance_norm->running_mean.size(0), 5); |
1744 | |
1745 | ASSERT_TRUE(instance_norm->running_var.defined()); |
1746 | ASSERT_EQ(instance_norm->running_var.dim(), 1); |
1747 | ASSERT_EQ(instance_norm->running_var.size(0), 5); |
1748 | |
1749 | ASSERT_TRUE(instance_norm->num_batches_tracked.defined()); |
1750 | ASSERT_EQ(instance_norm->num_batches_tracked.dim(), 0); |
1751 | |
1752 | ASSERT_TRUE(instance_norm->options.affine()); |
1753 | |
1754 | ASSERT_TRUE(instance_norm->weight.defined()); |
1755 | ASSERT_EQ(instance_norm->weight.dim(), 1); |
1756 | ASSERT_EQ(instance_norm->weight.size(0), 5); |
1757 | |
1758 | ASSERT_TRUE(instance_norm->bias.defined()); |
1759 | ASSERT_EQ(instance_norm->bias.dim(), 1); |
1760 | ASSERT_EQ(instance_norm->bias.size(0), 5); |
1761 | } |
1762 | |
1763 | TEST_F(ModulesTest, InstanceNorm2dStateless) { |
1764 | InstanceNorm2d instance_norm( |
1765 | InstanceNorm2dOptions(5).track_running_stats(false).affine(false)); |
1766 | |
1767 | ASSERT_FALSE(instance_norm->running_mean.defined()); |
1768 | ASSERT_FALSE(instance_norm->running_var.defined()); |
1769 | ASSERT_FALSE(instance_norm->num_batches_tracked.defined()); |
1770 | ASSERT_FALSE(instance_norm->weight.defined()); |
1771 | ASSERT_FALSE(instance_norm->bias.defined()); |
1772 | } |
1773 | |
1774 | TEST_F(ModulesTest, InstanceNorm2d) { |
1775 | InstanceNorm2d instance_norm(5); |
1776 | instance_norm->eval(); |
1777 | |
1778 | auto input = |
1779 | torch::arange(2. * 5 * 2 * 2).view({2, 5, 2, 2}).requires_grad_(); |
1780 | auto output = instance_norm->forward(input); |
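  // Sanity check for the expected values below: each 2x2 channel slice holds
  // four consecutive values {a, ..., a + 3} with mean a + 1.5 and biased
  // variance (2.25 + 0.25 + 0.25 + 2.25) / 4 = 1.25, std ~= 1.1180.
  // Normalizing {-1.5, -0.5, 0.5, 1.5} by 1.1180 yields the repeated
  // {-1.3416, -0.4472, 0.4472, 1.3416} pattern.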
1781 | auto expected = torch::tensor( |
1782 | {{{{-1.3416, -0.4472}, {0.4472, 1.3416}}, |
1783 | {{-1.3416, -0.4472}, {0.4472, 1.3416}}, |
1784 | {{-1.3416, -0.4472}, {0.4472, 1.3416}}, |
1785 | {{-1.3416, -0.4472}, {0.4472, 1.3416}}, |
1786 | {{-1.3416, -0.4472}, {0.4472, 1.3416}}}, |
1787 | {{{-1.3416, -0.4472}, {0.4472, 1.3416}}, |
1788 | {{-1.3416, -0.4472}, {0.4472, 1.3416}}, |
1789 | {{-1.3416, -0.4472}, {0.4472, 1.3416}}, |
1790 | {{-1.3416, -0.4472}, {0.4472, 1.3416}}, |
1791 | {{-1.3416, -0.4472}, {0.4472, 1.3416}}}}); |
1792 | ASSERT_TRUE(output.allclose(expected, 1e-3)); |
1793 | auto s = output.sum(); |
1794 | s.backward(); |
1795 | |
1796 | ASSERT_EQ(input.sizes(), input.grad().sizes()); |
1797 | } |
1798 | |
1799 | TEST_F(ModulesTest, InstanceNorm3dStateful) { |
1800 | InstanceNorm3d instance_norm( |
1801 | InstanceNorm3dOptions(5).track_running_stats(true).affine(true)); |
1802 | |
1803 | ASSERT_TRUE(instance_norm->options.track_running_stats()); |
1804 | |
1805 | ASSERT_TRUE(instance_norm->running_mean.defined()); |
1806 | ASSERT_EQ(instance_norm->running_mean.dim(), 1); |
1807 | ASSERT_EQ(instance_norm->running_mean.size(0), 5); |
1808 | |
1809 | ASSERT_TRUE(instance_norm->running_var.defined()); |
1810 | ASSERT_EQ(instance_norm->running_var.dim(), 1); |
1811 | ASSERT_EQ(instance_norm->running_var.size(0), 5); |
1812 | |
1813 | ASSERT_TRUE(instance_norm->num_batches_tracked.defined()); |
1814 | ASSERT_EQ(instance_norm->num_batches_tracked.dim(), 0); |
1815 | |
1816 | ASSERT_TRUE(instance_norm->options.affine()); |
1817 | |
1818 | ASSERT_TRUE(instance_norm->weight.defined()); |
1819 | ASSERT_EQ(instance_norm->weight.dim(), 1); |
1820 | ASSERT_EQ(instance_norm->weight.size(0), 5); |
1821 | |
1822 | ASSERT_TRUE(instance_norm->bias.defined()); |
1823 | ASSERT_EQ(instance_norm->bias.dim(), 1); |
1824 | ASSERT_EQ(instance_norm->bias.size(0), 5); |
1825 | } |
1826 | |
1827 | TEST_F(ModulesTest, InstanceNorm3dStateless) { |
1828 | InstanceNorm3d instance_norm( |
1829 | InstanceNorm3dOptions(5).track_running_stats(false).affine(false)); |
1830 | |
1831 | ASSERT_FALSE(instance_norm->running_mean.defined()); |
1832 | ASSERT_FALSE(instance_norm->running_var.defined()); |
1833 | ASSERT_FALSE(instance_norm->num_batches_tracked.defined()); |
1834 | ASSERT_FALSE(instance_norm->weight.defined()); |
1835 | ASSERT_FALSE(instance_norm->bias.defined()); |
1836 | } |
1837 | |
1838 | TEST_F(ModulesTest, InstanceNorm3d) { |
1839 | InstanceNorm3d instance_norm(5); |
1840 | instance_norm->eval(); |
1841 | |
1842 | auto input = |
1843 | torch::arange(2. * 5 * 2 * 2 * 2).view({2, 5, 2, 2, 2}).requires_grad_(); |
1844 | auto output = instance_norm->forward(input); |
1845 | auto expected = torch::tensor( |
1846 | {{{{{-1.5275, -1.0911}, {-0.6547, -0.2182}}, |
1847 | {{0.2182, 0.6547}, {1.0911, 1.5275}}}, |
1848 | {{{-1.5275, -1.0911}, {-0.6547, -0.2182}}, |
1849 | {{0.2182, 0.6547}, {1.0911, 1.5275}}}, |
1850 | {{{-1.5275, -1.0911}, {-0.6547, -0.2182}}, |
1851 | {{0.2182, 0.6547}, {1.0911, 1.5275}}}, |
1852 | {{{-1.5275, -1.0911}, {-0.6547, -0.2182}}, |
1853 | {{0.2182, 0.6547}, {1.0911, 1.5275}}}, |
1854 | {{{-1.5275, -1.0911}, {-0.6547, -0.2182}}, |
1855 | {{0.2182, 0.6547}, {1.0911, 1.5275}}}}, |
1856 | {{{{-1.5275, -1.0911}, {-0.6547, -0.2182}}, |
1857 | {{0.2182, 0.6547}, {1.0911, 1.5275}}}, |
1858 | {{{-1.5275, -1.0911}, {-0.6547, -0.2182}}, |
1859 | {{0.2182, 0.6547}, {1.0911, 1.5275}}}, |
1860 | {{{-1.5275, -1.0911}, {-0.6547, -0.2182}}, |
1861 | {{0.2182, 0.6547}, {1.0911, 1.5275}}}, |
1862 | {{{-1.5275, -1.0911}, {-0.6547, -0.2182}}, |
1863 | {{0.2182, 0.6547}, {1.0911, 1.5275}}}, |
1864 | {{{-1.5275, -1.0911}, {-0.6547, -0.2182}}, |
1865 | {{0.2182, 0.6547}, {1.0911, 1.5275}}}}}); |
1866 | ASSERT_TRUE(output.allclose(expected, 1e-3)); |
1867 | auto s = output.sum(); |
1868 | s.backward(); |
1869 | |
1870 | ASSERT_EQ(input.sizes(), input.grad().sizes()); |
1871 | } |
1872 | |
1873 | TEST_F(ModulesTest, Linear_CUDA) { |
1874 | Linear model(5, 2); |
1875 | model->to(torch::kCUDA); |
1876 | auto x = |
1877 | torch::randn({10, 5}, torch::device(torch::kCUDA).requires_grad(true)); |
1878 | auto y = model(x); |
1879 | torch::Tensor s = y.sum(); |
1880 | |
1881 | s.backward(); |
1882 | ASSERT_EQ(y.ndimension(), 2); |
1883 | ASSERT_EQ(s.ndimension(), 0); |
1884 | ASSERT_EQ(y.size(0), 10); |
1885 | ASSERT_EQ(y.size(1), 2); |
1886 | |
1887 | ASSERT_EQ(model->weight.grad().numel(), 2 * 5); |
1888 | } |
1889 | |
1890 | TEST_F(ModulesTest, Linear2_CUDA) { |
1891 | Linear model(5, 2); |
1892 | model->to(torch::kCUDA); |
1893 | model->to(torch::kCPU); |
1894 | auto x = torch::randn({10, 5}, torch::requires_grad()); |
1895 | auto y = model(x); |
1896 | torch::Tensor s = y.sum(); |
1897 | |
1898 | s.backward(); |
1899 | ASSERT_EQ(y.ndimension(), 2); |
1900 | ASSERT_EQ(s.ndimension(), 0); |
1901 | ASSERT_EQ(y.size(0), 10); |
1902 | ASSERT_EQ(y.size(1), 2); |
1903 | |
1904 | ASSERT_EQ(model->weight.grad().numel(), 2 * 5); |
1905 | } |
1906 | |
1907 | TEST_F(ModulesTest, L1Loss) { |
1908 | L1Loss loss; |
1909 | auto input = torch::randn({5, 6}, torch::requires_grad()); |
1910 | auto target = torch::empty({5, 6}).random_(2); |
1911 | auto output = loss->forward(torch::sigmoid(input), target); |
1912 | auto s = output.sum(); |
1913 | s.backward(); |
1914 | |
1915 | ASSERT_EQ(output.sizes(), std::vector<int64_t>()); |
1916 | ASSERT_EQ(input.sizes(), input.grad().sizes()); |
1917 | } |
1918 | |
1919 | TEST_F(ModulesTest, MSELoss) { |
1920 | MSELoss loss; |
1921 | auto input = torch::randn({5, 6}, torch::requires_grad()); |
1922 | auto target = torch::empty({5, 6}).random_(2); |
1923 | auto output = loss->forward(torch::sigmoid(input), target); |
1924 | auto s = output.sum(); |
1925 | s.backward(); |
1926 | |
1927 | ASSERT_EQ(output.sizes(), torch::IntArrayRef()); |
1928 | ASSERT_EQ(input.sizes(), input.grad().sizes()); |
1929 | } |
1930 | |
1931 | TEST_F(ModulesTest, BCELoss) { |
1932 | BCELoss loss; |
1933 | auto input = torch::randn({5, 6}, torch::requires_grad()); |
1934 | auto target = torch::empty({5, 6}).random_(2); |
1935 | auto output = loss->forward(torch::sigmoid(input), target); |
1936 | auto s = output.sum(); |
1937 | s.backward(); |
1938 | |
1939 | ASSERT_EQ(output.sizes(), torch::IntArrayRef()); |
1940 | ASSERT_EQ(input.sizes(), input.grad().sizes()); |
1941 | } |
1942 | |
1943 | TEST_F(ModulesTest, KLDivLoss) { |
1944 | KLDivLoss loss; |
1945 | auto input = torch::randn({5, 6}, torch::requires_grad()); |
1946 | auto target = torch::empty({5, 6}).random_(2); |
1947 | auto output = loss->forward(torch::sigmoid(input), target); |
1948 | auto s = output.sum(); |
1949 | s.backward(); |
1950 | |
1951 | ASSERT_EQ(output.sizes(), torch::IntArrayRef()); |
1952 | ASSERT_EQ(input.sizes(), input.grad().sizes()); |
1953 | } |
1954 | |
1955 | TEST_F(ModulesTest, HingeEmbeddingLoss) { |
1956 | HingeEmbeddingLoss loss(HingeEmbeddingLossOptions().margin(2)); |
1957 | auto input = torch::tensor( |
1958 | {{2, 22, 4}, {20, 10, 0}}, |
1959 | torch::dtype(torch::kFloat).requires_grad(true)); |
1960 | auto target = torch::tensor({{2, 6, 4}, {1, 10, 0}}, torch::kFloat); |
1961 | auto output = loss->forward(input, target); |
1962 | auto expected = torch::tensor({10}, torch::kFloat); |
1963 | auto s = output.sum(); |
1964 | s.backward(); |
1965 | |
1966 | ASSERT_TRUE(output.allclose(expected)); |
1967 | ASSERT_EQ(input.sizes(), input.grad().sizes()); |
1968 | } |
1969 | |
1970 | TEST_F(ModulesTest, MultiMarginLoss) { |
1971 | auto weight = torch::tensor({0.3, 0.3, 0.4}, torch::kFloat); |
1972 | MultiMarginLoss loss(MultiMarginLossOptions().margin(2).weight(weight)); |
1973 | auto input = torch::tensor( |
1974 | {{0.2, 0.2, 0.6}, {0.1, 0.8, 0.1}, {0.9, 0.09, 0.01}}, |
1975 | torch::dtype(torch::kFloat).requires_grad(true)); |
1976 | auto target = torch::tensor({2, 1, 0}, torch::kLong); |
1977 | auto output = loss->forward(input, target); |
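  // Worked derivation of the expected value: with the default p = 1, the
  // per-sample loss is weight[target] * sum_{j != target} max(0, margin -
  // x[target] + x[j]) / num_classes. Row 0: 0.4 * (1.6 + 1.6) / 3 ~= 0.4267;
  // row 1: 0.3 * (1.3 + 1.3) / 3 = 0.26; row 2: 0.3 * (1.19 + 1.11) / 3 =
  // 0.23. Their mean is ~0.305556.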
1978 | auto expected = torch::tensor({0.305556}, torch::kFloat); |
1979 | auto s = output.sum(); |
1980 | s.backward(); |
1981 | |
1982 | ASSERT_TRUE(output.allclose(expected, 1e-04)); |
1983 | ASSERT_EQ(input.sizes(), input.grad().sizes()); |
1984 | } |
1985 | |
1986 | TEST_F(ModulesTest, CosineEmbeddingLoss) { |
1987 | CosineEmbeddingLoss cos(CosineEmbeddingLossOptions().margin(0.5)); |
1988 | auto input1 = torch::tensor( |
1989 | {{2, 3, 4}, {6, 2, 4}}, torch::dtype(torch::kFloat).requires_grad(true)); |
1990 | auto input2 = torch::tensor( |
1991 | {{2, 3, 5}, {9, 12, 0}}, torch::dtype(torch::kFloat).requires_grad(true)); |
1992 | auto target = torch::tensor({1, -1}); |
1993 | auto output = cos(input1, input2, target); |
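  // Worked derivation of the expected value: for target 1 the loss is
  // 1 - cos(x1, x2); for target -1 it is max(0, cos(x1, x2) - margin).
  // Row 0: 1 - 0.9941 = 0.0059; row 1: max(0, 0.6949 - 0.5) = 0.1949.
  // The mean is ~0.1004.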
1994 | auto expected = torch::tensor({0.1004}, torch::kFloat); |
1995 | auto s = output.sum(); |
1996 | s.backward(); |
1997 | |
1998 | ASSERT_TRUE(output.allclose(expected, 1e-4)); |
1999 | ASSERT_EQ(input1.sizes(), input1.grad().sizes()); |
2000 | ASSERT_EQ(input2.sizes(), input2.grad().sizes()); |
2001 | } |
2002 | |
2003 | TEST_F(ModulesTest, SmoothL1LossDefaultOptions) { |
2004 | SmoothL1Loss loss; |
2005 | auto input = torch::tensor( |
2006 | {0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true)); |
2007 | auto target = torch::tensor({0., 1., 5.}, torch::kFloat); |
2008 | auto output = loss(input, target); |
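  // With the default beta = 1, all residuals {0.1, 0.2, 0.3} fall in the
  // quadratic region, so the loss is mean(0.5 * d^2) =
  // 0.5 * (0.01 + 0.04 + 0.09) / 3 ~= 0.023333.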
2009 | auto expected = torch::tensor(0.0233335, torch::kFloat); |
2010 | auto s = output.sum(); |
2011 | s.backward(); |
2012 | |
2013 | ASSERT_TRUE(output.allclose(expected)); |
2014 | ASSERT_EQ(input.sizes(), input.grad().sizes()); |
2015 | } |
2016 | |
2017 | TEST_F(ModulesTest, HuberLossDefaultOptions) { |
2018 | HuberLoss loss; |
2019 | auto input = torch::tensor( |
2020 | {0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true)); |
2021 | auto target = torch::tensor({0., 1., 5.}, torch::kFloat); |
2022 | auto output = loss(input, target); |
2023 | auto expected = torch::tensor(0.0233335, torch::kFloat); |
2024 | auto s = output.sum(); |
2025 | s.backward(); |
2026 | |
2027 | ASSERT_TRUE(output.allclose(expected)); |
2028 | ASSERT_EQ(input.sizes(), input.grad().sizes()); |
2029 | } |
2030 | |
2031 | TEST_F(ModulesTest, MultiLabelMarginLossDefaultOptions) { |
2032 | MultiLabelMarginLoss loss; |
2033 | auto input = torch::tensor( |
2034 | {{0.1, 0.2, 0.4, 0.8}}, torch::dtype(torch::kFloat).requires_grad(true)); |
2035 | auto target = torch::tensor({{3, 0, -1, 1}}, torch::kLong); |
2036 | auto output = loss->forward(input, target); |
2037 | auto expected = torch::tensor({0.8500}, torch::kFloat); |
2038 | auto s = output.sum(); |
2039 | s.backward(); |
2040 | |
2041 | ASSERT_TRUE(output.allclose(expected)); |
2042 | ASSERT_EQ(input.sizes(), input.grad().sizes()); |
2043 | } |
2044 | |
2045 | TEST_F(ModulesTest, SmoothL1LossNoReduction) { |
2046 | SmoothL1Loss loss(/*reduction=*/torch::kNone); |
2047 | auto input = torch::tensor( |
2048 | {0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true)); |
2049 | auto target = torch::tensor({0., 1., 5.}, torch::kFloat); |
2050 | auto output = loss(input, target); |
2051 | auto expected = torch::tensor({0.005, 0.02, 0.045}, torch::kFloat); |
2052 | auto s = output.sum(); |
2053 | s.backward(); |
2054 | |
2055 | ASSERT_TRUE(output.allclose(expected)); |
2056 | ASSERT_EQ(input.sizes(), input.grad().sizes()); |
2057 | } |
2058 | |
2059 | TEST_F(ModulesTest, HuberLossNoReduction) { |
2060 | HuberLoss loss(/*reduction=*/torch::kNone); |
2061 | auto input = torch::tensor( |
2062 | {0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true)); |
2063 | auto target = torch::tensor({0., 1., 5.}, torch::kFloat); |
2064 | auto output = loss(input, target); |
2065 | auto expected = torch::tensor({0.005, 0.02, 0.045}, torch::kFloat); |
2066 | auto s = output.sum(); |
2067 | s.backward(); |
2068 | |
2069 | ASSERT_TRUE(output.allclose(expected)); |
2070 | ASSERT_EQ(input.sizes(), input.grad().sizes()); |
2071 | } |
2072 | |
2073 | TEST_F(ModulesTest, MultiLabelMarginLossNoReduction) { |
2074 | MultiLabelMarginLoss loss(torch::kNone); |
2075 | auto input = torch::tensor( |
2076 | {{0.1, 0.2, 0.4, 0.8}}, torch::dtype(torch::kFloat).requires_grad(true)); |
2077 | auto target = torch::tensor({{3, 0, -1, 1}}, torch::kLong); |
2078 | auto output = loss->forward(input, target); |
2079 | auto expected = torch::tensor({0.8500}, torch::kFloat); |
2080 | auto s = output.sum(); |
2081 | s.backward(); |
2082 | |
2083 | ASSERT_TRUE(output.allclose(expected)); |
2084 | ASSERT_EQ(input.sizes(), input.grad().sizes()); |
2085 | } |
2086 | |
2087 | TEST_F(ModulesTest, SmoothL1LossBeta) { |
2088 | auto options = SmoothL1LossOptions().beta(0.2); |
2089 | SmoothL1Loss loss(options); |
2090 | auto input = torch::tensor( |
2091 | {0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true)); |
2092 | auto target = torch::tensor({0., 1., 5.}, torch::kFloat); |
2093 | auto output = loss(input, target); |
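  // With beta = 0.2, only d = 0.1 stays in the quadratic region
  // (0.5 * 0.01 / 0.2 = 0.025); d = 0.2 and d = 0.3 use the linear branch
  // (|d| - beta / 2 = 0.1 and 0.2). The mean (0.025 + 0.1 + 0.2) / 3 ~=
  // 0.108333.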
2094 | auto expected = torch::tensor(0.108333, torch::kFloat); |
2095 | auto s = output.sum(); |
2096 | s.backward(); |
2097 | |
2098 | ASSERT_TRUE(output.allclose(expected)); |
2099 | ASSERT_EQ(input.sizes(), input.grad().sizes()); |
2100 | } |
2101 | |
2102 | TEST_F(ModulesTest, HuberLossDelta) { |
2103 | auto options = HuberLossOptions().delta(0.2); |
2104 | HuberLoss loss(options); |
2105 | auto input = torch::tensor( |
2106 | {0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true)); |
2107 | auto target = torch::tensor({0., 1., 5.}, torch::kFloat); |
2108 | auto output = loss(input, target); |
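  // Huber is delta * SmoothL1(beta = delta): 0.5 * d^2 for |d| <= delta,
  // otherwise delta * (|d| - delta / 2). Here (0.005 + 0.02 + 0.04) / 3 ~=
  // 0.0216666, i.e. exactly 0.2 times the SmoothL1LossBeta result above.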
2109 | auto expected = torch::tensor(0.0216666, torch::kFloat); |
2110 | auto s = output.sum(); |
2111 | s.backward(); |
2112 | |
2113 | ASSERT_TRUE(output.allclose(expected)); |
2114 | ASSERT_EQ(input.sizes(), input.grad().sizes()); |
2115 | } |
2116 | |
2117 | TEST_F(ModulesTest, TripletMarginLoss) { |
2118 | TripletMarginLoss loss(TripletMarginLossOptions().margin(1.0)); |
2119 | auto anchor = torch::tensor( |
2120 | {{3., 3.}}, torch::dtype(torch::kFloat).requires_grad(true)); |
2121 | auto positive = torch::tensor( |
2122 | {{2., 2.}}, torch::dtype(torch::kFloat).requires_grad(true)); |
2123 | auto negative = torch::tensor( |
2124 | {{0., 0.}}, torch::dtype(torch::kFloat).requires_grad(true)); |
2125 | auto output = loss->forward(anchor, positive, negative); |
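  // d(anchor, positive) = sqrt(2) ~= 1.41 and d(anchor, negative) =
  // sqrt(18) ~= 4.24, so max(0, d_ap - d_an + margin) =
  // max(0, 1.41 - 4.24 + 1.0) = 0, hence the zero expected loss.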
2126 | auto expected = torch::tensor({0.}, torch::kFloat); |
2127 | auto s = output.sum(); |
2128 | s.backward(); |
2129 | |
2130 | ASSERT_TRUE(output.allclose(expected, 1e-04)); |
2131 | ASSERT_EQ(anchor.sizes(), anchor.grad().sizes()); |
2132 | } |
2133 | |
2134 | TEST_F(ModulesTest, TripletMarginWithDistanceLossDefaultParity) { |
  // Check that, if we use torch::pairwise_distance with the default
  // TripletMarginLoss options as our distance function, the outputs of
  // TripletMarginWithDistanceLoss match those of TripletMarginLoss.
2138 | |
2139 | std::vector<TripletMarginWithDistanceLossOptions::reduction_t> reductions = { |
2140 | torch::kSum, torch::kMean, torch::kNone}; |
2141 | std::vector<float> margins = {0.5, 1.0, 1.5}; |
2142 | std::vector<bool> swaps = {true, false}; |
2143 | |
2144 | for (auto& reduction : reductions) { |
2145 | for (auto& margin : margins) { |
2146 | for (const auto swap : swaps) { |
2147 | auto anchor = torch::randn( |
2148 | {100, 128}, torch::dtype(torch::kFloat).requires_grad(true)); |
2149 | auto positive = torch::randn( |
2150 | {100, 128}, torch::dtype(torch::kFloat).requires_grad(true)); |
2151 | auto negative = torch::randn( |
2152 | {100, 128}, torch::dtype(torch::kFloat).requires_grad(true)); |
2153 | |
2154 | auto basicOptions = |
2155 | TripletMarginLossOptions().reduction(reduction).margin(margin).swap( |
2156 | swap); |
2157 | auto distanceOptions = TripletMarginWithDistanceLossOptions() |
2158 | .reduction(reduction) |
2159 | .margin(margin) |
2160 | .swap(swap); |
2161 | TripletMarginLoss basicLoss(basicOptions); |
2162 | TripletMarginWithDistanceLoss distanceLoss(distanceOptions); |
2163 | |
2164 | auto basicOutput = basicLoss->forward(anchor, positive, negative); |
2165 | auto distanceOutput = distanceLoss->forward(anchor, positive, negative); |
2166 | auto basicOperatorOutput = basicLoss(anchor, positive, negative); |
2167 | auto distanceOperatorOutput = distanceLoss(anchor, positive, negative); |
2168 | |
2169 | ASSERT_TRUE(distanceOutput.allclose(basicOutput, 1e-6, 1e-6)); |
2170 | ASSERT_TRUE( |
2171 | distanceOperatorOutput.allclose(distanceOutput, 1e-6, 1e-6)); |
2172 | ASSERT_TRUE( |
2173 | distanceOperatorOutput.allclose(basicOperatorOutput, 1e-6, 1e-6)); |
2174 | |
          // sum() collapses the output so backward() also works for the
          // torch::kNone reduction
2176 | auto sum = distanceOutput.sum(); |
2177 | sum.backward(); |
2178 | ASSERT_EQ(anchor.sizes(), anchor.grad().sizes()); |
2179 | ASSERT_EQ(positive.sizes(), positive.grad().sizes()); |
2180 | ASSERT_EQ(negative.sizes(), negative.grad().sizes()); |
2181 | } |
2182 | } |
2183 | } |
2184 | } |
2185 | |
2186 | TEST_F(ModulesTest, TripletMarginWithDistanceLossFunctionalParity) { |
2187 | // Check for parity between F::triplet_margin_with_distance_loss and |
2188 | // TripletMarginWithDistanceLoss. |
2189 | auto pairwise_distance = [&](const torch::Tensor& x, const torch::Tensor& y) { |
2190 | return torch::pairwise_distance(x, y); |
2191 | }; |
2192 | auto cosine_distance = [&](const torch::Tensor& x, const torch::Tensor& y) { |
2193 | return 1.0 - torch::cosine_similarity(x, y); |
2194 | }; |
2195 | std::vector<TripletMarginWithDistanceLossOptions::distance_function_t> |
2196 | distance_functions = {pairwise_distance, cosine_distance}; |
2197 | |
2198 | std::vector<TripletMarginWithDistanceLossOptions::reduction_t> reductions = { |
2199 | torch::kSum, torch::kMean, torch::kNone}; |
2200 | std::vector<float> margins = {0.5, 1.0, 1.5}; |
2201 | std::vector<bool> swaps = {true, false}; |
2202 | |
2203 | for (auto& function : distance_functions) { |
2204 | for (auto& reduction : reductions) { |
2205 | for (auto& margin : margins) { |
2206 | for (const auto swap : swaps) { |
2207 | auto moduleOptions = TripletMarginWithDistanceLossOptions() |
2208 | .distance_function(function) |
2209 | .reduction(reduction) |
2210 | .margin(margin) |
2211 | .swap(swap); |
2212 | auto functionOptions = |
2213 | torch::nn::functional::TripletMarginWithDistanceLossFuncOptions() |
2214 | .distance_function(function) |
2215 | .reduction(reduction) |
2216 | .margin(margin) |
2217 | .swap(swap); |
2218 | |
2219 | auto anchor = torch::randn( |
2220 | {100, 128}, torch::dtype(torch::kFloat).requires_grad(true)); |
2221 | auto positive = torch::randn( |
2222 | {100, 128}, torch::dtype(torch::kFloat).requires_grad(true)); |
2223 | auto negative = torch::randn( |
2224 | {100, 128}, torch::dtype(torch::kFloat).requires_grad(true)); |
2225 | |
2226 | TripletMarginWithDistanceLoss distanceLoss(moduleOptions); |
2227 | |
2228 | auto moduleOutput = distanceLoss->forward(anchor, positive, negative); |
2229 | auto moduleOperatorOutput = distanceLoss(anchor, positive, negative); |
2230 | auto functionOutput = |
2231 | torch::nn::functional::triplet_margin_with_distance_loss( |
2232 | anchor, positive, negative, functionOptions); |
2233 | |
2234 | ASSERT_TRUE(moduleOutput.allclose(functionOutput, 1e-6, 1e-6)); |
2235 | ASSERT_TRUE( |
2236 | moduleOperatorOutput.allclose(functionOutput, 1e-6, 1e-6)); |
2237 | } |
2238 | } |
2239 | } |
2240 | } |
2241 | } |
2242 | |
2243 | TEST_F(ModulesTest, NLLLoss) { |
2244 | NLLLoss loss; |
2245 | auto input = torch::tensor( |
2246 | {{-0.1315, -3.1315, -2.5315}, |
2247 | {-3.7038, -0.1038, -2.6038}, |
2248 | {-2.3422, -1.3422, -0.4422}}, |
2249 | torch::dtype(torch::kFloat).requires_grad(true)); |
2250 | auto target = torch::tensor({1, 0, 2}, torch::kLong); |
2251 | auto output = loss->forward(input, target); |
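  // NLLLoss picks the log-probability at each target index and negates it:
  // (3.1315 + 3.7038 + 0.4422) / 3 ~= 2.4258 under the default kMean
  // reduction.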
2252 | auto expected = torch::tensor(2.4258, torch::kFloat); |
2253 | auto s = output.sum(); |
2254 | s.backward(); |
2255 | |
2256 | ASSERT_TRUE(output.allclose(expected, 1e-04)); |
2257 | ASSERT_TRUE( |
2258 | NLLLoss(NLLLossOptions().ignore_index(-100).reduction(torch::kMean)) |
2259 | ->forward(input, target) |
2260 | .allclose(expected, 1e-04)); |
2261 | } |
2262 | |
2263 | TEST_F(ModulesTest, CrossEntropyLoss) { |
2264 | CrossEntropyLoss loss; |
2265 | auto input = torch::tensor( |
2266 | {{3., 3.}, {2., 2.}}, torch::dtype(torch::kFloat).requires_grad(true)); |
2267 | auto target = torch::tensor({0, 1}, torch::kLong); |
2268 | auto output = loss->forward(input, target); |
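  // Both rows have equal logits, so the softmax probability of the target
  // class is 0.5 in each row and the loss is -log(0.5) = ln(2) ~= 0.6931.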
2269 | auto expected = torch::tensor(0.6931, torch::kFloat); |
2270 | auto s = output.sum(); |
2271 | s.backward(); |
2272 | |
2273 | ASSERT_TRUE(output.allclose(expected, 1e-04)); |
2274 | ASSERT_EQ(input.sizes(), input.grad().sizes()); |
2275 | ASSERT_TRUE( |
2276 | CrossEntropyLoss( |
2277 | CrossEntropyLossOptions().ignore_index(-100).reduction(torch::kMean)) |
2278 | ->forward(input, target) |
2279 | .allclose(expected, 1e-04)); |
2280 | |
2281 | // label smoothing with class indices |
2282 | loss = CrossEntropyLoss( |
2283 | CrossEntropyLossOptions().label_smoothing(0.15).reduction(torch::kMean)); |
2284 | input = torch::tensor( |
2285 | {{3., 1.}, {1., 2.}}, torch::dtype(torch::kFloat).requires_grad(true)); |
2286 | target = torch::tensor({0, 1}, torch::kLong); |
2287 | output = loss->forward(input, target); |
2288 | expected = torch::tensor(0.3326, torch::kFloat); |
2289 | s = output.sum(); |
2290 | s.backward(); |
2291 | |
2292 | ASSERT_TRUE(output.allclose(expected, 1e-04)); |
2293 | ASSERT_EQ(input.sizes(), input.grad().sizes()); |
2294 | |
  // label smoothing with target probabilities
2296 | loss = CrossEntropyLoss( |
2297 | CrossEntropyLossOptions().label_smoothing(0.2).reduction(torch::kMean)); |
2298 | input = torch::tensor( |
2299 | {{3., 1.}, {1., 2.}}, torch::dtype(torch::kFloat).requires_grad(true)); |
2300 | target = torch::tensor({{0.8, 0.2}, {0.1, 0.9}}, torch::kFloat); |
2301 | output = loss->forward(input, target); |
2302 | expected = torch::tensor(0.5701, torch::kFloat); |
2303 | s = output.sum(); |
2304 | s.backward(); |
2305 | |
2306 | ASSERT_TRUE(output.allclose(expected, 1e-04)); |
2307 | ASSERT_EQ(input.sizes(), input.grad().sizes()); |
2308 | } |
2309 | |
2310 | TEST_F(ModulesTest, CosineSimilarity) { |
2311 | CosineSimilarity cos(CosineSimilarityOptions().dim(1)); |
2312 | auto input1 = torch::tensor( |
2313 | {{1, 2, 3}, {4, 5, 6}}, torch::dtype(torch::kFloat).requires_grad(true)); |
2314 | auto input2 = torch::tensor( |
2315 | {{1, 8, 3}, {2, 1, 6}}, torch::dtype(torch::kFloat).requires_grad(true)); |
2316 | auto output = cos->forward(input1, input2); |
2317 | auto expected = torch::tensor({0.8078, 0.8721}, torch::kFloat); |
2318 | auto s = output.sum(); |
2319 | s.backward(); |
2320 | |
2321 | ASSERT_TRUE(output.allclose(expected, 1e-04)); |
2322 | ASSERT_EQ(input1.sizes(), input1.grad().sizes()); |
2323 | } |
2324 | |
2325 | TEST_F(ModulesTest, SoftMarginLossDefaultOptions) { |
2326 | SoftMarginLoss loss; |
2327 | auto input = torch::tensor( |
2328 | {2., 4., 1., 3.}, torch::dtype(torch::kFloat).requires_grad(true)); |
2329 | auto target = torch::tensor({-1., 1., 1., -1.}, torch::kFloat); |
2330 | auto output = loss->forward(input, target); |
2331 | auto expected = torch::tensor({1.3767317}, torch::kFloat); |
2332 | auto s = output.sum(); |
2333 | s.backward(); |
2334 | |
2335 | ASSERT_TRUE(output.allclose(expected)); |
2336 | ASSERT_EQ(input.sizes(), input.grad().sizes()); |
2337 | } |
2338 | |
2339 | TEST_F(ModulesTest, MultiLabelSoftMarginLossDefaultOptions) { |
2340 | MultiLabelSoftMarginLoss loss; |
2341 | auto input = torch::tensor( |
2342 | {{0., 2., 2., 0.}, {2., 1., 0., 1.}}, |
2343 | torch::dtype(torch::kFloat).requires_grad(true)); |
2344 | auto target = |
2345 | torch::tensor({{0., 0., 1., 0.}, {1., 0., 1., 1.}}, torch::kFloat); |
2346 | auto output = loss->forward(input, target); |
2347 | auto expected = torch::tensor({0.7608436}, torch::kFloat); |
2348 | auto s = output.sum(); |
2349 | s.backward(); |
2350 | |
2351 | ASSERT_TRUE(output.allclose(expected)); |
2352 | ASSERT_EQ(input.sizes(), input.grad().sizes()); |
2353 | } |
2354 | |
2355 | TEST_F(ModulesTest, SoftMarginLossNoReduction) { |
2356 | SoftMarginLoss loss(torch::kNone); |
2357 | auto input = torch::tensor( |
2358 | {2., 4., 1., 3.}, torch::dtype(torch::kFloat).requires_grad(true)); |
2359 | auto target = torch::tensor({-1., 1., 1., -1.}, torch::kFloat); |
2360 | auto output = loss->forward(input, target); |
2361 | auto expected = torch::tensor( |
2362 | {2.1269281, 0.01814993, 0.3132617, 3.0485873}, torch::kFloat); |
2363 | auto s = output.sum(); |
2364 | s.backward(); |
2365 | |
2366 | ASSERT_TRUE(output.allclose(expected)); |
2367 | ASSERT_EQ(input.sizes(), input.grad().sizes()); |
2368 | } |
2369 | |
2370 | TEST_F(ModulesTest, MultiLabelSoftMarginLossWeightedNoReduction) { |
2371 | auto input = torch::tensor( |
2372 | {{0., 2., 2., 0.}, {2., 1., 0., 1.}}, |
2373 | torch::dtype(torch::kFloat).requires_grad(true)); |
2374 | auto target = |
2375 | torch::tensor({{0., 0., 1., 0.}, {1., 0., 1., 1.}}, torch::kFloat); |
2376 | auto weight = torch::tensor({0.1, 0.6, 0.4, 0.8}, torch::kFloat); |
2377 | auto options = |
2378 | MultiLabelSoftMarginLossOptions().reduction(torch::kNone).weight(weight); |
2379 | MultiLabelSoftMarginLoss loss = MultiLabelSoftMarginLoss(options); |
2380 | auto output = loss->forward(input, target); |
2381 | auto expected = torch::tensor({0.4876902, 0.3321295}, torch::kFloat); |
2382 | auto s = output.sum(); |
2383 | s.backward(); |
2384 | |
2385 | ASSERT_TRUE(output.allclose(expected)); |
2386 | ASSERT_EQ(input.sizes(), input.grad().sizes()); |
2387 | } |
2388 | |
2389 | TEST_F(ModulesTest, PairwiseDistance) { |
2390 | PairwiseDistance dist(PairwiseDistanceOptions().p(1)); |
2391 | auto input1 = torch::tensor( |
2392 | {{1, 2, 3}, {4, 5, 6}}, torch::dtype(torch::kFloat).requires_grad(true)); |
2393 | auto input2 = torch::tensor( |
2394 | {{1, 8, 3}, {2, 1, 6}}, torch::dtype(torch::kFloat).requires_grad(true)); |
2395 | auto output = dist->forward(input1, input2); |
2396 | auto expected = torch::tensor({6, 6}, torch::kFloat); |
2397 | auto s = output.sum(); |
2398 | s.backward(); |
2399 | |
2400 | ASSERT_TRUE(output.allclose(expected)); |
2401 | ASSERT_EQ(input1.sizes(), input1.grad().sizes()); |
2402 | } |
2403 | |
2404 | TEST_F(ModulesTest, ELU) { |
2405 | const auto size = 3; |
2406 | for (const auto alpha : {0.0, 0.42, 1.0, 4.2, 42.42}) { |
2407 | for (const auto inplace : {false, true}) { |
2408 | ELU model{ELUOptions().alpha(alpha).inplace(inplace)}; |
2409 | auto x = torch::linspace(-10.0, 10.0, size * size * size); |
2410 | x.resize_({size, size, size}); |
2411 | if (!inplace) { |
2412 | x.requires_grad_(true); |
2413 | } |
2414 | auto x_orig = x.clone(); |
2415 | auto y = model(x); |
2416 | torch::Tensor s = y.sum(); |
2417 | |
2418 | ASSERT_EQ(s.ndimension(), 0); |
2419 | |
2420 | ASSERT_EQ(y.ndimension(), 3); |
2421 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size})); |
2422 | auto y_exp = torch::max(torch::zeros_like(x_orig), x_orig) + |
2423 | torch::min(torch::zeros_like(x_orig), |
2424 | alpha * (torch::exp(x_orig) - 1.0)); |
2425 | ASSERT_TRUE(torch::allclose(y, y_exp)); |
2426 | if (inplace) { |
2427 | ASSERT_TRUE(torch::allclose(x, y_exp)); |
2428 | } else { |
2429 | s.backward(); |
2430 | } |
2431 | } |
2432 | } |
2433 | } |
2434 | |
2435 | TEST_F(ModulesTest, SELU) { |
2436 | for (const auto inplace : {false, true}) { |
2437 | SELU model(inplace); |
2438 | auto input = torch::randn({5, 5}); |
2439 | if (!inplace) { |
2440 | input.requires_grad_(true); |
2441 | } |
2442 | auto input_orig = input.clone(); |
2443 | auto output = model->forward(input); |
2444 | const double scale = 1.0507009873554804934193349852946; |
2445 | const double alpha = 1.6732632423543772848170429916717; |
2446 | auto zero = torch::zeros_like(input); |
2447 | auto expected = scale * |
2448 | (torch::max(zero, input_orig) + |
2449 | torch::min(zero, alpha * (torch::exp(input_orig) - 1))); |
2450 | auto s = output.sum(); |
2451 | |
2452 | ASSERT_EQ(s.ndimension(), 0); |
2453 | ASSERT_TRUE(output.allclose(expected)); |
2454 | if (inplace) { |
2455 | ASSERT_TRUE(input.allclose(expected)); |
2456 | } else { |
2457 | s.backward(); |
2458 | } |
2459 | } |
2460 | } |
2461 | |
2462 | TEST_F(ModulesTest, Hardshrink) { |
2463 | const auto size = 3; |
2464 | for (const auto lambda : {-4.2, -1.0, -0.42, 0.0, 0.42, 1.0, 4.2, 42.42}) { |
2465 | Hardshrink model{HardshrinkOptions().lambda(lambda)}; |
2466 | auto x = torch::linspace(-10.0, 10.0, size * size * size); |
2467 | x.resize_({size, size, size}).set_requires_grad(true); |
2468 | auto y = model(x); |
2469 | torch::Tensor s = y.sum(); |
2470 | |
2471 | s.backward(); |
2472 | ASSERT_EQ(s.ndimension(), 0); |
2473 | ASSERT_EQ(y.ndimension(), 3); |
2474 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size})); |
2475 | auto y_exp = (x.abs() > lambda) * x; |
2476 | ASSERT_TRUE(torch::allclose(y, y_exp)); |
2477 | } |
2478 | } |
2479 | |
2480 | TEST_F(ModulesTest, Hardtanh) { |
2481 | const auto size = 3; |
2482 | for (const auto min_val : {-4.2, -1.0, -0.42, 0.0}) { |
2483 | for (const auto max_val : {0.42, 1.0, 4.2}) { |
2484 | for (const auto inplace : {false, true}) { |
2485 | Hardtanh model{ |
2486 | HardtanhOptions().min_val(min_val).max_val(max_val).inplace( |
2487 | inplace)}; |
2488 | auto x = torch::linspace(-10.0, 10.0, size * size * size); |
2489 | x.resize_({size, size, size}); |
2490 | if (!inplace) { |
2491 | x.requires_grad_(true); |
2492 | } |
2493 | auto x_orig = x.clone(); |
2494 | auto y = model(x); |
2495 | torch::Tensor s = y.sum(); |
2496 | |
2497 | ASSERT_EQ(s.ndimension(), 0); |
2498 | ASSERT_EQ(y.ndimension(), 3); |
2499 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size})); |
2500 | auto y_exp = (x_orig < min_val) * min_val + |
2501 | ((x_orig >= min_val) * (x_orig <= max_val)) * x_orig + |
2502 | (x_orig > max_val) * max_val; |
2503 | ASSERT_TRUE(torch::allclose(y, y_exp)); |
2504 | if (inplace) { |
2505 | ASSERT_TRUE(torch::allclose(x, y_exp)); |
2506 | } else { |
2507 | s.backward(); |
2508 | } |
2509 | } |
2510 | } |
2511 | } |
2512 | } |
2513 | |
2514 | TEST_F(ModulesTest, HardtanhMinValGEMaxVal) { |
  ASSERT_THROWS_WITH(
      Hardtanh{HardtanhOptions().min_val(0.42).max_val(0.42)},
      "max_val must be greater than min_val");
  ASSERT_THROWS_WITH(
      Hardtanh{HardtanhOptions().min_val(0.42).max_val(-0.42)},
      "max_val must be greater than min_val");

  Hardtanh ht{HardtanhOptions().min_val(-0.42).max_val(0.42)};
  ht->options.min_val(0.42);
  ASSERT_THROWS_WITH(ht->reset(), "max_val must be greater than min_val");
  ht->options.max_val(-0.42);
  ASSERT_THROWS_WITH(ht->reset(), "max_val must be greater than min_val");
2527 | } |
2528 | |
2529 | TEST_F(ModulesTest, LeakyReLU) { |
2530 | const auto size = 3; |
2531 | for (const auto inplace : {false, true}) { |
2532 | for (const auto negative_slope : {0.0, 0.42, 1.0}) { |
2533 | for (const auto type : {torch::kFloat, torch::kBFloat16}) { |
2534 | LeakyReLU model{ |
2535 | LeakyReLUOptions().negative_slope(negative_slope).inplace(inplace)}; |
2536 | auto x = torch::linspace(-10.0, 10.0, size * size * size).to(type); |
2537 | x.resize_({size, size, size}); |
2538 | if (!inplace) { |
2539 | x.requires_grad_(true); |
2540 | } |
2541 | auto x_orig = x.clone(); |
2542 | auto y = model(x); |
2543 | torch::Tensor s = y.sum(); |
2544 | |
2545 | ASSERT_EQ(s.ndimension(), 0); |
2546 | ASSERT_EQ(y.ndimension(), 3); |
2547 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size})); |
2548 | auto y_exp = |
2549 | (x_orig < 0) * x_orig * negative_slope + (x_orig >= 0) * x_orig; |
2550 | ASSERT_TRUE(torch::allclose(y, y_exp)); |
2551 | if (inplace) { |
2552 | ASSERT_TRUE(torch::allclose(x, y_exp)); |
2553 | } else { |
2554 | s.backward(); |
2555 | } |
2556 | } |
2557 | } |
2558 | } |
2559 | } |
2560 | |
2561 | TEST_F(ModulesTest, LogSigmoid) { |
2562 | const auto size = 3; |
2563 | LogSigmoid model; |
2564 | auto x = torch::linspace(-10.0, 10.0, size * size * size); |
2565 | x.resize_({size, size, size}).set_requires_grad(true); |
2566 | auto y = model(x); |
2567 | torch::Tensor s = y.sum(); |
2568 | |
2569 | s.backward(); |
2570 | ASSERT_EQ(s.ndimension(), 0); |
2571 | |
2572 | ASSERT_EQ(y.ndimension(), 3); |
2573 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size})); |
2574 | auto y_exp = torch::log( |
2575 | torch::ones_like(x) / (torch::ones_like(x) + torch::exp(torch::neg(x)))); |
2576 | ASSERT_TRUE(torch::allclose(y, y_exp, 1e-4, 1e-7)); |
2577 | } |
2578 | |
2579 | TEST_F(ModulesTest, Softmax) { |
2580 | Softmax m(/*dim=*/1); |
2581 | auto input = torch::arange(10, torch::kFloat).reshape({2, 5}); |
2582 | auto output = m(input); |
2583 | auto sum = torch::sum(torch::exp(input), 1); |
2584 | |
2585 | for (const auto i : c10::irange(2)) { |
2586 | auto expected = torch::exp(input[i]) / sum[i]; |
2587 | ASSERT_TRUE(torch::allclose(output[i], expected)); |
2588 | } |
2589 | } |
2590 | |
2591 | TEST_F(ModulesTest, Softmin) { |
2592 | Softmin m(/*dim=*/1); |
2593 | auto input = torch::arange(10, torch::kFloat).reshape({2, 5}); |
2594 | auto output = m(input); |
2595 | auto sum = torch::sum(torch::exp(-input), 1); |
2596 | |
2597 | for (const auto i : c10::irange(2)) { |
2598 | auto expected = torch::exp(-input[i]) / sum[i]; |
2599 | ASSERT_TRUE(torch::allclose(output[i], expected)); |
2600 | } |
2601 | } |
2602 | |
2603 | TEST_F(ModulesTest, LogSoftmax) { |
2604 | LogSoftmax m(/*dim=*/1); |
2605 | auto input = torch::arange(10, torch::kFloat).reshape({2, 5}); |
2606 | auto output = m(input); |
2607 | auto sum = torch::sum(torch::exp(input), 1); |
2608 | |
2609 | for (const auto i : c10::irange(2)) { |
2610 | auto expected = torch::log(torch::exp(input[i]) / sum[i]); |
2611 | ASSERT_TRUE(torch::allclose(output[i], expected)); |
2612 | } |
2613 | } |
2614 | |
2615 | TEST_F(ModulesTest, AdaptiveLogSoftmaxWithLoss) { |
2616 | { |
    // log_prob returns properly normalized log-probabilities:
    // exp(log_prob) sums to 1 over the class dimension
2618 | AdaptiveLogSoftmaxWithLoss asfm( |
2619 | AdaptiveLogSoftmaxWithLossOptions(8, 4, {2}).div_value(2.)); |
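    // The cutoffs {2} split the 4 classes into a head of 2 frequent classes
    // plus one tail cluster; div_value(2.) shrinks the tail projection from
    // 8 to 4 features.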
2620 | auto x = torch::randn({4, 8}); |
2621 | auto logprob_out = asfm->log_prob(x); |
2622 | ASSERT_TRUE( |
2623 | torch::allclose(torch::exp(logprob_out).data().sum(1), torch::ones(4))); |
2624 | } |
2625 | { |
2626 | // test predict |
2627 | AdaptiveLogSoftmaxWithLoss asfm( |
2628 | AdaptiveLogSoftmaxWithLossOptions(8, 10, {4, 8}) |
2629 | .div_value(2.) |
2630 | .head_bias(true)); |
2631 | auto x = torch::randn({64, 8}); |
2632 | auto logprob_out = asfm->log_prob(x); |
2633 | auto predict_out = asfm->predict(x); |
2634 | ASSERT_TRUE(torch::allclose(predict_out, logprob_out.argmax(1))); |
2635 | } |
2636 | { |
2637 | // cluster sizes |
2638 | AdaptiveLogSoftmaxWithLoss asfm( |
2639 | AdaptiveLogSoftmaxWithLossOptions(16, 20, {4, 10, 15}).div_value(2.)); |
2640 | auto x = torch::arange(100, 132, torch::kFloat).reshape({2, 16}); |
2641 | auto y = torch::tensor({0, 17}, torch::kLong); |
2642 | auto asm_out = asfm(x, y); |
2643 | ASSERT_EQ(asm_out.output.sizes(), std::vector<int64_t>({2})); |
2644 | } |
2645 | { |
    // forward's output and loss agree with applying NLLLoss to log_prob
2647 | AdaptiveLogSoftmaxWithLoss asfm( |
2648 | AdaptiveLogSoftmaxWithLossOptions(8, 4, {2}).div_value(2.)); |
2649 | auto x = torch::randn({4, 8}); |
2650 | auto logprob_out = asfm->log_prob(x); |
2651 | NLLLoss nll_loss; |
2652 | |
2653 | for (const auto v : c10::irange(4)) { |
2654 | auto y = torch::full({4}, v, torch::kLong); |
2655 | auto asm_out = asfm(x, y); |
2656 | auto out = asm_out.output; |
2657 | auto loss = torch::tensor(asm_out.loss, torch::kFloat); |
2658 | auto expected = nll_loss->forward(logprob_out, y); |
2659 | |
2660 | ASSERT_TRUE(torch::allclose(loss, expected)); |
2661 | ASSERT_TRUE(torch::allclose( |
2662 | out, logprob_out.gather(1, y.unsqueeze(1)).squeeze())); |
2663 | } |
2664 | } |
2665 | { |
2666 | // test no batch dim |
2667 | AdaptiveLogSoftmaxWithLoss asfm( |
2668 | AdaptiveLogSoftmaxWithLossOptions(16, 20, {4, 10, 15}).div_value(2.)); |
2669 | auto x = torch::randn({1, 16}); |
2670 | auto y = torch::tensor({17}); |
2671 | auto x2 = x.squeeze(0); |
2672 | auto y2 = y.squeeze(0); |
2673 | ASSERT_TRUE( |
2674 | torch::allclose(asfm(x, y).output.squeeze(0), asfm(x2, y2).output)); |
2675 | } |
2676 | } |
2677 | |
2678 | TEST_F(ModulesTest, Softmax2d) { |
2679 | Softmax2d m; |
2680 | auto input = torch::arange(24, torch::kFloat).reshape({1, 2, 3, 4}); |
2681 | auto output = m(input); |
2682 | auto sum = torch::sum(torch::exp(input), 1); |
2683 | |
2684 | for (const auto i : c10::irange(1)) { |
2685 | for (const auto j : c10::irange(2)) { |
2686 | for (const auto k : c10::irange(3)) { |
2687 | for (const auto l : c10::irange(4)) { |
2688 | auto expected = torch::exp(input[i][j][k][l]) / sum[i][k][l]; |
2689 | ASSERT_TRUE(torch::allclose(output[i][j][k][l], expected)); |
2690 | } |
2691 | } |
2692 | } |
2693 | } |
2694 | } |
2695 | |
2696 | TEST_F(ModulesTest, PReLU) { |
2697 | const auto num_parameters = 42; |
2698 | const auto init = 0.42; |
2699 | |
2700 | PReLU model{PReLUOptions().num_parameters(num_parameters).init(init)}; |
2701 | |
2702 | ASSERT_EQ(model->weight.sizes(), std::vector<int64_t>({num_parameters})); |
2703 | ASSERT_TRUE( |
2704 | torch::allclose(model->weight, torch::full(num_parameters, init))); |
2705 | |
2706 | const auto x = torch::rand({100, num_parameters}) * 200 - 100; |
2707 | const auto y = model(x); |
2708 | const auto s = y.sum(); |
2709 | |
2710 | s.backward(); |
2711 | ASSERT_EQ(s.ndimension(), 0); |
2712 | |
2713 | ASSERT_EQ(y.ndimension(), x.ndimension()); |
2714 | ASSERT_EQ(y.sizes(), x.sizes()); |
2715 | const auto y_exp = (x < 0) * model->weight * x + (x >= 0) * x; |
2716 | ASSERT_TRUE(torch::allclose(y, y_exp)); |
2717 | } |
2718 | |
2719 | TEST_F(ModulesTest, ReLU) { |
2720 | for (const auto inplace : {false, true}) { |
2721 | const auto size = 3; |
2722 | ReLU model(inplace); |
2723 | auto x = torch::linspace(-10.0, 10.0, size * size * size); |
2724 | x.resize_({size, size, size}); |
2725 | if (!inplace) { |
2726 | x.requires_grad_(true); |
2727 | } |
2728 | auto x_orig = x.clone(); |
2729 | auto y = model(x); |
2730 | torch::Tensor s = y.sum(); |
2731 | |
2732 | ASSERT_EQ(s.ndimension(), 0); |
2733 | ASSERT_EQ(y.ndimension(), 3); |
2734 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size})); |
2735 | auto y_exp = (x_orig < 0) * 0 + (x_orig >= 0) * x_orig; |
2736 | ASSERT_TRUE(torch::allclose(y, y_exp)); |
2737 | if (inplace) { |
2738 | ASSERT_TRUE(torch::allclose(x, y_exp)); |
2739 | } else { |
2740 | s.backward(); |
2741 | } |
2742 | } |
2743 | } |
2744 | |
2745 | TEST_F(ModulesTest, ReLU6) { |
2746 | for (const auto inplace : {false, true}) { |
2747 | const auto size = 3; |
2748 | ReLU6 model(inplace); |
2749 | auto x = torch::linspace(-10.0, 10.0, size * size * size); |
2750 | x.resize_({size, size, size}); |
2751 | if (!inplace) { |
2752 | x.requires_grad_(true); |
2753 | } |
2754 | auto x_orig = x.clone(); |
2755 | auto y = model(x); |
2756 | torch::Tensor s = y.sum(); |
2757 | |
2758 | ASSERT_EQ(s.ndimension(), 0); |
2759 | ASSERT_EQ(y.ndimension(), 3); |
2760 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size})); |
2761 | auto y_exp = (x_orig < 0) * 0 + ((x_orig >= 0) * (x_orig <= 6)) * x_orig + |
2762 | (x_orig > 6) * 6; |
2763 | ASSERT_TRUE(torch::allclose(y, y_exp)); |
2764 | if (inplace) { |
2765 | ASSERT_TRUE(torch::allclose(x, y_exp)); |
2766 | } else { |
2767 | s.backward(); |
2768 | } |
2769 | } |
2770 | } |
2771 | |
2772 | TEST_F(ModulesTest, RReLU) { |
2773 | const auto size = 3; |
2774 | for (const auto lower : {0.01, 0.1, 0.2}) { |
2775 | for (const auto upper : {0.3, 0.4, 0.5}) { |
2776 | for (const auto inplace : {false, true}) { |
2777 | for (const auto type : {torch::kFloat, torch::kBFloat16}) { |
2778 | RReLU model{ |
2779 | RReLUOptions().lower(lower).upper(upper).inplace(inplace)}; |
2780 | auto x = torch::linspace(-10.0, 10.0, size * size * size).to(type); |
2781 | x.resize_({size, size, size}); |
2782 | if (!inplace) { |
2783 | x.requires_grad_(true); |
2784 | } |
2785 | auto x_orig = x.clone(); |
2786 | auto y = model(x); |
2787 | torch::Tensor s = y.sum(); |
2788 | |
2789 | ASSERT_EQ(s.ndimension(), 0); |
2790 | ASSERT_EQ(y.ndimension(), 3); |
2791 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size})); |
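          // In training mode RReLU multiplies each negative element by a
          // slope drawn uniformly from [lower, upper], so for x < 0 the
          // output must lie between x * upper and x * lower; z flags every
          // element that satisfies this band (non-negative inputs pass
          // through unchanged).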
2792 | auto z = |
2793 | ((x_orig >= 0) * (x_orig == y) + |
2794 | (x_orig < 0) * (y >= x_orig * upper) * (y <= lower * x_orig)) * |
2795 | 1.0; |
2796 | ASSERT_TRUE(torch::allclose(z, torch::ones_like(z))); |
2797 | if (inplace) { |
2798 | ASSERT_TRUE(torch::allclose(x, y)); |
2799 | } else { |
2800 | s.backward(); |
2801 | } |
2802 | } |
2803 | } |
2804 | } |
2805 | } |
2806 | } |
2807 | |
2808 | TEST_F(ModulesTest, CELU) { |
2809 | const auto size = 3; |
2810 | for (const auto inplace : {false, true}) { |
2811 | for (const auto alpha : {0.42, 1.0, 4.2, 42.42}) { |
2812 | CELU model{CELUOptions().alpha(alpha).inplace(inplace)}; |
2813 | auto x = torch::linspace(-10.0, 10.0, size * size * size); |
2814 | x.resize_({size, size, size}); |
2815 | if (!inplace) { |
2816 | x.requires_grad_(true); |
2817 | } |
2818 | auto x_orig = x.clone(); |
2819 | auto y = model(x); |
2820 | torch::Tensor s = y.sum(); |
2821 | |
2822 | ASSERT_EQ(s.ndimension(), 0); |
2823 | ASSERT_EQ(y.ndimension(), 3); |
2824 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size})); |
2825 | auto y_exp = torch::max(torch::zeros_like(x_orig), x_orig) + |
2826 | torch::min(torch::zeros_like(x_orig), |
2827 | alpha * (torch::exp(x_orig / alpha) - 1.0)); |
2828 | ASSERT_TRUE(torch::allclose(y, y_exp)); |
2829 | if (inplace) { |
2830 | ASSERT_TRUE(torch::allclose(x, y_exp)); |
2831 | } else { |
2832 | s.backward(); |
2833 | } |
2834 | } |
2835 | } |
2836 | } |
2837 | |
2838 | TEST_F(ModulesTest, GLU) { |
2839 | int64_t dim = 1; |
2840 | GLU model(dim); |
2841 | auto input = torch::randn({4, 2}, torch::requires_grad()); |
2842 | auto output = model->forward(input); |
2843 | auto input_size = input.sizes()[dim] / 2; |
2844 | auto first_half = input.narrow(dim, 0, input_size); |
2845 | auto second_half = input.narrow(dim, input_size, input_size); |
2846 | auto expected = first_half * torch::sigmoid(second_half); |
2847 | auto s = output.sum(); |
2848 | s.backward(); |
2849 | |
2850 | ASSERT_EQ(s.ndimension(), 0); |
2851 | ASSERT_TRUE(output.allclose(expected)); |
2852 | |
2853 | GLU model_default_options; |
2854 | ASSERT_TRUE(model_default_options->forward(input).allclose(expected)); |
2855 | } |
2856 | |
2857 | TEST_F(ModulesTest, GELU) { |
2858 | GELU model(GELUOptions().approximate("none" )); |
2859 | const auto x = torch::linspace(-3.0, 3.0, 100); |
2860 | const auto y_exp = x * 0.5 * (1.0 + torch::erf(x / std::sqrt(2.0))); |
2861 | const auto y = model(x); |
2862 | ASSERT_TRUE(torch::allclose(y, y_exp, 1.4e-06, 1e-05)); |
2863 | } |
2864 | |
2865 | TEST_F(ModulesTest, TanhGELU) { |
2866 | GELU model(GELUOptions().approximate("tanh" )); |
2867 | const auto x = torch::linspace(-3.0, 3.0, 100); |
2868 | const auto inner = std::sqrt(2 / M_PI) * (x + 0.044715 * x.pow(3.0)); |
2869 | const auto y_exp = 0.5 * x * (1.0 + inner.tanh()); |
2870 | const auto y = model(x); |
2871 | ASSERT_TRUE(torch::allclose(y, y_exp, 1.4e-06, 1e-05)); |
2872 | } |
2873 | |
2874 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) |
2875 | TEST_F(ModulesTest, Mish) { |
2876 | Mish model; |
2877 | auto x = torch::randn(100) * 10; |
2878 | auto y_exp = x * x.exp().log1p().tanh(); |
2879 | auto y = model(x); |
2880 | |
2881 | ASSERT_TRUE(torch::allclose(y, y_exp)); |
2882 | } |
2883 | |
2884 | TEST_F(ModulesTest, Sigmoid) { |
2885 | Sigmoid model; |
2886 | auto x = torch::randn(100) * 10; |
2887 | auto y_exp = 1 / (1 + torch::exp(-x)); |
2888 | auto y = model(x); |
2889 | |
2890 | ASSERT_TRUE(torch::allclose(y, y_exp)); |
2891 | } |
2892 | |
2893 | TEST_F(ModulesTest, PixelShuffle) { |
2894 | PixelShuffle module(/*upscale_factor=*/2); |
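  // PixelShuffle rearranges a (N, C * r^2, H, W) tensor into
  // (N, C, H * r, W * r): with r = 2, the four 2x2 input channels interleave
  // into a single 4x4 output channel.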
2895 | auto x = torch::tensor( |
2896 | {{{{-17, 19}, {-1, 2}}, |
2897 | {{7, 14}, {-3, 1}}, |
2898 | {{0, -2}, {-12, 14}}, |
2899 | {{-15, 0}, {-3, 9}}}}, |
2900 | torch::kFloat); |
2901 | auto y_exp = torch::tensor( |
2902 | {{{{-17, 7, 19, 14}, {0, -15, -2, 0}, {-1, -3, 2, 1}, {-12, -3, 14, 9}}}}, |
2903 | torch::kFloat); |
2904 | auto y = module(x); |
2905 | |
2906 | ASSERT_EQ(y.ndimension(), 4); |
2907 | ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 1, 4, 4})); |
2908 | ASSERT_TRUE(y.allclose(y_exp)); |
2909 | } |
2910 | |
2911 | TEST_F(ModulesTest, PixelUnshuffle) { |
2912 | PixelUnshuffle module(/*downscale_factor=*/2); |
2913 | auto x = torch::tensor( |
2914 | {{{{-17, 7, 19, 14}, {0, -15, -2, 0}, {-1, -3, 2, 1}, {-12, -3, 14, 9}}}}, |
2915 | torch::kFloat); |
2916 | auto y_exp = torch::tensor( |
2917 | {{{{-17, 19}, {-1, 2}}, |
2918 | {{7, 14}, {-3, 1}}, |
2919 | {{0, -2}, {-12, 14}}, |
2920 | {{-15, 0}, {-3, 9}}}}, |
2921 | torch::kFloat); |
2922 | auto y = module(x); |
2923 | |
2924 | ASSERT_EQ(y.ndimension(), 4); |
2925 | ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 4, 2, 2})); |
2926 | ASSERT_TRUE(y.allclose(y_exp)); |
2927 | } |
2928 | |
2929 | TEST_F(ModulesTest, Softplus) { |
2930 | const auto size = 3; |
2931 | for (const auto beta : {0.5, 1.0, 2.0}) { |
2932 | for (const auto threshold : {1.0, 3.0, 5.0}) { |
2933 | Softplus model{SoftplusOptions().beta(beta).threshold(threshold)}; |
2934 | auto x = torch::linspace(-3.0, 3.0, 61); |
2935 | x.resize_({size, size, size}); |
2936 | auto y_exp = |
2937 | (x <= threshold) * torch::log(1 + torch::exp(x * beta)) / beta + |
2938 | (x > threshold) * x; |
2939 | auto y = model(x); |
2940 | |
2941 | ASSERT_EQ(y.ndimension(), 3); |
2942 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size})); |
2943 | ASSERT_TRUE(torch::allclose(y, y_exp)); |
2944 | } |
2945 | } |
2946 | } |
2947 | |
2948 | TEST_F(ModulesTest, Softshrink) { |
2949 | const auto size = 3; |
2950 | for (const auto lambda : {0.0, 0.42, 1.0, 4.2, 42.42}) { |
2951 | Softshrink model{/*lambda=*/lambda}; |
2952 | auto x = torch::linspace(-10.0, 10.0, size * size * size); |
2953 | x.resize_({size, size, size}).set_requires_grad(true); |
2954 | auto y = model(x); |
2955 | torch::Tensor s = y.sum(); |
2956 | |
2957 | s.backward(); |
2958 | ASSERT_EQ(s.ndimension(), 0); |
2959 | |
2960 | ASSERT_EQ(y.ndimension(), 3); |
2961 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size})); |
2962 | auto y_exp = (x < -lambda) * (x + lambda) + (x > lambda) * (x - lambda); |
2963 | ASSERT_TRUE(torch::allclose(y, y_exp)); |
2964 | } |
2965 | } |
2966 | |
2967 | TEST_F(ModulesTest, Softsign) { |
2968 | Softsign model; |
2969 | auto x = torch::randn(100) * 10; |
2970 | auto y_exp = x / (1 + x.abs()); |
2971 | auto y = model(x); |
2972 | |
2973 | ASSERT_TRUE(torch::allclose(y, y_exp)); |
2974 | } |
2975 | |
2976 | TEST_F(ModulesTest, Tanh) { |
2977 | Tanh model; |
2978 | auto x = torch::randn(100) * 10; |
2979 | auto y_exp = (x.exp() - (-x).exp()) / (x.exp() + (-x).exp()); |
2980 | auto y = model(x); |
2981 | |
2982 | ASSERT_TRUE(torch::allclose(y, y_exp)); |
2983 | } |
2984 | |
2985 | TEST_F(ModulesTest, Tanhshrink) { |
2986 | Tanhshrink model; |
2987 | auto x = torch::randn(100) * 10; |
2988 | auto y_exp = x - x.tanh(); |
2989 | auto y = model(x); |
2990 | |
2991 | ASSERT_TRUE(torch::allclose(y, y_exp)); |
2992 | } |
2993 | |
2994 | TEST_F(ModulesTest, Threshold) { |
2995 | const auto size = 3; |
2996 | for (const auto threshold : {0.5, 1.0, 2.0}) { |
2997 | for (const auto value : {0.5, 1.0, 2.0}) { |
2998 | for (const auto inplace : {false, true}) { |
2999 | Threshold model{ThresholdOptions(threshold, value).inplace(inplace)}; |
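        // Threshold keeps elements strictly greater than `threshold` and
        // replaces the rest with `value`; with inplace(true) the input tensor
        // itself is overwritten, which the extra check on x below verifies.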
3000 | auto x = torch::linspace(-3.0, 3.0, 61); |
3001 | x.resize_({size, size, size}); |
3002 | auto x_orig = x.clone(); |
3003 | auto y_exp = |
3004 | (x_orig <= threshold) * value + (x_orig > threshold) * x_orig; |
3005 | auto y = model(x); |
3006 | |
3007 | ASSERT_EQ(y.ndimension(), 3); |
3008 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size})); |
3009 | ASSERT_TRUE(torch::allclose(y, y_exp)); |
3010 | if (inplace) { |
3011 | ASSERT_TRUE(torch::allclose(x, y_exp)); |
3012 | } |
3013 | } |
3014 | } |
3015 | } |
3016 | } |
3017 | |
3018 | TEST_F(ModulesTest, Upsampling1D) { |
3019 | { |
3020 | Upsample model(UpsampleOptions() |
3021 | .size(std::vector<int64_t>({4})) |
3022 | .mode(torch::kNearest)); |
3023 | auto input = torch::ones({1, 1, 2}, torch::requires_grad()); |
3024 | auto output = model->forward(input); |
3025 | auto expected = torch::ones({1, 1, 4}); |
3026 | auto s = output.sum(); |
3027 | s.backward(); |
3028 | |
3029 | ASSERT_EQ(s.ndimension(), 0); |
3030 | ASSERT_TRUE(output.allclose(expected)); |
3031 | } |
3032 | { |
3033 | for (const auto align_corners : {true, false}) { |
      // test float scale factors for both upsampling and downsampling
3035 | for (const auto scale_factor : {0.5, 1.5, 2.0}) { |
3036 | Upsample model(UpsampleOptions() |
3037 | .scale_factor(std::vector<double>({scale_factor})) |
3038 | .mode(torch::kLinear) |
3039 | .align_corners(align_corners)); |
3040 | auto input = torch::ones({1, 1, 2}, torch::requires_grad()); |
3041 | auto output = model->forward(input); |
3042 | auto expected_size = |
3043 | static_cast<int64_t>(std::floor(input.size(-1) * scale_factor)); |
3044 | auto expected = torch::ones({1, 1, expected_size}); |
3045 | auto s = output.sum(); |
3046 | s.backward(); |
3047 | |
3048 | ASSERT_EQ(s.ndimension(), 0); |
3049 | ASSERT_TRUE(output.allclose(expected)); |
3050 | } |
3051 | } |
3052 | } |
3053 | { |
3054 | // linear (1D) upsampling spatial invariance |
3055 | Upsample model(UpsampleOptions() |
3056 | .scale_factor(std::vector<double>({3})) |
3057 | .mode(torch::kLinear) |
3058 | .align_corners(false)); |
3059 | auto input = torch::zeros({1, 1, 9}); |
3060 | input.narrow(2, 0, 4).normal_(); |
3061 | auto output = model->forward(input); |
3062 | auto expected = model->forward(input.narrow(2, 0, 5)); |
3063 | |
3064 | ASSERT_TRUE(torch::allclose(output.narrow(2, 0, 15), expected)); |
3065 | } |
3066 | } |
3067 | |
3068 | TEST_F(ModulesTest, Upsampling2D) { |
3069 | { |
3070 | Upsample model(UpsampleOptions() |
3071 | .size(std::vector<int64_t>({4, 4})) |
3072 | .mode(torch::kNearest)); |
3073 | auto input = torch::ones({1, 1, 2, 2}, torch::requires_grad()); |
3074 | auto output = model->forward(input); |
3075 | auto expected = torch::ones({1, 1, 4, 4}); |
3076 | auto s = output.sum(); |
3077 | s.backward(); |
3078 | |
3079 | ASSERT_EQ(s.ndimension(), 0); |
3080 | ASSERT_TRUE(output.allclose(expected)); |
3081 | } |
3082 | { |
3083 | for (const auto align_corners : {true, false}) { |
      // test float scale factors for both upsampling and downsampling
3085 | for (const auto scale_factor : {0.5, 1.5, 2.0}) { |
3086 | Upsample model( |
3087 | UpsampleOptions() |
3088 | .scale_factor(std::vector<double>({scale_factor, scale_factor})) |
3089 | .mode(torch::kBilinear) |
3090 | .align_corners(align_corners)); |
3091 | auto input = torch::ones({1, 1, 2, 2}, torch::requires_grad()); |
3092 | auto output = model->forward(input); |
3093 | auto expected_size = |
3094 | static_cast<int64_t>(std::floor(input.size(-1) * scale_factor)); |
3095 | auto expected = torch::ones({1, 1, expected_size, expected_size}); |
3096 | auto s = output.sum(); |
3097 | s.backward(); |
3098 | |
3099 | ASSERT_EQ(s.ndimension(), 0); |
3100 | ASSERT_TRUE(output.allclose(expected)); |
3101 | } |
3102 | } |
3103 | } |
3104 | { |
3105 | for (const auto align_corners : {true, false}) { |
      // test float scale factors for both upsampling and downsampling
3107 | for (const auto scale_factor : {0.5, 1.5, 2.0}) { |
3108 | Upsample model( |
3109 | UpsampleOptions() |
3110 | .scale_factor(std::vector<double>({scale_factor, scale_factor})) |
3111 | .mode(torch::kBicubic) |
3112 | .align_corners(align_corners)); |
3113 | auto input = torch::ones({1, 1, 2, 2}, torch::requires_grad()); |
3114 | auto output = model->forward(input); |
3115 | auto expected_size = |
3116 | static_cast<int64_t>(std::floor(input.size(-1) * scale_factor)); |
3117 | auto expected = torch::ones({1, 1, expected_size, expected_size}); |
3118 | auto s = output.sum(); |
3119 | s.backward(); |
3120 | |
3121 | ASSERT_EQ(s.ndimension(), 0); |
3122 | ASSERT_TRUE(output.allclose(expected)); |
3123 | } |
3124 | } |
3125 | } |
3126 | } |
3127 | |
3128 | TEST_F(ModulesTest, Upsampling3D) { |
3129 | { |
3130 | Upsample model(UpsampleOptions() |
3131 | .size(std::vector<int64_t>({4, 4, 4})) |
3132 | .mode(torch::kNearest)); |
3133 | auto input = torch::ones({1, 1, 2, 2, 2}, torch::requires_grad()); |
3134 | auto output = model->forward(input); |
3135 | auto expected = torch::ones({1, 1, 4, 4, 4}); |
3136 | auto s = output.sum(); |
3137 | s.backward(); |
3138 | |
3139 | ASSERT_EQ(s.ndimension(), 0); |
3140 | ASSERT_TRUE(output.allclose(expected)); |
3141 | } |
3142 | { |
3143 | for (const auto align_corners : {true, false}) { |
      // test float scale factors for both upsampling and downsampling
3145 | for (const auto scale_factor : {0.5, 1.5, 2.0}) { |
3146 | Upsample model(UpsampleOptions() |
3147 | .scale_factor(std::vector<double>( |
3148 | {scale_factor, scale_factor, scale_factor})) |
3149 | .mode(torch::kTrilinear) |
3150 | .align_corners(align_corners)); |
3151 | auto input = torch::ones({1, 1, 2, 2, 2}, torch::requires_grad()); |
3152 | auto output = model->forward(input); |
3153 | auto expected_size = |
3154 | static_cast<int64_t>(std::floor(input.size(-1) * scale_factor)); |
3155 | auto expected = |
3156 | torch::ones({1, 1, expected_size, expected_size, expected_size}); |
3157 | auto s = output.sum(); |
3158 | s.backward(); |
3159 | |
3160 | ASSERT_EQ(s.ndimension(), 0); |
3161 | ASSERT_TRUE(output.allclose(expected)); |
3162 | } |
3163 | } |
3164 | } |
3165 | } |
3166 | |
3167 | TEST_F(ModulesTest, CTCLoss) { |
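  // With all target lengths zero, the only valid alignment emits the blank
  // symbol (index 0) at every time step, so each per-sample loss equals the
  // negative sum over time of the blank log-probabilities -- exactly what the
  // second assertion checks.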
3168 | CTCLoss loss{CTCLossOptions().reduction(torch::kNone)}; |
3169 | const auto target_lengths = torch::tensor({0, 0, 0}); |
3170 | const auto input_lengths = torch::tensor({50, 50, 50}); |
3171 | const auto targets = |
3172 | torch::randint(1, 15, at::IntArrayRef({0}), torch::kLong); |
3173 | const auto log_probs = |
3174 | torch::randn({50, 3, 15}, torch::kDouble).log_softmax(2); |
3175 | const auto output = |
3176 | loss->forward(log_probs, targets, input_lengths, target_lengths); |
3177 | ASSERT_TRUE(output.ge(0).all().item<bool>()); |
3178 | ASSERT_TRUE(torch::allclose( |
3179 | -log_probs.sum(0).slice(1, 0, 1).view_as(output), output)); |
3180 | } |
3181 | |
3182 | TEST_F(ModulesTest, PoissonNLLLoss) { |
3183 | const auto input = torch::tensor({0.5, 1.5, 2.5}); |
3184 | const auto target = torch::tensor({1., 2., 3.}); |
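  // With the default log-input formulation (and no Stirling term), the
  // element-wise Poisson NLL is exp(input) - target * input.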
3185 | const auto component_wise_loss = torch::exp(input) - target * input; |
3186 | { |
3187 | PoissonNLLLoss loss{PoissonNLLLossOptions().reduction(torch::kNone)}; |
3188 | ASSERT_TRUE( |
3189 | torch::allclose(component_wise_loss, loss->forward(input, target))); |
3190 | } |
3191 | { |
3192 | PoissonNLLLoss loss{PoissonNLLLossOptions().reduction(torch::kSum)}; |
3193 | ASSERT_TRUE(torch::allclose( |
3194 | torch::sum(component_wise_loss), loss->forward(input, target))); |
3195 | } |
3196 | { |
3197 | PoissonNLLLoss loss{PoissonNLLLossOptions().reduction(torch::kMean)}; |
3198 | ASSERT_TRUE(torch::allclose( |
3199 | torch::mean(component_wise_loss), loss->forward(input, target))); |
3200 | } |
3201 | } |
3202 | |
3203 | TEST_F(ModulesTest, MarginRankingLoss) { |
3204 | { |
3205 | MarginRankingLoss loss; |
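    // With default options: loss = mean(max(0, -target * (input1 - input2))),
    // i.e. margin = 0 and mean reduction.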
3206 | const auto input1 = torch::randn(15) * 10; |
3207 | const auto input2 = torch::randn(15) * 10; |
3208 | const auto target = torch::randn(15).sign(); |
3209 | ASSERT_TRUE(torch::allclose( |
3210 | loss->forward(input1, input2, target), |
3211 | (-target * (input1 - input2)).clamp(0).mean())); |
3212 | } |
3213 | { |
3214 | MarginRankingLoss loss{ |
3215 | MarginRankingLossOptions().margin(0.5).reduction(torch::kSum)}; |
3216 | const auto input1 = torch::randn(15) * 10; |
3217 | const auto input2 = torch::randn(15) * 10; |
3218 | const auto target = torch::randn(15).sign(); |
3219 | const auto margin = 0.5; |
3220 | ASSERT_TRUE(torch::allclose( |
3221 | loss->forward(input1, input2, target), |
3222 | (-target * (input1 - input2) + margin).clamp(0).sum())); |
3223 | } |
3224 | { |
3225 | MarginRankingLoss loss{ |
3226 | MarginRankingLossOptions().margin(0.5).reduction(torch::kMean)}; |
3227 | const auto input1 = torch::randn(15) * 10; |
3228 | const auto input2 = torch::randn(15) * 10; |
3229 | const auto target = torch::randn(15).sign(); |
3230 | const auto margin = 0.5; |
3231 | ASSERT_TRUE(torch::allclose( |
3232 | loss->forward(input1, input2, target), |
3233 | (-target * (input1 - input2) + margin).clamp(0).mean())); |
3234 | } |
3235 | } |
3236 | |
3237 | TEST_F(ModulesTest, BCEWithLogitsLoss) { |
  { // test BCE with logits raises if target and input are different size
    {
      const auto target = torch::rand(5);
      const auto input = torch::rand({5, 1});
      ASSERT_THROWS_WITH(
          BCEWithLogitsLoss()(input, target), "must be the same as input size");
    }
3244 | { |
3245 | const auto target = torch::rand({5, 1}); |
3246 | const auto input = torch::rand(5); |
3247 | ASSERT_THROWS_WITH( |
3248 | BCEWithLogitsLoss()(input, target), "must be the same as input size" ); |
3249 | } |
3250 | } |
3251 | { // test BCE with logits gives same result as sigmoid and bce loss |
3252 | auto sigmoid = Sigmoid(); |
3253 | |
3254 | auto target = torch::rand({64, 4}); |
3255 | auto output = torch::rand({64, 4}) - 0.5; |
3256 | |
3257 | ASSERT_TRUE(torch::allclose( |
3258 | BCEWithLogitsLoss()(output, target), BCELoss()(sigmoid(output), target))); |
3259 | |
3260 | auto weight = torch::rand(4); |
3261 | ASSERT_TRUE(torch::allclose( |
3262 | BCEWithLogitsLoss(BCEWithLogitsLossOptions().weight(weight))( |
3263 | output, target), |
3264 | BCELoss(BCELossOptions().weight(weight))(sigmoid(output), target))); |
3265 | |
3266 | target = torch::zeros({4, 1}, torch::kFloat); |
3267 | output = torch::empty({4, 1}, torch::kFloat).fill_(-100); |
3268 | |
3269 | ASSERT_TRUE(torch::allclose( |
3270 | BCEWithLogitsLoss()(output, target), BCELoss()(sigmoid(output), target))); |
3271 | |
3272 | ASSERT_TRUE(torch::allclose( |
3273 | BCEWithLogitsLoss(BCEWithLogitsLossOptions().reduction(torch::kNone))( |
3274 | output, target), |
3275 | BCELoss(BCELossOptions().reduction(torch::kNone))( |
3276 | sigmoid(output), target))); |
3277 | |
3278 | weight = torch::rand({1}, torch::kFloat); |
3279 | ASSERT_TRUE(torch::allclose( |
3280 | BCEWithLogitsLoss(BCEWithLogitsLossOptions().weight(weight))( |
3281 | output, target), |
3282 | BCELoss(BCELossOptions().weight(weight))(sigmoid(output), target))); |
3283 | } |
3284 | { // test BCE with logits has correct grad at zero |
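    // d/dx BCEWithLogits(x, t) = sigmoid(x) - t, so at x = 0 with t = 0 every
    // element contributes a gradient of exactly 0.5 under sum reduction.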
3285 | const auto output = torch::zeros({3, 1}, torch::requires_grad()); |
3286 | const auto target = torch::zeros({3, 1}); |
3287 | BCEWithLogitsLoss(BCEWithLogitsLossOptions().reduction(torch::kSum))( |
3288 | output, target) |
3289 | .backward(); |
3290 | const auto expected_grad = torch::empty({3, 1}).fill_(0.5); |
3291 | ASSERT_TRUE(torch::allclose(output.grad(), expected_grad)); |
3292 | } |
3293 | { // test BCE with logits broadcasts weights |
3294 | const auto target = torch::rand({16, 4}); |
3295 | const auto output = torch::rand({16, 4}) - 0.5; |
3296 | |
3297 | auto weight = torch::rand(4); |
3298 | auto out1 = BCEWithLogitsLoss(BCEWithLogitsLossOptions().weight(weight))( |
3299 | output, target); |
3300 | |
3301 | weight = weight.expand({16, 4}).contiguous(); |
3302 | auto out2 = BCEWithLogitsLoss(BCEWithLogitsLossOptions().weight(weight))( |
3303 | output, target); |
3304 | |
3305 | ASSERT_TRUE(torch::allclose(out1, out2)); |
3306 | |
3307 | weight = torch::rand({16, 1}); |
3308 | out1 = BCEWithLogitsLoss(BCEWithLogitsLossOptions().weight(weight))( |
3309 | output, target); |
3310 | |
3311 | weight = weight.expand({16, 4}).contiguous(); |
3312 | out2 = BCEWithLogitsLoss(BCEWithLogitsLossOptions().weight(weight))( |
3313 | output, target); |
3314 | |
3315 | ASSERT_TRUE(torch::allclose(out1, out2)); |
3316 | } |
3317 | { // test BCE with logits ones in pos weights are the same as none |
3318 | const auto target = torch::rand({64, 4}); |
3319 | const auto output = torch::rand({64, 4}) - 0.5; |
3320 | const auto pos_weight = torch::ones({64, 4}); |
3321 | |
3322 | ASSERT_TRUE(torch::allclose( |
3323 | BCEWithLogitsLoss()(output, target), |
3324 | BCEWithLogitsLoss(BCEWithLogitsLossOptions().pos_weight(pos_weight))( |
3325 | output, target))); |
3326 | } |
3327 | { // test BCE with logits broadcasts pos weights |
3328 | const auto target = torch::rand({64, 4}); |
3329 | const auto output = torch::rand({64, 4}) - 0.5; |
3330 | const auto pos_weight = torch::rand(4); |
3331 | const auto out1 = BCEWithLogitsLoss( |
3332 | BCEWithLogitsLossOptions().pos_weight(pos_weight))(output, target); |
3333 | |
    const auto pos_weight1 = pos_weight.expand({1, 4});
    const auto out2 = BCEWithLogitsLoss(
        BCEWithLogitsLossOptions().pos_weight(pos_weight1))(output, target);

    const auto pos_weight2 = pos_weight.expand({64, 4});
    const auto out3 = BCEWithLogitsLoss(
        BCEWithLogitsLossOptions().pos_weight(pos_weight2))(output, target);
3341 | |
3342 | ASSERT_TRUE(torch::allclose(out1, out2)); |
3343 | ASSERT_TRUE(torch::allclose(out1, out3)); |
3344 | } |
3345 | { // test BCE with logits with pos weight has correct grad at zero |
3346 | const auto output = torch::zeros({3, 1}, torch::requires_grad()); |
3347 | const auto target = torch::zeros({3, 1}); |
3348 | const auto pos_weight = torch::ones({3, 1}); |
3349 | BCEWithLogitsLoss( |
3350 | BCEWithLogitsLossOptions().pos_weight(pos_weight).reduction(torch::kSum))( |
3351 | output, target) |
3352 | .backward(); |
3353 | const auto expected_grad = torch::empty({3, 1}).fill_(0.5); |
3354 | // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) |
3355 | const auto grad = output.grad(); |
3356 | ASSERT_TRUE(torch::allclose(grad, expected_grad)); |
3357 | } |
3358 | { // test BCE with logits stability |
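    // A naive sigmoid followed by log would produce -inf for the -120 logit;
    // the fused loss uses the log-sum-exp formulation, so both variants must
    // stay finite.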
3359 | const auto output = torch::tensor({0., -120.}); |
3360 | const auto target = torch::tensor({0., 1.}); |
3361 | const auto pos_weight = torch::tensor({1., 1.}); |
3362 | |
3363 | const auto out1 = BCEWithLogitsLoss()(output, target); |
3364 | ASSERT_TRUE(torch::isfinite(out1).all().item<bool>()); |
3365 | |
3366 | const auto out2 = BCEWithLogitsLoss( |
3367 | BCEWithLogitsLossOptions().pos_weight(pos_weight))(output, target); |
3368 | ASSERT_TRUE(torch::isfinite(out2).all().item<bool>()); |
3369 | } |
3370 | } |
3371 | |
3372 | namespace detail { |
3373 | |
3374 | namespace F = torch::nn::functional; |
3375 | |
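// Naive reference batched matmul over the leading (batch, head) dimensions,
// used to cross-check the fused attention path.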
3376 | torch::Tensor _batchmatmul(const torch::Tensor& a, const torch::Tensor& b) { |
3377 | TORCH_INTERNAL_ASSERT(a.size(0) == b.size(0)); |
3378 | TORCH_INTERNAL_ASSERT(a.size(1) == b.size(1)); |
3379 | auto retval = torch::zeros( |
3380 | {a.size(0), a.size(1), a.size(2), b.size(3)}, torch::kFloat32); |
3381 | for (const auto i : c10::irange(a.size(0))) { |
3382 | for (const auto j : c10::irange(a.size(1))) { |
3383 | retval[i][j] = torch::matmul(a[i][j], b[i][j]); |
3384 | } |
3385 | } |
3386 | return retval; |
3387 | } |
3388 | |
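// Reference softmax over the last dimension, computed row by row with the
// usual max-subtraction trick for numerical stability.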
3389 | torch::Tensor _softmax(const torch::Tensor& x) { |
3390 | auto output = torch::zeros(x.sizes()); |
3391 | for (const auto i : c10::irange(x.size(0))) { |
3392 | for (const auto j : c10::irange(x.size(1))) { |
3393 | for (const auto k : c10::irange(x.size(2))) { |
3394 | const auto& x_curr = x[i][j][k]; |
3395 | const auto e_x = torch::exp(x_curr - torch::max(x_curr)); |
3396 | output[i][j][k] = e_x / torch::sum(e_x); |
3397 | } |
3398 | } |
3399 | } |
3400 | return output; |
3401 | } |
3402 | |
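// Reference scaled dot-product attention: softmax(Q K^T / sqrt(d_head)) V.
// Masked positions are forced to -inf before the softmax so that they receive
// zero attention weight; the returned weights are optionally averaged over
// heads.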
3403 | std::tuple<torch::Tensor, torch::Tensor> _scaled_dot_attn_ref( |
3404 | const torch::Tensor& Q, |
3405 | const torch::Tensor& K, |
3406 | const torch::Tensor& V, |
3407 | at::IntArrayRef dims, |
3408 | const torch::Tensor& unseen_mask = {}, |
3409 | const torch::Tensor& key_padding_mask = {}, |
3410 | bool average_attn_weights = true) { |
3411 | auto QKT = _batchmatmul(Q, K.permute({0, 1, 3, 2}) / std::sqrt(dims[3])); |
3412 | const auto b1 = QKT.size(0); |
3413 | const auto b2 = QKT.size(1); |
3414 | const auto s1 = QKT.size(2); |
3415 | const auto s2 = QKT.size(3); |
3416 | if (unseen_mask.defined() || key_padding_mask.defined()) { |
3417 | for (const auto i : c10::irange(b1)) { |
3418 | for (const auto j : c10::irange(b2)) { |
3419 | for (const auto m : c10::irange(s1)) { |
3420 | for (const auto n : c10::irange(s2)) { |
3421 | if (unseen_mask.defined() && |
3422 | unseen_mask[m][n].item<double>() == 0) { |
3423 | QKT[i][j][m][n] = -std::numeric_limits<double>::infinity(); |
3424 | } |
3425 | if (key_padding_mask.defined() && |
3426 | key_padding_mask[i][n].item<double>() != 0) { |
3427 | QKT[i][j][m][n] = -std::numeric_limits<double>::infinity(); |
3428 | } |
3429 | } |
3430 | } |
3431 | } |
3432 | } |
3433 | } |
3434 | auto reference = _softmax(QKT); |
3435 | auto ref_attn_weight = reference; |
3436 | if (average_attn_weights) { |
3437 | // NOLINTNEXTLINE(bugprone-argument-comment) |
3438 | ref_attn_weight = torch::sum(ref_attn_weight, /*axis=*/1) / b2; |
3439 | } |
3440 | reference = _batchmatmul(reference, V); |
  return std::make_tuple(reference, ref_attn_weight);
3442 | } |
3443 | |
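// Reshapes (batch, seq, nheads * d_head) into (batch, nheads, seq, d_head) so
// that each head attends independently.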
3444 | torch::Tensor _split_heads_ref( |
3445 | const torch::Tensor& X, |
3446 | at::IntArrayRef dims, |
3447 | int nheads, |
3448 | int d_head) { |
3449 | auto X_split = X.reshape({dims[0], dims[1], nheads, d_head}); |
3450 | auto X_split_transposed = X_split.permute({0, 2, 1, 3}); |
3451 | return X_split_transposed.reshape({dims[0], nheads, dims[1], d_head}); |
3452 | } |
3453 | |
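// Inverse of _split_heads_ref: merges the per-head outputs back into a single
// (batch, seq, nheads * d_head) tensor.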
3454 | torch::Tensor _combine_heads_ref( |
3455 | const torch::Tensor& X, |
3456 | at::IntArrayRef dims, |
3457 | int nheads, |
3458 | int d_head) { |
3459 | auto X_transposed = X.permute({0, 2, 1, 3}); |
3460 | auto reference = X_transposed.reshape({dims[0], dims[1], nheads * d_head}); |
3461 | return reference; |
3462 | } |
3463 | |
3464 | torch::Tensor _fc( |
3465 | torch::Tensor X, |
3466 | torch::Tensor X_weight, |
3467 | torch::Tensor X_bias) { |
3468 | // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) |
3469 | auto X_fc_b = X_bias; |
3470 | // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) |
3471 | auto X_fc_w = X_weight; |
3472 | return torch::matmul(X, torch::t(X_fc_w)) + X_fc_b; |
3473 | } |
3474 | |
3475 | void _multihead_attn_test_helper( |
3476 | bool add_key_padding_mask = false, |
3477 | bool add_bias_kv = false, |
3478 | bool add_zero_attn = false, |
3479 | bool saved_kv = false, |
3480 | bool same_embed_dim = false, |
3481 | bool average_attn_weights = true) { |
3482 | std::random_device device; |
3483 | std::mt19937 generator(device()); |
3484 | std::uniform_int_distribution<int> d_2_10(2, 10); |
3485 | std::uniform_int_distribution<int> d_3_10(3, 10); |
3486 | bool registration_checked = false; |
3487 | for (const auto i : c10::irange(100)) { |
3488 | (void)i; // Suppress unused variable warning |
3489 | const auto batch_sz = d_2_10(generator); |
3490 | const auto seq_len = d_2_10(generator); |
3491 | const auto d_head = d_3_10(generator); |
3492 | const auto nheads = d_3_10(generator); |
3493 | const auto d_model = d_head * nheads; |
3494 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
3495 | int kv_dim; |
3496 | if (same_embed_dim) { |
3497 | kv_dim = d_model; |
3498 | } else { |
3499 | std::uniform_int_distribution<int> d(5, 20); |
3500 | kv_dim = d(generator); |
3501 | while (kv_dim == d_model) { |
3502 | kv_dim = d(generator); |
3503 | } |
3504 | } |
3505 | std::vector<int64_t> dims{batch_sz, seq_len, kv_dim}; |
3506 | torch::Tensor saved_k; |
3507 | torch::Tensor saved_k_tensor; |
3508 | torch::Tensor saved_v; |
3509 | torch::Tensor saved_v_tensor; |
3510 | if (saved_kv) { |
3511 | saved_k = torch::rand({batch_sz * nheads, seq_len, d_head}); |
3512 | saved_k_tensor = saved_k; |
3513 | saved_v = torch::rand({batch_sz * nheads, seq_len, d_head}); |
3514 | saved_v_tensor = saved_v; |
3515 | } |
3516 | torch::Tensor key_padding_mask; |
3517 | torch::Tensor key_padding_mask_tensor; |
3518 | if (add_key_padding_mask) { |
3519 | const auto seq_mask = torch::randint(0, 2, {1, seq_len}); |
3520 | key_padding_mask = seq_mask.repeat({batch_sz, 1}) == 1; |
3521 | key_padding_mask_tensor = key_padding_mask; |
3522 | } |
3523 | const auto decoder_state = torch::rand({batch_sz, d_model}); |
3524 | const torch::Tensor K = torch::rand(dims); |
3525 | // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) |
3526 | const torch::Tensor V = K; |
3527 | const torch::Tensor Q = |
3528 | decoder_state.clone().resize_({batch_sz, 1, d_model}); |
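    // Build an additive float attention mask: sampled zeros become -inf
    // (blocked) and ones become 0.0 (pass-through).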
3529 | auto attn_mask = torch::randint(0, 2, {1, seq_len}, torch::kFloat); |
3530 | const torch::Tensor attn_mask_tensor = attn_mask.clone(); |
3531 | attn_mask_tensor.masked_fill_( |
3532 | attn_mask_tensor == 0, -std::numeric_limits<double>::infinity()); |
3533 | attn_mask_tensor.masked_fill_(attn_mask_tensor > 0, double(0.0)); |
3534 | |
3535 | // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) |
3536 | const torch::Tensor decoder_state_tensor = decoder_state; |
3537 | const torch::Tensor source_hid_tensor = K.transpose(0, 1); |
3538 | |
3539 | const auto options = MultiheadAttentionOptions(d_model, nheads) |
3540 | .add_bias_kv(add_bias_kv) |
3541 | .add_zero_attn(add_zero_attn) |
3542 | .kdim(kv_dim) |
3543 | .vdim(kv_dim); |
3544 | const auto multihead_attn_module = MultiheadAttention(options); |
3545 | |
3546 | if (!registration_checked) { |
3547 | // make sure parameters are all registered correctly |
3548 | auto named_parameters = multihead_attn_module->named_parameters(); |
3549 | if (same_embed_dim) { |
3550 | ASSERT_TRUE(named_parameters.contains("in_proj_weight" )); |
3551 | } else { |
3552 | ASSERT_TRUE(named_parameters.contains("q_proj_weight" )); |
3553 | ASSERT_TRUE(named_parameters.contains("k_proj_weight" )); |
3554 | ASSERT_TRUE(named_parameters.contains("v_proj_weight" )); |
3555 | } |
3556 | if (add_bias_kv) { |
3557 | ASSERT_TRUE(named_parameters.contains("bias_k" )); |
3558 | ASSERT_TRUE(named_parameters.contains("bias_v" )); |
3559 | } |
3560 | // make sure sub modules are all registered correctly |
3561 | auto submodules = multihead_attn_module->named_children(); |
3562 | ASSERT_TRUE(submodules.contains("out_proj" )); |
3563 | registration_checked = true; |
3564 | } |
3565 | |
3566 | torch::Tensor bias_k; |
3567 | torch::Tensor bias_v; |
3568 | if (add_bias_kv) { |
3569 | bias_k = multihead_attn_module->bias_k.detach(); |
3570 | bias_v = multihead_attn_module->bias_v.detach(); |
3571 | } else { |
3572 | bias_k.reset(); |
3573 | bias_v.reset(); |
3574 | } |
3575 | |
3576 | torch::Tensor _Q = decoder_state_tensor.unsqueeze(1).transpose(0, 1); |
3577 | // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) |
3578 | torch::Tensor _V = source_hid_tensor; |
3579 | // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) |
3580 | torch::Tensor _K = source_hid_tensor; |
3581 | |
3582 | torch::Tensor result; |
3583 | torch::Tensor result_weight; |
3584 | if (multihead_attn_module->_qkv_same_embed_dim) { |
3585 | std::tie(result, result_weight) = F::multi_head_attention_forward( |
3586 | _Q, |
3587 | _K, |
3588 | _V, |
3589 | F::MultiheadAttentionForwardFuncOptions( |
3590 | /*embed_dim_to_check=*/d_model, |
3591 | /*num_heads=*/nheads, |
3592 | /*in_proj_weight=*/multihead_attn_module->in_proj_weight, |
3593 | /*in_proj_bias=*/multihead_attn_module->in_proj_bias, |
3594 | /*bias_k=*/multihead_attn_module->bias_k, |
3595 | /*bias_v=*/multihead_attn_module->bias_v, |
3596 | /*add_zero_attn=*/multihead_attn_module->options.add_zero_attn(), |
3597 | /*dropout_p=*/multihead_attn_module->options.dropout(), |
3598 | /*out_proj_weight=*/multihead_attn_module->out_proj->weight, |
3599 | /*out_proj_bias=*/multihead_attn_module->out_proj->bias) |
3600 | .training(multihead_attn_module->is_training()) |
3601 | .key_padding_mask(key_padding_mask_tensor) |
3602 | .need_weights(true) |
3603 | .attn_mask(attn_mask_tensor) |
3604 | .static_k(saved_k_tensor) |
3605 | .static_v(saved_v_tensor) |
3606 | .average_attn_weights(average_attn_weights)); |
3607 | } else { |
3608 | std::tie(result, result_weight) = F::multi_head_attention_forward( |
3609 | _Q, |
3610 | _K, |
3611 | _V, |
3612 | F::MultiheadAttentionForwardFuncOptions( |
3613 | /*embed_dim_to_check=*/d_model, |
3614 | /*num_heads=*/nheads, |
3615 | /*in_proj_weight=*/{}, |
3616 | /*in_proj_bias=*/multihead_attn_module->in_proj_bias, |
3617 | /*bias_k=*/multihead_attn_module->bias_k, |
3618 | /*bias_v=*/multihead_attn_module->bias_v, |
3619 | /*add_zero_attn=*/multihead_attn_module->options.add_zero_attn(), |
3620 | /*dropout_p=*/multihead_attn_module->options.dropout(), |
3621 | /*out_proj_weight=*/multihead_attn_module->out_proj->weight, |
3622 | /*out_proj_bias=*/multihead_attn_module->out_proj->bias) |
3623 | .training(multihead_attn_module->is_training()) |
3624 | .key_padding_mask(key_padding_mask_tensor) |
3625 | .need_weights(true) |
3626 | .attn_mask(attn_mask_tensor) |
3627 | .use_separate_proj_weight(true) |
3628 | .q_proj_weight(multihead_attn_module->q_proj_weight) |
3629 | .k_proj_weight(multihead_attn_module->k_proj_weight) |
3630 | .v_proj_weight(multihead_attn_module->v_proj_weight) |
3631 | .static_k(saved_k_tensor) |
3632 | .static_v(saved_v_tensor) |
3633 | .average_attn_weights(average_attn_weights)); |
3634 | } |
3635 | result = result.squeeze(0).detach(); |
3636 | torch::Tensor q_proj_weight; |
3637 | torch::Tensor k_proj_weight; |
3638 | torch::Tensor v_proj_weight; |
3639 | if (multihead_attn_module->_qkv_same_embed_dim) { |
3640 | q_proj_weight = |
3641 | multihead_attn_module->in_proj_weight.slice(/*dim=*/0, 0, d_model); |
3642 | k_proj_weight = multihead_attn_module->in_proj_weight.slice( |
3643 | /*dim=*/0, d_model, (d_model * 2)); |
3644 | v_proj_weight = |
3645 | multihead_attn_module->in_proj_weight.slice(/*dim=*/0, (d_model * 2)); |
3646 | } else { |
3647 | q_proj_weight = multihead_attn_module->q_proj_weight; |
3648 | k_proj_weight = multihead_attn_module->k_proj_weight; |
3649 | v_proj_weight = multihead_attn_module->v_proj_weight; |
3650 | } |
3651 | auto Q_fc = |
3652 | _fc(Q, |
3653 | q_proj_weight, |
3654 | multihead_attn_module->in_proj_bias.slice(/*dim=*/0, 0, d_model)); |
3655 | auto K_fc = |
3656 | _fc(K, |
3657 | k_proj_weight, |
3658 | multihead_attn_module->in_proj_bias.slice( |
3659 | /*dim=*/0, d_model, (d_model * 2))); |
3660 | auto V_fc = _fc( |
3661 | V, |
3662 | v_proj_weight, |
3663 | multihead_attn_module->in_proj_bias.slice(/*dim=*/0, (d_model * 2))); |
3664 | |
3665 | if (add_bias_kv) { |
3666 | K_fc = torch::cat( |
3667 | {K_fc, |
3668 | bias_k.repeat({K_fc.size(0) / bias_k.size(0), 1, 1} /*, axis=0*/)}, |
3669 | /*dim=*/1); |
3670 | V_fc = torch::cat( |
3671 | {V_fc, |
3672 | bias_v.repeat({V_fc.size(0) / bias_v.size(0), 1, 1} /*, axis=0*/)}, |
3673 | /*dim=*/1); |
3674 | if (attn_mask.defined()) { |
3675 | attn_mask = torch::cat({attn_mask, torch::ones({1, 1})}, /*dim=*/1); |
3676 | } |
3677 | if (key_padding_mask.defined()) { |
3678 | key_padding_mask = torch::cat( |
3679 | {key_padding_mask, torch::full({batch_sz, 1}, false, torch::kBool)}, |
3680 | /*dim=*/1); |
3681 | } |
3682 | dims[1] += 1; |
3683 | } |
3684 | const auto Q_split = |
3685 | _split_heads_ref(Q_fc, {batch_sz, 1, d_model}, nheads, d_head); |
3686 | torch::Tensor K_split; |
3687 | if (saved_k.defined()) { |
3688 | K_split = saved_k.reshape({dims[0], nheads, dims[1], d_head}); |
3689 | } else { |
3690 | K_split = _split_heads_ref(K_fc, dims, nheads, d_head); |
3691 | } |
3692 | torch::Tensor V_split; |
3693 | if (saved_v.defined()) { |
3694 | V_split = saved_v.reshape({dims[0], nheads, dims[1], d_head}); |
3695 | } else { |
3696 | V_split = _split_heads_ref(V_fc, dims, nheads, d_head); |
3697 | } |
3698 | if (add_zero_attn) { |
3699 | dims[1] += 1; |
3700 | K_split = torch::cat( |
3701 | {K_split, |
3702 | torch::zeros( |
3703 | {K_split.size(0), K_split.size(1), 1, K_split.size(3)})}, |
3704 | /*dim=*/2); |
3705 | V_split = torch::cat( |
3706 | {V_split, |
3707 | torch::zeros( |
3708 | {V_split.size(0), V_split.size(1), 1, V_split.size(3)})}, |
3709 | /*dim=*/2); |
3710 | if (attn_mask.defined()) { |
3711 | attn_mask = torch::cat({attn_mask, torch::ones({1, 1})}, /*dim=*/1); |
3712 | } |
3713 | if (key_padding_mask.defined()) { |
3714 | key_padding_mask = torch::cat( |
3715 | {key_padding_mask, torch::full({batch_sz, 1}, false, torch::kBool)}, |
3716 | /*dim=*/1); |
3717 | } |
3718 | } |
3719 | torch::Tensor attn_heads; |
3720 | torch::Tensor ref_attn_weight; |
3721 | std::tie(attn_heads, ref_attn_weight) = _scaled_dot_attn_ref( |
3722 | Q_split, |
3723 | K_split, |
3724 | V_split, |
3725 | Q_split.sizes(), |
3726 | attn_mask, |
3727 | key_padding_mask, |
3728 | average_attn_weights); |
3729 | const auto combined_attn_heads = |
3730 | _combine_heads_ref(attn_heads, {batch_sz, 1}, nheads, d_head); |
3731 | auto reference = |
3732 | _fc(combined_attn_heads, |
3733 | multihead_attn_module->out_proj->weight, |
3734 | multihead_attn_module->out_proj->bias); |
3735 | // NOLINTNEXTLINE(bugprone-argument-comment) |
3736 | reference = torch::squeeze(reference, /*axis=*/1); |
3737 | |
    // the forward result should match the reference implementation
3739 | ASSERT_EQ(result.sizes(), std::vector<int64_t>({batch_sz, d_model})); |
3740 | ASSERT_TRUE( |
3741 | torch::allclose(result, reference, 1e-5, 1e-5, /*equal_nan=*/true)); |
3742 | |
    // the returned attention weights should match the reference weights
3744 | result_weight = result_weight.detach(); |
3745 | ASSERT_EQ(result_weight.sizes(), ref_attn_weight.sizes()); |
3746 | ASSERT_TRUE(torch::allclose( |
3747 | result_weight, ref_attn_weight, 1e-5, 1e-5, /*equal_nan=*/true)); |
3748 | } |
3749 | } |
3750 | } // namespace detail |
3751 | |
3752 | TEST_F(ModulesTest, MultiheadAttention) { |
3753 | using namespace ::detail; |
3754 | |
3755 | for (auto average_attn_weights : {false, true}) { |
3756 | // test_multihead_attn_add_zero_attn |
3757 | _multihead_attn_test_helper( |
3758 | /*add_key_padding_mask=*/false, |
3759 | /*add_bias_kv=*/false, |
3760 | /*add_zero_attn=*/true, |
3761 | /*saved_kv=*/false, |
3762 | /*same_embed_dim=*/false, |
3763 | /*average_attn_weights=*/average_attn_weights); |
3764 | |
3765 | // test_multihead_attn_add_bias_kv |
3766 | _multihead_attn_test_helper( |
3767 | /*add_key_padding_mask=*/false, |
3768 | /*add_bias_kv=*/true, |
3769 | /*add_zero_attn=*/false, |
3770 | /*saved_kv=*/false, |
3771 | /*same_embed_dim=*/false, |
3772 | /*average_attn_weights=*/average_attn_weights); |
3773 | |
    // test_multihead_attn_no_masking
3775 | _multihead_attn_test_helper(); |
3776 | |
3777 | // test_multihead_attn_key_padding_mask |
3778 | _multihead_attn_test_helper( |
3779 | /*add_key_padding_mask=*/true, |
3780 | /*add_bias_kv=*/false, |
3781 | /*add_zero_attn=*/false, |
3782 | /*saved_kv=*/false, |
3783 | /*same_embed_dim=*/false, |
3784 | /*average_attn_weights=*/average_attn_weights); |
3785 | |
3786 | // test_multihead_attn_saved_kv |
3787 | _multihead_attn_test_helper( |
3788 | /*add_key_padding_mask=*/false, |
3789 | /*add_bias_kv=*/false, |
3790 | /*add_zero_attn=*/false, |
3791 | /*saved_kv=*/true, |
3792 | /*same_embed_dim=*/false, |
3793 | /*average_attn_weights=*/average_attn_weights); |
3794 | |
3795 | // test_multihead_attn_add_bias_kv_zero_attn |
3796 | _multihead_attn_test_helper( |
3797 | /*add_key_padding_mask=*/true, |
3798 | /*add_bias_kv=*/true, |
3799 | /*add_zero_attn=*/true, |
3800 | /*saved_kv=*/false, |
3801 | /*same_embed_dim=*/false, |
3802 | /*average_attn_weights=*/average_attn_weights); |
3803 | |
3804 | // test_multihead_attn_all_arguments1 |
3805 | _multihead_attn_test_helper( |
3806 | /*add_key_padding_mask=*/true, |
3807 | /*add_bias_kv=*/false, |
3808 | /*add_zero_attn=*/true, |
3809 | /*saved_kv=*/true, |
3810 | /*same_embed_dim=*/false, |
3811 | /*average_attn_weights=*/average_attn_weights); |
3812 | |
3813 | ASSERT_THROWS_WITH( |
3814 | // test_multihead_attn_all_arguments2 |
3815 | _multihead_attn_test_helper( |
3816 | /*add_key_padding_mask=*/true, |
3817 | /*add_bias_kv=*/true, |
3818 | /*add_zero_attn=*/true, |
3819 | /*saved_kv=*/true, |
3820 | /*same_embed_dim=*/false, |
3821 | /*average_attn_weights=*/average_attn_weights), |
3822 | "bias cannot be added to static key" ); |
3823 | |
3824 | // test_multihead_attn_all_arguments3 |
3825 | _multihead_attn_test_helper( |
3826 | /*add_key_padding_mask=*/true, |
3827 | /*add_bias_kv=*/false, |
3828 | /*add_zero_attn=*/true, |
3829 | /*saved_kv=*/true, |
3830 | /*same_embed_dim=*/true, |
3831 | /*average_attn_weights=*/average_attn_weights); |
3832 | } |
3833 | } |
3834 | |
3835 | TEST_F(ModulesTest, PrettyPrintIdentity) { |
3836 | ASSERT_EQ(c10::str(Identity()), "torch::nn::Identity()" ); |
3837 | } |
3838 | |
3839 | TEST_F(ModulesTest, PrettyPrintFlatten) { |
3840 | ASSERT_EQ(c10::str(Flatten()), "torch::nn::Flatten(start_dim=1, end_dim=-1)" ); |
3841 | ASSERT_EQ( |
3842 | c10::str(Flatten(FlattenOptions().start_dim(2).end_dim(4))), |
3843 | "torch::nn::Flatten(start_dim=2, end_dim=4)" ); |
3844 | } |
3845 | |
3846 | TEST_F(ModulesTest, PrettyPrintUnflatten) { |
3847 | ASSERT_EQ( |
3848 | c10::str(Unflatten(UnflattenOptions(0, {2, 2}))), |
3849 | "torch::nn::Unflatten(dim=0, unflattened_size={2, 2})" ); |
3850 | ASSERT_EQ( |
3851 | c10::str(Unflatten(UnflattenOptions( |
3852 | "B" , |
3853 | {std::pair<std::string, int64_t>{"B1" , 2}, |
3854 | std::pair<std::string, int64_t>{"B2" , 2}}))), |
3855 | "torch::nn::Unflatten(dim=\"B\", unflattened_size={{\"B1\", 2}, {\"B2\", 2}})" ); |
3856 | } |
3857 | |
3858 | TEST_F(ModulesTest, ReflectionPad1d) { |
3859 | { |
3860 | ReflectionPad1d m(ReflectionPad1dOptions(2)); |
3861 | auto input = torch::arange(8, torch::kFloat).reshape({1, 2, 4}); |
3862 | auto output = m(input); |
3863 | auto expected = torch::tensor( |
3864 | {{{2., 1., 0., 1., 2., 3., 2., 1.}, {6., 5., 4., 5., 6., 7., 6., 5.}}}, |
3865 | torch::kFloat); |
3866 | ASSERT_TRUE(output.allclose(expected)); |
3867 | } |
3868 | { |
3869 | ReflectionPad1d m(ReflectionPad1dOptions({3, 1})); |
3870 | auto input = torch::arange(8, torch::kFloat).reshape({1, 2, 4}); |
3871 | auto output = m(input); |
3872 | auto expected = torch::tensor( |
3873 | {{{3., 2., 1., 0., 1., 2., 3., 2.}, {7., 6., 5., 4., 5., 6., 7., 6.}}}, |
3874 | torch::kFloat); |
3875 | ASSERT_TRUE(output.allclose(expected)); |
3876 | } |
3877 | } |
3878 | |
3879 | TEST_F(ModulesTest, ReflectionPad2d) { |
3880 | { |
3881 | ReflectionPad2d m(ReflectionPad2dOptions(2)); |
3882 | auto input = torch::arange(9, torch::kFloat).reshape({1, 1, 3, 3}); |
3883 | auto output = m(input); |
3884 | auto expected = torch::tensor( |
3885 | {{{{8., 7., 6., 7., 8., 7., 6.}, |
3886 | {5., 4., 3., 4., 5., 4., 3.}, |
3887 | {2., 1., 0., 1., 2., 1., 0.}, |
3888 | {5., 4., 3., 4., 5., 4., 3.}, |
3889 | {8., 7., 6., 7., 8., 7., 6.}, |
3890 | {5., 4., 3., 4., 5., 4., 3.}, |
3891 | {2., 1., 0., 1., 2., 1., 0.}}}}, |
3892 | torch::kFloat); |
3893 | ASSERT_TRUE(output.allclose(expected)); |
3894 | } |
3895 | { |
3896 | ReflectionPad2d m(ReflectionPad2dOptions({1, 1, 2, 0})); |
3897 | auto input = torch::arange(9, torch::kFloat).reshape({1, 1, 3, 3}); |
3898 | auto output = m(input); |
3899 | auto expected = torch::tensor( |
3900 | {{{{7., 6., 7., 8., 7.}, |
3901 | {4., 3., 4., 5., 4.}, |
3902 | {1., 0., 1., 2., 1.}, |
3903 | {4., 3., 4., 5., 4.}, |
3904 | {7., 6., 7., 8., 7.}}}}, |
3905 | torch::kFloat); |
3906 | ASSERT_TRUE(output.allclose(expected)); |
3907 | } |
3908 | } |
3909 | |
3910 | TEST_F(ModulesTest, ReflectionPad3d) { |
3911 | { |
3912 | ReflectionPad3d m(ReflectionPad3dOptions(1)); |
3913 | auto input = torch::arange(8, torch::kFloat).reshape({1, 1, 2, 2, 2}); |
3914 | auto output = m(input); |
3915 | auto expected = torch::tensor( |
3916 | {{{{{7., 6., 7., 6.}, |
3917 | {5., 4., 5., 4.}, |
3918 | {7., 6., 7., 6.}, |
3919 | {5., 4., 5., 4.}}, |
3920 | {{3., 2., 3., 2.}, |
3921 | {1., 0., 1., 0.}, |
3922 | {3., 2., 3., 2.}, |
3923 | {1., 0., 1., 0.}}, |
3924 | {{7., 6., 7., 6.}, |
3925 | {5., 4., 5., 4.}, |
3926 | {7., 6., 7., 6.}, |
3927 | {5., 4., 5., 4.}}, |
3928 | {{3., 2., 3., 2.}, |
3929 | {1., 0., 1., 0.}, |
3930 | {3., 2., 3., 2.}, |
3931 | {1., 0., 1., 0.}}}}}, |
3932 | torch::kFloat); |
3933 | ASSERT_TRUE(output.allclose(expected)); |
3934 | } |
3935 | { |
3936 | ReflectionPad3d m(ReflectionPad3dOptions({0, 1, 1, 0, 1, 2})); |
3937 | auto input = torch::arange(16, torch::kFloat).reshape({1, 1, 4, 2, 2}); |
3938 | auto output = m(input); |
3939 | auto expected = torch::tensor( |
3940 | {{{{{6., 7., 6.}, {4., 5., 4.}, {6., 7., 6.}}, |
3941 | {{2., 3., 2.}, {0., 1., 0.}, {2., 3., 2.}}, |
3942 | {{6., 7., 6.}, {4., 5., 4.}, {6., 7., 6.}}, |
3943 | {{10., 11., 10.}, {8., 9., 8.}, {10., 11., 10.}}, |
3944 | {{14., 15., 14.}, {12., 13., 12.}, {14., 15., 14.}}, |
3945 | {{10., 11., 10.}, {8., 9., 8.}, {10., 11., 10.}}, |
3946 | {{6., 7., 6.}, {4., 5., 4.}, {6., 7., 6.}}}}}, |
3947 | torch::kFloat); |
3948 | ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 1, 7, 3, 3})); |
3949 | ASSERT_TRUE(output.allclose(expected)); |
3950 | } |
}

TEST_F(ModulesTest, ReplicationPad1d) {
3953 | { |
3954 | ReplicationPad1d m(ReplicationPad1dOptions(2)); |
3955 | auto input = torch::arange(8, torch::kFloat).reshape({1, 2, 4}); |
3956 | auto output = m(input); |
3957 | auto expected = torch::tensor( |
3958 | {{{0., 0., 0., 1., 2., 3., 3., 3.}, {4., 4., 4., 5., 6., 7., 7., 7.}}}, |
3959 | torch::kFloat); |
3960 | ASSERT_TRUE(output.allclose(expected)); |
3961 | } |
3962 | { |
3963 | ReplicationPad1d m(ReplicationPad1dOptions({3, 1})); |
3964 | auto input = torch::arange(8, torch::kFloat).reshape({1, 2, 4}); |
3965 | auto output = m(input); |
3966 | auto expected = torch::tensor( |
3967 | {{{0., 0., 0., 0., 1., 2., 3., 3.}, {4., 4., 4., 4., 5., 6., 7., 7.}}}, |
3968 | torch::kFloat); |
3969 | ASSERT_TRUE(output.allclose(expected)); |
3970 | } |
3971 | } |
3972 | |
3973 | TEST_F(ModulesTest, ReplicationPad2d) { |
3974 | { |
3975 | ReplicationPad2d m(ReplicationPad2dOptions(2)); |
3976 | auto input = torch::arange(9, torch::kFloat).reshape({1, 1, 3, 3}); |
3977 | auto output = m(input); |
3978 | auto expected = torch::tensor( |
3979 | {{{{0., 0., 0., 1., 2., 2., 2.}, |
3980 | {0., 0., 0., 1., 2., 2., 2.}, |
3981 | {0., 0., 0., 1., 2., 2., 2.}, |
3982 | {3., 3., 3., 4., 5., 5., 5.}, |
3983 | {6., 6., 6., 7., 8., 8., 8.}, |
3984 | {6., 6., 6., 7., 8., 8., 8.}, |
3985 | {6., 6., 6., 7., 8., 8., 8.}}}}, |
3986 | torch::kFloat); |
3987 | ASSERT_TRUE(output.allclose(expected)); |
3988 | } |
3989 | { |
3990 | ReplicationPad2d m(ReplicationPad2dOptions({1, 1, 2, 0})); |
3991 | auto input = torch::arange(9, torch::kFloat).reshape({1, 1, 3, 3}); |
3992 | auto output = m(input); |
3993 | auto expected = torch::tensor( |
3994 | {{{{0., 0., 1., 2., 2.}, |
3995 | {0., 0., 1., 2., 2.}, |
3996 | {0., 0., 1., 2., 2.}, |
3997 | {3., 3., 4., 5., 5.}, |
3998 | {6., 6., 7., 8., 8.}}}}, |
3999 | torch::kFloat); |
4000 | ASSERT_TRUE(output.allclose(expected)); |
4001 | } |
4002 | } |
4003 | |
4004 | TEST_F(ModulesTest, ReplicationPad3d) { |
4005 | { |
4006 | ReplicationPad3d m(ReplicationPad3dOptions(1)); |
4007 | auto input = torch::arange(8, torch::kFloat).reshape({1, 1, 2, 2, 2}); |
4008 | auto output = m(input); |
4009 | auto expected = torch::tensor( |
4010 | {{{{{0., 0., 1., 1.}, |
4011 | {0., 0., 1., 1.}, |
4012 | {2., 2., 3., 3.}, |
4013 | {2., 2., 3., 3.}}, |
4014 | {{0., 0., 1., 1.}, |
4015 | {0., 0., 1., 1.}, |
4016 | {2., 2., 3., 3.}, |
4017 | {2., 2., 3., 3.}}, |
4018 | {{4., 4., 5., 5.}, |
4019 | {4., 4., 5., 5.}, |
4020 | {6., 6., 7., 7.}, |
4021 | {6., 6., 7., 7.}}, |
4022 | {{4., 4., 5., 5.}, |
4023 | {4., 4., 5., 5.}, |
4024 | {6., 6., 7., 7.}, |
4025 | {6., 6., 7., 7.}}}}}, |
4026 | torch::kFloat); |
4027 | ASSERT_TRUE(output.allclose(expected)); |
4028 | } |
4029 | { |
4030 | ReplicationPad3d m(ReplicationPad3dOptions({1, 2, 1, 2, 1, 2})); |
4031 | auto input = torch::arange(8, torch::kFloat).reshape({1, 1, 2, 2, 2}); |
4032 | auto output = m(input); |
4033 | auto expected = torch::tensor( |
4034 | {{{{{0., 0., 1., 1., 1.}, |
4035 | {0., 0., 1., 1., 1.}, |
4036 | {2., 2., 3., 3., 3.}, |
4037 | {2., 2., 3., 3., 3.}, |
4038 | {2., 2., 3., 3., 3.}}, |
4039 | {{0., 0., 1., 1., 1.}, |
4040 | {0., 0., 1., 1., 1.}, |
4041 | {2., 2., 3., 3., 3.}, |
4042 | {2., 2., 3., 3., 3.}, |
4043 | {2., 2., 3., 3., 3.}}, |
4044 | {{4., 4., 5., 5., 5.}, |
4045 | {4., 4., 5., 5., 5.}, |
4046 | {6., 6., 7., 7., 7.}, |
4047 | {6., 6., 7., 7., 7.}, |
4048 | {6., 6., 7., 7., 7.}}, |
4049 | {{4., 4., 5., 5., 5.}, |
4050 | {4., 4., 5., 5., 5.}, |
4051 | {6., 6., 7., 7., 7.}, |
4052 | {6., 6., 7., 7., 7.}, |
4053 | {6., 6., 7., 7., 7.}}, |
4054 | {{4., 4., 5., 5., 5.}, |
4055 | {4., 4., 5., 5., 5.}, |
4056 | {6., 6., 7., 7., 7.}, |
4057 | {6., 6., 7., 7., 7.}, |
4058 | {6., 6., 7., 7., 7.}}}}}, |
4059 | torch::kFloat); |
4060 | ASSERT_TRUE(output.allclose(expected)); |
4061 | } |
4062 | } |
4063 | |
4064 | TEST_F(ModulesTest, ZeroPad2d) { |
4065 | { |
4066 | ZeroPad2d m(ZeroPad2dOptions(2)); |
4067 | auto input = torch::arange(9, torch::kFloat).reshape({1, 1, 3, 3}); |
4068 | auto output = m(input); |
4069 | auto expected = torch::tensor( |
4070 | {{{{0., 0., 0., 0., 0., 0., 0.}, |
4071 | {0., 0., 0., 0., 0., 0., 0.}, |
4072 | {0., 0., 0., 1., 2., 0., 0.}, |
4073 | {0., 0., 3., 4., 5., 0., 0.}, |
4074 | {0., 0., 6., 7., 8., 0., 0.}, |
4075 | {0., 0., 0., 0., 0., 0., 0.}, |
4076 | {0., 0., 0., 0., 0., 0., 0.}}}}, |
4077 | torch::kFloat); |
4078 | ASSERT_TRUE(output.allclose(expected)); |
4079 | } |
4080 | { |
4081 | ZeroPad2d m(ZeroPad2dOptions({1, 1, 2, 0})); |
4082 | auto input = torch::arange(9, torch::kFloat).reshape({1, 1, 3, 3}); |
4083 | auto output = m(input); |
4084 | auto expected = torch::tensor( |
4085 | {{{{0., 0., 0., 0., 0.}, |
4086 | {0., 0., 0., 0., 0.}, |
4087 | {0., 0., 1., 2., 0.}, |
4088 | {0., 3., 4., 5., 0.}, |
4089 | {0., 6., 7., 8., 0.}}}}, |
4090 | torch::kFloat); |
4091 | ASSERT_TRUE(output.allclose(expected)); |
4092 | } |
4093 | } |
4094 | |
4095 | TEST_F(ModulesTest, ConstantPad1d) { |
4096 | { |
4097 | ConstantPad1d m(ConstantPad1dOptions(2, 3.5)); |
4098 | auto input = torch::arange(8, torch::kFloat).reshape({1, 2, 4}); |
4099 | auto output = m(input); |
4100 | auto expected = torch::tensor( |
4101 | {{{3.5000, 3.5000, 0.0000, 1.0000, 2.0000, 3.0000, 3.5000, 3.5000}, |
4102 | {3.5000, 3.5000, 4.0000, 5.0000, 6.0000, 7.0000, 3.5000, 3.5000}}}, |
4103 | torch::kFloat); |
4104 | ASSERT_TRUE(output.allclose(expected)); |
4105 | } |
4106 | { |
4107 | ConstantPad1d m(ConstantPad1dOptions({3, 1}, 3.5)); |
4108 | auto input = torch::arange(6, torch::kFloat).reshape({1, 2, 3}); |
4109 | auto output = m(input); |
4110 | auto expected = torch::tensor( |
4111 | {{{3.5000, 3.5000, 3.5000, 0.0000, 1.0000, 2.0000, 3.5000}, |
4112 | {3.5000, 3.5000, 3.5000, 3.0000, 4.0000, 5.0000, 3.5000}}}, |
4113 | torch::kFloat); |
4114 | ASSERT_TRUE(output.allclose(expected)); |
4115 | } |
4116 | } |
4117 | |
4118 | TEST_F(ModulesTest, ConstantPad2d) { |
4119 | { |
4120 | ConstantPad2d m(ConstantPad2dOptions(2, 3.5)); |
4121 | auto input = torch::arange(4, torch::kFloat).reshape({1, 2, 2}); |
4122 | auto output = m(input); |
4123 | auto expected = torch::tensor( |
4124 | {{{3.5000, 3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, |
4125 | {3.5000, 3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, |
4126 | {3.5000, 3.5000, 0.0000, 1.0000, 3.5000, 3.5000}, |
4127 | {3.5000, 3.5000, 2.0000, 3.0000, 3.5000, 3.5000}, |
4128 | {3.5000, 3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, |
4129 | {3.5000, 3.5000, 3.5000, 3.5000, 3.5000, 3.5000}}}, |
4130 | torch::kFloat); |
4131 | ASSERT_TRUE(output.allclose(expected)); |
4132 | } |
4133 | { |
4134 | ConstantPad2d m(ConstantPad2dOptions({3, 0, 2, 1}, 3.5)); |
4135 | auto input = torch::arange(4, torch::kFloat).reshape({1, 2, 2}); |
4136 | auto output = m(input); |
4137 | auto expected = torch::tensor( |
4138 | {{{3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, |
4139 | {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, |
4140 | {3.5000, 3.5000, 3.5000, 0.0000, 1.0000}, |
4141 | {3.5000, 3.5000, 3.5000, 2.0000, 3.0000}, |
4142 | {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}}}, |
4143 | torch::kFloat); |
4144 | ASSERT_TRUE(output.allclose(expected)); |
4145 | } |
4146 | } |
4147 | |
4148 | TEST_F(ModulesTest, ConstantPad3d) { |
4149 | { |
4150 | ConstantPad3d m(ConstantPad3dOptions(1, 3.5)); |
4151 | auto input = torch::arange(8, torch::kFloat).reshape({1, 1, 2, 2, 2}); |
4152 | auto output = m(input); |
4153 | auto expected = torch::tensor( |
4154 | {{{{{3.5000, 3.5000, 3.5000, 3.5000}, |
4155 | {3.5000, 3.5000, 3.5000, 3.5000}, |
4156 | {3.5000, 3.5000, 3.5000, 3.5000}, |
4157 | {3.5000, 3.5000, 3.5000, 3.5000}}, |
4158 | {{3.5000, 3.5000, 3.5000, 3.5000}, |
4159 | {3.5000, 0.0000, 1.0000, 3.5000}, |
4160 | {3.5000, 2.0000, 3.0000, 3.5000}, |
4161 | {3.5000, 3.5000, 3.5000, 3.5000}}, |
4162 | {{3.5000, 3.5000, 3.5000, 3.5000}, |
4163 | {3.5000, 4.0000, 5.0000, 3.5000}, |
4164 | {3.5000, 6.0000, 7.0000, 3.5000}, |
4165 | {3.5000, 3.5000, 3.5000, 3.5000}}, |
4166 | {{3.5000, 3.5000, 3.5000, 3.5000}, |
4167 | {3.5000, 3.5000, 3.5000, 3.5000}, |
4168 | {3.5000, 3.5000, 3.5000, 3.5000}, |
4169 | {3.5000, 3.5000, 3.5000, 3.5000}}}}}, |
4170 | torch::kFloat); |
4171 | ASSERT_TRUE(output.allclose(expected)); |
4172 | } |
4173 | { |
4174 | ConstantPad3d m(ConstantPad3dOptions({1, 2, 1, 2, 1, 2}, 3.5)); |
4175 | auto input = torch::arange(8, torch::kFloat).reshape({1, 1, 2, 2, 2}); |
4176 | auto output = m(input); |
4177 | auto expected = torch::tensor( |
4178 | {{{{{3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, |
4179 | {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, |
4180 | {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, |
4181 | {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, |
4182 | {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}}, |
4183 | {{3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, |
4184 | {3.5000, 0.0000, 1.0000, 3.5000, 3.5000}, |
4185 | {3.5000, 2.0000, 3.0000, 3.5000, 3.5000}, |
4186 | {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, |
4187 | {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}}, |
4188 | {{3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, |
4189 | {3.5000, 4.0000, 5.0000, 3.5000, 3.5000}, |
4190 | {3.5000, 6.0000, 7.0000, 3.5000, 3.5000}, |
4191 | {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, |
4192 | {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}}, |
4193 | {{3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, |
4194 | {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, |
4195 | {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, |
4196 | {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, |
4197 | {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}}, |
4198 | {{3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, |
4199 | {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, |
4200 | {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, |
4201 | {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}, |
4202 | {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}}}}}, |
4203 | torch::kFloat); |
4204 | ASSERT_TRUE(output.allclose(expected)); |
4205 | } |
4206 | } |
4207 | |
4208 | TEST_F(ModulesTest, CrossMapLRN2d) { |
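  /// Cross-map LRN normalizes each activation across neighboring channels:
  ///   b_c = a_c / (k + alpha / n * sum_{c'} a_{c'}^2)^beta,
  /// where the sum runs over a window of `size` channels around c.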
4209 | /// size 3, default options |
4210 | auto input = |
4211 | torch::arange(9, torch::kFloat32).view({1, 1, 3, 3}).requires_grad_(true); |
4212 | auto expected = torch::tensor( |
4213 | {{{{0.00000000, 0.99997497, 1.99980010}, |
4214 | {2.99932500, 3.99840070, 4.99687700}, |
4215 | {5.99460600, 6.99143740, 7.98722360}}}}, |
4216 | torch::kFloat32); |
4217 | auto grad_expected = torch::tensor( |
4218 | {{{{1.00000000, 0.99992496, 0.99970007}, |
4219 | {0.99932520, 0.99880093, 0.99812720}, |
4220 | {0.99730474, 0.99633380, 0.99521490}}}}, |
4221 | torch::kFloat32); |
4222 | auto crossmaplrn2d = CrossMapLRN2d(3); |
4223 | auto output = crossmaplrn2d(input); |
4224 | output.sum().backward(); |
4225 | |
4226 | ASSERT_TRUE(input.grad().allclose(grad_expected)); |
4227 | ASSERT_TRUE(output.allclose(expected)); |
4228 | |
4229 | /// size change |
4230 | crossmaplrn2d = |
4231 | CrossMapLRN2d(CrossMapLRN2dOptions(4).alpha(1e-4).beta(0.75).k(1)); |
4232 | output = crossmaplrn2d(input); |
4233 | expected = torch::tensor( |
4234 | {{{{0.00000000, 0.99998120, 1.99985000}, |
4235 | {2.99949400, 3.99880050, 4.99765800}, |
4236 | {5.99595300, 6.99357600, 7.99041300}}}}, |
4237 | torch::kFloat32); |
4238 | ASSERT_TRUE(output.allclose(expected)); |
4239 | |
4240 | /// alpha change |
4241 | crossmaplrn2d = |
4242 | CrossMapLRN2d(CrossMapLRN2dOptions(3).alpha(1e-3).beta(0.75).k(1)); |
4243 | output = crossmaplrn2d(input); |
4244 | expected = torch::tensor( |
4245 | {{{{0.00000000, 0.99975010, 1.99800230}, |
4246 | {2.99326750, 3.98407440, 4.96897600}, |
4247 | {5.94656100, 6.91545720, 7.87434340}}}}, |
4248 | torch::kFloat32); |
4249 | ASSERT_TRUE(output.allclose(expected)); |
4250 | |
4251 | /// beta change |
4252 | crossmaplrn2d = |
4253 | CrossMapLRN2d(CrossMapLRN2dOptions(3).alpha(1e-4).beta(0.95).k(1)); |
4254 | output = crossmaplrn2d(input); |
4255 | expected = torch::tensor( |
4256 | {{{{0.00000000, 0.99996830, 1.99974680}, |
4257 | {2.99914500, 3.99797440, 4.99604460}, |
4258 | {5.99316840, 6.98915600, 7.98382000}}}}, |
4259 | torch::kFloat32); |
4260 | ASSERT_TRUE(output.allclose(expected)); |
4261 | |
4262 | /// k change |
4263 | crossmaplrn2d = |
4264 | CrossMapLRN2d(CrossMapLRN2dOptions(3).alpha(1e-4).beta(0.75).k(2)); |
4265 | output = crossmaplrn2d(input); |
4266 | expected = torch::tensor( |
4267 | {{{{0.00000000, 0.59459610, 1.18914770}, |
4268 | {1.78361000, 2.37793870, 2.97208900}, |
4269 | {3.56601700, 4.15967700, 4.75302650}}}}, |
4270 | torch::kFloat32); |
4271 | ASSERT_TRUE(output.allclose(expected)); |
4272 | } |
4273 | |
4274 | TEST_F(ModulesTest, RNNCell) { |
4275 | torch::manual_seed(0); |
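  // Single-step Elman cell with the default tanh nonlinearity:
  //   h' = tanh(W_ih x + b_ih + W_hh h + b_hh);
  // the hard-coded expectations below correspond to this fixed seed.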
4276 | auto rnn = RNNCell(1, 2); |
4277 | auto input = torch::randn({3, 1}); |
4278 | auto hx = torch::randn({3, 2}); |
4279 | auto output = rnn(input, hx); |
4280 | auto expected = |
4281 | torch::tensor({{-0.5078, 0.4380}, {-0.7215, 0.2969}, {-0.1304, 0.0653}}); |
4282 | ASSERT_TRUE(torch::allclose(output, expected, 1e-05, 2e-04)); |
4283 | |
4284 | output = rnn(input); |
4285 | expected = |
4286 | torch::tensor({{-0.0775, 0.6688}, {-0.0734, 0.4759}, {-0.0725, 0.4225}}); |
4287 | ASSERT_TRUE(torch::allclose(output, expected, 1e-05, 2e-04)); |
4288 | } |
4289 | |
4290 | TEST_F(ModulesTest, LSTMCell) { |
4291 | torch::manual_seed(0); |
4292 | auto rnn = LSTMCell(1, 2); |
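  // LSTMCell returns the pair (h', c'); when called without a state, both the
  // hidden and cell states default to zeros, which the second call exercises.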
4293 | auto input = torch::randn({3, 1}); |
4294 | auto hx = torch::randn({3, 2}); |
4295 | auto cx = torch::randn({3, 2}); |
4296 | auto output = rnn(input, std::make_tuple(hx, cx)); |
4297 | auto output_hx = std::get<0>(output); |
4298 | auto output_cx = std::get<1>(output); |
4299 | auto expected_hx = |
4300 | torch::tensor({{-0.2462, 0.0810}, {-0.2206, 0.1867}, {-0.0146, 0.0429}}); |
4301 | auto expected_cx = |
4302 | torch::tensor({{-0.4480, 0.1071}, {-0.6245, 0.2687}, {-0.0322, 0.0518}}); |
4303 | ASSERT_TRUE(torch::allclose(output_hx, expected_hx, 1e-05, 2e-04)); |
4304 | ASSERT_TRUE(torch::allclose(output_cx, expected_cx, 1e-05, 2e-04)); |
4305 | |
4306 | output = rnn(input); |
4307 | output_hx = std::get<0>(output); |
4308 | output_cx = std::get<1>(output); |
4309 | expected_hx = |
4310 | torch::tensor({{-0.1331, 0.1634}, {-0.1494, 0.2869}, {-0.1428, 0.2263}}); |
4311 | expected_cx = |
4312 | torch::tensor({{-0.2679, 0.2180}, {-0.3049, 0.3493}, {-0.2896, 0.2853}}); |
4313 | ASSERT_TRUE(torch::allclose(output_hx, expected_hx, 1e-05, 2e-04)); |
4314 | ASSERT_TRUE(torch::allclose(output_cx, expected_cx, 1e-05, 2e-04)); |
4315 | } |
4316 | |
4317 | TEST_F(ModulesTest, GRUCell) { |
4318 | torch::manual_seed(0); |
4319 | auto rnn = GRUCell(1, 2); |
4320 | auto input = torch::randn({3, 1}); |
4321 | auto hx = torch::randn({3, 2}); |
4322 | auto output = rnn(input, hx); |
4323 | auto expected = |
4324 | torch::tensor({{1.0243, 0.3227}, {-0.5659, 0.0330}, {-0.4030, -0.2800}}); |
4325 | ASSERT_TRUE(torch::allclose(output, expected, 1e-05, 2e-04)); |
4326 | |
4327 | output = rnn(input); |
4328 | expected = |
4329 | torch::tensor({{-0.0085, 0.1095}, {-0.1291, 0.2675}, {-0.1339, 0.2725}}); |
4330 | ASSERT_TRUE(torch::allclose(output, expected, 1e-05, 2e-04)); |
4331 | } |
4332 | |
4333 | TEST_F(ModulesTest, PrettyPrintLinear) { |
4334 | ASSERT_EQ( |
4335 | c10::str(Linear(3, 4)), |
4336 | "torch::nn::Linear(in_features=3, out_features=4, bias=true)" ); |
4337 | } |
4338 | |
4339 | TEST_F(ModulesTest, PrettyPrintBilinear) { |
4340 | ASSERT_EQ( |
4341 | c10::str(Bilinear(3, 2, 4)), |
4342 | "torch::nn::Bilinear(in1_features=3, in2_features=2, out_features=4, bias=true)" ); |
4343 | ASSERT_EQ( |
4344 | c10::str(Bilinear(BilinearOptions(3, 2, 4).bias(false))), |
4345 | "torch::nn::Bilinear(in1_features=3, in2_features=2, out_features=4, bias=false)" ); |
4346 | } |
4347 | |
4348 | TEST_F(ModulesTest, PrettyPrintConv) { |
4349 | ASSERT_EQ( |
4350 | c10::str(Conv1d(3, 4, 5)), |
4351 | "torch::nn::Conv1d(3, 4, kernel_size=5, stride=1)" ); |
4352 | |
4353 | ASSERT_EQ( |
4354 | c10::str(Conv2d(3, 4, 5)), |
4355 | "torch::nn::Conv2d(3, 4, kernel_size=[5, 5], stride=[1, 1])" ); |
4356 | ASSERT_EQ( |
4357 | c10::str(Conv2d(Conv2dOptions(3, 4, 5).stride(2))), |
4358 | "torch::nn::Conv2d(3, 4, kernel_size=[5, 5], stride=[2, 2])" ); |
4359 | { |
4360 | const auto options = |
4361 | Conv2dOptions(3, 4, std::vector<int64_t>{5, 6}).stride({1, 2}); |
4362 | ASSERT_EQ( |
4363 | c10::str(Conv2d(options)), |
4364 | "torch::nn::Conv2d(3, 4, kernel_size=[5, 6], stride=[1, 2])" ); |
4365 | } |
4366 | |
4367 | ASSERT_EQ( |
4368 | c10::str(Conv3d(4, 4, std::vector<int64_t>{5, 6, 7})), |
4369 | "torch::nn::Conv3d(4, 4, kernel_size=[5, 6, 7], stride=[1, 1, 1])" ); |
4370 | { |
4371 | const auto options = Conv3dOptions(4, 4, std::vector<int64_t>{5, 6, 7}) |
4372 | .stride({1, 2, 3}) |
4373 | .padding(1) |
4374 | .dilation(0) |
4375 | .groups(2) |
4376 | .bias(false) |
4377 | .padding_mode(torch::kCircular); |
4378 | ASSERT_EQ( |
4379 | c10::str(Conv3d(options)), |
4380 | "torch::nn::Conv3d(" |
4381 | "4, " |
4382 | "4, " |
4383 | "kernel_size=[5, 6, 7], " |
4384 | "stride=[1, 2, 3], " |
4385 | "padding=[1, 1, 1], " |
4386 | "dilation=[0, 0, 0], " |
4387 | "groups=2, " |
4388 | "bias=false, " |
4389 | "padding_mode=kCircular)" ); |
4390 | } |
4391 | } |
4392 | |
4393 | TEST_F(ModulesTest, PrettyPrintConvTranspose) { |
4394 | ASSERT_EQ( |
4395 | c10::str(ConvTranspose1d(3, 4, 5)), |
4396 | "torch::nn::ConvTranspose1d(3, 4, kernel_size=5, stride=1)" ); |
4397 | |
4398 | ASSERT_EQ( |
4399 | c10::str(ConvTranspose2d(3, 4, 5)), |
4400 | "torch::nn::ConvTranspose2d(3, 4, kernel_size=[5, 5], stride=[1, 1])" ); |
4401 | ASSERT_EQ( |
4402 | c10::str(ConvTranspose2d(ConvTranspose2dOptions(3, 4, 5).stride(2))), |
4403 | "torch::nn::ConvTranspose2d(3, 4, kernel_size=[5, 5], stride=[2, 2])" ); |
4404 | { |
4405 | const auto options = |
4406 | ConvTranspose2dOptions(3, 4, std::vector<int64_t>{5, 6}).stride({1, 2}); |
4407 | ASSERT_EQ( |
4408 | c10::str(ConvTranspose2d(options)), |
4409 | "torch::nn::ConvTranspose2d(3, 4, kernel_size=[5, 6], stride=[1, 2])" ); |
4410 | } |
4411 | |
4412 | ASSERT_EQ( |
4413 | c10::str(ConvTranspose3d(4, 4, std::vector<int64_t>{5, 6, 7})), |
4414 | "torch::nn::ConvTranspose3d(4, 4, kernel_size=[5, 6, 7], stride=[1, 1, 1])" ); |
4415 | { |
4416 | const auto options = |
4417 | ConvTranspose3dOptions(4, 4, std::vector<int64_t>{5, 6, 7}) |
4418 | .stride({1, 2, 3}) |
4419 | .padding(1) |
4420 | .dilation(0) |
4421 | .groups(2) |
4422 | .bias(false) |
4423 | .padding_mode(torch::kCircular); |
4424 | ASSERT_EQ( |
4425 | c10::str(ConvTranspose3d(options)), |
4426 | "torch::nn::ConvTranspose3d(" |
4427 | "4, " |
4428 | "4, " |
4429 | "kernel_size=[5, 6, 7], " |
4430 | "stride=[1, 2, 3], " |
4431 | "padding=[1, 1, 1], " |
4432 | "dilation=[0, 0, 0], " |
4433 | "groups=2, " |
4434 | "bias=false, " |
4435 | "padding_mode=kCircular)" ); |
4436 | } |
4437 | } |
4438 | |
4439 | TEST_F(ModulesTest, PrettyPrintUpsample) { |
4440 | ASSERT_EQ( |
4441 | c10::str( |
4442 | Upsample(UpsampleOptions().size(std::vector<int64_t>({2, 4, 4})))), |
4443 | "torch::nn::Upsample(size=[2, 4, 4], mode=kNearest)" ); |
4444 | ASSERT_EQ( |
4445 | c10::str(Upsample(UpsampleOptions() |
4446 | .scale_factor(std::vector<double>({0.5, 1.5})) |
4447 | .mode(torch::kBilinear))), |
4448 | "torch::nn::Upsample(scale_factor=[0.5, 1.5], mode=kBilinear)" ); |
4449 | } |
4450 | |
4451 | TEST_F(ModulesTest, PrettyPrintFold) { |
4452 | ASSERT_EQ( |
4453 | c10::str(Fold(FoldOptions({2, 2}, {5, 5}))), |
4454 | "torch::nn::Fold(output_size=[2, 2], kernel_size=[5, 5], dilation=[1, 1], padding=[0, 0], stride=[1, 1])" ); |
4455 | ASSERT_EQ( |
4456 | c10::str(Fold( |
4457 | FoldOptions({8, 8}, {3, 3}).dilation(2).padding({2, 1}).stride(2))), |
4458 | "torch::nn::Fold(output_size=[8, 8], kernel_size=[3, 3], dilation=[2, 2], padding=[2, 1], stride=[2, 2])" ); |
4459 | } |
4460 | |
4461 | TEST_F(ModulesTest, PrettyPrintUnfold) { |
4462 | ASSERT_EQ( |
4463 | c10::str(Unfold(torch::IntArrayRef({2, 4}))), |
4464 | "torch::nn::Unfold(kernel_size=[2, 4], dilation=[1, 1], padding=[0, 0], stride=[1, 1])" ); |
4465 | ASSERT_EQ( |
4466 | c10::str( |
4467 | Unfold(UnfoldOptions({2, 4}).dilation(2).padding({2, 1}).stride(2))), |
4468 | "torch::nn::Unfold(kernel_size=[2, 4], dilation=[2, 2], padding=[2, 1], stride=[2, 2])" ); |
4469 | } |
4470 | |
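// Scalar geometry options expand to one value per spatial dimension via
// torch::ExpandingArray, which is why MaxPool2d(5) prints
// kernel_size=[5, 5] while MaxPool1d(5) keeps the plain scalar.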
4471 | TEST_F(ModulesTest, PrettyPrintMaxPool) { |
4472 | ASSERT_EQ( |
4473 | c10::str(MaxPool1d(5)), |
4474 | "torch::nn::MaxPool1d(kernel_size=5, stride=5, padding=0, dilation=1, ceil_mode=false)" ); |
4475 | ASSERT_EQ( |
4476 | c10::str(MaxPool2d(5)), |
4477 | "torch::nn::MaxPool2d(kernel_size=[5, 5], stride=[5, 5], padding=[0, 0], dilation=[1, 1], ceil_mode=false)" ); |
4478 | ASSERT_EQ( |
4479 | c10::str(MaxPool2d(MaxPool2dOptions(5).stride(2))), |
4480 | "torch::nn::MaxPool2d(kernel_size=[5, 5], stride=[2, 2], padding=[0, 0], dilation=[1, 1], ceil_mode=false)" ); |
4481 | ASSERT_EQ( |
4482 | c10::str(MaxPool3d(5)), |
4483 | "torch::nn::MaxPool3d(kernel_size=[5, 5, 5], stride=[5, 5, 5], padding=[0, 0, 0], dilation=[1, 1, 1], ceil_mode=false)" ); |
4484 | ASSERT_EQ( |
4485 | c10::str(MaxPool3d(MaxPool3dOptions(5).stride(2))), |
4486 | "torch::nn::MaxPool3d(kernel_size=[5, 5, 5], stride=[2, 2, 2], padding=[0, 0, 0], dilation=[1, 1, 1], ceil_mode=false)" ); |
4487 | |
4488 | const auto options = |
4489 | MaxPool2dOptions(std::vector<int64_t>{5, 6}).stride({1, 2}); |
4490 | ASSERT_EQ( |
4491 | c10::str(MaxPool2d(options)), |
4492 | "torch::nn::MaxPool2d(kernel_size=[5, 6], stride=[1, 2], padding=[0, 0], dilation=[1, 1], ceil_mode=false)" ); |
4493 | } |
4494 | |
4495 | TEST_F(ModulesTest, PrettyPrintAvgPool) { |
4496 | ASSERT_EQ( |
4497 | c10::str(AvgPool1d(5)), |
4498 | "torch::nn::AvgPool1d(kernel_size=5, stride=5, padding=0)" ); |
4499 | ASSERT_EQ( |
4500 | c10::str(AvgPool2d(5)), |
4501 | "torch::nn::AvgPool2d(kernel_size=[5, 5], stride=[5, 5], padding=[0, 0])" ); |
4502 | ASSERT_EQ( |
4503 | c10::str(AvgPool2d(AvgPool2dOptions(5).stride(2))), |
4504 | "torch::nn::AvgPool2d(kernel_size=[5, 5], stride=[2, 2], padding=[0, 0])" ); |
4505 | ASSERT_EQ( |
4506 | c10::str(AvgPool3d(5)), |
4507 | "torch::nn::AvgPool3d(kernel_size=[5, 5, 5], stride=[5, 5, 5], padding=[0, 0, 0])" ); |
4508 | ASSERT_EQ( |
4509 | c10::str(AvgPool3d(AvgPool3dOptions(5).stride(2))), |
4510 | "torch::nn::AvgPool3d(kernel_size=[5, 5, 5], stride=[2, 2, 2], padding=[0, 0, 0])" ); |
4511 | |
4512 | const auto options = |
4513 | AvgPool2dOptions(std::vector<int64_t>{5, 6}).stride({1, 2}); |
4514 | ASSERT_EQ( |
4515 | c10::str(AvgPool2d(options)), |
4516 | "torch::nn::AvgPool2d(kernel_size=[5, 6], stride=[1, 2], padding=[0, 0])" ); |
4517 | } |
4518 | |
TEST_F(ModulesTest, PrettyPrintFractionalMaxPool) {
4520 | ASSERT_EQ( |
4521 | c10::str( |
4522 | FractionalMaxPool2d(FractionalMaxPool2dOptions(5).output_size(1))), |
4523 | "torch::nn::FractionalMaxPool2d()" ); |
4524 | ASSERT_EQ( |
4525 | c10::str( |
4526 | FractionalMaxPool3d(FractionalMaxPool3dOptions(5).output_size(1))), |
4527 | "torch::nn::FractionalMaxPool3d()" ); |
4528 | } |
4529 | |
4530 | TEST_F(ModulesTest, PrettyPrintLPPool) { |
4531 | ASSERT_EQ( |
4532 | c10::str(LPPool1d(2, 5)), |
4533 | "torch::nn::LPPool1d(norm_type=2, kernel_size=5, stride=5, ceil_mode=false)" ); |
4534 | ASSERT_EQ( |
4535 | c10::str(LPPool1d(LPPool1dOptions(1, 2).stride(5).ceil_mode(true))), |
4536 | "torch::nn::LPPool1d(norm_type=1, kernel_size=2, stride=5, ceil_mode=true)" ); |
4537 | ASSERT_EQ( |
4538 | c10::str(LPPool2d(2, std::vector<int64_t>({1, 2}))), |
4539 | "torch::nn::LPPool2d(norm_type=2, kernel_size=[1, 2], stride=[1, 2], ceil_mode=false)" ); |
4540 | ASSERT_EQ( |
4541 | c10::str(LPPool2d(LPPool2dOptions(1, std::vector<int64_t>({3, 4})) |
4542 | .stride({5, 6}) |
4543 | .ceil_mode(true))), |
4544 | "torch::nn::LPPool2d(norm_type=1, kernel_size=[3, 4], stride=[5, 6], ceil_mode=true)" ); |
4545 | } |
4546 | |
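// Output-size entries given as c10::nullopt print as "None"; at runtime an
// unspecified dimension keeps the input's extent along that axis.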
4547 | TEST_F(ModulesTest, PrettyPrintAdaptiveMaxPool) { |
4548 | ASSERT_EQ( |
4549 | c10::str(AdaptiveMaxPool1d(5)), |
4550 | "torch::nn::AdaptiveMaxPool1d(output_size=5)" ); |
4551 | |
4552 | const auto options = AdaptiveMaxPool1dOptions(3); |
4553 | ASSERT_EQ( |
4554 | c10::str(AdaptiveMaxPool1d(options)), |
4555 | "torch::nn::AdaptiveMaxPool1d(output_size=3)" ); |
4556 | |
4557 | ASSERT_EQ( |
4558 | c10::str(AdaptiveMaxPool2d(5)), |
4559 | "torch::nn::AdaptiveMaxPool2d(output_size=[5, 5])" ); |
4560 | ASSERT_EQ( |
4561 | c10::str(AdaptiveMaxPool2d(AdaptiveMaxPool2dOptions({5, 6}))), |
4562 | "torch::nn::AdaptiveMaxPool2d(output_size=[5, 6])" ); |
4563 | ASSERT_EQ( |
4564 | c10::str(AdaptiveMaxPool2d(AdaptiveMaxPool2dOptions({5, c10::nullopt}))), |
4565 | "torch::nn::AdaptiveMaxPool2d(output_size=[5, None])" ); |
4566 | ASSERT_EQ( |
4567 | c10::str(AdaptiveMaxPool2d( |
4568 | AdaptiveMaxPool2dOptions({c10::nullopt, c10::nullopt}))), |
4569 | "torch::nn::AdaptiveMaxPool2d(output_size=[None, None])" ); |
4570 | |
4571 | ASSERT_EQ( |
4572 | c10::str(AdaptiveMaxPool3d(5)), |
4573 | "torch::nn::AdaptiveMaxPool3d(output_size=[5, 5, 5])" ); |
4574 | ASSERT_EQ( |
4575 | c10::str(AdaptiveMaxPool3d(AdaptiveMaxPool3dOptions({5, 6, 7}))), |
4576 | "torch::nn::AdaptiveMaxPool3d(output_size=[5, 6, 7])" ); |
4577 | ASSERT_EQ( |
4578 | c10::str( |
4579 | AdaptiveMaxPool3d(AdaptiveMaxPool3dOptions({5, c10::nullopt, 7}))), |
4580 | "torch::nn::AdaptiveMaxPool3d(output_size=[5, None, 7])" ); |
4581 | ASSERT_EQ( |
4582 | c10::str(AdaptiveMaxPool3d(AdaptiveMaxPool3dOptions( |
4583 | {c10::nullopt, c10::nullopt, c10::nullopt}))), |
4584 | "torch::nn::AdaptiveMaxPool3d(output_size=[None, None, None])" ); |
4585 | } |
4586 | |
4587 | TEST_F(ModulesTest, PrettyPrintAdaptiveAvgPool) { |
4588 | ASSERT_EQ( |
4589 | c10::str(AdaptiveAvgPool1d(5)), |
4590 | "torch::nn::AdaptiveAvgPool1d(output_size=5)" ); |
4591 | |
4592 | ASSERT_EQ( |
4593 | c10::str(AdaptiveAvgPool2d(5)), |
4594 | "torch::nn::AdaptiveAvgPool2d(output_size=[5, 5])" ); |
4595 | ASSERT_EQ( |
4596 | c10::str(AdaptiveAvgPool2d(AdaptiveAvgPool2dOptions({5, 6}))), |
4597 | "torch::nn::AdaptiveAvgPool2d(output_size=[5, 6])" ); |
4598 | ASSERT_EQ( |
4599 | c10::str(AdaptiveAvgPool2d(AdaptiveAvgPool2dOptions({5, c10::nullopt}))), |
4600 | "torch::nn::AdaptiveAvgPool2d(output_size=[5, None])" ); |
4601 | ASSERT_EQ( |
4602 | c10::str(AdaptiveAvgPool2d( |
4603 | AdaptiveAvgPool2dOptions({c10::nullopt, c10::nullopt}))), |
4604 | "torch::nn::AdaptiveAvgPool2d(output_size=[None, None])" ); |
4605 | |
4606 | ASSERT_EQ( |
4607 | c10::str(AdaptiveAvgPool3d(5)), |
4608 | "torch::nn::AdaptiveAvgPool3d(output_size=[5, 5, 5])" ); |
4609 | ASSERT_EQ( |
4610 | c10::str(AdaptiveAvgPool3d(AdaptiveAvgPool3dOptions({5, 6, 7}))), |
4611 | "torch::nn::AdaptiveAvgPool3d(output_size=[5, 6, 7])" ); |
4612 | ASSERT_EQ( |
4613 | c10::str( |
4614 | AdaptiveAvgPool3d(AdaptiveAvgPool3dOptions({5, c10::nullopt, 7}))), |
4615 | "torch::nn::AdaptiveAvgPool3d(output_size=[5, None, 7])" ); |
4616 | ASSERT_EQ( |
4617 | c10::str(AdaptiveAvgPool3d(AdaptiveAvgPool3dOptions( |
4618 | {c10::nullopt, c10::nullopt, c10::nullopt}))), |
4619 | "torch::nn::AdaptiveAvgPool3d(output_size=[None, None, None])" ); |
4620 | } |
4621 | |
4622 | TEST_F(ModulesTest, PrettyPrintMaxUnpool) { |
4623 | ASSERT_EQ( |
4624 | c10::str(MaxUnpool1d(5)), |
4625 | "torch::nn::MaxUnpool1d(kernel_size=5, stride=5, padding=0)" ); |
4626 | ASSERT_EQ( |
4627 | c10::str(MaxUnpool1d(MaxUnpool1dOptions(5).stride(3).padding(1))), |
4628 | "torch::nn::MaxUnpool1d(kernel_size=5, stride=3, padding=1)" ); |
4629 | |
4630 | ASSERT_EQ( |
4631 | c10::str(MaxUnpool2d(5)), |
4632 | "torch::nn::MaxUnpool2d(kernel_size=[5, 5], stride=[5, 5], padding=[0, 0])" ); |
4633 | ASSERT_EQ( |
4634 | c10::str(MaxUnpool2d(std::vector<int64_t>{5, 6})), |
4635 | "torch::nn::MaxUnpool2d(kernel_size=[5, 6], stride=[5, 6], padding=[0, 0])" ); |
4636 | ASSERT_EQ( |
4637 | c10::str(MaxUnpool2d(MaxUnpool2dOptions(std::vector<int64_t>{5, 6}) |
4638 | .stride({3, 4}) |
4639 | .padding({1, 2}))), |
4640 | "torch::nn::MaxUnpool2d(kernel_size=[5, 6], stride=[3, 4], padding=[1, 2])" ); |
4641 | } |
4642 | |
4643 | TEST_F(ModulesTest, PrettyPrintDropout) { |
4644 | ASSERT_EQ(c10::str(Dropout()), "torch::nn::Dropout(p=0.5, inplace=false)" ); |
4645 | ASSERT_EQ( |
4646 | c10::str(Dropout(0.42)), "torch::nn::Dropout(p=0.42, inplace=false)" ); |
4647 | ASSERT_EQ( |
4648 | c10::str(Dropout(DropoutOptions().p(0.42).inplace(true))), |
4649 | "torch::nn::Dropout(p=0.42, inplace=true)" ); |
4650 | } |
4651 | |
4652 | TEST_F(ModulesTest, PrettyPrintDropout2d) { |
4653 | ASSERT_EQ( |
4654 | c10::str(Dropout2d()), "torch::nn::Dropout2d(p=0.5, inplace=false)" ); |
4655 | ASSERT_EQ( |
4656 | c10::str(Dropout2d(0.42)), "torch::nn::Dropout2d(p=0.42, inplace=false)" ); |
4657 | ASSERT_EQ( |
4658 | c10::str(Dropout2d(Dropout2dOptions().p(0.42).inplace(true))), |
4659 | "torch::nn::Dropout2d(p=0.42, inplace=true)" ); |
4660 | } |
4661 | |
4662 | TEST_F(ModulesTest, PrettyPrintDropout3d) { |
4663 | ASSERT_EQ( |
4664 | c10::str(Dropout3d()), "torch::nn::Dropout3d(p=0.5, inplace=false)" ); |
4665 | ASSERT_EQ( |
4666 | c10::str(Dropout3d(0.42)), "torch::nn::Dropout3d(p=0.42, inplace=false)" ); |
4667 | ASSERT_EQ( |
4668 | c10::str(Dropout3d(Dropout3dOptions().p(0.42).inplace(true))), |
4669 | "torch::nn::Dropout3d(p=0.42, inplace=true)" ); |
4670 | } |
4671 | |
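// Functional type-erases the callable it wraps, so nothing about
// torch::relu survives for the printout.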
4672 | TEST_F(ModulesTest, PrettyPrintFunctional) { |
4673 | ASSERT_EQ(c10::str(Functional(torch::relu)), "torch::nn::Functional()" ); |
4674 | } |
4675 | |
4676 | TEST_F(ModulesTest, PrettyPrintBatchNorm1d) { |
4677 | ASSERT_EQ( |
4678 | c10::str(BatchNorm1d(BatchNorm1dOptions(4) |
4679 | .eps(0.5) |
4680 | .momentum(0.1) |
4681 | .affine(false) |
4682 | .track_running_stats(true))), |
4683 | "torch::nn::BatchNorm1d(4, eps=0.5, momentum=0.1, affine=false, track_running_stats=true)" ); |
4684 | } |
4685 | |
4686 | TEST_F(ModulesTest, PrettyPrintBatchNorm2d) { |
4687 | ASSERT_EQ( |
4688 | c10::str(BatchNorm2d(BatchNorm2dOptions(4) |
4689 | .eps(0.5) |
4690 | .momentum(0.1) |
4691 | .affine(false) |
4692 | .track_running_stats(true))), |
4693 | "torch::nn::BatchNorm2d(4, eps=0.5, momentum=0.1, affine=false, track_running_stats=true)" ); |
4694 | } |
4695 | |
4696 | TEST_F(ModulesTest, PrettyPrintBatchNorm3d) { |
4697 | ASSERT_EQ( |
4698 | c10::str(BatchNorm3d(BatchNorm3dOptions(4) |
4699 | .eps(0.5) |
4700 | .momentum(0.1) |
4701 | .affine(false) |
4702 | .track_running_stats(true))), |
4703 | "torch::nn::BatchNorm3d(4, eps=0.5, momentum=0.1, affine=false, track_running_stats=true)" ); |
4704 | } |
4705 | |
4706 | TEST_F(ModulesTest, PrettyPrintInstanceNorm1d) { |
4707 | ASSERT_EQ( |
4708 | c10::str(InstanceNorm1d(InstanceNorm1dOptions(4) |
4709 | .eps(0.5) |
4710 | .momentum(0.1) |
4711 | .affine(false) |
4712 | .track_running_stats(true))), |
4713 | "torch::nn::InstanceNorm1d(4, eps=0.5, momentum=0.1, affine=false, track_running_stats=true)" ); |
4714 | } |
4715 | |
4716 | TEST_F(ModulesTest, PrettyPrintInstanceNorm2d) { |
4717 | ASSERT_EQ( |
4718 | c10::str(InstanceNorm2d(InstanceNorm2dOptions(4) |
4719 | .eps(0.5) |
4720 | .momentum(0.1) |
4721 | .affine(false) |
4722 | .track_running_stats(true))), |
4723 | "torch::nn::InstanceNorm2d(4, eps=0.5, momentum=0.1, affine=false, track_running_stats=true)" ); |
4724 | } |
4725 | |
4726 | TEST_F(ModulesTest, PrettyPrintInstanceNorm3d) { |
4727 | ASSERT_EQ( |
4728 | c10::str(InstanceNorm3d(InstanceNorm3dOptions(4) |
4729 | .eps(0.5) |
4730 | .momentum(0.1) |
4731 | .affine(false) |
4732 | .track_running_stats(true))), |
4733 | "torch::nn::InstanceNorm3d(4, eps=0.5, momentum=0.1, affine=false, track_running_stats=true)" ); |
4734 | } |
4735 | |
4736 | TEST_F(ModulesTest, PrettyPrintLayerNorm) { |
4737 | ASSERT_EQ( |
4738 | c10::str(LayerNorm(LayerNormOptions({2, 2}))), |
4739 | "torch::nn::LayerNorm([2, 2], eps=1e-05, elementwise_affine=true)" ); |
4740 | ASSERT_EQ( |
4741 | c10::str(LayerNorm( |
4742 | LayerNormOptions({2, 2}).elementwise_affine(false).eps(2e-5))), |
4743 | "torch::nn::LayerNorm([2, 2], eps=2e-05, elementwise_affine=false)" ); |
4744 | } |
4745 | |
4746 | TEST_F(ModulesTest, PrettyPrintGroupNorm) { |
4747 | ASSERT_EQ( |
4748 | c10::str(GroupNorm(GroupNormOptions(2, 2))), |
4749 | "torch::nn::GroupNorm(2, 2, eps=1e-05, affine=true)" ); |
4750 | ASSERT_EQ( |
4751 | c10::str(GroupNorm(GroupNormOptions(2, 2).eps(2e-5).affine(false))), |
4752 | "torch::nn::GroupNorm(2, 2, eps=2e-05, affine=false)" ); |
4753 | } |
4754 | |
4755 | TEST_F(ModulesTest, PrettyPrintLocalResponseNorm) { |
4756 | ASSERT_EQ( |
4757 | c10::str(LocalResponseNorm(LocalResponseNormOptions(2))), |
4758 | "torch::nn::LocalResponseNorm(2, alpha=0.0001, beta=0.75, k=1)" ); |
4759 | ASSERT_EQ( |
4760 | c10::str(LocalResponseNorm( |
4761 | LocalResponseNormOptions(2).alpha(0.0002).beta(0.85).k(2.))), |
4762 | "torch::nn::LocalResponseNorm(2, alpha=0.0002, beta=0.85, k=2)" ); |
4763 | } |
4764 | |
4765 | TEST_F(ModulesTest, PrettyPrintEmbedding) { |
4766 | ASSERT_EQ( |
4767 | c10::str(Embedding(EmbeddingOptions(10, 2))), |
4768 | "torch::nn::Embedding(num_embeddings=10, embedding_dim=2)" ); |
4769 | ASSERT_EQ( |
4770 | c10::str(Embedding(EmbeddingOptions(10, 2).padding_idx(3).max_norm(2))), |
4771 | "torch::nn::Embedding(num_embeddings=10, embedding_dim=2, padding_idx=3, max_norm=2)" ); |
4772 | ASSERT_EQ( |
4773 | c10::str(Embedding(EmbeddingOptions(10, 2) |
4774 | .padding_idx(3) |
4775 | .max_norm(2) |
4776 | .norm_type(2.5) |
4777 | .scale_grad_by_freq(true) |
4778 | .sparse(true))), |
4779 | "torch::nn::Embedding(num_embeddings=10, embedding_dim=2, padding_idx=3, max_norm=2, norm_type=2.5, scale_grad_by_freq=true, sparse=true)" ); |
4780 | } |
4781 | |
4782 | TEST_F(ModulesTest, PrettyPrintEmbeddingBag) { |
4783 | ASSERT_EQ( |
4784 | c10::str(EmbeddingBag(EmbeddingBagOptions(10, 2))), |
4785 | "torch::nn::EmbeddingBag(num_embeddings=10, embedding_dim=2)" ); |
4786 | ASSERT_EQ( |
4787 | c10::str(EmbeddingBag(EmbeddingBagOptions(10, 2).max_norm(2))), |
4788 | "torch::nn::EmbeddingBag(num_embeddings=10, embedding_dim=2, max_norm=2)" ); |
4789 | ASSERT_EQ( |
4790 | c10::str(EmbeddingBag(EmbeddingBagOptions(10, 2) |
4791 | .max_norm(2) |
4792 | .norm_type(2.5) |
4793 | .scale_grad_by_freq(true) |
4794 | .sparse(true))), |
4795 | "torch::nn::EmbeddingBag(num_embeddings=10, embedding_dim=2, max_norm=2, norm_type=2.5, scale_grad_by_freq=true, sparse=true)" ); |
4796 | ASSERT_EQ( |
4797 | c10::str(EmbeddingBag(EmbeddingBagOptions(10, 2) |
4798 | .max_norm(2) |
4799 | .norm_type(2.5) |
4800 | .scale_grad_by_freq(true) |
4801 | .sparse(true) |
4802 | .mode(torch::kSum))), |
4803 | "torch::nn::EmbeddingBag(num_embeddings=10, embedding_dim=2, max_norm=2, norm_type=2.5, scale_grad_by_freq=true, sparse=true, mode=kSum)" ); |
4804 | ASSERT_EQ( |
4805 | c10::str(EmbeddingBag(EmbeddingBagOptions(10, 2) |
4806 | .max_norm(2) |
4807 | .norm_type(2.5) |
4808 | .scale_grad_by_freq(true) |
4809 | .sparse(true) |
4810 | .mode(torch::kSum) |
4811 | .padding_idx(5))), |
4812 | "torch::nn::EmbeddingBag(num_embeddings=10, embedding_dim=2, max_norm=2, norm_type=2.5, scale_grad_by_freq=true, sparse=true, mode=kSum, padding_idx=5)" ); |
4813 | } |
4814 | |
4815 | TEST_F(ModulesTest, PrettyPrintL1Loss) { |
4816 | ASSERT_EQ(c10::str(L1Loss()), "torch::nn::L1Loss()" ); |
4817 | } |
4818 | TEST_F(ModulesTest, PrettyPrintKLDivLoss) { |
4819 | ASSERT_EQ(c10::str(KLDivLoss()), "torch::nn::KLDivLoss()" ); |
4820 | } |
4821 | TEST_F(ModulesTest, PrettyPrintMSELoss) { |
4822 | ASSERT_EQ(c10::str(MSELoss()), "torch::nn::MSELoss()" ); |
4823 | } |
4824 | TEST_F(ModulesTest, PrettyPrintBCELoss) { |
4825 | ASSERT_EQ(c10::str(BCELoss()), "torch::nn::BCELoss()" ); |
4826 | } |
4827 | TEST_F(ModulesTest, PrettyPrintHingeEmbeddingLoss) { |
4828 | ASSERT_EQ( |
4829 | c10::str(HingeEmbeddingLoss(HingeEmbeddingLossOptions().margin(4))), |
4830 | "torch::nn::HingeEmbeddingLoss(margin=4)" ); |
4831 | } |
4832 | |
4833 | TEST_F(ModulesTest, PrettyPrintCosineEmbeddingLoss) { |
4834 | ASSERT_EQ( |
4835 | c10::str(CosineEmbeddingLoss(CosineEmbeddingLossOptions().margin(0.25))), |
4836 | "torch::nn::CosineEmbeddingLoss(margin=0.25)" ); |
4837 | } |
4838 | |
4839 | TEST_F(ModulesTest, PrettyPrintTripletMarginLoss) { |
4840 | ASSERT_EQ( |
4841 | c10::str(TripletMarginLoss( |
4842 | TripletMarginLossOptions().margin(3).p(2).eps(1e-06).swap(false))), |
4843 | "torch::nn::TripletMarginLoss(margin=3, p=2, eps=1e-06, swap=false)" ); |
4844 | } |
4845 | |
4846 | TEST_F(ModulesTest, PrettyPrintTripletMarginWithDistanceLoss) { |
4847 | auto distanceOptions = TripletMarginWithDistanceLossOptions() |
4848 | .distance_function([&](const torch::Tensor& x, |
4849 | const torch::Tensor& y) { |
4850 | return torch::pairwise_distance(x, y, 2.0, 1e-6); |
4851 | }) |
4852 | .margin(1.5) |
4853 | .swap(true) |
4854 | .reduction(torch::kMean); |
4855 | ASSERT_EQ( |
4856 | c10::str(TripletMarginWithDistanceLoss(distanceOptions)), |
4857 | "torch::nn::TripletMarginWithDistanceLoss(margin=1.5, swap=true)" ); |
4858 | } |
4859 | |
4860 | TEST_F(ModulesTest, PrettyPrintNLLLoss) { |
4861 | ASSERT_EQ(c10::str(NLLLoss()), "torch::nn::NLLLoss()" ); |
4862 | } |
4863 | |
TEST_F(ModulesTest, PrettyPrintCrossEntropyLoss) {
4865 | ASSERT_EQ(c10::str(CrossEntropyLoss()), "torch::nn::CrossEntropyLoss()" ); |
4866 | } |
4867 | |
4868 | TEST_F(ModulesTest, PrettyPrintMultiLabelMarginLoss) { |
4869 | ASSERT_EQ( |
4870 | c10::str(MultiLabelMarginLoss()), "torch::nn::MultiLabelMarginLoss()" ); |
4871 | } |
4872 | |
4873 | TEST_F(ModulesTest, PrettyPrintMultiLabelSoftMarginLoss) { |
4874 | ASSERT_EQ( |
4875 | c10::str(MultiLabelSoftMarginLoss()), |
4876 | "torch::nn::MultiLabelSoftMarginLoss()" ); |
4877 | } |
4878 | |
4879 | TEST_F(ModulesTest, PrettyPrintSoftMarginLoss) { |
4880 | ASSERT_EQ(c10::str(SoftMarginLoss()), "torch::nn::SoftMarginLoss()" ); |
4881 | } |
4882 | |
4883 | TEST_F(ModulesTest, PrettyPrintCosineSimilarity) { |
4884 | ASSERT_EQ( |
4885 | c10::str(CosineSimilarity()), |
4886 | "torch::nn::CosineSimilarity(dim=1, eps=1e-08)" ); |
4887 | ASSERT_EQ( |
4888 | c10::str(CosineSimilarity(CosineSimilarityOptions().dim(0).eps(0.5))), |
4889 | "torch::nn::CosineSimilarity(dim=0, eps=0.5)" ); |
4890 | } |
4891 | |
4892 | TEST_F(ModulesTest, PrettyPrintPairwiseDistance) { |
4893 | ASSERT_EQ( |
4894 | c10::str(PairwiseDistance()), |
4895 | "torch::nn::PairwiseDistance(p=2, eps=1e-06, keepdim=false)" ); |
4896 | ASSERT_EQ( |
4897 | c10::str(PairwiseDistance( |
4898 | PairwiseDistanceOptions().p(3).eps(0.5).keepdim(true))), |
4899 | "torch::nn::PairwiseDistance(p=3, eps=0.5, keepdim=true)" ); |
4900 | } |
4901 | |
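// For the padding modules a scalar pads all sides equally; an explicit
// list is ordered (left, right) in 1d, (left, right, top, bottom) in 2d,
// and (left, right, top, bottom, front, back) in 3d.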
4902 | TEST_F(ModulesTest, PrettyPrintReflectionPad) { |
4903 | ASSERT_EQ( |
4904 | c10::str(ReflectionPad1d(ReflectionPad1dOptions(2))), |
4905 | "torch::nn::ReflectionPad1d(padding=[2, 2])" ); |
4906 | ASSERT_EQ( |
4907 | c10::str(ReflectionPad1d(ReflectionPad1dOptions({3, 1}))), |
4908 | "torch::nn::ReflectionPad1d(padding=[3, 1])" ); |
4909 | ASSERT_EQ( |
4910 | c10::str(ReflectionPad2d(ReflectionPad2dOptions(2))), |
4911 | "torch::nn::ReflectionPad2d(padding=[2, 2, 2, 2])" ); |
4912 | ASSERT_EQ( |
4913 | c10::str(ReflectionPad2d(ReflectionPad2dOptions({1, 1, 2, 0}))), |
4914 | "torch::nn::ReflectionPad2d(padding=[1, 1, 2, 0])" ); |
4915 | } |
4916 | |
4917 | TEST_F(ModulesTest, PrettyPrintReplicationPad) { |
4918 | ASSERT_EQ( |
4919 | c10::str(ReplicationPad1d(ReplicationPad1dOptions(2))), |
4920 | "torch::nn::ReplicationPad1d(padding=[2, 2])" ); |
4921 | ASSERT_EQ( |
4922 | c10::str(ReplicationPad1d(ReplicationPad1dOptions({3, 1}))), |
4923 | "torch::nn::ReplicationPad1d(padding=[3, 1])" ); |
4924 | ASSERT_EQ( |
4925 | c10::str(ReplicationPad2d(ReplicationPad2dOptions(2))), |
4926 | "torch::nn::ReplicationPad2d(padding=[2, 2, 2, 2])" ); |
4927 | ASSERT_EQ( |
4928 | c10::str(ReplicationPad2d(ReplicationPad2dOptions({1, 1, 2, 0}))), |
4929 | "torch::nn::ReplicationPad2d(padding=[1, 1, 2, 0])" ); |
4930 | ASSERT_EQ( |
4931 | c10::str(ReplicationPad3d(ReplicationPad3dOptions(1))), |
4932 | "torch::nn::ReplicationPad3d(padding=[1, 1, 1, 1, 1, 1])" ); |
4933 | ASSERT_EQ( |
4934 | c10::str(ReplicationPad3d(ReplicationPad3dOptions({1, 2, 1, 2, 1, 2}))), |
4935 | "torch::nn::ReplicationPad3d(padding=[1, 2, 1, 2, 1, 2])" ); |
4936 | } |
4937 | |
4938 | TEST_F(ModulesTest, PrettyPrintZeroPad2d) { |
4939 | ASSERT_EQ( |
4940 | c10::str(ZeroPad2d(ZeroPad2dOptions(2))), |
4941 | "torch::nn::ZeroPad2d(padding=[2, 2, 2, 2])" ); |
4942 | ASSERT_EQ( |
4943 | c10::str(ZeroPad2d(ZeroPad2dOptions({1, 1, 2, 0}))), |
4944 | "torch::nn::ZeroPad2d(padding=[1, 1, 2, 0])" ); |
4945 | } |
4946 | |
4947 | TEST_F(ModulesTest, PrettyPrintConstantPad) { |
4948 | ASSERT_EQ( |
4949 | c10::str(ConstantPad1d(ConstantPad1dOptions(2, 3.5))), |
4950 | "torch::nn::ConstantPad1d(padding=[2, 2], value=3.5)" ); |
4951 | ASSERT_EQ( |
4952 | c10::str(ConstantPad1d(ConstantPad1dOptions({3, 1}, 3.5))), |
4953 | "torch::nn::ConstantPad1d(padding=[3, 1], value=3.5)" ); |
4954 | ASSERT_EQ( |
4955 | c10::str(ConstantPad2d(ConstantPad2dOptions(2, 3.5))), |
4956 | "torch::nn::ConstantPad2d(padding=[2, 2, 2, 2], value=3.5)" ); |
4957 | ASSERT_EQ( |
4958 | c10::str(ConstantPad2d(ConstantPad2dOptions({3, 0, 2, 1}, 3.5))), |
4959 | "torch::nn::ConstantPad2d(padding=[3, 0, 2, 1], value=3.5)" ); |
4960 | ASSERT_EQ( |
4961 | c10::str(ConstantPad3d(ConstantPad3dOptions(1, 3.5))), |
4962 | "torch::nn::ConstantPad3d(padding=[1, 1, 1, 1, 1, 1], value=3.5)" ); |
4963 | ASSERT_EQ( |
4964 | c10::str(ConstantPad3d(ConstantPad3dOptions({1, 2, 1, 2, 1, 2}, 3.5))), |
4965 | "torch::nn::ConstantPad3d(padding=[1, 2, 1, 2, 1, 2], value=3.5)" ); |
4966 | } |
4967 | |
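// Nested submodules print recursively, one "(name): <repr>" line per
// child, indented two spaces per nesting level.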
4968 | TEST_F(ModulesTest, PrettyPrintNestedModel) { |
4969 | struct InnerTestModule : torch::nn::Module { |
4970 | InnerTestModule() |
4971 | : torch::nn::Module("InnerTestModule" ), |
4972 | fc(register_module("fc" , torch::nn::Linear(3, 4))), |
4973 | table(register_module("table" , torch::nn::Embedding(10, 2))) {} |
4974 | |
4975 | torch::nn::Linear fc; |
4976 | torch::nn::Embedding table; |
4977 | }; |
4978 | |
4979 | struct TestModule : torch::nn::Module { |
4980 | TestModule() |
4981 | : torch::nn::Module("TestModule" ), |
4982 | fc(register_module("fc" , torch::nn::Linear(4, 5))), |
4983 | table(register_module( |
4984 | "table" , |
4985 | torch::nn::Embedding(EmbeddingOptions(10, 2)))), |
4986 | inner(register_module("inner" , std::make_shared<InnerTestModule>())) { |
4987 | } |
4988 | |
4989 | torch::nn::Linear fc; |
4990 | torch::nn::Embedding table; |
4991 | std::shared_ptr<InnerTestModule> inner; |
4992 | }; |
4993 | |
4994 | ASSERT_EQ( |
4995 | c10::str(TestModule{}), |
4996 | "TestModule(\n" |
4997 | " (fc): torch::nn::Linear(in_features=4, out_features=5, bias=true)\n" |
4998 | " (table): torch::nn::Embedding(num_embeddings=10, embedding_dim=2)\n" |
4999 | " (inner): InnerTestModule(\n" |
5000 | " (fc): torch::nn::Linear(in_features=3, out_features=4, bias=true)\n" |
5001 | " (table): torch::nn::Embedding(num_embeddings=10, embedding_dim=2)\n" |
5002 | " )\n" |
5003 | ")" ); |
5004 | } |
5005 | |
5006 | TEST_F(ModulesTest, PrettyPrintELU) { |
5007 | ASSERT_EQ(c10::str(ELU()), "torch::nn::ELU(alpha=1)" ); |
5008 | ASSERT_EQ( |
5009 | c10::str(ELU(ELUOptions().alpha(42.42).inplace(true))), |
5010 | "torch::nn::ELU(alpha=42.42, inplace=true)" ); |
5011 | } |
5012 | |
5013 | TEST_F(ModulesTest, PrettyPrintSELU) { |
5014 | ASSERT_EQ(c10::str(SELU()), "torch::nn::SELU()" ); |
5015 | ASSERT_EQ( |
5016 | c10::str(SELU(SELUOptions().inplace(true))), |
5017 | "torch::nn::SELU(inplace=true)" ); |
5018 | } |
5019 | |
5020 | TEST_F(ModulesTest, PrettyPrintGLU) { |
5021 | ASSERT_EQ(c10::str(GLU()), "torch::nn::GLU(dim=-1)" ); |
5022 | ASSERT_EQ(c10::str(GLU(1)), "torch::nn::GLU(dim=1)" ); |
5023 | } |
5024 | |
5025 | TEST_F(ModulesTest, PrettyPrintHardshrink) { |
5026 | ASSERT_EQ(c10::str(Hardshrink()), "torch::nn::Hardshrink(0.5)" ); |
5027 | ASSERT_EQ( |
5028 | c10::str(Hardshrink(HardshrinkOptions().lambda(42.42))), |
5029 | "torch::nn::Hardshrink(42.42)" ); |
5030 | } |
5031 | |
5032 | TEST_F(ModulesTest, PrettyPrintHardtanh) { |
5033 | ASSERT_EQ(c10::str(Hardtanh()), "torch::nn::Hardtanh(min_val=-1, max_val=1)" ); |
5034 | ASSERT_EQ( |
5035 | c10::str(Hardtanh( |
5036 | HardtanhOptions().min_val(-42.42).max_val(0.42).inplace(true))), |
5037 | "torch::nn::Hardtanh(min_val=-42.42, max_val=0.42, inplace=true)" ); |
5038 | } |
5039 | |
5040 | TEST_F(ModulesTest, PrettyPrintLeakyReLU) { |
5041 | ASSERT_EQ(c10::str(LeakyReLU()), "torch::nn::LeakyReLU(negative_slope=0.01)" ); |
5042 | ASSERT_EQ( |
5043 | c10::str( |
5044 | LeakyReLU(LeakyReLUOptions().negative_slope(0.42).inplace(true))), |
5045 | "torch::nn::LeakyReLU(negative_slope=0.42, inplace=true)" ); |
5046 | } |
5047 | |
5048 | TEST_F(ModulesTest, PrettyPrintLogSigmoid) { |
5049 | ASSERT_EQ(c10::str(LogSigmoid()), "torch::nn::LogSigmoid()" ); |
5050 | } |
5051 | |
5052 | TEST_F(ModulesTest, PrettyPrintSoftmax) { |
5053 | ASSERT_EQ(c10::str(Softmax(SoftmaxOptions(1))), "torch::nn::Softmax(dim=1)" ); |
5054 | } |
5055 | |
5056 | TEST_F(ModulesTest, PrettyPrintSoftmin) { |
5057 | ASSERT_EQ(c10::str(Softmin(SoftminOptions(1))), "torch::nn::Softmin(dim=1)" ); |
5058 | } |
5059 | |
5060 | TEST_F(ModulesTest, PrettyPrintLogSoftmax) { |
5061 | ASSERT_EQ( |
5062 | c10::str(LogSoftmax(LogSoftmaxOptions(1))), |
5063 | "torch::nn::LogSoftmax(dim=1)" ); |
5064 | } |
5065 | |
5066 | TEST_F(ModulesTest, PrettyPrintSoftmax2d) { |
5067 | ASSERT_EQ(c10::str(Softmax2d()), "torch::nn::Softmax2d()" ); |
5068 | } |
5069 | |
5070 | TEST_F(ModulesTest, PrettyPrintPReLU) { |
5071 | ASSERT_EQ(c10::str(PReLU()), "torch::nn::PReLU(num_parameters=1)" ); |
5072 | ASSERT_EQ( |
5073 | c10::str(PReLU(PReLUOptions().num_parameters(42))), |
5074 | "torch::nn::PReLU(num_parameters=42)" ); |
5075 | } |
5076 | |
5077 | TEST_F(ModulesTest, PrettyPrintReLU) { |
5078 | ASSERT_EQ(c10::str(ReLU()), "torch::nn::ReLU()" ); |
5079 | ASSERT_EQ( |
5080 | c10::str(ReLU(ReLUOptions().inplace(true))), |
5081 | "torch::nn::ReLU(inplace=true)" ); |
5082 | ASSERT_EQ(c10::str(ReLU(/*inplace=*/true)), "torch::nn::ReLU(inplace=true)" ); |
5083 | } |
5084 | |
5085 | TEST_F(ModulesTest, PrettyPrintReLU6) { |
5086 | ASSERT_EQ(c10::str(ReLU6()), "torch::nn::ReLU6()" ); |
5087 | ASSERT_EQ( |
5088 | c10::str(ReLU6(ReLU6Options().inplace(true))), |
5089 | "torch::nn::ReLU6(inplace=true)" ); |
5090 | ASSERT_EQ( |
5091 | c10::str(ReLU6(/*inplace=*/true)), "torch::nn::ReLU6(inplace=true)" ); |
5092 | } |
5093 | |
5094 | TEST_F(ModulesTest, PrettyPrintRReLU) { |
5095 | ASSERT_EQ(c10::str(RReLU()), "torch::nn::RReLU(lower=0.125, upper=0.333333)" ); |
5096 | ASSERT_EQ( |
5097 | c10::str(RReLU(RReLUOptions().lower(0.24).upper(0.42).inplace(true))), |
5098 | "torch::nn::RReLU(lower=0.24, upper=0.42, inplace=true)" ); |
5099 | } |
5100 | |
5101 | TEST_F(ModulesTest, PrettyPrintCELU) { |
5102 | ASSERT_EQ(c10::str(CELU()), "torch::nn::CELU(alpha=1)" ); |
5103 | ASSERT_EQ( |
5104 | c10::str(CELU(CELUOptions().alpha(42.42).inplace(true))), |
5105 | "torch::nn::CELU(alpha=42.42, inplace=true)" ); |
5106 | } |
5107 | |
5108 | TEST_F(ModulesTest, PrettyPrintSigmoid) { |
5109 | ASSERT_EQ(c10::str(Sigmoid()), "torch::nn::Sigmoid()" ); |
5110 | } |
5111 | |
5112 | TEST_F(ModulesTest, PrettyPrintPixelShuffle) { |
5113 | ASSERT_EQ( |
5114 | c10::str(PixelShuffle(PixelShuffleOptions(5))), |
5115 | "torch::nn::PixelShuffle(upscale_factor=5)" ); |
5116 | } |
5117 | |
5118 | TEST_F(ModulesTest, PrettyPrintPixelUnshuffle) { |
5119 | ASSERT_EQ( |
5120 | c10::str(PixelUnshuffle(PixelUnshuffleOptions(5))), |
5121 | "torch::nn::PixelUnshuffle(downscale_factor=5)" ); |
5122 | } |
5123 | |
5124 | TEST_F(ModulesTest, PrettyPrintSoftplus) { |
5125 | ASSERT_EQ(c10::str(Softplus()), "torch::nn::Softplus(beta=1, threshold=20)" ); |
5126 | ASSERT_EQ( |
5127 | c10::str(Softplus(SoftplusOptions().beta(0.24).threshold(42.42))), |
5128 | "torch::nn::Softplus(beta=0.24, threshold=42.42)" ); |
5129 | } |
5130 | |
5131 | TEST_F(ModulesTest, PrettyPrintSoftshrink) { |
5132 | ASSERT_EQ(c10::str(Softshrink()), "torch::nn::Softshrink(0.5)" ); |
5133 | ASSERT_EQ( |
5134 | c10::str(Softshrink(SoftshrinkOptions(42.42))), |
5135 | "torch::nn::Softshrink(42.42)" ); |
5136 | } |
5137 | |
5138 | TEST_F(ModulesTest, PrettyPrintSoftsign) { |
5139 | ASSERT_EQ(c10::str(Softsign()), "torch::nn::Softsign()" ); |
5140 | } |
5141 | |
5142 | TEST_F(ModulesTest, PrettyPrintTanh) { |
5143 | ASSERT_EQ(c10::str(Tanh()), "torch::nn::Tanh()" ); |
5144 | } |
5145 | |
5146 | TEST_F(ModulesTest, PrettyPrintTanhshrink) { |
5147 | ASSERT_EQ(c10::str(Tanhshrink()), "torch::nn::Tanhshrink()" ); |
5148 | } |
5149 | |
5150 | TEST_F(ModulesTest, PrettyPrintThreshold) { |
5151 | ASSERT_EQ( |
5152 | c10::str(Threshold(24.24, 42.42)), |
5153 | "torch::nn::Threshold(threshold=24.24, value=42.42)" ); |
5154 | ASSERT_EQ( |
5155 | c10::str(Threshold(ThresholdOptions(42.42, 24.24).inplace(true))), |
5156 | "torch::nn::Threshold(threshold=42.42, value=24.24, inplace=true)" ); |
5157 | } |
5158 | |
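// Several loss modules intentionally print no options, so even a fully
// customized instance round-trips to the bare module name (CTCLoss here,
// and likewise PoissonNLLLoss, MarginRankingLoss, and BCEWithLogitsLoss
// below).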
5159 | TEST_F(ModulesTest, PrettyPrintCTCLoss) { |
5160 | ASSERT_EQ(c10::str(CTCLoss()), "torch::nn::CTCLoss()" ); |
5161 | ASSERT_EQ( |
5162 | c10::str( |
5163 | CTCLoss(CTCLossOptions().blank(42).zero_infinity(false).reduction( |
5164 | torch::kSum))), |
5165 | "torch::nn::CTCLoss()" ); |
5166 | } |
5167 | |
5168 | TEST_F(ModulesTest, PrettyPrintPoissonNLLLoss) { |
5169 | ASSERT_EQ(c10::str(PoissonNLLLoss()), "torch::nn::PoissonNLLLoss()" ); |
5170 | ASSERT_EQ( |
5171 | c10::str(PoissonNLLLoss(PoissonNLLLossOptions() |
5172 | .log_input(false) |
5173 | .full(true) |
5174 | .eps(0.42) |
5175 | .reduction(torch::kSum))), |
5176 | "torch::nn::PoissonNLLLoss()" ); |
5177 | } |
5178 | |
5179 | TEST_F(ModulesTest, PrettyPrintMarginRankingLoss) { |
5180 | ASSERT_EQ(c10::str(MarginRankingLoss()), "torch::nn::MarginRankingLoss()" ); |
5181 | ASSERT_EQ( |
5182 | c10::str(MarginRankingLoss( |
5183 | MarginRankingLossOptions().margin(0.5).reduction(torch::kSum))), |
5184 | "torch::nn::MarginRankingLoss()" ); |
5185 | } |
5186 | |
5187 | TEST_F(ModulesTest, PrettyPrintCrossMapLRN2d) { |
5188 | ASSERT_EQ( |
5189 | c10::str(CrossMapLRN2d(4)), |
5190 | "torch::nn::CrossMapLRN2d(4, alpha=0.0001, beta=0.75, k=1)" ); |
5191 | ASSERT_EQ( |
5192 | c10::str( |
5193 | CrossMapLRN2d(CrossMapLRN2dOptions(3).alpha(1e-5).beta(0.1).k(10))), |
5194 | "torch::nn::CrossMapLRN2d(3, alpha=1e-05, beta=0.1, k=10)" ); |
5195 | } |
5196 | |
5197 | TEST_F(ModulesTest, PrettyPrintAlphaDropout) { |
5198 | ASSERT_EQ( |
5199 | c10::str(AlphaDropout()), |
5200 | "torch::nn::AlphaDropout(p=0.5, inplace=false)" ); |
5201 | ASSERT_EQ( |
5202 | c10::str(AlphaDropout(AlphaDropoutOptions(0.2))), |
5203 | "torch::nn::AlphaDropout(p=0.2, inplace=false)" ); |
5204 | ASSERT_EQ( |
5205 | c10::str(AlphaDropout(AlphaDropoutOptions(0.2).inplace(true))), |
5206 | "torch::nn::AlphaDropout(p=0.2, inplace=true)" ); |
5207 | } |
5208 | |
5209 | TEST_F(ModulesTest, PrettyPrintFeatureAlphaDropout) { |
5210 | ASSERT_EQ( |
5211 | c10::str(FeatureAlphaDropout()), |
5212 | "torch::nn::FeatureAlphaDropout(p=0.5, inplace=false)" ); |
5213 | ASSERT_EQ( |
5214 | c10::str(FeatureAlphaDropout(FeatureAlphaDropoutOptions(0.2))), |
5215 | "torch::nn::FeatureAlphaDropout(p=0.2, inplace=false)" ); |
5216 | ASSERT_EQ( |
5217 | c10::str( |
5218 | FeatureAlphaDropout(FeatureAlphaDropoutOptions(0.2).inplace(true))), |
5219 | "torch::nn::FeatureAlphaDropout(p=0.2, inplace=true)" ); |
5220 | } |
5221 | |
5222 | TEST_F(ModulesTest, PrettyPrintBCEWithLogitsLoss) { |
5223 | ASSERT_EQ(c10::str(BCEWithLogitsLoss()), "torch::nn::BCEWithLogitsLoss()" ); |
5224 | ASSERT_EQ( |
5225 | c10::str(BCEWithLogitsLoss(BCEWithLogitsLossOptions() |
5226 | .weight(torch::ones({3, 3})) |
5227 | .pos_weight(torch::ones({3, 3})) |
5228 | .reduction(torch::kSum))), |
5229 | "torch::nn::BCEWithLogitsLoss()" ); |
5230 | } |
5231 | |
5232 | TEST_F(ModulesTest, PrettyPrintMultiheadAttention) { |
5233 | ASSERT_EQ( |
5234 | c10::str(MultiheadAttention(20, 10)), |
5235 | "torch::nn::MultiheadAttention(\n (out_proj): torch::nn::Linear(in_features=20, out_features=20, bias=true)\n)" ); |
5236 | ASSERT_EQ( |
5237 | c10::str( |
5238 | MultiheadAttention(MultiheadAttentionOptions(20, 10).bias(false))), |
5239 | "torch::nn::MultiheadAttention(\n (out_proj): torch::nn::Linear(in_features=20, out_features=20, bias=false)\n)" ); |
5240 | } |
5241 | |
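// RNNCell prints the nonlinearity only when it differs from the default
// kTanh, hence the field is absent from the second expected string.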
5242 | TEST_F(ModulesTest, PrettyPrintRNNCell) { |
5243 | ASSERT_EQ(c10::str(RNNCell(20, 10)), "torch::nn::RNNCell(20, 10)" ); |
5244 | ASSERT_EQ( |
5245 | c10::str(RNNCell( |
5246 | RNNCellOptions(20, 10).bias(false).nonlinearity(torch::kTanh))), |
5247 | "torch::nn::RNNCell(20, 10, bias=false)" ); |
5248 | ASSERT_EQ( |
5249 | c10::str(RNNCell( |
5250 | RNNCellOptions(20, 10).bias(false).nonlinearity(torch::kReLU))), |
5251 | "torch::nn::RNNCell(20, 10, bias=false, nonlinearity=kReLU)" ); |
5252 | } |
5253 | |
5254 | TEST_F(ModulesTest, PrettyPrintLSTMCell) { |
5255 | ASSERT_EQ(c10::str(LSTMCell(20, 10)), "torch::nn::LSTMCell(20, 10)" ); |
5256 | ASSERT_EQ( |
5257 | c10::str(LSTMCell(LSTMCellOptions(20, 10).bias(false))), |
5258 | "torch::nn::LSTMCell(20, 10, bias=false)" ); |
5259 | } |
5260 | |
5261 | TEST_F(ModulesTest, PrettyPrintGRUCell) { |
5262 | ASSERT_EQ(c10::str(GRUCell(20, 10)), "torch::nn::GRUCell(20, 10)" ); |
5263 | ASSERT_EQ( |
5264 | c10::str(GRUCell(GRUCellOptions(20, 10).bias(false))), |
5265 | "torch::nn::GRUCell(20, 10, bias=false)" ); |
5266 | } |
5267 | |
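// Sizing behind the expected strings: the head projects to cutoffs[0]
// shortlist classes plus one logit per tail cluster (2 + 1 = 3 in the
// first case, 4 + 2 = 6 in the second), and tail cluster i first projects
// down to in_features / div_value^(i + 1) before mapping to that
// cluster's classes.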
5268 | TEST_F(ModulesTest, PrettyPrintAdaptiveLogSoftmaxWithLoss) { |
5269 | { |
5270 | AdaptiveLogSoftmaxWithLoss asfm( |
5271 | AdaptiveLogSoftmaxWithLossOptions(8, 4, {2}).div_value(2.)); |
5272 | ASSERT_EQ( |
5273 | c10::str(asfm), |
5274 | "torch::nn::AdaptiveLogSoftmaxWithLoss(\n" |
5275 | " (head): torch::nn::Linear(in_features=8, out_features=3, bias=false)\n" |
5276 | " (tail): torch::nn::ModuleList(\n" |
5277 | " (0): torch::nn::Sequential(\n" |
5278 | " (0): torch::nn::Linear(in_features=8, out_features=4, bias=false)\n" |
5279 | " (1): torch::nn::Linear(in_features=4, out_features=2, bias=false)\n" |
5280 | " )\n" |
5281 | " )\n" |
5282 | ")" ); |
5283 | } |
5284 | { |
5285 | AdaptiveLogSoftmaxWithLoss asfm( |
5286 | AdaptiveLogSoftmaxWithLossOptions(8, 10, {4, 8}) |
5287 | .div_value(2.) |
5288 | .head_bias(true)); |
5289 | ASSERT_EQ( |
5290 | c10::str(asfm), |
5291 | "torch::nn::AdaptiveLogSoftmaxWithLoss(\n" |
5292 | " (head): torch::nn::Linear(in_features=8, out_features=6, bias=true)\n" |
5293 | " (tail): torch::nn::ModuleList(\n" |
5294 | " (0): torch::nn::Sequential(\n" |
5295 | " (0): torch::nn::Linear(in_features=8, out_features=4, bias=false)\n" |
5296 | " (1): torch::nn::Linear(in_features=4, out_features=4, bias=false)\n" |
5297 | " )\n" |
5298 | " (1): torch::nn::Sequential(\n" |
5299 | " (0): torch::nn::Linear(in_features=8, out_features=2, bias=false)\n" |
5300 | " (1): torch::nn::Linear(in_features=2, out_features=2, bias=false)\n" |
5301 | " )\n" |
5302 | " )\n" |
5303 | ")" ); |
5304 | } |
5305 | } |
5306 | |