1#include <gtest/gtest.h>
2
3#include <c10/util/irange.h>
4#include <torch/torch.h>
5
6#include <test/cpp/api/support.h>
7
8#include <torch/expanding_array.h>
9#include <torch/nn/functional/activation.h>
10#include <torch/nn/options/activation.h>
11#include <limits>
12#include <random>
13
14using namespace torch::nn;
15using namespace torch::test;
16
17class TestModel : public torch::nn::Module {
18 public:
19 TestModel()
20 : l1(register_module("l1", Linear(10, 3))),
21 l2(register_module("l2", Linear(3, 5))),
22 l3(register_module("l3", Linear(5, 100))) {}
23
24 Linear l1, l2, l3;
25};
26
27class NestedModel : public torch::nn::Module {
28 public:
29 NestedModel()
30 : param_(register_parameter("param", torch::empty({3, 2, 21}))),
31 l1(register_module("l1", Linear(5, 20))),
32 t(register_module("test", std::make_shared<TestModel>())) {}
33
34 torch::Tensor param_;
35 Linear l1;
36 std::shared_ptr<TestModel> t;
37};
38
39struct ModulesTest : torch::test::SeedingFixture {};
40
41TEST_F(ModulesTest, Conv1d) {
42 Conv1d model(Conv1dOptions(3, 2, 3).stride(1).bias(false));
43 model->weight.set_data(
44 torch::arange(18, torch::dtype(torch::kFloat)).reshape({2, 3, 3}));
45 auto x = torch::arange(30, torch::dtype(torch::kFloat).requires_grad(true))
46 .reshape({2, 3, 5});
47 auto y = model(x);
48 auto expected = torch::tensor(
49 {{{312., 348., 384.}, {798., 915., 1032.}},
50
51 {{852., 888., 924.}, {2553., 2670., 2787.}}},
52 torch::kFloat);
53 ASSERT_TRUE(torch::allclose(y, expected));
54
55 torch::Tensor s = y.sum();
56 s.backward();
57 ASSERT_EQ(s.ndimension(), 0);
58 ASSERT_EQ(model->weight.grad().numel(), 3 * 2 * 3);
59}
60
61TEST_F(ModulesTest, Conv1dSameStrided) {
62 auto options = Conv1dOptions(3, 2, 3);
63 options.stride(1).padding(torch::kSame);
64 Conv1d model_valid(options);
65 ASSERT_THROWS_WITH(
66 [&] { Conv1d model_invalid(options.stride(2)); }(),
67 "padding='same' is not supported for strided convolutions");
68}
69
70TEST_F(ModulesTest, Conv2dEven) {
71 Conv2d model(Conv2dOptions(3, 2, 3).stride(1).bias(false));
72 model->weight.set_data(
73 torch::arange(54, torch::dtype(torch::kFloat)).reshape({2, 3, 3, 3}));
74 auto x = torch::arange(75, torch::dtype(torch::kFloat).requires_grad(true))
75 .reshape({1, 3, 5, 5});
76 auto y = model(x);
77 auto expected = torch::tensor(
78 {{{{15219., 15570., 15921.},
79 {16974., 17325., 17676.},
80 {18729., 19080., 19431.}},
81
82 {{37818., 38898., 39978.},
83 {43218., 44298., 45378.},
84 {48618., 49698., 50778.}}}},
85 torch::kFloat);
86 ASSERT_TRUE(torch::allclose(y, expected));
87
88 torch::Tensor s = y.sum();
89 s.backward();
90 ASSERT_EQ(s.ndimension(), 0);
91 ASSERT_EQ(model->weight.grad().numel(), 3 * 2 * 3 * 3);
92}
93
94TEST_F(ModulesTest, Conv2dUneven) {
95 Conv2d model(Conv2dOptions(3, 2, {3, 2}).stride({1, 1}).bias(false));
96 model->weight.set_data(
97 torch::arange(36, torch::dtype(torch::kFloat)).reshape({2, 3, 3, 2}));
98 auto x = torch::arange(60, torch::dtype(torch::kFloat).requires_grad(true))
99 .reshape({1, 3, 5, 4});
100 auto y = model(x);
101 auto expected = torch::tensor(
102 {{{{5289., 5442., 5595.}, {5901., 6054., 6207.}, {6513., 6666., 6819.}},
103
104 {{13227., 13704., 14181.},
105 {15135., 15612., 16089.},
106 {17043., 17520., 17997.}}}},
107 torch::kFloat);
108 ASSERT_TRUE(torch::allclose(y, expected));
109
110 torch::Tensor s = y.sum();
111 s.backward();
112 ASSERT_EQ(s.ndimension(), 0);
113 ASSERT_EQ(model->weight.grad().numel(), 3 * 2 * 3 * 2);
114}
115
116TEST_F(ModulesTest, Conv2dSameStrided) {
117 auto options = Conv2dOptions(3, 2, {3, 4});
118 options.stride(1).padding(torch::kSame);
119 Conv2d model_valid(options);
120 ASSERT_THROWS_WITH(
121 [&] { Conv2d model_invalid(options.stride(2)); }(),
122 "padding='same' is not supported for strided convolutions");
123 ASSERT_THROWS_WITH(
124 [&] {
125 Conv2d model_invalid(options.stride({1, 2}));
126 }(),
127 "padding='same' is not supported for strided convolutions");
128}
129
130TEST_F(ModulesTest, Conv3d) {
131 Conv3d model(Conv3dOptions(3, 2, 3).stride(1).bias(false));
132 model->weight.set_data(
133 torch::arange(162, torch::dtype(torch::kFloat)).reshape({2, 3, 3, 3, 3}));
134 auto x = torch::arange(375, torch::dtype(torch::kFloat).requires_grad(true))
135 .reshape({1, 3, 5, 5, 5});
136 auto y = model(x);
137 auto expected = torch::tensor(
138 {{{{{700704., 703944., 707184.},
139 {716904., 720144., 723384.},
140 {733104., 736344., 739584.}},
141
142 {{781704., 784944., 788184.},
143 {797904., 801144., 804384.},
144 {814104., 817344., 820584.}},
145
146 {{862704., 865944., 869184.},
147 {878904., 882144., 885384.},
148 {895104., 898344., 901584.}}},
149
150 {{{1724220., 1734021., 1743822.},
151 {1773225., 1783026., 1792827.},
152 {1822230., 1832031., 1841832.}},
153
154 {{1969245., 1979046., 1988847.},
155 {2018250., 2028051., 2037852.},
156 {2067255., 2077056., 2086857.}},
157
158 {{2214270., 2224071., 2233872.},
159 {2263275., 2273076., 2282877.},
160 {2312280., 2322081., 2331882.}}}}},
161 torch::kFloat);
162 ASSERT_TRUE(torch::allclose(y, expected));
163
164 torch::Tensor s = y.sum();
165 s.backward();
166 ASSERT_EQ(s.ndimension(), 0);
167 ASSERT_TRUE(model->weight.grad().numel() == 3 * 2 * 3 * 3 * 3);
168}
169
170TEST_F(ModulesTest, Conv3dSameStrided) {
171 auto options = Conv3dOptions(3, 2, {3, 4, 5});
172 options.stride(1).padding(torch::kSame);
173 Conv3d model_valid(options);
174 ASSERT_THROWS_WITH(
175 [&] { Conv3d model_invalid(options.stride(2)); }(),
176 "padding='same' is not supported for strided convolutions");
177 ASSERT_THROWS_WITH(
178 [&] {
179 Conv3d model_invalid(options.stride({1, 2, 1}));
180 }(),
181 "padding='same' is not supported for strided convolutions");
182}
183
184TEST_F(ModulesTest, ConvTranspose1d) {
185 ConvTranspose1d model(ConvTranspose1dOptions(3, 2, 3).stride(1).bias(false));
186 model->weight.set_data(torch::arange(18.).view({2, 3, 3}));
187 auto x = torch::arange(20.).reshape({2, 2, 5});
188 auto y = model(x);
189 auto expected = torch::tensor(
190 {{{45., 104., 179., 212., 245., 188., 107.},
191 {60., 140., 242., 293., 344., 260., 146.},
192 {75., 176., 305., 374., 443., 332., 185.}},
193 {{135., 304., 509., 542., 575., 428., 237.},
194 {210., 460., 752., 803., 854., 620., 336.},
195 {285., 616., 995., 1064., 1133., 812., 435.}}});
196 ASSERT_TRUE(torch::allclose(y, expected));
197
198 torch::Tensor s = y.sum();
199 s.backward();
200 ASSERT_EQ(s.ndimension(), 0);
201 ASSERT_EQ(model->weight.grad().numel(), 3 * 2 * 3);
202}
203
204TEST_F(ModulesTest, ConvTranspose2dEven) {
205 ConvTranspose2d model(ConvTranspose2dOptions(3, 2, 3).stride(1).bias(false));
206 model->weight.set_data(torch::arange(54.).view({2, 3, 3, 3}));
207 auto x = torch::arange(50.).view({1, 2, 5, 5});
208 auto y = model(x);
209 auto expected = torch::tensor(
210 {{{{675., 1402., 2183., 2270., 2357., 1634., 849.},
211 {1560., 3240., 5044., 5236., 5428., 3760., 1952.},
212 {2685., 5574., 8673., 8988., 9303., 6438., 3339.},
213 {3180., 6594., 10248., 10563., 10878., 7518., 3894.},
214 {3675., 7614., 11823., 12138., 12453., 8598., 4449.},
215 {2820., 5832., 9040., 9268., 9496., 6544., 3380.},
216 {1605., 3314., 5129., 5252., 5375., 3698., 1907.}},
217 {{900., 1870., 2912., 3053., 3194., 2210., 1146.},
218 {2100., 4356., 6772., 7072., 7372., 5092., 2636.},
219 {3630., 7518., 11670., 12147., 12624., 8706., 4500.},
220 {4395., 9078., 14055., 14532., 15009., 10326., 5325.},
221 {5160., 10638., 16440., 16917., 17394., 11946., 6150.},
222 {3900., 8028., 12388., 12724., 13060., 8956., 4604.},
223 {2190., 4502., 6938., 7115., 7292., 4994., 2564.}},
224 {{1125., 2338., 3641., 3836., 4031., 2786., 1443.},
225 {2640., 5472., 8500., 8908., 9316., 6424., 3320.},
226 {4575., 9462., 14667., 15306., 15945., 10974., 5661.},
227 {5610., 11562., 17862., 18501., 19140., 13134., 6756.},
228 {6645., 13662., 21057., 21696., 22335., 15294., 7851.},
229 {4980., 10224., 15736., 16180., 16624., 11368., 5828.},
230 {2775., 5690., 8747., 8978., 9209., 6290., 3221.}}}});
231 ASSERT_TRUE(torch::allclose(y, expected));
232
233 torch::Tensor s = y.sum();
234 s.backward();
235 ASSERT_EQ(s.ndimension(), 0);
236 ASSERT_EQ(model->weight.grad().numel(), 3 * 2 * 3 * 3);
237}
238
239TEST_F(ModulesTest, ConvTranspose2dUneven) {
240 ConvTranspose2d model(
241 ConvTranspose2dOptions(3, 2, {3, 2}).stride({1, 1}).bias(false));
242 model->weight.set_data(torch::arange(36.).view({2, 3, 3, 2}));
243 auto x = torch::arange(40.).view({1, 2, 5, 4});
244 auto y = model(x);
245 auto expected = torch::tensor(
246 {{{{360., 758., 796., 834., 440.},
247 {832., 1752., 1836., 1920., 1012.},
248 {1432., 3014., 3152., 3290., 1732.},
249 {1696., 3566., 3704., 3842., 2020.},
250 {1960., 4118., 4256., 4394., 2308.},
251 {1504., 3152., 3252., 3352., 1756.},
252 {856., 1790., 1844., 1898., 992.}},
253 {{480., 1010., 1072., 1134., 596.},
254 {1120., 2352., 2484., 2616., 1372.},
255 {1936., 4058., 4268., 4478., 2344.},
256 {2344., 4898., 5108., 5318., 2776.},
257 {2752., 5738., 5948., 6158., 3208.},
258 {2080., 4328., 4476., 4624., 2404.},
259 {1168., 2426., 2504., 2582., 1340.}},
260 {{600., 1262., 1348., 1434., 752.},
261 {1408., 2952., 3132., 3312., 1732.},
262 {2440., 5102., 5384., 5666., 2956.},
263 {2992., 6230., 6512., 6794., 3532.},
264 {3544., 7358., 7640., 7922., 4108.},
265 {2656., 5504., 5700., 5896., 3052.},
266 {1480., 3062., 3164., 3266., 1688.}}}});
267 ASSERT_TRUE(torch::allclose(y, expected));
268
269 torch::Tensor s = y.sum();
270 s.backward();
271 ASSERT_EQ(s.ndimension(), 0);
272 ASSERT_EQ(model->weight.grad().numel(), 3 * 2 * 3 * 2);
273}
274
275TEST_F(ModulesTest, ConvTranspose3d) {
276 ConvTranspose3d model(ConvTranspose3dOptions(2, 2, 2).stride(1).bias(false));
277 model->weight.set_data(torch::arange(32.).reshape({2, 2, 2, 2, 2}));
278 auto x = torch::arange(16.).reshape({1, 2, 2, 2, 2});
279 auto y = model(x);
280 auto expected = torch::tensor(
281 {{{{{128., 280., 154.}, {304., 664., 364.}, {184., 400., 218.}},
282 {{352., 768., 420.}, {832., 1808., 984.}, {496., 1072., 580.}},
283 {{256., 552., 298.}, {592., 1272., 684.}, {344., 736., 394.}}},
284 {{{192., 424., 234.}, {464., 1016., 556.}, {280., 608., 330.}},
285 {{544., 1184., 644.}, {1280., 2768., 1496.}, {752., 1616., 868.}},
286 {{384., 824., 442.}, {880., 1880., 1004.}, {504., 1072., 570.}}}}});
287 ASSERT_TRUE(torch::allclose(y, expected));
288
289 torch::Tensor s = y.sum();
290 s.backward();
291 ASSERT_EQ(s.ndimension(), 0);
292 ASSERT_TRUE(model->weight.grad().numel() == 2 * 2 * 2 * 2 * 2);
293}
294
295TEST_F(ModulesTest, MaxPool1d) {
296 MaxPool1d model(MaxPool1dOptions(3).stride(2));
297 auto x = torch::ones({1, 1, 5}, torch::requires_grad());
298 auto y = model(x);
299 torch::Tensor s = y.sum();
300
301 s.backward();
302 ASSERT_EQ(y.ndimension(), 3);
303 ASSERT_TRUE(torch::allclose(y, torch::ones({1, 1, 2})));
304 ASSERT_EQ(s.ndimension(), 0);
305 ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 2}));
306}
307
308TEST_F(ModulesTest, MaxPool1dReturnIndices) {
309 MaxPool1d model(MaxPool1dOptions(3).stride(2));
310 auto x = torch::ones({1, 1, 5}, torch::requires_grad());
311 torch::Tensor y, indices;
312 std::tie(y, indices) = model->forward_with_indices(x);
313
314 ASSERT_EQ(y.dim(), 3);
315 ASSERT_TRUE(torch::allclose(y, torch::ones({1, 1, 2})));
316 ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 2}));
317
318 ASSERT_TRUE(
319 torch::allclose(indices, torch::tensor({{{0, 2}}}, torch::kLong)));
320 ASSERT_EQ(indices.sizes(), std::vector<int64_t>({1, 1, 2}));
321}
322
323TEST_F(ModulesTest, MaxPool2dEven) {
324 MaxPool2d model(MaxPool2dOptions(3).stride(2));
325 auto x = torch::ones({2, 5, 5}, torch::requires_grad());
326 auto y = model(x);
327 torch::Tensor s = y.sum();
328
329 s.backward();
330 ASSERT_EQ(y.ndimension(), 3);
331 ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2})));
332 ASSERT_EQ(s.ndimension(), 0);
333 ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2}));
334}
335
336TEST_F(ModulesTest, MaxPool2dUneven) {
337 MaxPool2d model(MaxPool2dOptions({3, 2}).stride({2, 2}));
338 auto x = torch::ones({2, 5, 4}, torch::requires_grad());
339 auto y = model(x);
340 torch::Tensor s = y.sum();
341
342 s.backward();
343 ASSERT_EQ(y.ndimension(), 3);
344 ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2})));
345 ASSERT_EQ(s.ndimension(), 0);
346 ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2}));
347}
348
349TEST_F(ModulesTest, MaxPool2dReturnIndices) {
350 MaxPool2d model(MaxPool2dOptions(3).stride(2));
351 auto x = torch::ones({2, 5, 5}, torch::requires_grad());
352 torch::Tensor y, indices;
353 std::tie(y, indices) = model->forward_with_indices(x);
354
355 ASSERT_EQ(y.dim(), 3);
356 ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2})));
357 ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2}));
358 ASSERT_TRUE(torch::allclose(
359 indices,
360 torch::tensor({{{0, 2}, {10, 12}}, {{0, 2}, {10, 12}}}, torch::kLong)));
361 ASSERT_EQ(indices.sizes(), std::vector<int64_t>({2, 2, 2}));
362}
363
364TEST_F(ModulesTest, MaxPool3d) {
365 MaxPool3d model(MaxPool3dOptions(3).stride(2));
366 auto x = torch::ones({2, 5, 5, 5}, torch::requires_grad());
367 auto y = model(x);
368 torch::Tensor s = y.sum();
369
370 s.backward();
371 ASSERT_EQ(y.ndimension(), 4);
372 ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2, 2})));
373 ASSERT_EQ(s.ndimension(), 0);
374 ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2, 2}));
375}
376
377TEST_F(ModulesTest, MaxPool3dReturnIndices) {
378 MaxPool3d model(MaxPool3dOptions(3).stride(2));
379 auto x = torch::ones({2, 5, 5, 5}, torch::requires_grad());
380 torch::Tensor y, indices;
381 std::tie(y, indices) = model->forward_with_indices(x);
382
383 ASSERT_EQ(y.dim(), 4);
384 ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2, 2})));
385 ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2, 2}));
386
387 ASSERT_TRUE(torch::allclose(
388 indices,
389 torch::tensor(
390 {{{{0, 2}, {10, 12}}, {{50, 52}, {60, 62}}},
391 {{{0, 2}, {10, 12}}, {{50, 52}, {60, 62}}}},
392 torch::kLong)));
393 ASSERT_EQ(indices.sizes(), std::vector<int64_t>({2, 2, 2, 2}));
394}
395
396TEST_F(ModulesTest, AvgPool1d) {
397 AvgPool1d model(AvgPool1dOptions(3).stride(2));
398 auto x = torch::ones({1, 1, 5}, torch::requires_grad());
399 auto y = model(x);
400 torch::Tensor s = y.sum();
401
402 s.backward();
403 ASSERT_EQ(y.ndimension(), 3);
404 ASSERT_TRUE(torch::allclose(y, torch::ones({1, 1, 2})));
405 ASSERT_EQ(s.ndimension(), 0);
406 ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 2}));
407}
408
409TEST_F(ModulesTest, AvgPool2dEven) {
410 AvgPool2d model(AvgPool2dOptions(3).stride(2));
411 auto x = torch::ones({2, 5, 5}, torch::requires_grad());
412 auto y = model(x);
413 torch::Tensor s = y.sum();
414
415 s.backward();
416 ASSERT_EQ(y.ndimension(), 3);
417 ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2})));
418 ASSERT_EQ(s.ndimension(), 0);
419 ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2}));
420}
421
422TEST_F(ModulesTest, AvgPool2dUneven) {
423 AvgPool2d model(AvgPool2dOptions({3, 2}).stride({2, 2}));
424 auto x = torch::ones({2, 5, 4}, torch::requires_grad());
425 auto y = model(x);
426 torch::Tensor s = y.sum();
427
428 s.backward();
429 ASSERT_EQ(y.ndimension(), 3);
430 ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2})));
431 ASSERT_EQ(s.ndimension(), 0);
432 ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2}));
433}
434
435TEST_F(ModulesTest, AvgPool3d) {
436 AvgPool3d model(AvgPool3dOptions(3).stride(2));
437 auto x = torch::ones({2, 5, 5, 5}, torch::requires_grad());
438 auto y = model(x);
439 torch::Tensor s = y.sum();
440
441 s.backward();
442 ASSERT_EQ(y.ndimension(), 4);
443 ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2, 2})));
444 ASSERT_EQ(s.ndimension(), 0);
445 ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2, 2}));
446}
447
448TEST_F(ModulesTest, FractionalMaxPool2d) {
449 FractionalMaxPool2d model(FractionalMaxPool2dOptions(3).output_size(2));
450 auto x = torch::ones({2, 5, 5}, torch::requires_grad());
451 auto y = model(x);
452 torch::Tensor s = y.sum();
453
454 s.backward();
455 ASSERT_EQ(y.ndimension(), 3);
456 ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2})));
457 ASSERT_EQ(s.ndimension(), 0);
458 ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2}));
459}
460
461TEST_F(ModulesTest, FractionalMaxPool2dReturnIndices) {
462 FractionalMaxPool2d model(FractionalMaxPool2dOptions(3).output_size(2));
463 auto x = torch::ones({2, 5, 5}, torch::requires_grad());
464 torch::Tensor y, indices;
465 std::tie(y, indices) = model->forward_with_indices(x);
466
467 ASSERT_EQ(y.dim(), 3);
468 ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2})));
469 ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2}));
470 ASSERT_TRUE(torch::allclose(
471 indices, torch::tensor({{{0, 2}, {10, 12}}, {{0, 2}, {10, 12}}})));
472 ASSERT_EQ(indices.sizes(), std::vector<int64_t>({2, 2, 2}));
473}
474
475TEST_F(ModulesTest, FractionalMaxPool3d) {
476 FractionalMaxPool3d model(FractionalMaxPool3dOptions(3).output_size(2));
477 auto x = torch::ones({2, 5, 5, 5}, torch::requires_grad());
478 auto y = model(x);
479 torch::Tensor s = y.sum();
480
481 s.backward();
482 ASSERT_EQ(y.ndimension(), 4);
483 ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2, 2})));
484 ASSERT_EQ(s.ndimension(), 0);
485 ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2, 2}));
486}
487
488TEST_F(ModulesTest, FractionalMaxPool3dReturnIndices) {
489 FractionalMaxPool3d model(FractionalMaxPool3dOptions(3).output_size(2));
490 auto x = torch::ones({2, 5, 5, 5}, torch::requires_grad());
491 torch::Tensor y, indices;
492 std::tie(y, indices) = model->forward_with_indices(x);
493
494 ASSERT_EQ(y.dim(), 4);
495 ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2, 2})));
496 ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2, 2}));
497
498 ASSERT_TRUE(torch::allclose(
499 indices,
500 torch::tensor(
501 {{{{0, 2}, {10, 12}}, {{50, 52}, {60, 62}}},
502 {{{0, 2}, {10, 12}}, {{50, 52}, {60, 62}}}})));
503 ASSERT_EQ(indices.sizes(), std::vector<int64_t>({2, 2, 2, 2}));
504}
505
506TEST_F(ModulesTest, LPPool1d) {
507 int norm_type = 2;
508 int stride = 2;
509 int kernel_size = 3;
510
511 LPPool1d model(LPPool1dOptions(norm_type, kernel_size).stride(stride));
512 auto x = torch::ones({1, 1, 5});
513 auto y = model(x);
514 auto expected =
515 (torch::pow(torch::tensor({{{1, 1}}}, torch::kFloat), norm_type) *
516 kernel_size)
517 .pow(1. / norm_type);
518
519 ASSERT_EQ(y.ndimension(), 3);
520 ASSERT_TRUE(torch::allclose(y, expected));
521 ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 1, 2}));
522}
523
524TEST_F(ModulesTest, LPPool2d) {
525 int norm_type = 2;
526 int stride = 2;
527 std::vector<int64_t> kernel_size({2, 3});
528
529 LPPool2d model(LPPool2dOptions(norm_type, kernel_size).stride(stride));
530 auto x = torch::ones({1, 2, 5});
531 auto y = model(x);
532 auto expected =
533 (torch::pow(torch::tensor({{{1, 1}}}, torch::kFloat), norm_type) *
534 (kernel_size[0] * kernel_size[1]))
535 .pow(1. / norm_type);
536
537 ASSERT_EQ(y.ndimension(), 3);
538 ASSERT_TRUE(torch::allclose(y, expected));
539 ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 1, 2}));
540}
541
542TEST_F(ModulesTest, Identity) {
543 Identity identity;
544 auto input = torch::tensor(
545 {{1, 3, 4}, {2, 3, 4}}, torch::dtype(torch::kFloat).requires_grad(true));
546 auto output = identity->forward(input);
547 auto expected = torch::tensor({{1, 3, 4}, {2, 3, 4}}, torch::kFloat);
548 auto s = output.sum();
549 s.backward();
550
551 ASSERT_TRUE(torch::equal(output, expected));
552 ASSERT_TRUE(torch::equal(input.grad(), torch::ones_like(input)));
553}
554
555TEST_F(ModulesTest, Flatten) {
556 Flatten flatten;
557 auto input = torch::tensor(
558 {{1, 3, 4}, {2, 5, 6}}, torch::dtype(torch::kFloat).requires_grad(true));
559 auto output = flatten->forward(input);
560 auto expected = torch::tensor({{1, 3, 4}, {2, 5, 6}}, torch::kFloat);
561 auto s = output.sum();
562
563 s.backward();
564 ASSERT_TRUE(torch::equal(output, expected));
565 ASSERT_TRUE(torch::equal(input.grad(), torch::ones_like(input)));
566
567 // Testing with optional arguments start_dim and end_dim
568 Flatten flatten_optional_dims(FlattenOptions().start_dim(2).end_dim(3));
569 input = torch::tensor(
570 {{{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}},
571 {{{9, 10}, {11, 12}}, {{13, 14}, {15, 16}}}},
572 torch::dtype(torch::kFloat)
573 .requires_grad(true)); // Tensor with sizes (2, 2, 2, 2)
574
575 output = flatten_optional_dims->forward(input);
576 expected = torch::tensor(
577 {{{1, 2, 3, 4}, {5, 6, 7, 8}}, {{9, 10, 11, 12}, {13, 14, 15, 16}}},
578 torch::kFloat); // Tensor with sizes (2, 2, 4)
579
580 s = output.sum();
581 s.backward();
582 ASSERT_TRUE(torch::equal(output, expected));
583 ASSERT_TRUE(torch::equal(input.grad(), torch::ones_like(input)));
584}
585
586TEST_F(ModulesTest, Unflatten) {
587 // Non-named tensor
588 Unflatten unflatten(UnflattenOptions(0, {2, 2}));
589 auto output = unflatten->forward(torch::tensor({1, 2, 3, 4}));
590 auto expected = torch::tensor({{1, 2}, {3, 4}});
591 ASSERT_TRUE(torch::equal(output, expected));
592
593 // Named tensor
594 auto make_dimnames = [](std::vector<std::string> names) {
595 std::vector<torch::Dimname> dimnames;
596 // NOLINTNEXTLINE(performance-for-range-copy)
597 for (auto name : names) {
598 // NOLINTNEXTLINE(performance-inefficient-vector-operation)
599 dimnames.push_back(
600 torch::Dimname::fromSymbol(torch::Symbol::dimname(name)));
601 }
602 return dimnames;
603 };
604
605 unflatten = Unflatten(UnflattenOptions(
606 "B",
607 {std::pair<std::string, int64_t>{"B1", 2},
608 std::pair<std::string, int64_t>{"B2", 2}}));
609 output = unflatten->forward(
610 torch::tensor({{1, 2, 3, 4}}).refine_names(make_dimnames({"A", "B"})));
611 expected = torch::tensor({{{1, 2}, {3, 4}}})
612 .refine_names(make_dimnames({"A", "B1", "B2"}));
613 ASSERT_TRUE(torch::equal(output, expected));
614}
615
616TEST_F(ModulesTest, AdaptiveMaxPool1d) {
617 AdaptiveMaxPool1d model(3);
618 auto x = torch::tensor(
619 {{{1, 2, 3, 4, 5}}}, torch::dtype(torch::kFloat).requires_grad(true));
620 auto y = model(x);
621 torch::Tensor s = y.sum();
622
623 s.backward();
624 ASSERT_EQ(y.ndimension(), 3);
625 ASSERT_TRUE(torch::allclose(y, torch::tensor({{{2, 4, 5}}}, torch::kFloat)));
626 ASSERT_EQ(s.ndimension(), 0);
627 ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 3}));
628}
629
630TEST_F(ModulesTest, AdaptiveMaxPool1dReturnIndices) {
631 AdaptiveMaxPool1d model(3);
632 auto x = torch::tensor(
633 {{{1, 2, 3, 4, 5}}}, torch::dtype(torch::kFloat).requires_grad(true));
634 torch::Tensor y, indices;
635 std::tie(y, indices) = model->forward_with_indices(x);
636
637 ASSERT_EQ(y.dim(), 3);
638 ASSERT_TRUE(torch::allclose(y, torch::tensor({{{2, 4, 5}}}, torch::kFloat)));
639 ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 3}));
640 ASSERT_TRUE(
641 torch::allclose(indices, torch::tensor({{{1, 3, 4}}}, torch::kLong)));
642 ASSERT_EQ(indices.sizes(), std::vector<int64_t>({1, 1, 3}));
643}
644
645TEST_F(ModulesTest, AdaptiveMaxPool2dEven) {
646 AdaptiveMaxPool2d model(3);
647 auto x = torch::arange(0., 50);
648 x.resize_({2, 5, 5}).set_requires_grad(true);
649 auto y = model(x);
650 torch::Tensor s = y.sum();
651
652 s.backward();
653 ASSERT_EQ(y.ndimension(), 3);
654 ASSERT_TRUE(torch::allclose(
655 y,
656 torch::tensor(
657 {
658 {{6, 8, 9}, {16, 18, 19}, {21, 23, 24}},
659 {{31, 33, 34}, {41, 43, 44}, {46, 48, 49}},
660 },
661 torch::kFloat)));
662 ASSERT_EQ(s.ndimension(), 0);
663 ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 3, 3}));
664}
665
666TEST_F(ModulesTest, AdaptiveMaxPool2dUneven) {
667 AdaptiveMaxPool2d model(AdaptiveMaxPool2dOptions({3, 2}));
668 auto x = torch::arange(0., 40);
669 x.resize_({2, 5, 4}).set_requires_grad(true);
670 auto y = model(x);
671 torch::Tensor s = y.sum();
672
673 s.backward();
674 ASSERT_EQ(y.ndimension(), 3);
675 ASSERT_TRUE(torch::allclose(
676 y,
677 torch::tensor(
678 {
679 {{5, 7}, {13, 15}, {17, 19}},
680 {{25, 27}, {33, 35}, {37, 39}},
681 },
682 torch::kFloat)));
683 ASSERT_EQ(s.ndimension(), 0);
684 ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 3, 2}));
685}
686
687TEST_F(ModulesTest, AdaptiveMaxPool2dReturnIndicesEven) {
688 AdaptiveMaxPool2d model(3);
689 auto x = torch::arange(0., 50);
690 x.resize_({2, 5, 5}).set_requires_grad(true);
691 torch::Tensor y, indices;
692 std::tie(y, indices) = model->forward_with_indices(x);
693 torch::Tensor s = y.sum();
694
695 s.backward();
696 ASSERT_EQ(s.ndimension(), 0);
697
698 ASSERT_EQ(y.ndimension(), 3);
699 ASSERT_TRUE(torch::allclose(
700 y,
701 torch::tensor(
702 {
703 {{6, 8, 9}, {16, 18, 19}, {21, 23, 24}},
704 {{31, 33, 34}, {41, 43, 44}, {46, 48, 49}},
705 },
706 torch::kFloat)));
707 ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 3, 3}));
708
709 ASSERT_EQ(indices.ndimension(), 3);
710 ASSERT_TRUE(torch::allclose(
711 indices,
712 torch::tensor(
713 {
714 {{6, 8, 9}, {16, 18, 19}, {21, 23, 24}},
715 {{6, 8, 9}, {16, 18, 19}, {21, 23, 24}},
716 },
717 torch::kLong)));
718 ASSERT_EQ(indices.sizes(), std::vector<int64_t>({2, 3, 3}));
719}
720
721TEST_F(ModulesTest, AdaptiveMaxPool2dReturnIndicesUneven) {
722 AdaptiveMaxPool2d model(AdaptiveMaxPool2dOptions({3, 2}));
723 auto x = torch::arange(0., 40);
724 x.resize_({2, 5, 4}).set_requires_grad(true);
725 torch::Tensor y, indices;
726 std::tie(y, indices) = model->forward_with_indices(x);
727 torch::Tensor s = y.sum();
728
729 s.backward();
730 ASSERT_EQ(s.ndimension(), 0);
731
732 ASSERT_EQ(y.ndimension(), 3);
733 ASSERT_TRUE(torch::allclose(
734 y,
735 torch::tensor(
736 {
737 {{5, 7}, {13, 15}, {17, 19}},
738 {{25, 27}, {33, 35}, {37, 39}},
739 },
740 torch::kFloat)));
741 ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 3, 2}));
742
743 ASSERT_EQ(indices.ndimension(), 3);
744 ASSERT_TRUE(torch::allclose(
745 indices,
746 torch::tensor(
747 {
748 {{5, 7}, {13, 15}, {17, 19}},
749 {{5, 7}, {13, 15}, {17, 19}},
750 },
751 torch::kLong)));
752 ASSERT_EQ(indices.sizes(), std::vector<int64_t>({2, 3, 2}));
753}
754
755TEST_F(ModulesTest, AdaptiveMaxPool3d) {
756 AdaptiveMaxPool3d model(3);
757 auto x = torch::arange(0., 64);
758 x.resize_({1, 4, 4, 4}).set_requires_grad(true);
759 auto y = model(x);
760 torch::Tensor s = y.sum();
761
762 s.backward();
763 ASSERT_EQ(s.ndimension(), 0);
764
765 ASSERT_EQ(y.ndimension(), 4);
766 ASSERT_TRUE(torch::allclose(
767 y,
768 torch::tensor(
769 {
770 {{21, 22, 23}, {25, 26, 27}, {29, 30, 31}},
771 {{37, 38, 39}, {41, 42, 43}, {45, 46, 47}},
772 {{53, 54, 55}, {57, 58, 59}, {61, 62, 63}},
773 },
774 torch::kFloat)));
775 ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 3, 3, 3}));
776}
777
778TEST_F(ModulesTest, AdaptiveMaxPool3dReturnIndices) {
779 AdaptiveMaxPool3d model(3);
780 auto x = torch::arange(0., 64);
781 x.resize_({1, 4, 4, 4}).set_requires_grad(true);
782 torch::Tensor y, indices;
783 std::tie(y, indices) = model->forward_with_indices(x);
784 torch::Tensor s = y.sum();
785
786 s.backward();
787 ASSERT_EQ(s.ndimension(), 0);
788
789 ASSERT_EQ(y.ndimension(), 4);
790 ASSERT_TRUE(torch::allclose(
791 y,
792 torch::tensor(
793 {
794 {{21, 22, 23}, {25, 26, 27}, {29, 30, 31}},
795 {{37, 38, 39}, {41, 42, 43}, {45, 46, 47}},
796 {{53, 54, 55}, {57, 58, 59}, {61, 62, 63}},
797 },
798 torch::kFloat)));
799 ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 3, 3, 3}));
800
801 ASSERT_EQ(indices.ndimension(), 4);
802 ASSERT_TRUE(torch::allclose(
803 indices,
804 torch::tensor(
805 {
806 {{21, 22, 23}, {25, 26, 27}, {29, 30, 31}},
807 {{37, 38, 39}, {41, 42, 43}, {45, 46, 47}},
808 {{53, 54, 55}, {57, 58, 59}, {61, 62, 63}},
809 },
810 torch::kLong)));
811 ASSERT_EQ(indices.sizes(), std::vector<int64_t>({1, 3, 3, 3}));
812}
813
814TEST_F(ModulesTest, AdaptiveAvgPool1d) {
815 AdaptiveAvgPool1d model(3);
816 auto x = torch::tensor(
817 {{{1, 2, 3, 4, 5}}}, torch::dtype(torch::kFloat).requires_grad(true));
818 auto y = model(x);
819 torch::Tensor s = y.sum();
820
821 s.backward();
822 ASSERT_EQ(s.ndimension(), 0);
823
824 ASSERT_EQ(y.ndimension(), 3);
825 ASSERT_TRUE(
826 torch::allclose(y, torch::tensor({{{1.5, 3.0, 4.5}}}, torch::kFloat)));
827 ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 3}));
828}
829
830TEST_F(ModulesTest, AdaptiveAvgPool2dEven) {
831 AdaptiveAvgPool2d model(3);
832 auto x = torch::arange(0., 50);
833 x.resize_({2, 5, 5}).set_requires_grad(true);
834 auto y = model(x);
835 torch::Tensor s = y.sum();
836
837 s.backward();
838 ASSERT_EQ(s.ndimension(), 0);
839
840 ASSERT_EQ(y.ndimension(), 3);
841 ASSERT_TRUE(torch::allclose(
842 y,
843 torch::tensor(
844 {
845 {{3.0, 4.5, 6.0}, {10.5, 12.0, 13.5}, {18.0, 19.5, 21.0}},
846 {{28.0, 29.5, 31.0}, {35.5, 37.0, 38.5}, {43.0, 44.5, 46.0}},
847 },
848 torch::kFloat)));
849 ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 3, 3}));
850}
851
852TEST_F(ModulesTest, AdaptiveAvgPool2dUneven) {
853 AdaptiveAvgPool2d model(AdaptiveAvgPool2dOptions({3, 2}));
854 auto x = torch::arange(0., 40);
855 x.resize_({2, 5, 4}).set_requires_grad(true);
856 auto y = model(x);
857 torch::Tensor s = y.sum();
858
859 s.backward();
860 ASSERT_EQ(s.ndimension(), 0);
861
862 ASSERT_EQ(y.ndimension(), 3);
863 ASSERT_TRUE(torch::allclose(
864 y,
865 torch::tensor(
866 {
867 {{2.5, 4.5}, {8.5, 10.5}, {14.5, 16.5}},
868 {{22.5, 24.5}, {28.5, 30.5}, {34.5, 36.5}},
869 },
870 torch::kFloat)));
871 ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 3, 2}));
872}
873
874TEST_F(ModulesTest, AdaptiveAvgPool3d) {
875 AdaptiveAvgPool3d model(3);
876 auto x = torch::arange(0., 64);
877 x.resize_({1, 4, 4, 4}).set_requires_grad(true);
878 auto y = model(x);
879 torch::Tensor s = y.sum();
880
881 s.backward();
882 ASSERT_EQ(s.ndimension(), 0);
883
884 ASSERT_EQ(y.ndimension(), 4);
885 ASSERT_TRUE(torch::allclose(
886 y,
887 torch::tensor(
888 {
889 {{10.5, 11.5, 12.5}, {14.5, 15.5, 16.5}, {18.5, 19.5, 20.5}},
890 {{26.5, 27.5, 28.5}, {30.5, 31.5, 32.5}, {34.5, 35.5, 36.5}},
891 {{42.5, 43.5, 44.5}, {46.5, 47.5, 48.5}, {50.5, 51.5, 52.5}},
892 },
893 torch::kFloat)));
894 ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 3, 3, 3}));
895}
896
897TEST_F(ModulesTest, MaxUnpool1d) {
898 auto indices = torch::tensor({{{1, 3, 4}}}, torch::kLong);
899 auto x = torch::tensor(
900 {{{2, 4, 5}}}, torch::dtype(torch::kFloat).requires_grad(true));
901 auto model = MaxUnpool1d{3};
902 auto y = model->forward(x, indices);
903
904 ASSERT_EQ(y.dim(), 3);
905 ASSERT_TRUE(torch::allclose(
906 y, torch::tensor({{{0, 2, 0, 4, 5, 0, 0, 0, 0}}}, torch::kFloat)));
907 ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 9}));
908
909 indices = torch::tensor({{{1, 3, 4}}}, torch::kLong);
910 x = torch::tensor(
911 {{{2, 4, 5}}}, torch::dtype(torch::kFloat).requires_grad(true));
912 model = MaxUnpool1d{MaxUnpool1dOptions(3).stride(2).padding(1)};
913 y = model->forward(x, indices, std::vector<int64_t>({1, 1, 5}));
914
915 ASSERT_EQ(y.dim(), 3);
916 ASSERT_TRUE(
917 torch::allclose(y, torch::tensor({{{0, 2, 0, 4, 5}}}, torch::kFloat)));
918 ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 5}));
919}
920
921TEST_F(ModulesTest, MaxPool1d_MaxUnpool1d) {
922 MaxPool1d pool{MaxPool1dOptions(2).stride(2)};
923 MaxUnpool1d unpool{MaxUnpool1dOptions(2).stride(2)};
924 auto input = torch::tensor({{{1, 2, 3, 4, 5, 6, 7, 8}}}, torch::kFloat);
925 torch::Tensor output, indices;
926 std::tie(output, indices) = pool->forward_with_indices(input);
927 ASSERT_TRUE(torch::allclose(
928 unpool(output, indices),
929 torch::tensor({{{0, 2, 0, 4, 0, 6, 0, 8}}}, torch::kFloat)));
930
931 // Example showcasing the use of output_size
932 input = torch::tensor({{{1, 2, 3, 4, 5, 6, 7, 8, 9}}}, torch::kFloat);
933 std::tie(output, indices) = pool->forward_with_indices(input);
934 ASSERT_TRUE(torch::allclose(
935 unpool(output, indices, input.sizes().vec()),
936 torch::tensor({{{0, 2, 0, 4, 0, 6, 0, 8, 0}}}, torch::kFloat)));
937 ASSERT_TRUE(torch::allclose(
938 unpool(output, indices),
939 torch::tensor({{{0, 2, 0, 4, 0, 6, 0, 8}}}, torch::kFloat)));
940}
941
942TEST_F(ModulesTest, MaxUnpool2d) {
943 auto indices = torch::tensor(
944 {{{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}},
945 {{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}}},
946 torch::kLong);
947 auto x = torch::tensor(
948 {{{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}},
949 {{{31, 33, 34}, {41, 43, 44}, {46, 48, 49}}}},
950 torch::dtype(torch::kFloat).requires_grad(true));
951 auto model = MaxUnpool2d{MaxUnpool2dOptions(3).stride(2).padding(1)};
952 auto y = model->forward(x, indices);
953
954 ASSERT_EQ(y.dim(), 4);
955 ASSERT_TRUE(torch::allclose(
956 y,
957 torch::tensor(
958 {{{{0, 0, 0, 0, 0},
959 {0, 6, 0, 8, 9},
960 {0, 0, 0, 0, 0},
961 {0, 16, 0, 18, 19},
962 {0, 21, 0, 23, 24}}},
963 {{{0, 0, 0, 0, 0},
964 {0, 31, 0, 33, 34},
965 {0, 0, 0, 0, 0},
966 {0, 41, 0, 43, 44},
967 {0, 46, 0, 48, 49}}}},
968 torch::kFloat)));
969 ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 1, 5, 5}));
970}
971
972TEST_F(ModulesTest, MaxPool2d_MaxUnpool2d) {
973 MaxPool2d pool{MaxPool2dOptions(2).stride(2)};
974 MaxUnpool2d unpool{MaxUnpool2dOptions(2).stride(2)};
975 auto input = torch::tensor(
976 {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}, {13, 14, 15, 16}}}},
977 torch::kFloat);
978 torch::Tensor output, indices;
979 std::tie(output, indices) = pool->forward_with_indices(input);
980 ASSERT_TRUE(torch::allclose(
981 unpool(output, indices),
982 torch::tensor(
983 {{{{0, 0, 0, 0}, {0, 6, 0, 8}, {0, 0, 0, 0}, {0, 14, 0, 16}}}},
984 torch::kFloat)));
985
986 ASSERT_TRUE(torch::allclose(
987 unpool(output, indices, std::vector<int64_t>{1, 1, 5, 5}),
988 torch::tensor(
989 {{{{0, 0, 0, 0, 0},
990 {6, 0, 8, 0, 0},
991 {0, 0, 0, 14, 0},
992 {16, 0, 0, 0, 0},
993 {0, 0, 0, 0, 0}}}},
994 torch::kFloat)));
995}
996
997TEST_F(ModulesTest, MaxUnpool3d) {
998 auto indices = torch::tensor({{{{{26}}}}}, torch::kLong);
999 auto x = torch::tensor(
1000 {{{{{26}}}}}, torch::dtype(torch::kFloat).requires_grad(true));
1001 auto model = MaxUnpool3d{3};
1002 auto y = model->forward(x, indices);
1003
1004 ASSERT_EQ(y.dim(), 5);
1005 ASSERT_TRUE(torch::allclose(
1006 y,
1007 torch::tensor(
1008 {{{{{0, 0, 0}, {0, 0, 0}, {0, 0, 0}},
1009 {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}},
1010 {{0, 0, 0}, {0, 0, 0}, {0, 0, 26}}}}},
1011 torch::kFloat)));
1012 ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 3, 3, 3}));
1013}
1014
1015TEST_F(ModulesTest, MaxUnpool3dOutputSize) {
1016 auto indices = torch::tensor(
1017 {{{{{21, 23}, {29, 31}}, {{53, 55}, {61, 63}}}}}, torch::kLong);
1018 auto x = torch::tensor(
1019 {{{{{21, 23}, {29, 31}}, {{53, 55}, {61, 63}}}}},
1020 torch::dtype(torch::kFloat).requires_grad(true));
1021 auto model = MaxUnpool3d{MaxUnpool3dOptions(3).stride(2).padding(1)};
1022 auto y = model->forward(x, indices, std::vector<int64_t>({1, 1, 4, 4, 4}));
1023
1024 ASSERT_EQ(y.dim(), 5);
1025 ASSERT_TRUE(torch::allclose(
1026 y,
1027 torch::tensor(
1028 {{{{{0, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 0, 0}},
1029 {{0, 0, 0, 0}, {0, 21, 0, 23}, {0, 0, 0, 0}, {0, 29, 0, 31}},
1030 {{0, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 0, 0}},
1031 {{0, 0, 0, 0}, {0, 53, 0, 55}, {0, 0, 0, 0}, {0, 61, 0, 63}}}}},
1032 torch::kFloat)));
1033 ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 4, 4, 4}));
1034}
1035
1036TEST_F(ModulesTest, MaxPool3d_MaxUnpool3d) {
1037 MaxPool3d pool{MaxPool3dOptions(3).stride(2)};
1038 MaxUnpool3d unpool{MaxUnpool3dOptions(3).stride(2)};
1039 auto input = torch::randn({20, 16, 51, 33, 15});
1040 torch::Tensor output, indices;
1041 std::tie(output, indices) = pool->forward_with_indices(input);
1042 auto unpooled_output = unpool(output, indices);
1043 ASSERT_EQ(
1044 unpooled_output.sizes(), std::vector<int64_t>({20, 16, 51, 33, 15}));
1045}
1046
1047TEST_F(ModulesTest, Linear) {
1048 {
1049 Linear model(5, 2);
1050 auto x = torch::randn({10, 5}, torch::requires_grad());
1051 auto y = model(x);
1052 torch::Tensor s = y.sum();
1053
1054 s.backward();
1055 ASSERT_EQ(y.ndimension(), 2);
1056 ASSERT_EQ(s.ndimension(), 0);
1057 ASSERT_EQ(y.size(0), 10);
1058 ASSERT_EQ(y.size(1), 2);
1059
1060 ASSERT_EQ(model->weight.grad().numel(), 2 * 5);
1061
1062 auto y_exp = torch::addmm(model->bias, x, model->weight.t());
1063 ASSERT_TRUE(torch::allclose(y, y_exp));
1064 }
1065 {
1066 Linear model(LinearOptions(5, 2).bias(false));
1067 auto x = torch::randn({10, 5}, torch::requires_grad());
1068 auto y = model(x);
1069 torch::Tensor s = y.sum();
1070
1071 s.backward();
1072 ASSERT_EQ(y.ndimension(), 2);
1073 ASSERT_EQ(s.ndimension(), 0);
1074 ASSERT_EQ(y.size(0), 10);
1075 ASSERT_EQ(y.size(1), 2);
1076
1077 ASSERT_EQ(model->weight.grad().numel(), 2 * 5);
1078
1079 auto y_exp = torch::mm(x, model->weight.t());
1080 ASSERT_TRUE(torch::allclose(y, y_exp));
1081 }
1082}
1083
1084TEST_F(ModulesTest, LocalResponseNorm) {
1085 {
1086 LocalResponseNorm model(LocalResponseNormOptions(2));
1087 const auto x =
1088 torch::arange(100., 136, torch::requires_grad()).reshape({2, 3, 3, 2});
1089 auto y = model(x);
1090 const auto y_exp = torch::tensor(
1091 {{{{73.7788, 74.1462}, {74.5031, 74.8572}, {75.2010, 75.5420}},
1092
1093 {{61.6057, 61.7227}, {61.8347, 61.9418}, {62.0441, 62.1418}},
1094
1095 {{62.2349, 62.3235}, {62.4077, 62.4877}, {62.5635, 62.6353}}},
1096
1097 {{{79.3915, 79.6491}, {79.8978, 80.1446}, {80.3827, 80.6190}},
1098
1099 {{63.0317, 63.0742}, {63.1135, 63.1496}, {63.1826, 63.2126}},
1100
1101 {{63.2396, 63.2637}, {63.2850, 63.3036}, {63.3195, 63.3328}}}},
1102 torch::kFloat);
1103 torch::Tensor s = y.sum();
1104
1105 s.backward();
1106 ASSERT_EQ(y.ndimension(), 4);
1107 ASSERT_EQ(s.ndimension(), 0);
1108 ASSERT_EQ(y.sizes(), x.sizes());
1109 ASSERT_TRUE(torch::allclose(y, y_exp, 1e-4, 1e-7));
1110 }
1111}
1112
1113TEST_F(ModulesTest, LayerNorm) {
1114 LayerNorm model(LayerNormOptions({2, 2}).eps(2e-5));
1115 auto x = torch::randn({2, 2}, torch::requires_grad());
1116 auto y = model(x);
1117 auto y_exp = torch::layer_norm(x, {2, 2}, model->weight, model->bias, 2e-5);
1118 torch::Tensor s = y.sum();
1119
1120 s.backward();
1121 ASSERT_EQ(y.ndimension(), 2);
1122 ASSERT_EQ(s.ndimension(), 0);
1123 for (const auto i : c10::irange(2)) {
1124 ASSERT_EQ(y.size(i), 2);
1125 }
1126
1127 ASSERT_EQ(model->weight.grad().numel(), 2 * 2);
1128 ASSERT_TRUE(torch::allclose(y, y_exp));
1129}
1130
1131TEST_F(ModulesTest, GroupNorm) {
1132 GroupNorm model(GroupNormOptions(2, 2).eps(2e-5));
1133 auto x = torch::randn({2, 2}, torch::requires_grad());
1134 auto y = model(x);
1135 auto y_exp = torch::group_norm(x, 2, model->weight, model->bias, 2e-5);
1136 torch::Tensor s = y.sum();
1137
1138 s.backward();
1139 ASSERT_EQ(y.ndimension(), 2);
1140 ASSERT_EQ(s.ndimension(), 0);
1141 for (const auto i : c10::irange(2)) {
1142 ASSERT_EQ(y.size(i), 2);
1143 }
1144
1145 ASSERT_EQ(model->weight.grad().numel(), 2);
1146 ASSERT_TRUE(torch::allclose(y, y_exp));
1147}
1148
1149TEST_F(ModulesTest, Bilinear) {
1150 Bilinear model(5, 3, 2);
1151 auto x1 = torch::randn({10, 5}, torch::requires_grad());
1152 auto x2 = torch::randn({10, 3}, torch::requires_grad());
1153 auto y = model(x1, x2);
1154 torch::Tensor s = y.sum();
1155
1156 s.backward();
1157 ASSERT_EQ(y.ndimension(), 2);
1158 ASSERT_EQ(s.ndimension(), 0);
1159 ASSERT_EQ(y.size(0), 10);
1160 ASSERT_EQ(y.size(1), 2);
1161
1162 ASSERT_EQ(model->weight.grad().numel(), 2 * 5 * 3);
1163}
1164
1165TEST_F(ModulesTest, Fold) {
1166 {
1167 Fold model(FoldOptions({3, 2}, {2, 2}));
1168 auto input = torch::ones({1, 3 * 2 * 2, 2}, torch::requires_grad());
1169 auto output = model(input);
1170 auto expected = torch::tensor(
1171 {{{{1.0, 1.0}, {2.0, 2.0}, {1.0, 1.0}},
1172 {{1.0, 1.0}, {2.0, 2.0}, {1.0, 1.0}},
1173 {{1.0, 1.0}, {2.0, 2.0}, {1.0, 1.0}}}},
1174 torch::kFloat);
1175 auto s = output.sum();
1176 s.backward();
1177
1178 ASSERT_EQ(s.ndimension(), 0);
1179 ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 3, 3, 2}));
1180 ASSERT_TRUE(output.allclose(expected));
1181 }
1182 {
1183 // input wrong dimension
1184 Fold model(FoldOptions({8, 8}, {3, 3}));
1185 ASSERT_THROWS_WITH(
1186 model(torch::randn({1, 3, 16, 16})),
1187 "Input Error: Only unbatched (2D) or batched (3D) input Tensors are supported (got 4D)");
1188 }
1189}
1190
1191TEST_F(ModulesTest, Unfold) {
1192 {
1193 Unfold model(UnfoldOptions({2, 2}).padding(1).stride(2));
1194 auto input =
1195 torch::arange(2., 14, torch::requires_grad()).view({1, 2, 2, 3});
1196 auto output = model(input);
1197 auto expected = torch::tensor(
1198 {{{0.0, 0.0, 0.0, 6.0},
1199 {0.0, 0.0, 5.0, 7.0},
1200 {0.0, 3.0, 0.0, 0.0},
1201 {2.0, 4.0, 0.0, 0.0},
1202 {0.0, 0.0, 0.0, 12.0},
1203 {0.0, 0.0, 11.0, 13.0},
1204 {0.0, 9.0, 0.0, 0.0},
1205 {8.0, 10.0, 0.0, 0.0}}},
1206 torch::kFloat);
1207 auto s = output.sum();
1208 s.backward();
1209
1210 ASSERT_EQ(s.ndimension(), 0);
1211 ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 8, 4}));
1212 ASSERT_TRUE(output.allclose(expected));
1213 }
1214 {
1215 // input wrong dimension
1216 Unfold model(UnfoldOptions({2, 4}));
1217 ASSERT_THROWS_WITH(
1218 model(torch::randn({1, 5, 2})),
1219 "Input Error: Only 4D input Tensors are supported (got 3D)");
1220 }
1221 {
1222 // calculated output shape is too small
1223 Unfold model(UnfoldOptions({2, 3}));
1224 ASSERT_THROWS_WITH(
1225 model(torch::randn({1, 2, 2, 2})),
1226 "Given input with spatial size (2, 2), kernel_size=(2, 3), "
1227 "dilation=(1, 1), padding=(0, 0), calculated shape of the array of "
1228 "sliding blocks as (1, 0), but its components must be at least one.");
1229 }
1230}
1231
1232TEST_F(ModulesTest, SimpleContainer) {
1233 auto model = std::make_shared<SimpleContainer>();
1234 auto l1 = model->add(Linear(10, 3), "l1");
1235 auto l2 = model->add(Linear(3, 5), "l2");
1236 auto l3 = model->add(Linear(5, 100), "l3");
1237
1238 auto x = torch::randn({1000, 10}, torch::requires_grad());
1239 x = l1(x).clamp_min(0);
1240 x = l2(x).clamp_min(0);
1241 x = l3(x).clamp_min(0);
1242
1243 x.backward(torch::ones_like(x));
1244 ASSERT_EQ(x.ndimension(), 2);
1245 ASSERT_EQ(x.size(0), 1000);
1246 ASSERT_EQ(x.size(1), 100);
1247 ASSERT_EQ(x.min().item<float>(), 0);
1248}
1249
1250TEST_F(ModulesTest, EmbeddingBasic) {
1251 const int64_t dict_size = 10;
1252 Embedding model(dict_size, 2);
1253 ASSERT_TRUE(model->named_parameters().contains("weight"));
1254 ASSERT_EQ(model->weight.ndimension(), 2);
1255 ASSERT_EQ(model->weight.size(0), dict_size);
1256 ASSERT_EQ(model->weight.size(1), 2);
1257
1258 // Cannot get gradients to change indices (input) - only for embedding
1259 // params
1260 auto x = torch::full({10}, dict_size - 1, torch::kInt64);
1261 auto y = model(x);
1262 torch::Tensor s = y.sum();
1263
1264 s.backward();
1265 ASSERT_EQ(y.ndimension(), 2);
1266 ASSERT_EQ(s.ndimension(), 0);
1267 ASSERT_EQ(y.size(0), 10);
1268 ASSERT_EQ(y.size(1), 2);
1269
1270 ASSERT_EQ(model->weight.grad().numel(), 2 * dict_size);
1271}
1272
1273TEST_F(ModulesTest, EmbeddingList) {
1274 Embedding model(6, 4);
1275 auto x = torch::full({2, 3}, 5, torch::kInt64);
1276 auto y = model(x);
1277 torch::Tensor s = y.sum();
1278
1279 s.backward();
1280 ASSERT_EQ(y.ndimension(), 3);
1281 ASSERT_EQ(y.size(0), 2);
1282 ASSERT_EQ(y.size(1), 3);
1283 ASSERT_EQ(y.size(2), 4);
1284}
1285
1286TEST_F(ModulesTest, EmbeddingFromPretrained) {
1287 auto weight = torch::tensor({{1., 2.3, 3.}, {4., 5.1, 6.3}});
1288 Embedding embedding = torch::nn::Embedding::from_pretrained(weight);
1289 auto input = torch::tensor({1}, torch::kLong);
1290 ASSERT_TRUE(torch::allclose(
1291 embedding(input), torch::tensor({4.0000, 5.1000, 6.3000})));
1292}
1293
1294TEST_F(ModulesTest, EmbeddingBagFromPretrained) {
1295 auto weight = torch::tensor({{1., 2.3, 3.}, {4., 5.1, 6.3}});
1296 EmbeddingBag embeddingbag = torch::nn::EmbeddingBag::from_pretrained(weight);
1297 auto input = torch::zeros({{1, 2}}, torch::kLong);
1298 input[0] = torch::tensor({1, 0});
1299 ASSERT_TRUE(torch::allclose(
1300 embeddingbag(input), torch::tensor({2.5000, 3.7000, 4.6500})));
1301}
1302
1303TEST_F(ModulesTest, AlphaDropout) {
1304 AlphaDropout alpha_dropout(0.5);
1305 torch::Tensor x = torch::ones(100, torch::requires_grad());
1306 torch::Tensor y = alpha_dropout(x);
1307
1308 y.backward(torch::ones_like(y));
1309
1310 ASSERT_EQ(y.ndimension(), 1);
1311 ASSERT_EQ(y.size(0), 100);
1312 ASSERT_LT(y.sum().item<float>(), 130); // Probably
1313 ASSERT_GT(y.sum().item<float>(), 40); // Probably
1314
1315 alpha_dropout->eval();
1316 y = alpha_dropout(x);
1317
1318 ASSERT_EQ(y.sum().item<float>(), 100);
1319}
1320
1321TEST_F(ModulesTest, FeatureAlphaDropout) {
1322 FeatureAlphaDropout feature_alpha_dropout(0.5);
1323 torch::Tensor x = torch::ones({10, 10}, torch::requires_grad());
1324 torch::Tensor y = feature_alpha_dropout(x);
1325
1326 y.backward(torch::ones_like(y));
1327
1328 ASSERT_EQ(y.ndimension(), 2);
1329 ASSERT_EQ(y.size(0), 10);
1330 ASSERT_EQ(y.size(1), 10);
1331 ASSERT_LT(y.sum().item<float>(), 130); // Probably
1332 ASSERT_GT(y.sum().item<float>(), 40); // Probably
1333
1334 feature_alpha_dropout->eval();
1335 y = feature_alpha_dropout(x);
1336
1337 ASSERT_EQ(y.sum().item<float>(), 100);
1338}
1339
1340TEST_F(ModulesTest, Dropout) {
1341 for (const auto inplace : {false, true}) {
1342 Dropout dropout(DropoutOptions(0.5).inplace(inplace));
1343 torch::Tensor x = torch::ones(100);
1344 if (!inplace) {
1345 x.requires_grad_(true);
1346 }
1347 torch::Tensor y = dropout(x);
1348
1349 ASSERT_EQ(y.ndimension(), 1);
1350 ASSERT_EQ(y.size(0), 100);
1351 ASSERT_LT(y.sum().item<float>(), 130); // Probably
1352 ASSERT_GT(y.sum().item<float>(), 70); // Probably
1353 if (inplace) {
1354 ASSERT_TRUE(y.allclose(x));
1355 } else {
1356 y.backward(torch::ones_like(y));
1357 }
1358
1359 dropout->eval();
1360 y = dropout(torch::ones(100));
1361 ASSERT_EQ(y.sum().item<float>(), 100);
1362 }
1363}
1364
1365TEST_F(ModulesTest, Dropout2d) {
1366 auto p = 0.5;
1367 for (const auto inplace : {false, true}) {
1368 Dropout2d dropout(Dropout2dOptions(p).inplace(inplace));
1369 torch::Tensor x = torch::empty({50, 50, 2, 2}).fill_(1 - p);
1370 if (!inplace) {
1371 x.requires_grad_(true);
1372 }
1373 torch::Tensor y = dropout(x);
1374
1375 ASSERT_EQ(y.ndimension(), 4);
1376 ASSERT_EQ(y.size(0), 50);
1377 ASSERT_EQ(y.size(1), 50);
1378 ASSERT_EQ(y.size(2), 2);
1379 ASSERT_EQ(y.size(3), 2);
1380 ASSERT_LT((y.mean() - (1 - p)).abs().item<float>(), 0.05);
1381
1382 if (inplace) {
1383 ASSERT_TRUE(y.allclose(x));
1384 } else {
1385 y.backward(torch::ones_like(y));
1386 }
1387
1388 dropout->eval();
1389 y = dropout(torch::ones({2, 2, 10, 10}));
1390 ASSERT_EQ(y.sum().item<float>(), 400);
1391 }
1392}
1393
1394TEST_F(ModulesTest, Dropout3d) {
1395 for (const auto inplace : {false, true}) {
1396 auto p = 0.5;
1397 Dropout3d dropout(Dropout3dOptions(p).inplace(inplace));
1398 torch::Tensor x = torch::empty({50, 50, 2, 2, 2}).fill_(1 - p);
1399 if (!inplace) {
1400 x.requires_grad_(true);
1401 }
1402 torch::Tensor y = dropout(x);
1403
1404 ASSERT_EQ(y.ndimension(), 5);
1405 ASSERT_EQ(y.size(0), 50);
1406 ASSERT_EQ(y.size(1), 50);
1407 ASSERT_EQ(y.size(2), 2);
1408 ASSERT_EQ(y.size(3), 2);
1409 ASSERT_EQ(y.size(4), 2);
1410 ASSERT_LT((y.mean() - (1 - p)).abs().item<float>(), 0.05);
1411
1412 if (inplace) {
1413 ASSERT_TRUE(y.allclose(x));
1414 } else {
1415 y.backward(torch::ones_like(y));
1416 }
1417
1418 dropout->eval();
1419 y = dropout(torch::ones({4, 4, 5, 5}));
1420 ASSERT_EQ(y.sum().item<float>(), 400);
1421 }
1422}
1423
1424TEST_F(ModulesTest, Parameters) {
1425 auto model = std::make_shared<NestedModel>();
1426 auto parameters = model->named_parameters();
1427 ASSERT_EQ(parameters["param"].size(0), 3);
1428 ASSERT_EQ(parameters["param"].size(1), 2);
1429 ASSERT_EQ(parameters["param"].size(2), 21);
1430 ASSERT_EQ(parameters["l1.bias"].size(0), 20);
1431 ASSERT_EQ(parameters["l1.weight"].size(0), 20);
1432 ASSERT_EQ(parameters["l1.weight"].size(1), 5);
1433 ASSERT_EQ(parameters["test.l1.bias"].size(0), 3);
1434 ASSERT_EQ(parameters["test.l1.weight"].size(0), 3);
1435 ASSERT_EQ(parameters["test.l1.weight"].size(1), 10);
1436 ASSERT_EQ(parameters["test.l2.bias"].size(0), 5);
1437 ASSERT_EQ(parameters["test.l2.weight"].size(0), 5);
1438 ASSERT_EQ(parameters["test.l2.weight"].size(1), 3);
1439 ASSERT_EQ(parameters["test.l3.bias"].size(0), 100);
1440 ASSERT_EQ(parameters["test.l3.weight"].size(0), 100);
1441 ASSERT_EQ(parameters["test.l3.weight"].size(1), 5);
1442}
1443
1444TEST_F(ModulesTest, FunctionalCallsSuppliedFunction) {
1445 bool was_called = false;
1446 auto functional = Functional([&was_called](torch::Tensor input) {
1447 was_called = true;
1448 return input;
1449 });
1450 auto output = functional(torch::ones(5, torch::requires_grad()));
1451 ASSERT_TRUE(was_called);
1452 ASSERT_TRUE(output.equal(torch::ones(5, torch::requires_grad())));
1453
1454 was_called = false;
1455 // Use the call operator overload here.
1456 output = functional(torch::ones(5, torch::requires_grad()));
1457 ASSERT_TRUE(was_called);
1458 ASSERT_TRUE(output.equal(torch::ones(5, torch::requires_grad())));
1459}
1460
1461TEST_F(ModulesTest, FunctionalWithTorchFunction) {
1462 auto functional = Functional(torch::relu);
1463 ASSERT_EQ(functional(torch::ones({})).item<float>(), 1);
1464 ASSERT_EQ(functional(torch::ones({})).item<float>(), 1);
1465 ASSERT_EQ(functional(torch::ones({}) * -1).item<float>(), 0);
1466}
1467
1468TEST_F(ModulesTest, FunctionalArgumentBinding) {
1469 auto functional =
1470 Functional(torch::elu, /*alpha=*/1, /*scale=*/0, /*input_scale=*/1);
1471 ASSERT_EQ(functional(torch::ones({})).item<float>(), 0);
1472}
1473
1474TEST_F(ModulesTest, BatchNorm1dStateful) {
1475 BatchNorm1d bn(5);
1476
1477 ASSERT_TRUE(bn->options.track_running_stats());
1478
1479 ASSERT_TRUE(bn->running_mean.defined());
1480 ASSERT_EQ(bn->running_mean.dim(), 1);
1481 ASSERT_EQ(bn->running_mean.size(0), 5);
1482
1483 ASSERT_TRUE(bn->running_var.defined());
1484 ASSERT_EQ(bn->running_var.dim(), 1);
1485 ASSERT_EQ(bn->running_var.size(0), 5);
1486
1487 ASSERT_TRUE(bn->num_batches_tracked.defined());
1488 ASSERT_EQ(bn->num_batches_tracked.dim(), 0);
1489
1490 ASSERT_TRUE(bn->options.affine());
1491
1492 ASSERT_TRUE(bn->weight.defined());
1493 ASSERT_EQ(bn->weight.dim(), 1);
1494 ASSERT_EQ(bn->weight.size(0), 5);
1495
1496 ASSERT_TRUE(bn->bias.defined());
1497 ASSERT_EQ(bn->bias.dim(), 1);
1498 ASSERT_EQ(bn->bias.size(0), 5);
1499}
1500
1501TEST_F(ModulesTest, BatchNorm1dStateless) {
1502 BatchNorm1d bn(
1503 BatchNorm1dOptions(5).track_running_stats(false).affine(false));
1504
1505 ASSERT_FALSE(bn->running_mean.defined());
1506 ASSERT_FALSE(bn->running_var.defined());
1507 ASSERT_FALSE(bn->num_batches_tracked.defined());
1508 ASSERT_FALSE(bn->weight.defined());
1509 ASSERT_FALSE(bn->bias.defined());
1510}
1511
1512TEST_F(ModulesTest, BatchNorm1d) {
1513 BatchNorm1d bn(5);
1514 bn->eval();
1515
1516 auto input = torch::arange(2. * 5 * 2).view({2, 5, 2}).requires_grad_();
1517 auto output = bn->forward(input);
1518 auto expected = torch::tensor(
1519 {{{0.0000, 1.0000},
1520 {2.0000, 3.0000},
1521 {4.0000, 5.0000},
1522 {6.0000, 7.0000},
1523 {8.0000, 9.0000}},
1524 {{10.0000, 10.9999},
1525 {11.9999, 12.9999},
1526 {13.9999, 14.9999},
1527 {15.9999, 16.9999},
1528 {17.9999, 18.9999}}});
1529 ASSERT_TRUE(output.allclose(expected));
1530 auto s = output.sum();
1531 s.backward();
1532
1533 ASSERT_EQ(input.sizes(), input.grad().sizes());
1534}
1535
1536TEST_F(ModulesTest, BatchNorm2dStateful) {
1537 BatchNorm2d bn(5);
1538
1539 ASSERT_TRUE(bn->options.track_running_stats());
1540
1541 ASSERT_TRUE(bn->running_mean.defined());
1542 ASSERT_EQ(bn->running_mean.dim(), 1);
1543 ASSERT_EQ(bn->running_mean.size(0), 5);
1544
1545 ASSERT_TRUE(bn->running_var.defined());
1546 ASSERT_EQ(bn->running_var.dim(), 1);
1547 ASSERT_EQ(bn->running_var.size(0), 5);
1548
1549 ASSERT_TRUE(bn->num_batches_tracked.defined());
1550 ASSERT_EQ(bn->num_batches_tracked.dim(), 0);
1551
1552 ASSERT_TRUE(bn->options.affine());
1553
1554 ASSERT_TRUE(bn->weight.defined());
1555 ASSERT_EQ(bn->weight.dim(), 1);
1556 ASSERT_EQ(bn->weight.size(0), 5);
1557
1558 ASSERT_TRUE(bn->bias.defined());
1559 ASSERT_EQ(bn->bias.dim(), 1);
1560 ASSERT_EQ(bn->bias.size(0), 5);
1561}
1562
1563TEST_F(ModulesTest, BatchNorm2dStateless) {
1564 BatchNorm2d bn(
1565 BatchNorm2dOptions(5).track_running_stats(false).affine(false));
1566
1567 ASSERT_FALSE(bn->running_mean.defined());
1568 ASSERT_FALSE(bn->running_var.defined());
1569 ASSERT_FALSE(bn->num_batches_tracked.defined());
1570 ASSERT_FALSE(bn->weight.defined());
1571 ASSERT_FALSE(bn->bias.defined());
1572}
1573
1574TEST_F(ModulesTest, BatchNorm2d) {
1575 BatchNorm2d bn(5);
1576 bn->eval();
1577
1578 auto input =
1579 torch::arange(2. * 5 * 2 * 2).view({2, 5, 2, 2}).requires_grad_();
1580 auto output = bn->forward(input);
1581 auto expected = torch::tensor(
1582 {{{{0.0000, 1.0000}, {2.0000, 3.0000}},
1583 {{4.0000, 5.0000}, {6.0000, 7.0000}},
1584 {{8.0000, 9.0000}, {10.0000, 10.9999}},
1585 {{11.9999, 12.9999}, {13.9999, 14.9999}},
1586 {{15.9999, 16.9999}, {17.9999, 18.9999}}},
1587 {{{19.9999, 20.9999}, {21.9999, 22.9999}},
1588 {{23.9999, 24.9999}, {25.9999, 26.9999}},
1589 {{27.9999, 28.9999}, {29.9998, 30.9998}},
1590 {{31.9998, 32.9998}, {33.9998, 34.9998}},
1591 {{35.9998, 36.9998}, {37.9998, 38.9998}}}});
1592 ASSERT_TRUE(output.allclose(expected));
1593 auto s = output.sum();
1594 s.backward();
1595
1596 ASSERT_EQ(input.sizes(), input.grad().sizes());
1597}
1598
1599TEST_F(ModulesTest, BatchNorm3dStateful) {
1600 BatchNorm3d bn(5);
1601
1602 ASSERT_TRUE(bn->options.track_running_stats());
1603
1604 ASSERT_TRUE(bn->running_mean.defined());
1605 ASSERT_EQ(bn->running_mean.dim(), 1);
1606 ASSERT_EQ(bn->running_mean.size(0), 5);
1607
1608 ASSERT_TRUE(bn->running_var.defined());
1609 ASSERT_EQ(bn->running_var.dim(), 1);
1610 ASSERT_EQ(bn->running_var.size(0), 5);
1611
1612 ASSERT_TRUE(bn->num_batches_tracked.defined());
1613 ASSERT_EQ(bn->num_batches_tracked.dim(), 0);
1614
1615 ASSERT_TRUE(bn->options.affine());
1616
1617 ASSERT_TRUE(bn->weight.defined());
1618 ASSERT_EQ(bn->weight.dim(), 1);
1619 ASSERT_EQ(bn->weight.size(0), 5);
1620
1621 ASSERT_TRUE(bn->bias.defined());
1622 ASSERT_EQ(bn->bias.dim(), 1);
1623 ASSERT_EQ(bn->bias.size(0), 5);
1624}
1625
1626TEST_F(ModulesTest, BatchNorm3dStateless) {
1627 BatchNorm3d bn(
1628 BatchNorm3dOptions(5).track_running_stats(false).affine(false));
1629
1630 ASSERT_FALSE(bn->running_mean.defined());
1631 ASSERT_FALSE(bn->running_var.defined());
1632 ASSERT_FALSE(bn->num_batches_tracked.defined());
1633 ASSERT_FALSE(bn->weight.defined());
1634 ASSERT_FALSE(bn->bias.defined());
1635}
1636
1637TEST_F(ModulesTest, BatchNorm3d) {
1638 BatchNorm3d bn(5);
1639 bn->eval();
1640
1641 auto input =
1642 torch::arange(2. * 5 * 2 * 2 * 2).view({2, 5, 2, 2, 2}).requires_grad_();
1643 auto output = bn->forward(input);
1644 auto expected = torch::tensor(
1645 {{{{{0.0000, 1.0000}, {2.0000, 3.0000}},
1646 {{4.0000, 5.0000}, {6.0000, 7.0000}}},
1647 {{{8.0000, 9.0000}, {10.0000, 10.9999}},
1648 {{11.9999, 12.9999}, {13.9999, 14.9999}}},
1649 {{{15.9999, 16.9999}, {17.9999, 18.9999}},
1650 {{19.9999, 20.9999}, {21.9999, 22.9999}}},
1651 {{{23.9999, 24.9999}, {25.9999, 26.9999}},
1652 {{27.9999, 28.9999}, {29.9998, 30.9998}}},
1653 {{{31.9998, 32.9998}, {33.9998, 34.9998}},
1654 {{35.9998, 36.9998}, {37.9998, 38.9998}}}},
1655 {{{{39.9998, 40.9998}, {41.9998, 42.9998}},
1656 {{43.9998, 44.9998}, {45.9998, 46.9998}}},
1657 {{{47.9998, 48.9998}, {49.9997, 50.9997}},
1658 {{51.9997, 52.9997}, {53.9997, 54.9997}}},
1659 {{{55.9997, 56.9997}, {57.9997, 58.9997}},
1660 {{59.9997, 60.9997}, {61.9997, 62.9997}}},
1661 {{{63.9997, 64.9997}, {65.9997, 66.9997}},
1662 {{67.9997, 68.9997}, {69.9996, 70.9996}}},
1663 {{{71.9996, 72.9996}, {73.9996, 74.9996}},
1664 {{75.9996, 76.9996}, {77.9996, 78.9996}}}}});
1665 ASSERT_TRUE(output.allclose(expected));
1666 auto s = output.sum();
1667 s.backward();
1668
1669 ASSERT_EQ(input.sizes(), input.grad().sizes());
1670}
1671
1672TEST_F(ModulesTest, InstanceNorm1dStateful) {
1673 InstanceNorm1d instance_norm(
1674 InstanceNorm1dOptions(5).track_running_stats(true).affine(true));
1675
1676 ASSERT_TRUE(instance_norm->options.track_running_stats());
1677
1678 ASSERT_TRUE(instance_norm->running_mean.defined());
1679 ASSERT_EQ(instance_norm->running_mean.dim(), 1);
1680 ASSERT_EQ(instance_norm->running_mean.size(0), 5);
1681
1682 ASSERT_TRUE(instance_norm->running_var.defined());
1683 ASSERT_EQ(instance_norm->running_var.dim(), 1);
1684 ASSERT_EQ(instance_norm->running_var.size(0), 5);
1685
1686 ASSERT_TRUE(instance_norm->num_batches_tracked.defined());
1687 ASSERT_EQ(instance_norm->num_batches_tracked.dim(), 0);
1688
1689 ASSERT_TRUE(instance_norm->options.affine());
1690
1691 ASSERT_TRUE(instance_norm->weight.defined());
1692 ASSERT_EQ(instance_norm->weight.dim(), 1);
1693 ASSERT_EQ(instance_norm->weight.size(0), 5);
1694
1695 ASSERT_TRUE(instance_norm->bias.defined());
1696 ASSERT_EQ(instance_norm->bias.dim(), 1);
1697 ASSERT_EQ(instance_norm->bias.size(0), 5);
1698}
1699
1700TEST_F(ModulesTest, InstanceNorm1dStateless) {
1701 InstanceNorm1d instance_norm(
1702 InstanceNorm1dOptions(5).track_running_stats(false).affine(false));
1703
1704 ASSERT_FALSE(instance_norm->running_mean.defined());
1705 ASSERT_FALSE(instance_norm->running_var.defined());
1706 ASSERT_FALSE(instance_norm->num_batches_tracked.defined());
1707 ASSERT_FALSE(instance_norm->weight.defined());
1708 ASSERT_FALSE(instance_norm->bias.defined());
1709}
1710
1711TEST_F(ModulesTest, InstanceNorm1d) {
1712 InstanceNorm1d instance_norm(5);
1713 instance_norm->eval();
1714
1715 auto input = torch::arange(2. * 5 * 2).view({2, 5, 2}).requires_grad_();
1716 auto output = instance_norm->forward(input);
1717 auto expected = torch::tensor(
1718 {{{-1.0000, 1.0000},
1719 {-1.0000, 1.0000},
1720 {-1.0000, 1.0000},
1721 {-1.0000, 1.0000},
1722 {-1.0000, 1.0000}},
1723 {{-1.0000, 1.0000},
1724 {-1.0000, 1.0000},
1725 {-1.0000, 1.0000},
1726 {-1.0000, 1.0000},
1727 {-1.0000, 1.0000}}});
1728 ASSERT_TRUE(output.allclose(expected, 1e-3));
1729 auto s = output.sum();
1730 s.backward();
1731
1732 ASSERT_EQ(input.sizes(), input.grad().sizes());
1733}
1734
1735TEST_F(ModulesTest, InstanceNorm2dStateful) {
1736 InstanceNorm2d instance_norm(
1737 InstanceNorm2dOptions(5).track_running_stats(true).affine(true));
1738
1739 ASSERT_TRUE(instance_norm->options.track_running_stats());
1740
1741 ASSERT_TRUE(instance_norm->running_mean.defined());
1742 ASSERT_EQ(instance_norm->running_mean.dim(), 1);
1743 ASSERT_EQ(instance_norm->running_mean.size(0), 5);
1744
1745 ASSERT_TRUE(instance_norm->running_var.defined());
1746 ASSERT_EQ(instance_norm->running_var.dim(), 1);
1747 ASSERT_EQ(instance_norm->running_var.size(0), 5);
1748
1749 ASSERT_TRUE(instance_norm->num_batches_tracked.defined());
1750 ASSERT_EQ(instance_norm->num_batches_tracked.dim(), 0);
1751
1752 ASSERT_TRUE(instance_norm->options.affine());
1753
1754 ASSERT_TRUE(instance_norm->weight.defined());
1755 ASSERT_EQ(instance_norm->weight.dim(), 1);
1756 ASSERT_EQ(instance_norm->weight.size(0), 5);
1757
1758 ASSERT_TRUE(instance_norm->bias.defined());
1759 ASSERT_EQ(instance_norm->bias.dim(), 1);
1760 ASSERT_EQ(instance_norm->bias.size(0), 5);
1761}
1762
1763TEST_F(ModulesTest, InstanceNorm2dStateless) {
1764 InstanceNorm2d instance_norm(
1765 InstanceNorm2dOptions(5).track_running_stats(false).affine(false));
1766
1767 ASSERT_FALSE(instance_norm->running_mean.defined());
1768 ASSERT_FALSE(instance_norm->running_var.defined());
1769 ASSERT_FALSE(instance_norm->num_batches_tracked.defined());
1770 ASSERT_FALSE(instance_norm->weight.defined());
1771 ASSERT_FALSE(instance_norm->bias.defined());
1772}
1773
1774TEST_F(ModulesTest, InstanceNorm2d) {
1775 InstanceNorm2d instance_norm(5);
1776 instance_norm->eval();
1777
1778 auto input =
1779 torch::arange(2. * 5 * 2 * 2).view({2, 5, 2, 2}).requires_grad_();
1780 auto output = instance_norm->forward(input);
1781 auto expected = torch::tensor(
1782 {{{{-1.3416, -0.4472}, {0.4472, 1.3416}},
1783 {{-1.3416, -0.4472}, {0.4472, 1.3416}},
1784 {{-1.3416, -0.4472}, {0.4472, 1.3416}},
1785 {{-1.3416, -0.4472}, {0.4472, 1.3416}},
1786 {{-1.3416, -0.4472}, {0.4472, 1.3416}}},
1787 {{{-1.3416, -0.4472}, {0.4472, 1.3416}},
1788 {{-1.3416, -0.4472}, {0.4472, 1.3416}},
1789 {{-1.3416, -0.4472}, {0.4472, 1.3416}},
1790 {{-1.3416, -0.4472}, {0.4472, 1.3416}},
1791 {{-1.3416, -0.4472}, {0.4472, 1.3416}}}});
1792 ASSERT_TRUE(output.allclose(expected, 1e-3));
1793 auto s = output.sum();
1794 s.backward();
1795
1796 ASSERT_EQ(input.sizes(), input.grad().sizes());
1797}
1798
1799TEST_F(ModulesTest, InstanceNorm3dStateful) {
1800 InstanceNorm3d instance_norm(
1801 InstanceNorm3dOptions(5).track_running_stats(true).affine(true));
1802
1803 ASSERT_TRUE(instance_norm->options.track_running_stats());
1804
1805 ASSERT_TRUE(instance_norm->running_mean.defined());
1806 ASSERT_EQ(instance_norm->running_mean.dim(), 1);
1807 ASSERT_EQ(instance_norm->running_mean.size(0), 5);
1808
1809 ASSERT_TRUE(instance_norm->running_var.defined());
1810 ASSERT_EQ(instance_norm->running_var.dim(), 1);
1811 ASSERT_EQ(instance_norm->running_var.size(0), 5);
1812
1813 ASSERT_TRUE(instance_norm->num_batches_tracked.defined());
1814 ASSERT_EQ(instance_norm->num_batches_tracked.dim(), 0);
1815
1816 ASSERT_TRUE(instance_norm->options.affine());
1817
1818 ASSERT_TRUE(instance_norm->weight.defined());
1819 ASSERT_EQ(instance_norm->weight.dim(), 1);
1820 ASSERT_EQ(instance_norm->weight.size(0), 5);
1821
1822 ASSERT_TRUE(instance_norm->bias.defined());
1823 ASSERT_EQ(instance_norm->bias.dim(), 1);
1824 ASSERT_EQ(instance_norm->bias.size(0), 5);
1825}
1826
1827TEST_F(ModulesTest, InstanceNorm3dStateless) {
1828 InstanceNorm3d instance_norm(
1829 InstanceNorm3dOptions(5).track_running_stats(false).affine(false));
1830
1831 ASSERT_FALSE(instance_norm->running_mean.defined());
1832 ASSERT_FALSE(instance_norm->running_var.defined());
1833 ASSERT_FALSE(instance_norm->num_batches_tracked.defined());
1834 ASSERT_FALSE(instance_norm->weight.defined());
1835 ASSERT_FALSE(instance_norm->bias.defined());
1836}
1837
1838TEST_F(ModulesTest, InstanceNorm3d) {
1839 InstanceNorm3d instance_norm(5);
1840 instance_norm->eval();
1841
1842 auto input =
1843 torch::arange(2. * 5 * 2 * 2 * 2).view({2, 5, 2, 2, 2}).requires_grad_();
1844 auto output = instance_norm->forward(input);
1845 auto expected = torch::tensor(
1846 {{{{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
1847 {{0.2182, 0.6547}, {1.0911, 1.5275}}},
1848 {{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
1849 {{0.2182, 0.6547}, {1.0911, 1.5275}}},
1850 {{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
1851 {{0.2182, 0.6547}, {1.0911, 1.5275}}},
1852 {{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
1853 {{0.2182, 0.6547}, {1.0911, 1.5275}}},
1854 {{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
1855 {{0.2182, 0.6547}, {1.0911, 1.5275}}}},
1856 {{{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
1857 {{0.2182, 0.6547}, {1.0911, 1.5275}}},
1858 {{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
1859 {{0.2182, 0.6547}, {1.0911, 1.5275}}},
1860 {{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
1861 {{0.2182, 0.6547}, {1.0911, 1.5275}}},
1862 {{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
1863 {{0.2182, 0.6547}, {1.0911, 1.5275}}},
1864 {{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
1865 {{0.2182, 0.6547}, {1.0911, 1.5275}}}}});
1866 ASSERT_TRUE(output.allclose(expected, 1e-3));
1867 auto s = output.sum();
1868 s.backward();
1869
1870 ASSERT_EQ(input.sizes(), input.grad().sizes());
1871}
1872
1873TEST_F(ModulesTest, Linear_CUDA) {
1874 Linear model(5, 2);
1875 model->to(torch::kCUDA);
1876 auto x =
1877 torch::randn({10, 5}, torch::device(torch::kCUDA).requires_grad(true));
1878 auto y = model(x);
1879 torch::Tensor s = y.sum();
1880
1881 s.backward();
1882 ASSERT_EQ(y.ndimension(), 2);
1883 ASSERT_EQ(s.ndimension(), 0);
1884 ASSERT_EQ(y.size(0), 10);
1885 ASSERT_EQ(y.size(1), 2);
1886
1887 ASSERT_EQ(model->weight.grad().numel(), 2 * 5);
1888}
1889
1890TEST_F(ModulesTest, Linear2_CUDA) {
1891 Linear model(5, 2);
1892 model->to(torch::kCUDA);
1893 model->to(torch::kCPU);
1894 auto x = torch::randn({10, 5}, torch::requires_grad());
1895 auto y = model(x);
1896 torch::Tensor s = y.sum();
1897
1898 s.backward();
1899 ASSERT_EQ(y.ndimension(), 2);
1900 ASSERT_EQ(s.ndimension(), 0);
1901 ASSERT_EQ(y.size(0), 10);
1902 ASSERT_EQ(y.size(1), 2);
1903
1904 ASSERT_EQ(model->weight.grad().numel(), 2 * 5);
1905}
1906
1907TEST_F(ModulesTest, L1Loss) {
1908 L1Loss loss;
1909 auto input = torch::randn({5, 6}, torch::requires_grad());
1910 auto target = torch::empty({5, 6}).random_(2);
1911 auto output = loss->forward(torch::sigmoid(input), target);
1912 auto s = output.sum();
1913 s.backward();
1914
1915 ASSERT_EQ(output.sizes(), std::vector<int64_t>());
1916 ASSERT_EQ(input.sizes(), input.grad().sizes());
1917}
1918
1919TEST_F(ModulesTest, MSELoss) {
1920 MSELoss loss;
1921 auto input = torch::randn({5, 6}, torch::requires_grad());
1922 auto target = torch::empty({5, 6}).random_(2);
1923 auto output = loss->forward(torch::sigmoid(input), target);
1924 auto s = output.sum();
1925 s.backward();
1926
1927 ASSERT_EQ(output.sizes(), torch::IntArrayRef());
1928 ASSERT_EQ(input.sizes(), input.grad().sizes());
1929}
1930
1931TEST_F(ModulesTest, BCELoss) {
1932 BCELoss loss;
1933 auto input = torch::randn({5, 6}, torch::requires_grad());
1934 auto target = torch::empty({5, 6}).random_(2);
1935 auto output = loss->forward(torch::sigmoid(input), target);
1936 auto s = output.sum();
1937 s.backward();
1938
1939 ASSERT_EQ(output.sizes(), torch::IntArrayRef());
1940 ASSERT_EQ(input.sizes(), input.grad().sizes());
1941}
1942
1943TEST_F(ModulesTest, KLDivLoss) {
1944 KLDivLoss loss;
1945 auto input = torch::randn({5, 6}, torch::requires_grad());
1946 auto target = torch::empty({5, 6}).random_(2);
1947 auto output = loss->forward(torch::sigmoid(input), target);
1948 auto s = output.sum();
1949 s.backward();
1950
1951 ASSERT_EQ(output.sizes(), torch::IntArrayRef());
1952 ASSERT_EQ(input.sizes(), input.grad().sizes());
1953}
1954
1955TEST_F(ModulesTest, HingeEmbeddingLoss) {
1956 HingeEmbeddingLoss loss(HingeEmbeddingLossOptions().margin(2));
1957 auto input = torch::tensor(
1958 {{2, 22, 4}, {20, 10, 0}},
1959 torch::dtype(torch::kFloat).requires_grad(true));
1960 auto target = torch::tensor({{2, 6, 4}, {1, 10, 0}}, torch::kFloat);
1961 auto output = loss->forward(input, target);
1962 auto expected = torch::tensor({10}, torch::kFloat);
1963 auto s = output.sum();
1964 s.backward();
1965
1966 ASSERT_TRUE(output.allclose(expected));
1967 ASSERT_EQ(input.sizes(), input.grad().sizes());
1968}
1969
1970TEST_F(ModulesTest, MultiMarginLoss) {
1971 auto weight = torch::tensor({0.3, 0.3, 0.4}, torch::kFloat);
1972 MultiMarginLoss loss(MultiMarginLossOptions().margin(2).weight(weight));
1973 auto input = torch::tensor(
1974 {{0.2, 0.2, 0.6}, {0.1, 0.8, 0.1}, {0.9, 0.09, 0.01}},
1975 torch::dtype(torch::kFloat).requires_grad(true));
1976 auto target = torch::tensor({2, 1, 0}, torch::kLong);
1977 auto output = loss->forward(input, target);
1978 auto expected = torch::tensor({0.305556}, torch::kFloat);
1979 auto s = output.sum();
1980 s.backward();
1981
1982 ASSERT_TRUE(output.allclose(expected, 1e-04));
1983 ASSERT_EQ(input.sizes(), input.grad().sizes());
1984}
1985
1986TEST_F(ModulesTest, CosineEmbeddingLoss) {
1987 CosineEmbeddingLoss cos(CosineEmbeddingLossOptions().margin(0.5));
1988 auto input1 = torch::tensor(
1989 {{2, 3, 4}, {6, 2, 4}}, torch::dtype(torch::kFloat).requires_grad(true));
1990 auto input2 = torch::tensor(
1991 {{2, 3, 5}, {9, 12, 0}}, torch::dtype(torch::kFloat).requires_grad(true));
1992 auto target = torch::tensor({1, -1});
1993 auto output = cos(input1, input2, target);
1994 auto expected = torch::tensor({0.1004}, torch::kFloat);
1995 auto s = output.sum();
1996 s.backward();
1997
1998 ASSERT_TRUE(output.allclose(expected, 1e-4));
1999 ASSERT_EQ(input1.sizes(), input1.grad().sizes());
2000 ASSERT_EQ(input2.sizes(), input2.grad().sizes());
2001}
2002
2003TEST_F(ModulesTest, SmoothL1LossDefaultOptions) {
2004 SmoothL1Loss loss;
2005 auto input = torch::tensor(
2006 {0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true));
2007 auto target = torch::tensor({0., 1., 5.}, torch::kFloat);
2008 auto output = loss(input, target);
2009 auto expected = torch::tensor(0.0233335, torch::kFloat);
2010 auto s = output.sum();
2011 s.backward();
2012
2013 ASSERT_TRUE(output.allclose(expected));
2014 ASSERT_EQ(input.sizes(), input.grad().sizes());
2015}
2016
2017TEST_F(ModulesTest, HuberLossDefaultOptions) {
2018 HuberLoss loss;
2019 auto input = torch::tensor(
2020 {0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true));
2021 auto target = torch::tensor({0., 1., 5.}, torch::kFloat);
2022 auto output = loss(input, target);
2023 auto expected = torch::tensor(0.0233335, torch::kFloat);
2024 auto s = output.sum();
2025 s.backward();
2026
2027 ASSERT_TRUE(output.allclose(expected));
2028 ASSERT_EQ(input.sizes(), input.grad().sizes());
2029}
2030
2031TEST_F(ModulesTest, MultiLabelMarginLossDefaultOptions) {
2032 MultiLabelMarginLoss loss;
2033 auto input = torch::tensor(
2034 {{0.1, 0.2, 0.4, 0.8}}, torch::dtype(torch::kFloat).requires_grad(true));
2035 auto target = torch::tensor({{3, 0, -1, 1}}, torch::kLong);
2036 auto output = loss->forward(input, target);
2037 auto expected = torch::tensor({0.8500}, torch::kFloat);
2038 auto s = output.sum();
2039 s.backward();
2040
2041 ASSERT_TRUE(output.allclose(expected));
2042 ASSERT_EQ(input.sizes(), input.grad().sizes());
2043}
2044
2045TEST_F(ModulesTest, SmoothL1LossNoReduction) {
2046 SmoothL1Loss loss(/*reduction=*/torch::kNone);
2047 auto input = torch::tensor(
2048 {0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true));
2049 auto target = torch::tensor({0., 1., 5.}, torch::kFloat);
2050 auto output = loss(input, target);
2051 auto expected = torch::tensor({0.005, 0.02, 0.045}, torch::kFloat);
2052 auto s = output.sum();
2053 s.backward();
2054
2055 ASSERT_TRUE(output.allclose(expected));
2056 ASSERT_EQ(input.sizes(), input.grad().sizes());
2057}
2058
2059TEST_F(ModulesTest, HuberLossNoReduction) {
2060 HuberLoss loss(/*reduction=*/torch::kNone);
2061 auto input = torch::tensor(
2062 {0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true));
2063 auto target = torch::tensor({0., 1., 5.}, torch::kFloat);
2064 auto output = loss(input, target);
2065 auto expected = torch::tensor({0.005, 0.02, 0.045}, torch::kFloat);
2066 auto s = output.sum();
2067 s.backward();
2068
2069 ASSERT_TRUE(output.allclose(expected));
2070 ASSERT_EQ(input.sizes(), input.grad().sizes());
2071}
2072
2073TEST_F(ModulesTest, MultiLabelMarginLossNoReduction) {
2074 MultiLabelMarginLoss loss(torch::kNone);
2075 auto input = torch::tensor(
2076 {{0.1, 0.2, 0.4, 0.8}}, torch::dtype(torch::kFloat).requires_grad(true));
2077 auto target = torch::tensor({{3, 0, -1, 1}}, torch::kLong);
2078 auto output = loss->forward(input, target);
2079 auto expected = torch::tensor({0.8500}, torch::kFloat);
2080 auto s = output.sum();
2081 s.backward();
2082
2083 ASSERT_TRUE(output.allclose(expected));
2084 ASSERT_EQ(input.sizes(), input.grad().sizes());
2085}
2086
2087TEST_F(ModulesTest, SmoothL1LossBeta) {
2088 auto options = SmoothL1LossOptions().beta(0.2);
2089 SmoothL1Loss loss(options);
2090 auto input = torch::tensor(
2091 {0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true));
2092 auto target = torch::tensor({0., 1., 5.}, torch::kFloat);
2093 auto output = loss(input, target);
2094 auto expected = torch::tensor(0.108333, torch::kFloat);
2095 auto s = output.sum();
2096 s.backward();
2097
2098 ASSERT_TRUE(output.allclose(expected));
2099 ASSERT_EQ(input.sizes(), input.grad().sizes());
2100}
2101
2102TEST_F(ModulesTest, HuberLossDelta) {
2103 auto options = HuberLossOptions().delta(0.2);
2104 HuberLoss loss(options);
2105 auto input = torch::tensor(
2106 {0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true));
2107 auto target = torch::tensor({0., 1., 5.}, torch::kFloat);
2108 auto output = loss(input, target);
2109 auto expected = torch::tensor(0.0216666, torch::kFloat);
2110 auto s = output.sum();
2111 s.backward();
2112
2113 ASSERT_TRUE(output.allclose(expected));
2114 ASSERT_EQ(input.sizes(), input.grad().sizes());
2115}
2116
2117TEST_F(ModulesTest, TripletMarginLoss) {
2118 TripletMarginLoss loss(TripletMarginLossOptions().margin(1.0));
2119 auto anchor = torch::tensor(
2120 {{3., 3.}}, torch::dtype(torch::kFloat).requires_grad(true));
2121 auto positive = torch::tensor(
2122 {{2., 2.}}, torch::dtype(torch::kFloat).requires_grad(true));
2123 auto negative = torch::tensor(
2124 {{0., 0.}}, torch::dtype(torch::kFloat).requires_grad(true));
2125 auto output = loss->forward(anchor, positive, negative);
2126 auto expected = torch::tensor({0.}, torch::kFloat);
2127 auto s = output.sum();
2128 s.backward();
2129
2130 ASSERT_TRUE(output.allclose(expected, 1e-04));
2131 ASSERT_EQ(anchor.sizes(), anchor.grad().sizes());
2132}
2133
2134TEST_F(ModulesTest, TripletMarginWithDistanceLossDefaultParity) {
2135 // Check that if we use torch::pairwise_distance with the default
2136 // TripletMarginLoss options as our distance function, the outputs
2137 // are equal (i.e., equal under defaults).
2138
2139 std::vector<TripletMarginWithDistanceLossOptions::reduction_t> reductions = {
2140 torch::kSum, torch::kMean, torch::kNone};
2141 std::vector<float> margins = {0.5, 1.0, 1.5};
2142 std::vector<bool> swaps = {true, false};
2143
2144 for (auto& reduction : reductions) {
2145 for (auto& margin : margins) {
2146 for (const auto swap : swaps) {
2147 auto anchor = torch::randn(
2148 {100, 128}, torch::dtype(torch::kFloat).requires_grad(true));
2149 auto positive = torch::randn(
2150 {100, 128}, torch::dtype(torch::kFloat).requires_grad(true));
2151 auto negative = torch::randn(
2152 {100, 128}, torch::dtype(torch::kFloat).requires_grad(true));
2153
2154 auto basicOptions =
2155 TripletMarginLossOptions().reduction(reduction).margin(margin).swap(
2156 swap);
2157 auto distanceOptions = TripletMarginWithDistanceLossOptions()
2158 .reduction(reduction)
2159 .margin(margin)
2160 .swap(swap);
2161 TripletMarginLoss basicLoss(basicOptions);
2162 TripletMarginWithDistanceLoss distanceLoss(distanceOptions);
2163
2164 auto basicOutput = basicLoss->forward(anchor, positive, negative);
2165 auto distanceOutput = distanceLoss->forward(anchor, positive, negative);
2166 auto basicOperatorOutput = basicLoss(anchor, positive, negative);
2167 auto distanceOperatorOutput = distanceLoss(anchor, positive, negative);
2168
2169 ASSERT_TRUE(distanceOutput.allclose(basicOutput, 1e-6, 1e-6));
2170 ASSERT_TRUE(
2171 distanceOperatorOutput.allclose(distanceOutput, 1e-6, 1e-6));
2172 ASSERT_TRUE(
2173 distanceOperatorOutput.allclose(basicOperatorOutput, 1e-6, 1e-6));
2174
2175 // handle for torch::kNone reduction
2176 auto sum = distanceOutput.sum();
2177 sum.backward();
2178 ASSERT_EQ(anchor.sizes(), anchor.grad().sizes());
2179 ASSERT_EQ(positive.sizes(), positive.grad().sizes());
2180 ASSERT_EQ(negative.sizes(), negative.grad().sizes());
2181 }
2182 }
2183 }
2184}
2185
2186TEST_F(ModulesTest, TripletMarginWithDistanceLossFunctionalParity) {
2187 // Check for parity between F::triplet_margin_with_distance_loss and
2188 // TripletMarginWithDistanceLoss.
2189 auto pairwise_distance = [&](const torch::Tensor& x, const torch::Tensor& y) {
2190 return torch::pairwise_distance(x, y);
2191 };
2192 auto cosine_distance = [&](const torch::Tensor& x, const torch::Tensor& y) {
2193 return 1.0 - torch::cosine_similarity(x, y);
2194 };
2195 std::vector<TripletMarginWithDistanceLossOptions::distance_function_t>
2196 distance_functions = {pairwise_distance, cosine_distance};
2197
2198 std::vector<TripletMarginWithDistanceLossOptions::reduction_t> reductions = {
2199 torch::kSum, torch::kMean, torch::kNone};
2200 std::vector<float> margins = {0.5, 1.0, 1.5};
2201 std::vector<bool> swaps = {true, false};
2202
2203 for (auto& function : distance_functions) {
2204 for (auto& reduction : reductions) {
2205 for (auto& margin : margins) {
2206 for (const auto swap : swaps) {
2207 auto moduleOptions = TripletMarginWithDistanceLossOptions()
2208 .distance_function(function)
2209 .reduction(reduction)
2210 .margin(margin)
2211 .swap(swap);
2212 auto functionOptions =
2213 torch::nn::functional::TripletMarginWithDistanceLossFuncOptions()
2214 .distance_function(function)
2215 .reduction(reduction)
2216 .margin(margin)
2217 .swap(swap);
2218
2219 auto anchor = torch::randn(
2220 {100, 128}, torch::dtype(torch::kFloat).requires_grad(true));
2221 auto positive = torch::randn(
2222 {100, 128}, torch::dtype(torch::kFloat).requires_grad(true));
2223 auto negative = torch::randn(
2224 {100, 128}, torch::dtype(torch::kFloat).requires_grad(true));
2225
2226 TripletMarginWithDistanceLoss distanceLoss(moduleOptions);
2227
2228 auto moduleOutput = distanceLoss->forward(anchor, positive, negative);
2229 auto moduleOperatorOutput = distanceLoss(anchor, positive, negative);
2230 auto functionOutput =
2231 torch::nn::functional::triplet_margin_with_distance_loss(
2232 anchor, positive, negative, functionOptions);
2233
2234 ASSERT_TRUE(moduleOutput.allclose(functionOutput, 1e-6, 1e-6));
2235 ASSERT_TRUE(
2236 moduleOperatorOutput.allclose(functionOutput, 1e-6, 1e-6));
2237 }
2238 }
2239 }
2240 }
2241}
2242
2243TEST_F(ModulesTest, NLLLoss) {
2244 NLLLoss loss;
2245 auto input = torch::tensor(
2246 {{-0.1315, -3.1315, -2.5315},
2247 {-3.7038, -0.1038, -2.6038},
2248 {-2.3422, -1.3422, -0.4422}},
2249 torch::dtype(torch::kFloat).requires_grad(true));
2250 auto target = torch::tensor({1, 0, 2}, torch::kLong);
2251 auto output = loss->forward(input, target);
2252 auto expected = torch::tensor(2.4258, torch::kFloat);
2253 auto s = output.sum();
2254 s.backward();
2255
2256 ASSERT_TRUE(output.allclose(expected, 1e-04));
2257 ASSERT_TRUE(
2258 NLLLoss(NLLLossOptions().ignore_index(-100).reduction(torch::kMean))
2259 ->forward(input, target)
2260 .allclose(expected, 1e-04));
2261}
2262
2263TEST_F(ModulesTest, CrossEntropyLoss) {
2264 CrossEntropyLoss loss;
2265 auto input = torch::tensor(
2266 {{3., 3.}, {2., 2.}}, torch::dtype(torch::kFloat).requires_grad(true));
2267 auto target = torch::tensor({0, 1}, torch::kLong);
2268 auto output = loss->forward(input, target);
2269 auto expected = torch::tensor(0.6931, torch::kFloat);
2270 auto s = output.sum();
2271 s.backward();
2272
2273 ASSERT_TRUE(output.allclose(expected, 1e-04));
2274 ASSERT_EQ(input.sizes(), input.grad().sizes());
2275 ASSERT_TRUE(
2276 CrossEntropyLoss(
2277 CrossEntropyLossOptions().ignore_index(-100).reduction(torch::kMean))
2278 ->forward(input, target)
2279 .allclose(expected, 1e-04));
2280
2281 // label smoothing with class indices
2282 loss = CrossEntropyLoss(
2283 CrossEntropyLossOptions().label_smoothing(0.15).reduction(torch::kMean));
2284 input = torch::tensor(
2285 {{3., 1.}, {1., 2.}}, torch::dtype(torch::kFloat).requires_grad(true));
2286 target = torch::tensor({0, 1}, torch::kLong);
2287 output = loss->forward(input, target);
2288 expected = torch::tensor(0.3326, torch::kFloat);
2289 s = output.sum();
2290 s.backward();
2291
2292 ASSERT_TRUE(output.allclose(expected, 1e-04));
2293 ASSERT_EQ(input.sizes(), input.grad().sizes());
2294
2295 // label smoothing with with target probabilities
2296 loss = CrossEntropyLoss(
2297 CrossEntropyLossOptions().label_smoothing(0.2).reduction(torch::kMean));
2298 input = torch::tensor(
2299 {{3., 1.}, {1., 2.}}, torch::dtype(torch::kFloat).requires_grad(true));
2300 target = torch::tensor({{0.8, 0.2}, {0.1, 0.9}}, torch::kFloat);
2301 output = loss->forward(input, target);
2302 expected = torch::tensor(0.5701, torch::kFloat);
2303 s = output.sum();
2304 s.backward();
2305
2306 ASSERT_TRUE(output.allclose(expected, 1e-04));
2307 ASSERT_EQ(input.sizes(), input.grad().sizes());
2308}
2309
2310TEST_F(ModulesTest, CosineSimilarity) {
2311 CosineSimilarity cos(CosineSimilarityOptions().dim(1));
2312 auto input1 = torch::tensor(
2313 {{1, 2, 3}, {4, 5, 6}}, torch::dtype(torch::kFloat).requires_grad(true));
2314 auto input2 = torch::tensor(
2315 {{1, 8, 3}, {2, 1, 6}}, torch::dtype(torch::kFloat).requires_grad(true));
2316 auto output = cos->forward(input1, input2);
2317 auto expected = torch::tensor({0.8078, 0.8721}, torch::kFloat);
2318 auto s = output.sum();
2319 s.backward();
2320
2321 ASSERT_TRUE(output.allclose(expected, 1e-04));
2322 ASSERT_EQ(input1.sizes(), input1.grad().sizes());
2323}
2324
2325TEST_F(ModulesTest, SoftMarginLossDefaultOptions) {
2326 SoftMarginLoss loss;
2327 auto input = torch::tensor(
2328 {2., 4., 1., 3.}, torch::dtype(torch::kFloat).requires_grad(true));
2329 auto target = torch::tensor({-1., 1., 1., -1.}, torch::kFloat);
2330 auto output = loss->forward(input, target);
2331 auto expected = torch::tensor({1.3767317}, torch::kFloat);
2332 auto s = output.sum();
2333 s.backward();
2334
2335 ASSERT_TRUE(output.allclose(expected));
2336 ASSERT_EQ(input.sizes(), input.grad().sizes());
2337}
2338
2339TEST_F(ModulesTest, MultiLabelSoftMarginLossDefaultOptions) {
2340 MultiLabelSoftMarginLoss loss;
2341 auto input = torch::tensor(
2342 {{0., 2., 2., 0.}, {2., 1., 0., 1.}},
2343 torch::dtype(torch::kFloat).requires_grad(true));
2344 auto target =
2345 torch::tensor({{0., 0., 1., 0.}, {1., 0., 1., 1.}}, torch::kFloat);
2346 auto output = loss->forward(input, target);
2347 auto expected = torch::tensor({0.7608436}, torch::kFloat);
2348 auto s = output.sum();
2349 s.backward();
2350
2351 ASSERT_TRUE(output.allclose(expected));
2352 ASSERT_EQ(input.sizes(), input.grad().sizes());
2353}
2354
2355TEST_F(ModulesTest, SoftMarginLossNoReduction) {
2356 SoftMarginLoss loss(torch::kNone);
2357 auto input = torch::tensor(
2358 {2., 4., 1., 3.}, torch::dtype(torch::kFloat).requires_grad(true));
2359 auto target = torch::tensor({-1., 1., 1., -1.}, torch::kFloat);
2360 auto output = loss->forward(input, target);
2361 auto expected = torch::tensor(
2362 {2.1269281, 0.01814993, 0.3132617, 3.0485873}, torch::kFloat);
2363 auto s = output.sum();
2364 s.backward();
2365
2366 ASSERT_TRUE(output.allclose(expected));
2367 ASSERT_EQ(input.sizes(), input.grad().sizes());
2368}
2369
2370TEST_F(ModulesTest, MultiLabelSoftMarginLossWeightedNoReduction) {
2371 auto input = torch::tensor(
2372 {{0., 2., 2., 0.}, {2., 1., 0., 1.}},
2373 torch::dtype(torch::kFloat).requires_grad(true));
2374 auto target =
2375 torch::tensor({{0., 0., 1., 0.}, {1., 0., 1., 1.}}, torch::kFloat);
2376 auto weight = torch::tensor({0.1, 0.6, 0.4, 0.8}, torch::kFloat);
2377 auto options =
2378 MultiLabelSoftMarginLossOptions().reduction(torch::kNone).weight(weight);
2379 MultiLabelSoftMarginLoss loss = MultiLabelSoftMarginLoss(options);
2380 auto output = loss->forward(input, target);
2381 auto expected = torch::tensor({0.4876902, 0.3321295}, torch::kFloat);
2382 auto s = output.sum();
2383 s.backward();
2384
2385 ASSERT_TRUE(output.allclose(expected));
2386 ASSERT_EQ(input.sizes(), input.grad().sizes());
2387}
2388
2389TEST_F(ModulesTest, PairwiseDistance) {
2390 PairwiseDistance dist(PairwiseDistanceOptions().p(1));
2391 auto input1 = torch::tensor(
2392 {{1, 2, 3}, {4, 5, 6}}, torch::dtype(torch::kFloat).requires_grad(true));
2393 auto input2 = torch::tensor(
2394 {{1, 8, 3}, {2, 1, 6}}, torch::dtype(torch::kFloat).requires_grad(true));
2395 auto output = dist->forward(input1, input2);
2396 auto expected = torch::tensor({6, 6}, torch::kFloat);
2397 auto s = output.sum();
2398 s.backward();
2399
2400 ASSERT_TRUE(output.allclose(expected));
2401 ASSERT_EQ(input1.sizes(), input1.grad().sizes());
2402}
2403
2404TEST_F(ModulesTest, ELU) {
2405 const auto size = 3;
2406 for (const auto alpha : {0.0, 0.42, 1.0, 4.2, 42.42}) {
2407 for (const auto inplace : {false, true}) {
2408 ELU model{ELUOptions().alpha(alpha).inplace(inplace)};
2409 auto x = torch::linspace(-10.0, 10.0, size * size * size);
2410 x.resize_({size, size, size});
2411 if (!inplace) {
2412 x.requires_grad_(true);
2413 }
2414 auto x_orig = x.clone();
2415 auto y = model(x);
2416 torch::Tensor s = y.sum();
2417
2418 ASSERT_EQ(s.ndimension(), 0);
2419
2420 ASSERT_EQ(y.ndimension(), 3);
2421 ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
2422 auto y_exp = torch::max(torch::zeros_like(x_orig), x_orig) +
2423 torch::min(torch::zeros_like(x_orig),
2424 alpha * (torch::exp(x_orig) - 1.0));
2425 ASSERT_TRUE(torch::allclose(y, y_exp));
2426 if (inplace) {
2427 ASSERT_TRUE(torch::allclose(x, y_exp));
2428 } else {
2429 s.backward();
2430 }
2431 }
2432 }
2433}
2434
2435TEST_F(ModulesTest, SELU) {
2436 for (const auto inplace : {false, true}) {
2437 SELU model(inplace);
2438 auto input = torch::randn({5, 5});
2439 if (!inplace) {
2440 input.requires_grad_(true);
2441 }
2442 auto input_orig = input.clone();
2443 auto output = model->forward(input);
2444 const double scale = 1.0507009873554804934193349852946;
2445 const double alpha = 1.6732632423543772848170429916717;
2446 auto zero = torch::zeros_like(input);
2447 auto expected = scale *
2448 (torch::max(zero, input_orig) +
2449 torch::min(zero, alpha * (torch::exp(input_orig) - 1)));
2450 auto s = output.sum();
2451
2452 ASSERT_EQ(s.ndimension(), 0);
2453 ASSERT_TRUE(output.allclose(expected));
2454 if (inplace) {
2455 ASSERT_TRUE(input.allclose(expected));
2456 } else {
2457 s.backward();
2458 }
2459 }
2460}
2461
2462TEST_F(ModulesTest, Hardshrink) {
2463 const auto size = 3;
2464 for (const auto lambda : {-4.2, -1.0, -0.42, 0.0, 0.42, 1.0, 4.2, 42.42}) {
2465 Hardshrink model{HardshrinkOptions().lambda(lambda)};
2466 auto x = torch::linspace(-10.0, 10.0, size * size * size);
2467 x.resize_({size, size, size}).set_requires_grad(true);
2468 auto y = model(x);
2469 torch::Tensor s = y.sum();
2470
2471 s.backward();
2472 ASSERT_EQ(s.ndimension(), 0);
2473 ASSERT_EQ(y.ndimension(), 3);
2474 ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
2475 auto y_exp = (x.abs() > lambda) * x;
2476 ASSERT_TRUE(torch::allclose(y, y_exp));
2477 }
2478}
2479
2480TEST_F(ModulesTest, Hardtanh) {
2481 const auto size = 3;
2482 for (const auto min_val : {-4.2, -1.0, -0.42, 0.0}) {
2483 for (const auto max_val : {0.42, 1.0, 4.2}) {
2484 for (const auto inplace : {false, true}) {
2485 Hardtanh model{
2486 HardtanhOptions().min_val(min_val).max_val(max_val).inplace(
2487 inplace)};
2488 auto x = torch::linspace(-10.0, 10.0, size * size * size);
2489 x.resize_({size, size, size});
2490 if (!inplace) {
2491 x.requires_grad_(true);
2492 }
2493 auto x_orig = x.clone();
2494 auto y = model(x);
2495 torch::Tensor s = y.sum();
2496
2497 ASSERT_EQ(s.ndimension(), 0);
2498 ASSERT_EQ(y.ndimension(), 3);
2499 ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
2500 auto y_exp = (x_orig < min_val) * min_val +
2501 ((x_orig >= min_val) * (x_orig <= max_val)) * x_orig +
2502 (x_orig > max_val) * max_val;
2503 ASSERT_TRUE(torch::allclose(y, y_exp));
2504 if (inplace) {
2505 ASSERT_TRUE(torch::allclose(x, y_exp));
2506 } else {
2507 s.backward();
2508 }
2509 }
2510 }
2511 }
2512}
2513
2514TEST_F(ModulesTest, HardtanhMinValGEMaxVal) {
2515 ASSERT_THROWS_WITH(
2516 Hardtanh{HardtanhOptions().min_val(0.42).max_val(0.42)},
2517 "max_val must be greater than min_val");
2518 ASSERT_THROWS_WITH(
2519 Hardtanh{HardtanhOptions().min_val(0.42).max_val(-0.42)},
2520 "max_val must be greater than min_val");
2521
2522 Hardtanh ht{HardtanhOptions().min_val(-0.42).max_val(0.42)};
2523 ht->options.min_val(0.42);
2524 ASSERT_THROWS_WITH(ht->reset(), "max_val must be greater than min_val");
2525 ht->options.max_val(-0.42);
2526 ASSERT_THROWS_WITH(ht->reset(), "max_val must be greater than min_val");
2527}
2528
2529TEST_F(ModulesTest, LeakyReLU) {
2530 const auto size = 3;
2531 for (const auto inplace : {false, true}) {
2532 for (const auto negative_slope : {0.0, 0.42, 1.0}) {
2533 for (const auto type : {torch::kFloat, torch::kBFloat16}) {
2534 LeakyReLU model{
2535 LeakyReLUOptions().negative_slope(negative_slope).inplace(inplace)};
2536 auto x = torch::linspace(-10.0, 10.0, size * size * size).to(type);
2537 x.resize_({size, size, size});
2538 if (!inplace) {
2539 x.requires_grad_(true);
2540 }
2541 auto x_orig = x.clone();
2542 auto y = model(x);
2543 torch::Tensor s = y.sum();
2544
2545 ASSERT_EQ(s.ndimension(), 0);
2546 ASSERT_EQ(y.ndimension(), 3);
2547 ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
2548 auto y_exp =
2549 (x_orig < 0) * x_orig * negative_slope + (x_orig >= 0) * x_orig;
2550 ASSERT_TRUE(torch::allclose(y, y_exp));
2551 if (inplace) {
2552 ASSERT_TRUE(torch::allclose(x, y_exp));
2553 } else {
2554 s.backward();
2555 }
2556 }
2557 }
2558 }
2559}
2560
2561TEST_F(ModulesTest, LogSigmoid) {
2562 const auto size = 3;
2563 LogSigmoid model;
2564 auto x = torch::linspace(-10.0, 10.0, size * size * size);
2565 x.resize_({size, size, size}).set_requires_grad(true);
2566 auto y = model(x);
2567 torch::Tensor s = y.sum();
2568
2569 s.backward();
2570 ASSERT_EQ(s.ndimension(), 0);
2571
2572 ASSERT_EQ(y.ndimension(), 3);
2573 ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
2574 auto y_exp = torch::log(
2575 torch::ones_like(x) / (torch::ones_like(x) + torch::exp(torch::neg(x))));
2576 ASSERT_TRUE(torch::allclose(y, y_exp, 1e-4, 1e-7));
2577}
2578
2579TEST_F(ModulesTest, Softmax) {
2580 Softmax m(/*dim=*/1);
2581 auto input = torch::arange(10, torch::kFloat).reshape({2, 5});
2582 auto output = m(input);
2583 auto sum = torch::sum(torch::exp(input), 1);
2584
2585 for (const auto i : c10::irange(2)) {
2586 auto expected = torch::exp(input[i]) / sum[i];
2587 ASSERT_TRUE(torch::allclose(output[i], expected));
2588 }
2589}
2590
2591TEST_F(ModulesTest, Softmin) {
2592 Softmin m(/*dim=*/1);
2593 auto input = torch::arange(10, torch::kFloat).reshape({2, 5});
2594 auto output = m(input);
2595 auto sum = torch::sum(torch::exp(-input), 1);
2596
2597 for (const auto i : c10::irange(2)) {
2598 auto expected = torch::exp(-input[i]) / sum[i];
2599 ASSERT_TRUE(torch::allclose(output[i], expected));
2600 }
2601}
2602
2603TEST_F(ModulesTest, LogSoftmax) {
2604 LogSoftmax m(/*dim=*/1);
2605 auto input = torch::arange(10, torch::kFloat).reshape({2, 5});
2606 auto output = m(input);
2607 auto sum = torch::sum(torch::exp(input), 1);
2608
2609 for (const auto i : c10::irange(2)) {
2610 auto expected = torch::log(torch::exp(input[i]) / sum[i]);
2611 ASSERT_TRUE(torch::allclose(output[i], expected));
2612 }
2613}
2614
2615TEST_F(ModulesTest, AdaptiveLogSoftmaxWithLoss) {
2616 {
2617 // log_probs actually returns log_proba
2618 AdaptiveLogSoftmaxWithLoss asfm(
2619 AdaptiveLogSoftmaxWithLossOptions(8, 4, {2}).div_value(2.));
2620 auto x = torch::randn({4, 8});
2621 auto logprob_out = asfm->log_prob(x);
2622 ASSERT_TRUE(
2623 torch::allclose(torch::exp(logprob_out).data().sum(1), torch::ones(4)));
2624 }
2625 {
2626 // test predict
2627 AdaptiveLogSoftmaxWithLoss asfm(
2628 AdaptiveLogSoftmaxWithLossOptions(8, 10, {4, 8})
2629 .div_value(2.)
2630 .head_bias(true));
2631 auto x = torch::randn({64, 8});
2632 auto logprob_out = asfm->log_prob(x);
2633 auto predict_out = asfm->predict(x);
2634 ASSERT_TRUE(torch::allclose(predict_out, logprob_out.argmax(1)));
2635 }
2636 {
2637 // cluster sizes
2638 AdaptiveLogSoftmaxWithLoss asfm(
2639 AdaptiveLogSoftmaxWithLossOptions(16, 20, {4, 10, 15}).div_value(2.));
2640 auto x = torch::arange(100, 132, torch::kFloat).reshape({2, 16});
2641 auto y = torch::tensor({0, 17}, torch::kLong);
2642 auto asm_out = asfm(x, y);
2643 ASSERT_EQ(asm_out.output.sizes(), std::vector<int64_t>({2}));
2644 }
2645 {
2646 // forward returns the same thing as log_probs
2647 AdaptiveLogSoftmaxWithLoss asfm(
2648 AdaptiveLogSoftmaxWithLossOptions(8, 4, {2}).div_value(2.));
2649 auto x = torch::randn({4, 8});
2650 auto logprob_out = asfm->log_prob(x);
2651 NLLLoss nll_loss;
2652
2653 for (const auto v : c10::irange(4)) {
2654 auto y = torch::full({4}, v, torch::kLong);
2655 auto asm_out = asfm(x, y);
2656 auto out = asm_out.output;
2657 auto loss = torch::tensor(asm_out.loss, torch::kFloat);
2658 auto expected = nll_loss->forward(logprob_out, y);
2659
2660 ASSERT_TRUE(torch::allclose(loss, expected));
2661 ASSERT_TRUE(torch::allclose(
2662 out, logprob_out.gather(1, y.unsqueeze(1)).squeeze()));
2663 }
2664 }
2665 {
2666 // test no batch dim
2667 AdaptiveLogSoftmaxWithLoss asfm(
2668 AdaptiveLogSoftmaxWithLossOptions(16, 20, {4, 10, 15}).div_value(2.));
2669 auto x = torch::randn({1, 16});
2670 auto y = torch::tensor({17});
2671 auto x2 = x.squeeze(0);
2672 auto y2 = y.squeeze(0);
2673 ASSERT_TRUE(
2674 torch::allclose(asfm(x, y).output.squeeze(0), asfm(x2, y2).output));
2675 }
2676}
2677
2678TEST_F(ModulesTest, Softmax2d) {
2679 Softmax2d m;
2680 auto input = torch::arange(24, torch::kFloat).reshape({1, 2, 3, 4});
2681 auto output = m(input);
2682 auto sum = torch::sum(torch::exp(input), 1);
2683
2684 for (const auto i : c10::irange(1)) {
2685 for (const auto j : c10::irange(2)) {
2686 for (const auto k : c10::irange(3)) {
2687 for (const auto l : c10::irange(4)) {
2688 auto expected = torch::exp(input[i][j][k][l]) / sum[i][k][l];
2689 ASSERT_TRUE(torch::allclose(output[i][j][k][l], expected));
2690 }
2691 }
2692 }
2693 }
2694}
2695
2696TEST_F(ModulesTest, PReLU) {
2697 const auto num_parameters = 42;
2698 const auto init = 0.42;
2699
2700 PReLU model{PReLUOptions().num_parameters(num_parameters).init(init)};
2701
2702 ASSERT_EQ(model->weight.sizes(), std::vector<int64_t>({num_parameters}));
2703 ASSERT_TRUE(
2704 torch::allclose(model->weight, torch::full(num_parameters, init)));
2705
2706 const auto x = torch::rand({100, num_parameters}) * 200 - 100;
2707 const auto y = model(x);
2708 const auto s = y.sum();
2709
2710 s.backward();
2711 ASSERT_EQ(s.ndimension(), 0);
2712
2713 ASSERT_EQ(y.ndimension(), x.ndimension());
2714 ASSERT_EQ(y.sizes(), x.sizes());
2715 const auto y_exp = (x < 0) * model->weight * x + (x >= 0) * x;
2716 ASSERT_TRUE(torch::allclose(y, y_exp));
2717}
2718
2719TEST_F(ModulesTest, ReLU) {
2720 for (const auto inplace : {false, true}) {
2721 const auto size = 3;
2722 ReLU model(inplace);
2723 auto x = torch::linspace(-10.0, 10.0, size * size * size);
2724 x.resize_({size, size, size});
2725 if (!inplace) {
2726 x.requires_grad_(true);
2727 }
2728 auto x_orig = x.clone();
2729 auto y = model(x);
2730 torch::Tensor s = y.sum();
2731
2732 ASSERT_EQ(s.ndimension(), 0);
2733 ASSERT_EQ(y.ndimension(), 3);
2734 ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
2735 auto y_exp = (x_orig < 0) * 0 + (x_orig >= 0) * x_orig;
2736 ASSERT_TRUE(torch::allclose(y, y_exp));
2737 if (inplace) {
2738 ASSERT_TRUE(torch::allclose(x, y_exp));
2739 } else {
2740 s.backward();
2741 }
2742 }
2743}
2744
2745TEST_F(ModulesTest, ReLU6) {
2746 for (const auto inplace : {false, true}) {
2747 const auto size = 3;
2748 ReLU6 model(inplace);
2749 auto x = torch::linspace(-10.0, 10.0, size * size * size);
2750 x.resize_({size, size, size});
2751 if (!inplace) {
2752 x.requires_grad_(true);
2753 }
2754 auto x_orig = x.clone();
2755 auto y = model(x);
2756 torch::Tensor s = y.sum();
2757
2758 ASSERT_EQ(s.ndimension(), 0);
2759 ASSERT_EQ(y.ndimension(), 3);
2760 ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
2761 auto y_exp = (x_orig < 0) * 0 + ((x_orig >= 0) * (x_orig <= 6)) * x_orig +
2762 (x_orig > 6) * 6;
2763 ASSERT_TRUE(torch::allclose(y, y_exp));
2764 if (inplace) {
2765 ASSERT_TRUE(torch::allclose(x, y_exp));
2766 } else {
2767 s.backward();
2768 }
2769 }
2770}
2771
2772TEST_F(ModulesTest, RReLU) {
2773 const auto size = 3;
2774 for (const auto lower : {0.01, 0.1, 0.2}) {
2775 for (const auto upper : {0.3, 0.4, 0.5}) {
2776 for (const auto inplace : {false, true}) {
2777 for (const auto type : {torch::kFloat, torch::kBFloat16}) {
2778 RReLU model{
2779 RReLUOptions().lower(lower).upper(upper).inplace(inplace)};
2780 auto x = torch::linspace(-10.0, 10.0, size * size * size).to(type);
2781 x.resize_({size, size, size});
2782 if (!inplace) {
2783 x.requires_grad_(true);
2784 }
2785 auto x_orig = x.clone();
2786 auto y = model(x);
2787 torch::Tensor s = y.sum();
2788
2789 ASSERT_EQ(s.ndimension(), 0);
2790 ASSERT_EQ(y.ndimension(), 3);
2791 ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
2792 auto z =
2793 ((x_orig >= 0) * (x_orig == y) +
2794 (x_orig < 0) * (y >= x_orig * upper) * (y <= lower * x_orig)) *
2795 1.0;
2796 ASSERT_TRUE(torch::allclose(z, torch::ones_like(z)));
2797 if (inplace) {
2798 ASSERT_TRUE(torch::allclose(x, y));
2799 } else {
2800 s.backward();
2801 }
2802 }
2803 }
2804 }
2805 }
2806}
2807
2808TEST_F(ModulesTest, CELU) {
2809 const auto size = 3;
2810 for (const auto inplace : {false, true}) {
2811 for (const auto alpha : {0.42, 1.0, 4.2, 42.42}) {
2812 CELU model{CELUOptions().alpha(alpha).inplace(inplace)};
2813 auto x = torch::linspace(-10.0, 10.0, size * size * size);
2814 x.resize_({size, size, size});
2815 if (!inplace) {
2816 x.requires_grad_(true);
2817 }
2818 auto x_orig = x.clone();
2819 auto y = model(x);
2820 torch::Tensor s = y.sum();
2821
2822 ASSERT_EQ(s.ndimension(), 0);
2823 ASSERT_EQ(y.ndimension(), 3);
2824 ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
2825 auto y_exp = torch::max(torch::zeros_like(x_orig), x_orig) +
2826 torch::min(torch::zeros_like(x_orig),
2827 alpha * (torch::exp(x_orig / alpha) - 1.0));
2828 ASSERT_TRUE(torch::allclose(y, y_exp));
2829 if (inplace) {
2830 ASSERT_TRUE(torch::allclose(x, y_exp));
2831 } else {
2832 s.backward();
2833 }
2834 }
2835 }
2836}
2837
2838TEST_F(ModulesTest, GLU) {
2839 int64_t dim = 1;
2840 GLU model(dim);
2841 auto input = torch::randn({4, 2}, torch::requires_grad());
2842 auto output = model->forward(input);
2843 auto input_size = input.sizes()[dim] / 2;
2844 auto first_half = input.narrow(dim, 0, input_size);
2845 auto second_half = input.narrow(dim, input_size, input_size);
2846 auto expected = first_half * torch::sigmoid(second_half);
2847 auto s = output.sum();
2848 s.backward();
2849
2850 ASSERT_EQ(s.ndimension(), 0);
2851 ASSERT_TRUE(output.allclose(expected));
2852
2853 GLU model_default_options;
2854 ASSERT_TRUE(model_default_options->forward(input).allclose(expected));
2855}
2856
2857TEST_F(ModulesTest, GELU) {
2858 GELU model(GELUOptions().approximate("none"));
2859 const auto x = torch::linspace(-3.0, 3.0, 100);
2860 const auto y_exp = x * 0.5 * (1.0 + torch::erf(x / std::sqrt(2.0)));
2861 const auto y = model(x);
2862 ASSERT_TRUE(torch::allclose(y, y_exp, 1.4e-06, 1e-05));
2863}
2864
2865TEST_F(ModulesTest, TanhGELU) {
2866 GELU model(GELUOptions().approximate("tanh"));
2867 const auto x = torch::linspace(-3.0, 3.0, 100);
2868 const auto inner = std::sqrt(2 / M_PI) * (x + 0.044715 * x.pow(3.0));
2869 const auto y_exp = 0.5 * x * (1.0 + inner.tanh());
2870 const auto y = model(x);
2871 ASSERT_TRUE(torch::allclose(y, y_exp, 1.4e-06, 1e-05));
2872}
2873
2874// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
2875TEST_F(ModulesTest, Mish) {
2876 Mish model;
2877 auto x = torch::randn(100) * 10;
2878 auto y_exp = x * x.exp().log1p().tanh();
2879 auto y = model(x);
2880
2881 ASSERT_TRUE(torch::allclose(y, y_exp));
2882}
2883
2884TEST_F(ModulesTest, Sigmoid) {
2885 Sigmoid model;
2886 auto x = torch::randn(100) * 10;
2887 auto y_exp = 1 / (1 + torch::exp(-x));
2888 auto y = model(x);
2889
2890 ASSERT_TRUE(torch::allclose(y, y_exp));
2891}
2892
2893TEST_F(ModulesTest, PixelShuffle) {
2894 PixelShuffle module(/*upscale_factor=*/2);
2895 auto x = torch::tensor(
2896 {{{{-17, 19}, {-1, 2}},
2897 {{7, 14}, {-3, 1}},
2898 {{0, -2}, {-12, 14}},
2899 {{-15, 0}, {-3, 9}}}},
2900 torch::kFloat);
2901 auto y_exp = torch::tensor(
2902 {{{{-17, 7, 19, 14}, {0, -15, -2, 0}, {-1, -3, 2, 1}, {-12, -3, 14, 9}}}},
2903 torch::kFloat);
2904 auto y = module(x);
2905
2906 ASSERT_EQ(y.ndimension(), 4);
2907 ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 1, 4, 4}));
2908 ASSERT_TRUE(y.allclose(y_exp));
2909}
2910
2911TEST_F(ModulesTest, PixelUnshuffle) {
2912 PixelUnshuffle module(/*downscale_factor=*/2);
2913 auto x = torch::tensor(
2914 {{{{-17, 7, 19, 14}, {0, -15, -2, 0}, {-1, -3, 2, 1}, {-12, -3, 14, 9}}}},
2915 torch::kFloat);
2916 auto y_exp = torch::tensor(
2917 {{{{-17, 19}, {-1, 2}},
2918 {{7, 14}, {-3, 1}},
2919 {{0, -2}, {-12, 14}},
2920 {{-15, 0}, {-3, 9}}}},
2921 torch::kFloat);
2922 auto y = module(x);
2923
2924 ASSERT_EQ(y.ndimension(), 4);
2925 ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 4, 2, 2}));
2926 ASSERT_TRUE(y.allclose(y_exp));
2927}
2928
2929TEST_F(ModulesTest, Softplus) {
2930 const auto size = 3;
2931 for (const auto beta : {0.5, 1.0, 2.0}) {
2932 for (const auto threshold : {1.0, 3.0, 5.0}) {
2933 Softplus model{SoftplusOptions().beta(beta).threshold(threshold)};
2934 auto x = torch::linspace(-3.0, 3.0, 61);
2935 x.resize_({size, size, size});
2936 auto y_exp =
2937 (x <= threshold) * torch::log(1 + torch::exp(x * beta)) / beta +
2938 (x > threshold) * x;
2939 auto y = model(x);
2940
2941 ASSERT_EQ(y.ndimension(), 3);
2942 ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
2943 ASSERT_TRUE(torch::allclose(y, y_exp));
2944 }
2945 }
2946}
2947
2948TEST_F(ModulesTest, Softshrink) {
2949 const auto size = 3;
2950 for (const auto lambda : {0.0, 0.42, 1.0, 4.2, 42.42}) {
2951 Softshrink model{/*lambda=*/lambda};
2952 auto x = torch::linspace(-10.0, 10.0, size * size * size);
2953 x.resize_({size, size, size}).set_requires_grad(true);
2954 auto y = model(x);
2955 torch::Tensor s = y.sum();
2956
2957 s.backward();
2958 ASSERT_EQ(s.ndimension(), 0);
2959
2960 ASSERT_EQ(y.ndimension(), 3);
2961 ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
2962 auto y_exp = (x < -lambda) * (x + lambda) + (x > lambda) * (x - lambda);
2963 ASSERT_TRUE(torch::allclose(y, y_exp));
2964 }
2965}
2966
2967TEST_F(ModulesTest, Softsign) {
2968 Softsign model;
2969 auto x = torch::randn(100) * 10;
2970 auto y_exp = x / (1 + x.abs());
2971 auto y = model(x);
2972
2973 ASSERT_TRUE(torch::allclose(y, y_exp));
2974}
2975
2976TEST_F(ModulesTest, Tanh) {
2977 Tanh model;
2978 auto x = torch::randn(100) * 10;
2979 auto y_exp = (x.exp() - (-x).exp()) / (x.exp() + (-x).exp());
2980 auto y = model(x);
2981
2982 ASSERT_TRUE(torch::allclose(y, y_exp));
2983}
2984
2985TEST_F(ModulesTest, Tanhshrink) {
2986 Tanhshrink model;
2987 auto x = torch::randn(100) * 10;
2988 auto y_exp = x - x.tanh();
2989 auto y = model(x);
2990
2991 ASSERT_TRUE(torch::allclose(y, y_exp));
2992}
2993
2994TEST_F(ModulesTest, Threshold) {
2995 const auto size = 3;
2996 for (const auto threshold : {0.5, 1.0, 2.0}) {
2997 for (const auto value : {0.5, 1.0, 2.0}) {
2998 for (const auto inplace : {false, true}) {
2999 Threshold model{ThresholdOptions(threshold, value).inplace(inplace)};
3000 auto x = torch::linspace(-3.0, 3.0, 61);
3001 x.resize_({size, size, size});
3002 auto x_orig = x.clone();
3003 auto y_exp =
3004 (x_orig <= threshold) * value + (x_orig > threshold) * x_orig;
3005 auto y = model(x);
3006
3007 ASSERT_EQ(y.ndimension(), 3);
3008 ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
3009 ASSERT_TRUE(torch::allclose(y, y_exp));
3010 if (inplace) {
3011 ASSERT_TRUE(torch::allclose(x, y_exp));
3012 }
3013 }
3014 }
3015 }
3016}
3017
3018TEST_F(ModulesTest, Upsampling1D) {
3019 {
3020 Upsample model(UpsampleOptions()
3021 .size(std::vector<int64_t>({4}))
3022 .mode(torch::kNearest));
3023 auto input = torch::ones({1, 1, 2}, torch::requires_grad());
3024 auto output = model->forward(input);
3025 auto expected = torch::ones({1, 1, 4});
3026 auto s = output.sum();
3027 s.backward();
3028
3029 ASSERT_EQ(s.ndimension(), 0);
3030 ASSERT_TRUE(output.allclose(expected));
3031 }
3032 {
3033 for (const auto align_corners : {true, false}) {
3034 // test float scale factor up & down sampling
3035 for (const auto scale_factor : {0.5, 1.5, 2.0}) {
3036 Upsample model(UpsampleOptions()
3037 .scale_factor(std::vector<double>({scale_factor}))
3038 .mode(torch::kLinear)
3039 .align_corners(align_corners));
3040 auto input = torch::ones({1, 1, 2}, torch::requires_grad());
3041 auto output = model->forward(input);
3042 auto expected_size =
3043 static_cast<int64_t>(std::floor(input.size(-1) * scale_factor));
3044 auto expected = torch::ones({1, 1, expected_size});
3045 auto s = output.sum();
3046 s.backward();
3047
3048 ASSERT_EQ(s.ndimension(), 0);
3049 ASSERT_TRUE(output.allclose(expected));
3050 }
3051 }
3052 }
3053 {
3054 // linear (1D) upsampling spatial invariance
3055 Upsample model(UpsampleOptions()
3056 .scale_factor(std::vector<double>({3}))
3057 .mode(torch::kLinear)
3058 .align_corners(false));
3059 auto input = torch::zeros({1, 1, 9});
3060 input.narrow(2, 0, 4).normal_();
3061 auto output = model->forward(input);
3062 auto expected = model->forward(input.narrow(2, 0, 5));
3063
3064 ASSERT_TRUE(torch::allclose(output.narrow(2, 0, 15), expected));
3065 }
3066}
3067
3068TEST_F(ModulesTest, Upsampling2D) {
3069 {
3070 Upsample model(UpsampleOptions()
3071 .size(std::vector<int64_t>({4, 4}))
3072 .mode(torch::kNearest));
3073 auto input = torch::ones({1, 1, 2, 2}, torch::requires_grad());
3074 auto output = model->forward(input);
3075 auto expected = torch::ones({1, 1, 4, 4});
3076 auto s = output.sum();
3077 s.backward();
3078
3079 ASSERT_EQ(s.ndimension(), 0);
3080 ASSERT_TRUE(output.allclose(expected));
3081 }
3082 {
3083 for (const auto align_corners : {true, false}) {
3084 // test float scale factor up & down sampling
3085 for (const auto scale_factor : {0.5, 1.5, 2.0}) {
3086 Upsample model(
3087 UpsampleOptions()
3088 .scale_factor(std::vector<double>({scale_factor, scale_factor}))
3089 .mode(torch::kBilinear)
3090 .align_corners(align_corners));
3091 auto input = torch::ones({1, 1, 2, 2}, torch::requires_grad());
3092 auto output = model->forward(input);
3093 auto expected_size =
3094 static_cast<int64_t>(std::floor(input.size(-1) * scale_factor));
3095 auto expected = torch::ones({1, 1, expected_size, expected_size});
3096 auto s = output.sum();
3097 s.backward();
3098
3099 ASSERT_EQ(s.ndimension(), 0);
3100 ASSERT_TRUE(output.allclose(expected));
3101 }
3102 }
3103 }
3104 {
3105 for (const auto align_corners : {true, false}) {
3106 // test float scale factor up & down sampling
3107 for (const auto scale_factor : {0.5, 1.5, 2.0}) {
3108 Upsample model(
3109 UpsampleOptions()
3110 .scale_factor(std::vector<double>({scale_factor, scale_factor}))
3111 .mode(torch::kBicubic)
3112 .align_corners(align_corners));
3113 auto input = torch::ones({1, 1, 2, 2}, torch::requires_grad());
3114 auto output = model->forward(input);
3115 auto expected_size =
3116 static_cast<int64_t>(std::floor(input.size(-1) * scale_factor));
3117 auto expected = torch::ones({1, 1, expected_size, expected_size});
3118 auto s = output.sum();
3119 s.backward();
3120
3121 ASSERT_EQ(s.ndimension(), 0);
3122 ASSERT_TRUE(output.allclose(expected));
3123 }
3124 }
3125 }
3126}
3127
3128TEST_F(ModulesTest, Upsampling3D) {
3129 {
3130 Upsample model(UpsampleOptions()
3131 .size(std::vector<int64_t>({4, 4, 4}))
3132 .mode(torch::kNearest));
3133 auto input = torch::ones({1, 1, 2, 2, 2}, torch::requires_grad());
3134 auto output = model->forward(input);
3135 auto expected = torch::ones({1, 1, 4, 4, 4});
3136 auto s = output.sum();
3137 s.backward();
3138
3139 ASSERT_EQ(s.ndimension(), 0);
3140 ASSERT_TRUE(output.allclose(expected));
3141 }
3142 {
3143 for (const auto align_corners : {true, false}) {
3144 // test float scale factor up & down sampling
3145 for (const auto scale_factor : {0.5, 1.5, 2.0}) {
3146 Upsample model(UpsampleOptions()
3147 .scale_factor(std::vector<double>(
3148 {scale_factor, scale_factor, scale_factor}))
3149 .mode(torch::kTrilinear)
3150 .align_corners(align_corners));
3151 auto input = torch::ones({1, 1, 2, 2, 2}, torch::requires_grad());
3152 auto output = model->forward(input);
3153 auto expected_size =
3154 static_cast<int64_t>(std::floor(input.size(-1) * scale_factor));
3155 auto expected =
3156 torch::ones({1, 1, expected_size, expected_size, expected_size});
3157 auto s = output.sum();
3158 s.backward();
3159
3160 ASSERT_EQ(s.ndimension(), 0);
3161 ASSERT_TRUE(output.allclose(expected));
3162 }
3163 }
3164 }
3165}
3166
3167TEST_F(ModulesTest, CTCLoss) {
3168 CTCLoss loss{CTCLossOptions().reduction(torch::kNone)};
3169 const auto target_lengths = torch::tensor({0, 0, 0});
3170 const auto input_lengths = torch::tensor({50, 50, 50});
3171 const auto targets =
3172 torch::randint(1, 15, at::IntArrayRef({0}), torch::kLong);
3173 const auto log_probs =
3174 torch::randn({50, 3, 15}, torch::kDouble).log_softmax(2);
3175 const auto output =
3176 loss->forward(log_probs, targets, input_lengths, target_lengths);
3177 ASSERT_TRUE(output.ge(0).all().item<bool>());
3178 ASSERT_TRUE(torch::allclose(
3179 -log_probs.sum(0).slice(1, 0, 1).view_as(output), output));
3180}
3181
3182TEST_F(ModulesTest, PoissonNLLLoss) {
3183 const auto input = torch::tensor({0.5, 1.5, 2.5});
3184 const auto target = torch::tensor({1., 2., 3.});
3185 const auto component_wise_loss = torch::exp(input) - target * input;
3186 {
3187 PoissonNLLLoss loss{PoissonNLLLossOptions().reduction(torch::kNone)};
3188 ASSERT_TRUE(
3189 torch::allclose(component_wise_loss, loss->forward(input, target)));
3190 }
3191 {
3192 PoissonNLLLoss loss{PoissonNLLLossOptions().reduction(torch::kSum)};
3193 ASSERT_TRUE(torch::allclose(
3194 torch::sum(component_wise_loss), loss->forward(input, target)));
3195 }
3196 {
3197 PoissonNLLLoss loss{PoissonNLLLossOptions().reduction(torch::kMean)};
3198 ASSERT_TRUE(torch::allclose(
3199 torch::mean(component_wise_loss), loss->forward(input, target)));
3200 }
3201}
3202
3203TEST_F(ModulesTest, MarginRankingLoss) {
3204 {
3205 MarginRankingLoss loss;
3206 const auto input1 = torch::randn(15) * 10;
3207 const auto input2 = torch::randn(15) * 10;
3208 const auto target = torch::randn(15).sign();
3209 ASSERT_TRUE(torch::allclose(
3210 loss->forward(input1, input2, target),
3211 (-target * (input1 - input2)).clamp(0).mean()));
3212 }
3213 {
3214 MarginRankingLoss loss{
3215 MarginRankingLossOptions().margin(0.5).reduction(torch::kSum)};
3216 const auto input1 = torch::randn(15) * 10;
3217 const auto input2 = torch::randn(15) * 10;
3218 const auto target = torch::randn(15).sign();
3219 const auto margin = 0.5;
3220 ASSERT_TRUE(torch::allclose(
3221 loss->forward(input1, input2, target),
3222 (-target * (input1 - input2) + margin).clamp(0).sum()));
3223 }
3224 {
3225 MarginRankingLoss loss{
3226 MarginRankingLossOptions().margin(0.5).reduction(torch::kMean)};
3227 const auto input1 = torch::randn(15) * 10;
3228 const auto input2 = torch::randn(15) * 10;
3229 const auto target = torch::randn(15).sign();
3230 const auto margin = 0.5;
3231 ASSERT_TRUE(torch::allclose(
3232 loss->forward(input1, input2, target),
3233 (-target * (input1 - input2) + margin).clamp(0).mean()));
3234 }
3235}
3236
3237TEST_F(ModulesTest, BCEWithLogitsLoss) {
3238 {// test BCE with logits raises if target and input are different size
3239 {const auto target = torch::rand(5);
3240 const auto input = torch::rand({5, 1});
3241 ASSERT_THROWS_WITH(
3242 BCEWithLogitsLoss()(input, target), "must be the same as input size");
3243}
3244{
3245 const auto target = torch::rand({5, 1});
3246 const auto input = torch::rand(5);
3247 ASSERT_THROWS_WITH(
3248 BCEWithLogitsLoss()(input, target), "must be the same as input size");
3249}
3250}
3251{ // test BCE with logits gives same result as sigmoid and bce loss
3252 auto sigmoid = Sigmoid();
3253
3254 auto target = torch::rand({64, 4});
3255 auto output = torch::rand({64, 4}) - 0.5;
3256
3257 ASSERT_TRUE(torch::allclose(
3258 BCEWithLogitsLoss()(output, target), BCELoss()(sigmoid(output), target)));
3259
3260 auto weight = torch::rand(4);
3261 ASSERT_TRUE(torch::allclose(
3262 BCEWithLogitsLoss(BCEWithLogitsLossOptions().weight(weight))(
3263 output, target),
3264 BCELoss(BCELossOptions().weight(weight))(sigmoid(output), target)));
3265
3266 target = torch::zeros({4, 1}, torch::kFloat);
3267 output = torch::empty({4, 1}, torch::kFloat).fill_(-100);
3268
3269 ASSERT_TRUE(torch::allclose(
3270 BCEWithLogitsLoss()(output, target), BCELoss()(sigmoid(output), target)));
3271
3272 ASSERT_TRUE(torch::allclose(
3273 BCEWithLogitsLoss(BCEWithLogitsLossOptions().reduction(torch::kNone))(
3274 output, target),
3275 BCELoss(BCELossOptions().reduction(torch::kNone))(
3276 sigmoid(output), target)));
3277
3278 weight = torch::rand({1}, torch::kFloat);
3279 ASSERT_TRUE(torch::allclose(
3280 BCEWithLogitsLoss(BCEWithLogitsLossOptions().weight(weight))(
3281 output, target),
3282 BCELoss(BCELossOptions().weight(weight))(sigmoid(output), target)));
3283}
3284{ // test BCE with logits has correct grad at zero
3285 const auto output = torch::zeros({3, 1}, torch::requires_grad());
3286 const auto target = torch::zeros({3, 1});
3287 BCEWithLogitsLoss(BCEWithLogitsLossOptions().reduction(torch::kSum))(
3288 output, target)
3289 .backward();
3290 const auto expected_grad = torch::empty({3, 1}).fill_(0.5);
3291 ASSERT_TRUE(torch::allclose(output.grad(), expected_grad));
3292}
3293{ // test BCE with logits broadcasts weights
3294 const auto target = torch::rand({16, 4});
3295 const auto output = torch::rand({16, 4}) - 0.5;
3296
3297 auto weight = torch::rand(4);
3298 auto out1 = BCEWithLogitsLoss(BCEWithLogitsLossOptions().weight(weight))(
3299 output, target);
3300
3301 weight = weight.expand({16, 4}).contiguous();
3302 auto out2 = BCEWithLogitsLoss(BCEWithLogitsLossOptions().weight(weight))(
3303 output, target);
3304
3305 ASSERT_TRUE(torch::allclose(out1, out2));
3306
3307 weight = torch::rand({16, 1});
3308 out1 = BCEWithLogitsLoss(BCEWithLogitsLossOptions().weight(weight))(
3309 output, target);
3310
3311 weight = weight.expand({16, 4}).contiguous();
3312 out2 = BCEWithLogitsLoss(BCEWithLogitsLossOptions().weight(weight))(
3313 output, target);
3314
3315 ASSERT_TRUE(torch::allclose(out1, out2));
3316}
3317{ // test BCE with logits ones in pos weights are the same as none
3318 const auto target = torch::rand({64, 4});
3319 const auto output = torch::rand({64, 4}) - 0.5;
3320 const auto pos_weight = torch::ones({64, 4});
3321
3322 ASSERT_TRUE(torch::allclose(
3323 BCEWithLogitsLoss()(output, target),
3324 BCEWithLogitsLoss(BCEWithLogitsLossOptions().pos_weight(pos_weight))(
3325 output, target)));
3326}
3327{ // test BCE with logits broadcasts pos weights
3328 const auto target = torch::rand({64, 4});
3329 const auto output = torch::rand({64, 4}) - 0.5;
3330 const auto pos_weight = torch::rand(4);
3331 const auto out1 = BCEWithLogitsLoss(
3332 BCEWithLogitsLossOptions().pos_weight(pos_weight))(output, target);
3333
3334 const auto pos_weight1 = pos_weight.expand({1, 4});
3335 const auto out2 = BCEWithLogitsLoss(
3336 BCEWithLogitsLossOptions().pos_weight(pos_weight))(output, target);
3337
3338 const auto pos_weight2 = pos_weight.expand({64, 4});
3339 const auto out3 = BCEWithLogitsLoss(
3340 BCEWithLogitsLossOptions().pos_weight(pos_weight))(output, target);
3341
3342 ASSERT_TRUE(torch::allclose(out1, out2));
3343 ASSERT_TRUE(torch::allclose(out1, out3));
3344}
3345{ // test BCE with logits with pos weight has correct grad at zero
3346 const auto output = torch::zeros({3, 1}, torch::requires_grad());
3347 const auto target = torch::zeros({3, 1});
3348 const auto pos_weight = torch::ones({3, 1});
3349 BCEWithLogitsLoss(
3350 BCEWithLogitsLossOptions().pos_weight(pos_weight).reduction(torch::kSum))(
3351 output, target)
3352 .backward();
3353 const auto expected_grad = torch::empty({3, 1}).fill_(0.5);
3354 // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
3355 const auto grad = output.grad();
3356 ASSERT_TRUE(torch::allclose(grad, expected_grad));
3357}
3358{ // test BCE with logits stability
3359 const auto output = torch::tensor({0., -120.});
3360 const auto target = torch::tensor({0., 1.});
3361 const auto pos_weight = torch::tensor({1., 1.});
3362
3363 const auto out1 = BCEWithLogitsLoss()(output, target);
3364 ASSERT_TRUE(torch::isfinite(out1).all().item<bool>());
3365
3366 const auto out2 = BCEWithLogitsLoss(
3367 BCEWithLogitsLossOptions().pos_weight(pos_weight))(output, target);
3368 ASSERT_TRUE(torch::isfinite(out2).all().item<bool>());
3369}
3370}
3371
3372namespace detail {
3373
3374namespace F = torch::nn::functional;
3375
3376torch::Tensor _batchmatmul(const torch::Tensor& a, const torch::Tensor& b) {
3377 TORCH_INTERNAL_ASSERT(a.size(0) == b.size(0));
3378 TORCH_INTERNAL_ASSERT(a.size(1) == b.size(1));
3379 auto retval = torch::zeros(
3380 {a.size(0), a.size(1), a.size(2), b.size(3)}, torch::kFloat32);
3381 for (const auto i : c10::irange(a.size(0))) {
3382 for (const auto j : c10::irange(a.size(1))) {
3383 retval[i][j] = torch::matmul(a[i][j], b[i][j]);
3384 }
3385 }
3386 return retval;
3387}
3388
3389torch::Tensor _softmax(const torch::Tensor& x) {
3390 auto output = torch::zeros(x.sizes());
3391 for (const auto i : c10::irange(x.size(0))) {
3392 for (const auto j : c10::irange(x.size(1))) {
3393 for (const auto k : c10::irange(x.size(2))) {
3394 const auto& x_curr = x[i][j][k];
3395 const auto e_x = torch::exp(x_curr - torch::max(x_curr));
3396 output[i][j][k] = e_x / torch::sum(e_x);
3397 }
3398 }
3399 }
3400 return output;
3401}
3402
3403std::tuple<torch::Tensor, torch::Tensor> _scaled_dot_attn_ref(
3404 const torch::Tensor& Q,
3405 const torch::Tensor& K,
3406 const torch::Tensor& V,
3407 at::IntArrayRef dims,
3408 const torch::Tensor& unseen_mask = {},
3409 const torch::Tensor& key_padding_mask = {},
3410 bool average_attn_weights = true) {
3411 auto QKT = _batchmatmul(Q, K.permute({0, 1, 3, 2}) / std::sqrt(dims[3]));
3412 const auto b1 = QKT.size(0);
3413 const auto b2 = QKT.size(1);
3414 const auto s1 = QKT.size(2);
3415 const auto s2 = QKT.size(3);
3416 if (unseen_mask.defined() || key_padding_mask.defined()) {
3417 for (const auto i : c10::irange(b1)) {
3418 for (const auto j : c10::irange(b2)) {
3419 for (const auto m : c10::irange(s1)) {
3420 for (const auto n : c10::irange(s2)) {
3421 if (unseen_mask.defined() &&
3422 unseen_mask[m][n].item<double>() == 0) {
3423 QKT[i][j][m][n] = -std::numeric_limits<double>::infinity();
3424 }
3425 if (key_padding_mask.defined() &&
3426 key_padding_mask[i][n].item<double>() != 0) {
3427 QKT[i][j][m][n] = -std::numeric_limits<double>::infinity();
3428 }
3429 }
3430 }
3431 }
3432 }
3433 }
3434 auto reference = _softmax(QKT);
3435 auto ref_attn_weight = reference;
3436 if (average_attn_weights) {
3437 // NOLINTNEXTLINE(bugprone-argument-comment)
3438 ref_attn_weight = torch::sum(ref_attn_weight, /*axis=*/1) / b2;
3439 }
3440 reference = _batchmatmul(reference, V);
3441 return std::tie(reference, ref_attn_weight);
3442}
3443
3444torch::Tensor _split_heads_ref(
3445 const torch::Tensor& X,
3446 at::IntArrayRef dims,
3447 int nheads,
3448 int d_head) {
3449 auto X_split = X.reshape({dims[0], dims[1], nheads, d_head});
3450 auto X_split_transposed = X_split.permute({0, 2, 1, 3});
3451 return X_split_transposed.reshape({dims[0], nheads, dims[1], d_head});
3452}
3453
3454torch::Tensor _combine_heads_ref(
3455 const torch::Tensor& X,
3456 at::IntArrayRef dims,
3457 int nheads,
3458 int d_head) {
3459 auto X_transposed = X.permute({0, 2, 1, 3});
3460 auto reference = X_transposed.reshape({dims[0], dims[1], nheads * d_head});
3461 return reference;
3462}
3463
3464torch::Tensor _fc(
3465 torch::Tensor X,
3466 torch::Tensor X_weight,
3467 torch::Tensor X_bias) {
3468 // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
3469 auto X_fc_b = X_bias;
3470 // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
3471 auto X_fc_w = X_weight;
3472 return torch::matmul(X, torch::t(X_fc_w)) + X_fc_b;
3473}
3474
3475void _multihead_attn_test_helper(
3476 bool add_key_padding_mask = false,
3477 bool add_bias_kv = false,
3478 bool add_zero_attn = false,
3479 bool saved_kv = false,
3480 bool same_embed_dim = false,
3481 bool average_attn_weights = true) {
3482 std::random_device device;
3483 std::mt19937 generator(device());
3484 std::uniform_int_distribution<int> d_2_10(2, 10);
3485 std::uniform_int_distribution<int> d_3_10(3, 10);
3486 bool registration_checked = false;
3487 for (const auto i : c10::irange(100)) {
3488 (void)i; // Suppress unused variable warning
3489 const auto batch_sz = d_2_10(generator);
3490 const auto seq_len = d_2_10(generator);
3491 const auto d_head = d_3_10(generator);
3492 const auto nheads = d_3_10(generator);
3493 const auto d_model = d_head * nheads;
3494 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
3495 int kv_dim;
3496 if (same_embed_dim) {
3497 kv_dim = d_model;
3498 } else {
3499 std::uniform_int_distribution<int> d(5, 20);
3500 kv_dim = d(generator);
3501 while (kv_dim == d_model) {
3502 kv_dim = d(generator);
3503 }
3504 }
3505 std::vector<int64_t> dims{batch_sz, seq_len, kv_dim};
3506 torch::Tensor saved_k;
3507 torch::Tensor saved_k_tensor;
3508 torch::Tensor saved_v;
3509 torch::Tensor saved_v_tensor;
3510 if (saved_kv) {
3511 saved_k = torch::rand({batch_sz * nheads, seq_len, d_head});
3512 saved_k_tensor = saved_k;
3513 saved_v = torch::rand({batch_sz * nheads, seq_len, d_head});
3514 saved_v_tensor = saved_v;
3515 }
3516 torch::Tensor key_padding_mask;
3517 torch::Tensor key_padding_mask_tensor;
3518 if (add_key_padding_mask) {
3519 const auto seq_mask = torch::randint(0, 2, {1, seq_len});
3520 key_padding_mask = seq_mask.repeat({batch_sz, 1}) == 1;
3521 key_padding_mask_tensor = key_padding_mask;
3522 }
3523 const auto decoder_state = torch::rand({batch_sz, d_model});
3524 const torch::Tensor K = torch::rand(dims);
3525 // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
3526 const torch::Tensor V = K;
3527 const torch::Tensor Q =
3528 decoder_state.clone().resize_({batch_sz, 1, d_model});
3529 auto attn_mask = torch::randint(0, 2, {1, seq_len}, torch::kFloat);
3530 const torch::Tensor attn_mask_tensor = attn_mask.clone();
3531 attn_mask_tensor.masked_fill_(
3532 attn_mask_tensor == 0, -std::numeric_limits<double>::infinity());
3533 attn_mask_tensor.masked_fill_(attn_mask_tensor > 0, double(0.0));
3534
3535 // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
3536 const torch::Tensor decoder_state_tensor = decoder_state;
3537 const torch::Tensor source_hid_tensor = K.transpose(0, 1);
3538
3539 const auto options = MultiheadAttentionOptions(d_model, nheads)
3540 .add_bias_kv(add_bias_kv)
3541 .add_zero_attn(add_zero_attn)
3542 .kdim(kv_dim)
3543 .vdim(kv_dim);
3544 const auto multihead_attn_module = MultiheadAttention(options);
3545
3546 if (!registration_checked) {
3547 // make sure parameters are all registered correctly
3548 auto named_parameters = multihead_attn_module->named_parameters();
3549 if (same_embed_dim) {
3550 ASSERT_TRUE(named_parameters.contains("in_proj_weight"));
3551 } else {
3552 ASSERT_TRUE(named_parameters.contains("q_proj_weight"));
3553 ASSERT_TRUE(named_parameters.contains("k_proj_weight"));
3554 ASSERT_TRUE(named_parameters.contains("v_proj_weight"));
3555 }
3556 if (add_bias_kv) {
3557 ASSERT_TRUE(named_parameters.contains("bias_k"));
3558 ASSERT_TRUE(named_parameters.contains("bias_v"));
3559 }
3560 // make sure sub modules are all registered correctly
3561 auto submodules = multihead_attn_module->named_children();
3562 ASSERT_TRUE(submodules.contains("out_proj"));
3563 registration_checked = true;
3564 }
3565
3566 torch::Tensor bias_k;
3567 torch::Tensor bias_v;
3568 if (add_bias_kv) {
3569 bias_k = multihead_attn_module->bias_k.detach();
3570 bias_v = multihead_attn_module->bias_v.detach();
3571 } else {
3572 bias_k.reset();
3573 bias_v.reset();
3574 }
3575
3576 torch::Tensor _Q = decoder_state_tensor.unsqueeze(1).transpose(0, 1);
3577 // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
3578 torch::Tensor _V = source_hid_tensor;
3579 // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
3580 torch::Tensor _K = source_hid_tensor;
3581
3582 torch::Tensor result;
3583 torch::Tensor result_weight;
3584 if (multihead_attn_module->_qkv_same_embed_dim) {
3585 std::tie(result, result_weight) = F::multi_head_attention_forward(
3586 _Q,
3587 _K,
3588 _V,
3589 F::MultiheadAttentionForwardFuncOptions(
3590 /*embed_dim_to_check=*/d_model,
3591 /*num_heads=*/nheads,
3592 /*in_proj_weight=*/multihead_attn_module->in_proj_weight,
3593 /*in_proj_bias=*/multihead_attn_module->in_proj_bias,
3594 /*bias_k=*/multihead_attn_module->bias_k,
3595 /*bias_v=*/multihead_attn_module->bias_v,
3596 /*add_zero_attn=*/multihead_attn_module->options.add_zero_attn(),
3597 /*dropout_p=*/multihead_attn_module->options.dropout(),
3598 /*out_proj_weight=*/multihead_attn_module->out_proj->weight,
3599 /*out_proj_bias=*/multihead_attn_module->out_proj->bias)
3600 .training(multihead_attn_module->is_training())
3601 .key_padding_mask(key_padding_mask_tensor)
3602 .need_weights(true)
3603 .attn_mask(attn_mask_tensor)
3604 .static_k(saved_k_tensor)
3605 .static_v(saved_v_tensor)
3606 .average_attn_weights(average_attn_weights));
3607 } else {
3608 std::tie(result, result_weight) = F::multi_head_attention_forward(
3609 _Q,
3610 _K,
3611 _V,
3612 F::MultiheadAttentionForwardFuncOptions(
3613 /*embed_dim_to_check=*/d_model,
3614 /*num_heads=*/nheads,
3615 /*in_proj_weight=*/{},
3616 /*in_proj_bias=*/multihead_attn_module->in_proj_bias,
3617 /*bias_k=*/multihead_attn_module->bias_k,
3618 /*bias_v=*/multihead_attn_module->bias_v,
3619 /*add_zero_attn=*/multihead_attn_module->options.add_zero_attn(),
3620 /*dropout_p=*/multihead_attn_module->options.dropout(),
3621 /*out_proj_weight=*/multihead_attn_module->out_proj->weight,
3622 /*out_proj_bias=*/multihead_attn_module->out_proj->bias)
3623 .training(multihead_attn_module->is_training())
3624 .key_padding_mask(key_padding_mask_tensor)
3625 .need_weights(true)
3626 .attn_mask(attn_mask_tensor)
3627 .use_separate_proj_weight(true)
3628 .q_proj_weight(multihead_attn_module->q_proj_weight)
3629 .k_proj_weight(multihead_attn_module->k_proj_weight)
3630 .v_proj_weight(multihead_attn_module->v_proj_weight)
3631 .static_k(saved_k_tensor)
3632 .static_v(saved_v_tensor)
3633 .average_attn_weights(average_attn_weights));
3634 }
3635 result = result.squeeze(0).detach();
3636 torch::Tensor q_proj_weight;
3637 torch::Tensor k_proj_weight;
3638 torch::Tensor v_proj_weight;
3639 if (multihead_attn_module->_qkv_same_embed_dim) {
3640 q_proj_weight =
3641 multihead_attn_module->in_proj_weight.slice(/*dim=*/0, 0, d_model);
3642 k_proj_weight = multihead_attn_module->in_proj_weight.slice(
3643 /*dim=*/0, d_model, (d_model * 2));
3644 v_proj_weight =
3645 multihead_attn_module->in_proj_weight.slice(/*dim=*/0, (d_model * 2));
3646 } else {
3647 q_proj_weight = multihead_attn_module->q_proj_weight;
3648 k_proj_weight = multihead_attn_module->k_proj_weight;
3649 v_proj_weight = multihead_attn_module->v_proj_weight;
3650 }
3651 auto Q_fc =
3652 _fc(Q,
3653 q_proj_weight,
3654 multihead_attn_module->in_proj_bias.slice(/*dim=*/0, 0, d_model));
3655 auto K_fc =
3656 _fc(K,
3657 k_proj_weight,
3658 multihead_attn_module->in_proj_bias.slice(
3659 /*dim=*/0, d_model, (d_model * 2)));
3660 auto V_fc = _fc(
3661 V,
3662 v_proj_weight,
3663 multihead_attn_module->in_proj_bias.slice(/*dim=*/0, (d_model * 2)));
3664
3665 if (add_bias_kv) {
3666 K_fc = torch::cat(
3667 {K_fc,
3668 bias_k.repeat({K_fc.size(0) / bias_k.size(0), 1, 1} /*, axis=0*/)},
3669 /*dim=*/1);
3670 V_fc = torch::cat(
3671 {V_fc,
3672 bias_v.repeat({V_fc.size(0) / bias_v.size(0), 1, 1} /*, axis=0*/)},
3673 /*dim=*/1);
3674 if (attn_mask.defined()) {
3675 attn_mask = torch::cat({attn_mask, torch::ones({1, 1})}, /*dim=*/1);
3676 }
3677 if (key_padding_mask.defined()) {
3678 key_padding_mask = torch::cat(
3679 {key_padding_mask, torch::full({batch_sz, 1}, false, torch::kBool)},
3680 /*dim=*/1);
3681 }
3682 dims[1] += 1;
3683 }
3684 const auto Q_split =
3685 _split_heads_ref(Q_fc, {batch_sz, 1, d_model}, nheads, d_head);
3686 torch::Tensor K_split;
3687 if (saved_k.defined()) {
3688 K_split = saved_k.reshape({dims[0], nheads, dims[1], d_head});
3689 } else {
3690 K_split = _split_heads_ref(K_fc, dims, nheads, d_head);
3691 }
3692 torch::Tensor V_split;
3693 if (saved_v.defined()) {
3694 V_split = saved_v.reshape({dims[0], nheads, dims[1], d_head});
3695 } else {
3696 V_split = _split_heads_ref(V_fc, dims, nheads, d_head);
3697 }
3698 if (add_zero_attn) {
3699 dims[1] += 1;
3700 K_split = torch::cat(
3701 {K_split,
3702 torch::zeros(
3703 {K_split.size(0), K_split.size(1), 1, K_split.size(3)})},
3704 /*dim=*/2);
3705 V_split = torch::cat(
3706 {V_split,
3707 torch::zeros(
3708 {V_split.size(0), V_split.size(1), 1, V_split.size(3)})},
3709 /*dim=*/2);
3710 if (attn_mask.defined()) {
3711 attn_mask = torch::cat({attn_mask, torch::ones({1, 1})}, /*dim=*/1);
3712 }
3713 if (key_padding_mask.defined()) {
3714 key_padding_mask = torch::cat(
3715 {key_padding_mask, torch::full({batch_sz, 1}, false, torch::kBool)},
3716 /*dim=*/1);
3717 }
3718 }
3719 torch::Tensor attn_heads;
3720 torch::Tensor ref_attn_weight;
3721 std::tie(attn_heads, ref_attn_weight) = _scaled_dot_attn_ref(
3722 Q_split,
3723 K_split,
3724 V_split,
3725 Q_split.sizes(),
3726 attn_mask,
3727 key_padding_mask,
3728 average_attn_weights);
3729 const auto combined_attn_heads =
3730 _combine_heads_ref(attn_heads, {batch_sz, 1}, nheads, d_head);
3731 auto reference =
3732 _fc(combined_attn_heads,
3733 multihead_attn_module->out_proj->weight,
3734 multihead_attn_module->out_proj->bias);
3735 // NOLINTNEXTLINE(bugprone-argument-comment)
3736 reference = torch::squeeze(reference, /*axis=*/1);
3737
3738 // result = reference
3739 ASSERT_EQ(result.sizes(), std::vector<int64_t>({batch_sz, d_model}));
3740 ASSERT_TRUE(
3741 torch::allclose(result, reference, 1e-5, 1e-5, /*equal_nan=*/true));
3742
3743 // result_weight = ref_attn_weight
3744 result_weight = result_weight.detach();
3745 ASSERT_EQ(result_weight.sizes(), ref_attn_weight.sizes());
3746 ASSERT_TRUE(torch::allclose(
3747 result_weight, ref_attn_weight, 1e-5, 1e-5, /*equal_nan=*/true));
3748 }
3749}
3750} // namespace detail
3751
3752TEST_F(ModulesTest, MultiheadAttention) {
3753 using namespace ::detail;
3754
3755 for (auto average_attn_weights : {false, true}) {
3756 // test_multihead_attn_add_zero_attn
3757 _multihead_attn_test_helper(
3758 /*add_key_padding_mask=*/false,
3759 /*add_bias_kv=*/false,
3760 /*add_zero_attn=*/true,
3761 /*saved_kv=*/false,
3762 /*same_embed_dim=*/false,
3763 /*average_attn_weights=*/average_attn_weights);
3764
3765 // test_multihead_attn_add_bias_kv
3766 _multihead_attn_test_helper(
3767 /*add_key_padding_mask=*/false,
3768 /*add_bias_kv=*/true,
3769 /*add_zero_attn=*/false,
3770 /*saved_kv=*/false,
3771 /*same_embed_dim=*/false,
3772 /*average_attn_weights=*/average_attn_weights);
3773
3774 // test_multihead_attn_no_masking():
3775 _multihead_attn_test_helper();
3776
3777 // test_multihead_attn_key_padding_mask
3778 _multihead_attn_test_helper(
3779 /*add_key_padding_mask=*/true,
3780 /*add_bias_kv=*/false,
3781 /*add_zero_attn=*/false,
3782 /*saved_kv=*/false,
3783 /*same_embed_dim=*/false,
3784 /*average_attn_weights=*/average_attn_weights);
3785
3786 // test_multihead_attn_saved_kv
3787 _multihead_attn_test_helper(
3788 /*add_key_padding_mask=*/false,
3789 /*add_bias_kv=*/false,
3790 /*add_zero_attn=*/false,
3791 /*saved_kv=*/true,
3792 /*same_embed_dim=*/false,
3793 /*average_attn_weights=*/average_attn_weights);
3794
3795 // test_multihead_attn_add_bias_kv_zero_attn
3796 _multihead_attn_test_helper(
3797 /*add_key_padding_mask=*/true,
3798 /*add_bias_kv=*/true,
3799 /*add_zero_attn=*/true,
3800 /*saved_kv=*/false,
3801 /*same_embed_dim=*/false,
3802 /*average_attn_weights=*/average_attn_weights);
3803
3804 // test_multihead_attn_all_arguments1
3805 _multihead_attn_test_helper(
3806 /*add_key_padding_mask=*/true,
3807 /*add_bias_kv=*/false,
3808 /*add_zero_attn=*/true,
3809 /*saved_kv=*/true,
3810 /*same_embed_dim=*/false,
3811 /*average_attn_weights=*/average_attn_weights);
3812
3813 ASSERT_THROWS_WITH(
3814 // test_multihead_attn_all_arguments2
3815 _multihead_attn_test_helper(
3816 /*add_key_padding_mask=*/true,
3817 /*add_bias_kv=*/true,
3818 /*add_zero_attn=*/true,
3819 /*saved_kv=*/true,
3820 /*same_embed_dim=*/false,
3821 /*average_attn_weights=*/average_attn_weights),
3822 "bias cannot be added to static key");
3823
3824 // test_multihead_attn_all_arguments3
3825 _multihead_attn_test_helper(
3826 /*add_key_padding_mask=*/true,
3827 /*add_bias_kv=*/false,
3828 /*add_zero_attn=*/true,
3829 /*saved_kv=*/true,
3830 /*same_embed_dim=*/true,
3831 /*average_attn_weights=*/average_attn_weights);
3832 }
3833}
3834
3835TEST_F(ModulesTest, PrettyPrintIdentity) {
3836 ASSERT_EQ(c10::str(Identity()), "torch::nn::Identity()");
3837}
3838
3839TEST_F(ModulesTest, PrettyPrintFlatten) {
3840 ASSERT_EQ(c10::str(Flatten()), "torch::nn::Flatten(start_dim=1, end_dim=-1)");
3841 ASSERT_EQ(
3842 c10::str(Flatten(FlattenOptions().start_dim(2).end_dim(4))),
3843 "torch::nn::Flatten(start_dim=2, end_dim=4)");
3844}
3845
3846TEST_F(ModulesTest, PrettyPrintUnflatten) {
3847 ASSERT_EQ(
3848 c10::str(Unflatten(UnflattenOptions(0, {2, 2}))),
3849 "torch::nn::Unflatten(dim=0, unflattened_size={2, 2})");
3850 ASSERT_EQ(
3851 c10::str(Unflatten(UnflattenOptions(
3852 "B",
3853 {std::pair<std::string, int64_t>{"B1", 2},
3854 std::pair<std::string, int64_t>{"B2", 2}}))),
3855 "torch::nn::Unflatten(dim=\"B\", unflattened_size={{\"B1\", 2}, {\"B2\", 2}})");
3856}
3857
3858TEST_F(ModulesTest, ReflectionPad1d) {
3859 {
3860 ReflectionPad1d m(ReflectionPad1dOptions(2));
3861 auto input = torch::arange(8, torch::kFloat).reshape({1, 2, 4});
3862 auto output = m(input);
3863 auto expected = torch::tensor(
3864 {{{2., 1., 0., 1., 2., 3., 2., 1.}, {6., 5., 4., 5., 6., 7., 6., 5.}}},
3865 torch::kFloat);
3866 ASSERT_TRUE(output.allclose(expected));
3867 }
3868 {
3869 ReflectionPad1d m(ReflectionPad1dOptions({3, 1}));
3870 auto input = torch::arange(8, torch::kFloat).reshape({1, 2, 4});
3871 auto output = m(input);
3872 auto expected = torch::tensor(
3873 {{{3., 2., 1., 0., 1., 2., 3., 2.}, {7., 6., 5., 4., 5., 6., 7., 6.}}},
3874 torch::kFloat);
3875 ASSERT_TRUE(output.allclose(expected));
3876 }
3877}
3878
3879TEST_F(ModulesTest, ReflectionPad2d) {
3880 {
3881 ReflectionPad2d m(ReflectionPad2dOptions(2));
3882 auto input = torch::arange(9, torch::kFloat).reshape({1, 1, 3, 3});
3883 auto output = m(input);
3884 auto expected = torch::tensor(
3885 {{{{8., 7., 6., 7., 8., 7., 6.},
3886 {5., 4., 3., 4., 5., 4., 3.},
3887 {2., 1., 0., 1., 2., 1., 0.},
3888 {5., 4., 3., 4., 5., 4., 3.},
3889 {8., 7., 6., 7., 8., 7., 6.},
3890 {5., 4., 3., 4., 5., 4., 3.},
3891 {2., 1., 0., 1., 2., 1., 0.}}}},
3892 torch::kFloat);
3893 ASSERT_TRUE(output.allclose(expected));
3894 }
3895 {
3896 ReflectionPad2d m(ReflectionPad2dOptions({1, 1, 2, 0}));
3897 auto input = torch::arange(9, torch::kFloat).reshape({1, 1, 3, 3});
3898 auto output = m(input);
3899 auto expected = torch::tensor(
3900 {{{{7., 6., 7., 8., 7.},
3901 {4., 3., 4., 5., 4.},
3902 {1., 0., 1., 2., 1.},
3903 {4., 3., 4., 5., 4.},
3904 {7., 6., 7., 8., 7.}}}},
3905 torch::kFloat);
3906 ASSERT_TRUE(output.allclose(expected));
3907 }
3908}
3909
3910TEST_F(ModulesTest, ReflectionPad3d) {
3911 {
3912 ReflectionPad3d m(ReflectionPad3dOptions(1));
3913 auto input = torch::arange(8, torch::kFloat).reshape({1, 1, 2, 2, 2});
3914 auto output = m(input);
3915 auto expected = torch::tensor(
3916 {{{{{7., 6., 7., 6.},
3917 {5., 4., 5., 4.},
3918 {7., 6., 7., 6.},
3919 {5., 4., 5., 4.}},
3920 {{3., 2., 3., 2.},
3921 {1., 0., 1., 0.},
3922 {3., 2., 3., 2.},
3923 {1., 0., 1., 0.}},
3924 {{7., 6., 7., 6.},
3925 {5., 4., 5., 4.},
3926 {7., 6., 7., 6.},
3927 {5., 4., 5., 4.}},
3928 {{3., 2., 3., 2.},
3929 {1., 0., 1., 0.},
3930 {3., 2., 3., 2.},
3931 {1., 0., 1., 0.}}}}},
3932 torch::kFloat);
3933 ASSERT_TRUE(output.allclose(expected));
3934 }
3935 {
3936 ReflectionPad3d m(ReflectionPad3dOptions({0, 1, 1, 0, 1, 2}));
3937 auto input = torch::arange(16, torch::kFloat).reshape({1, 1, 4, 2, 2});
3938 auto output = m(input);
3939 auto expected = torch::tensor(
3940 {{{{{6., 7., 6.}, {4., 5., 4.}, {6., 7., 6.}},
3941 {{2., 3., 2.}, {0., 1., 0.}, {2., 3., 2.}},
3942 {{6., 7., 6.}, {4., 5., 4.}, {6., 7., 6.}},
3943 {{10., 11., 10.}, {8., 9., 8.}, {10., 11., 10.}},
3944 {{14., 15., 14.}, {12., 13., 12.}, {14., 15., 14.}},
3945 {{10., 11., 10.}, {8., 9., 8.}, {10., 11., 10.}},
3946 {{6., 7., 6.}, {4., 5., 4.}, {6., 7., 6.}}}}},
3947 torch::kFloat);
3948 ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 1, 7, 3, 3}));
3949 ASSERT_TRUE(output.allclose(expected));
3950 }
3951}
3952TEST_F(ModulesTest, ReplicationPad1d) {
3953 {
3954 ReplicationPad1d m(ReplicationPad1dOptions(2));
3955 auto input = torch::arange(8, torch::kFloat).reshape({1, 2, 4});
3956 auto output = m(input);
3957 auto expected = torch::tensor(
3958 {{{0., 0., 0., 1., 2., 3., 3., 3.}, {4., 4., 4., 5., 6., 7., 7., 7.}}},
3959 torch::kFloat);
3960 ASSERT_TRUE(output.allclose(expected));
3961 }
3962 {
3963 ReplicationPad1d m(ReplicationPad1dOptions({3, 1}));
3964 auto input = torch::arange(8, torch::kFloat).reshape({1, 2, 4});
3965 auto output = m(input);
3966 auto expected = torch::tensor(
3967 {{{0., 0., 0., 0., 1., 2., 3., 3.}, {4., 4., 4., 4., 5., 6., 7., 7.}}},
3968 torch::kFloat);
3969 ASSERT_TRUE(output.allclose(expected));
3970 }
3971}
3972
3973TEST_F(ModulesTest, ReplicationPad2d) {
3974 {
3975 ReplicationPad2d m(ReplicationPad2dOptions(2));
3976 auto input = torch::arange(9, torch::kFloat).reshape({1, 1, 3, 3});
3977 auto output = m(input);
3978 auto expected = torch::tensor(
3979 {{{{0., 0., 0., 1., 2., 2., 2.},
3980 {0., 0., 0., 1., 2., 2., 2.},
3981 {0., 0., 0., 1., 2., 2., 2.},
3982 {3., 3., 3., 4., 5., 5., 5.},
3983 {6., 6., 6., 7., 8., 8., 8.},
3984 {6., 6., 6., 7., 8., 8., 8.},
3985 {6., 6., 6., 7., 8., 8., 8.}}}},
3986 torch::kFloat);
3987 ASSERT_TRUE(output.allclose(expected));
3988 }
3989 {
3990 ReplicationPad2d m(ReplicationPad2dOptions({1, 1, 2, 0}));
3991 auto input = torch::arange(9, torch::kFloat).reshape({1, 1, 3, 3});
3992 auto output = m(input);
3993 auto expected = torch::tensor(
3994 {{{{0., 0., 1., 2., 2.},
3995 {0., 0., 1., 2., 2.},
3996 {0., 0., 1., 2., 2.},
3997 {3., 3., 4., 5., 5.},
3998 {6., 6., 7., 8., 8.}}}},
3999 torch::kFloat);
4000 ASSERT_TRUE(output.allclose(expected));
4001 }
4002}
4003
4004TEST_F(ModulesTest, ReplicationPad3d) {
4005 {
4006 ReplicationPad3d m(ReplicationPad3dOptions(1));
4007 auto input = torch::arange(8, torch::kFloat).reshape({1, 1, 2, 2, 2});
4008 auto output = m(input);
4009 auto expected = torch::tensor(
4010 {{{{{0., 0., 1., 1.},
4011 {0., 0., 1., 1.},
4012 {2., 2., 3., 3.},
4013 {2., 2., 3., 3.}},
4014 {{0., 0., 1., 1.},
4015 {0., 0., 1., 1.},
4016 {2., 2., 3., 3.},
4017 {2., 2., 3., 3.}},
4018 {{4., 4., 5., 5.},
4019 {4., 4., 5., 5.},
4020 {6., 6., 7., 7.},
4021 {6., 6., 7., 7.}},
4022 {{4., 4., 5., 5.},
4023 {4., 4., 5., 5.},
4024 {6., 6., 7., 7.},
4025 {6., 6., 7., 7.}}}}},
4026 torch::kFloat);
4027 ASSERT_TRUE(output.allclose(expected));
4028 }
4029 {
4030 ReplicationPad3d m(ReplicationPad3dOptions({1, 2, 1, 2, 1, 2}));
4031 auto input = torch::arange(8, torch::kFloat).reshape({1, 1, 2, 2, 2});
4032 auto output = m(input);
4033 auto expected = torch::tensor(
4034 {{{{{0., 0., 1., 1., 1.},
4035 {0., 0., 1., 1., 1.},
4036 {2., 2., 3., 3., 3.},
4037 {2., 2., 3., 3., 3.},
4038 {2., 2., 3., 3., 3.}},
4039 {{0., 0., 1., 1., 1.},
4040 {0., 0., 1., 1., 1.},
4041 {2., 2., 3., 3., 3.},
4042 {2., 2., 3., 3., 3.},
4043 {2., 2., 3., 3., 3.}},
4044 {{4., 4., 5., 5., 5.},
4045 {4., 4., 5., 5., 5.},
4046 {6., 6., 7., 7., 7.},
4047 {6., 6., 7., 7., 7.},
4048 {6., 6., 7., 7., 7.}},
4049 {{4., 4., 5., 5., 5.},
4050 {4., 4., 5., 5., 5.},
4051 {6., 6., 7., 7., 7.},
4052 {6., 6., 7., 7., 7.},
4053 {6., 6., 7., 7., 7.}},
4054 {{4., 4., 5., 5., 5.},
4055 {4., 4., 5., 5., 5.},
4056 {6., 6., 7., 7., 7.},
4057 {6., 6., 7., 7., 7.},
4058 {6., 6., 7., 7., 7.}}}}},
4059 torch::kFloat);
4060 ASSERT_TRUE(output.allclose(expected));
4061 }
4062}
4063
4064TEST_F(ModulesTest, ZeroPad2d) {
4065 {
4066 ZeroPad2d m(ZeroPad2dOptions(2));
4067 auto input = torch::arange(9, torch::kFloat).reshape({1, 1, 3, 3});
4068 auto output = m(input);
4069 auto expected = torch::tensor(
4070 {{{{0., 0., 0., 0., 0., 0., 0.},
4071 {0., 0., 0., 0., 0., 0., 0.},
4072 {0., 0., 0., 1., 2., 0., 0.},
4073 {0., 0., 3., 4., 5., 0., 0.},
4074 {0., 0., 6., 7., 8., 0., 0.},
4075 {0., 0., 0., 0., 0., 0., 0.},
4076 {0., 0., 0., 0., 0., 0., 0.}}}},
4077 torch::kFloat);
4078 ASSERT_TRUE(output.allclose(expected));
4079 }
4080 {
4081 ZeroPad2d m(ZeroPad2dOptions({1, 1, 2, 0}));
4082 auto input = torch::arange(9, torch::kFloat).reshape({1, 1, 3, 3});
4083 auto output = m(input);
4084 auto expected = torch::tensor(
4085 {{{{0., 0., 0., 0., 0.},
4086 {0., 0., 0., 0., 0.},
4087 {0., 0., 1., 2., 0.},
4088 {0., 3., 4., 5., 0.},
4089 {0., 6., 7., 8., 0.}}}},
4090 torch::kFloat);
4091 ASSERT_TRUE(output.allclose(expected));
4092 }
4093}
4094
4095TEST_F(ModulesTest, ConstantPad1d) {
4096 {
4097 ConstantPad1d m(ConstantPad1dOptions(2, 3.5));
4098 auto input = torch::arange(8, torch::kFloat).reshape({1, 2, 4});
4099 auto output = m(input);
4100 auto expected = torch::tensor(
4101 {{{3.5000, 3.5000, 0.0000, 1.0000, 2.0000, 3.0000, 3.5000, 3.5000},
4102 {3.5000, 3.5000, 4.0000, 5.0000, 6.0000, 7.0000, 3.5000, 3.5000}}},
4103 torch::kFloat);
4104 ASSERT_TRUE(output.allclose(expected));
4105 }
4106 {
4107 ConstantPad1d m(ConstantPad1dOptions({3, 1}, 3.5));
4108 auto input = torch::arange(6, torch::kFloat).reshape({1, 2, 3});
4109 auto output = m(input);
4110 auto expected = torch::tensor(
4111 {{{3.5000, 3.5000, 3.5000, 0.0000, 1.0000, 2.0000, 3.5000},
4112 {3.5000, 3.5000, 3.5000, 3.0000, 4.0000, 5.0000, 3.5000}}},
4113 torch::kFloat);
4114 ASSERT_TRUE(output.allclose(expected));
4115 }
4116}
4117
4118TEST_F(ModulesTest, ConstantPad2d) {
4119 {
4120 ConstantPad2d m(ConstantPad2dOptions(2, 3.5));
4121 auto input = torch::arange(4, torch::kFloat).reshape({1, 2, 2});
4122 auto output = m(input);
4123 auto expected = torch::tensor(
4124 {{{3.5000, 3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4125 {3.5000, 3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4126 {3.5000, 3.5000, 0.0000, 1.0000, 3.5000, 3.5000},
4127 {3.5000, 3.5000, 2.0000, 3.0000, 3.5000, 3.5000},
4128 {3.5000, 3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4129 {3.5000, 3.5000, 3.5000, 3.5000, 3.5000, 3.5000}}},
4130 torch::kFloat);
4131 ASSERT_TRUE(output.allclose(expected));
4132 }
4133 {
4134 ConstantPad2d m(ConstantPad2dOptions({3, 0, 2, 1}, 3.5));
4135 auto input = torch::arange(4, torch::kFloat).reshape({1, 2, 2});
4136 auto output = m(input);
4137 auto expected = torch::tensor(
4138 {{{3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4139 {3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4140 {3.5000, 3.5000, 3.5000, 0.0000, 1.0000},
4141 {3.5000, 3.5000, 3.5000, 2.0000, 3.0000},
4142 {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}}},
4143 torch::kFloat);
4144 ASSERT_TRUE(output.allclose(expected));
4145 }
4146}
4147
4148TEST_F(ModulesTest, ConstantPad3d) {
4149 {
4150 ConstantPad3d m(ConstantPad3dOptions(1, 3.5));
4151 auto input = torch::arange(8, torch::kFloat).reshape({1, 1, 2, 2, 2});
4152 auto output = m(input);
4153 auto expected = torch::tensor(
4154 {{{{{3.5000, 3.5000, 3.5000, 3.5000},
4155 {3.5000, 3.5000, 3.5000, 3.5000},
4156 {3.5000, 3.5000, 3.5000, 3.5000},
4157 {3.5000, 3.5000, 3.5000, 3.5000}},
4158 {{3.5000, 3.5000, 3.5000, 3.5000},
4159 {3.5000, 0.0000, 1.0000, 3.5000},
4160 {3.5000, 2.0000, 3.0000, 3.5000},
4161 {3.5000, 3.5000, 3.5000, 3.5000}},
4162 {{3.5000, 3.5000, 3.5000, 3.5000},
4163 {3.5000, 4.0000, 5.0000, 3.5000},
4164 {3.5000, 6.0000, 7.0000, 3.5000},
4165 {3.5000, 3.5000, 3.5000, 3.5000}},
4166 {{3.5000, 3.5000, 3.5000, 3.5000},
4167 {3.5000, 3.5000, 3.5000, 3.5000},
4168 {3.5000, 3.5000, 3.5000, 3.5000},
4169 {3.5000, 3.5000, 3.5000, 3.5000}}}}},
4170 torch::kFloat);
4171 ASSERT_TRUE(output.allclose(expected));
4172 }
4173 {
4174 ConstantPad3d m(ConstantPad3dOptions({1, 2, 1, 2, 1, 2}, 3.5));
4175 auto input = torch::arange(8, torch::kFloat).reshape({1, 1, 2, 2, 2});
4176 auto output = m(input);
4177 auto expected = torch::tensor(
4178 {{{{{3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4179 {3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4180 {3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4181 {3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4182 {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}},
4183 {{3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4184 {3.5000, 0.0000, 1.0000, 3.5000, 3.5000},
4185 {3.5000, 2.0000, 3.0000, 3.5000, 3.5000},
4186 {3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4187 {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}},
4188 {{3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4189 {3.5000, 4.0000, 5.0000, 3.5000, 3.5000},
4190 {3.5000, 6.0000, 7.0000, 3.5000, 3.5000},
4191 {3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4192 {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}},
4193 {{3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4194 {3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4195 {3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4196 {3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4197 {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}},
4198 {{3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4199 {3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4200 {3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4201 {3.5000, 3.5000, 3.5000, 3.5000, 3.5000},
4202 {3.5000, 3.5000, 3.5000, 3.5000, 3.5000}}}}},
4203 torch::kFloat);
4204 ASSERT_TRUE(output.allclose(expected));
4205 }
4206}
4207
4208TEST_F(ModulesTest, CrossMapLRN2d) {
4209 /// size 3, default options
4210 auto input =
4211 torch::arange(9, torch::kFloat32).view({1, 1, 3, 3}).requires_grad_(true);
4212 auto expected = torch::tensor(
4213 {{{{0.00000000, 0.99997497, 1.99980010},
4214 {2.99932500, 3.99840070, 4.99687700},
4215 {5.99460600, 6.99143740, 7.98722360}}}},
4216 torch::kFloat32);
4217 auto grad_expected = torch::tensor(
4218 {{{{1.00000000, 0.99992496, 0.99970007},
4219 {0.99932520, 0.99880093, 0.99812720},
4220 {0.99730474, 0.99633380, 0.99521490}}}},
4221 torch::kFloat32);
4222 auto crossmaplrn2d = CrossMapLRN2d(3);
4223 auto output = crossmaplrn2d(input);
4224 output.sum().backward();
4225
4226 ASSERT_TRUE(input.grad().allclose(grad_expected));
4227 ASSERT_TRUE(output.allclose(expected));
4228
4229 /// size change
4230 crossmaplrn2d =
4231 CrossMapLRN2d(CrossMapLRN2dOptions(4).alpha(1e-4).beta(0.75).k(1));
4232 output = crossmaplrn2d(input);
4233 expected = torch::tensor(
4234 {{{{0.00000000, 0.99998120, 1.99985000},
4235 {2.99949400, 3.99880050, 4.99765800},
4236 {5.99595300, 6.99357600, 7.99041300}}}},
4237 torch::kFloat32);
4238 ASSERT_TRUE(output.allclose(expected));
4239
4240 /// alpha change
4241 crossmaplrn2d =
4242 CrossMapLRN2d(CrossMapLRN2dOptions(3).alpha(1e-3).beta(0.75).k(1));
4243 output = crossmaplrn2d(input);
4244 expected = torch::tensor(
4245 {{{{0.00000000, 0.99975010, 1.99800230},
4246 {2.99326750, 3.98407440, 4.96897600},
4247 {5.94656100, 6.91545720, 7.87434340}}}},
4248 torch::kFloat32);
4249 ASSERT_TRUE(output.allclose(expected));
4250
4251 /// beta change
4252 crossmaplrn2d =
4253 CrossMapLRN2d(CrossMapLRN2dOptions(3).alpha(1e-4).beta(0.95).k(1));
4254 output = crossmaplrn2d(input);
4255 expected = torch::tensor(
4256 {{{{0.00000000, 0.99996830, 1.99974680},
4257 {2.99914500, 3.99797440, 4.99604460},
4258 {5.99316840, 6.98915600, 7.98382000}}}},
4259 torch::kFloat32);
4260 ASSERT_TRUE(output.allclose(expected));
4261
4262 /// k change
4263 crossmaplrn2d =
4264 CrossMapLRN2d(CrossMapLRN2dOptions(3).alpha(1e-4).beta(0.75).k(2));
4265 output = crossmaplrn2d(input);
4266 expected = torch::tensor(
4267 {{{{0.00000000, 0.59459610, 1.18914770},
4268 {1.78361000, 2.37793870, 2.97208900},
4269 {3.56601700, 4.15967700, 4.75302650}}}},
4270 torch::kFloat32);
4271 ASSERT_TRUE(output.allclose(expected));
4272}
4273
4274TEST_F(ModulesTest, RNNCell) {
4275 torch::manual_seed(0);
4276 auto rnn = RNNCell(1, 2);
4277 auto input = torch::randn({3, 1});
4278 auto hx = torch::randn({3, 2});
4279 auto output = rnn(input, hx);
4280 auto expected =
4281 torch::tensor({{-0.5078, 0.4380}, {-0.7215, 0.2969}, {-0.1304, 0.0653}});
4282 ASSERT_TRUE(torch::allclose(output, expected, 1e-05, 2e-04));
4283
4284 output = rnn(input);
4285 expected =
4286 torch::tensor({{-0.0775, 0.6688}, {-0.0734, 0.4759}, {-0.0725, 0.4225}});
4287 ASSERT_TRUE(torch::allclose(output, expected, 1e-05, 2e-04));
4288}
4289
4290TEST_F(ModulesTest, LSTMCell) {
4291 torch::manual_seed(0);
4292 auto rnn = LSTMCell(1, 2);
4293 auto input = torch::randn({3, 1});
4294 auto hx = torch::randn({3, 2});
4295 auto cx = torch::randn({3, 2});
4296 auto output = rnn(input, std::make_tuple(hx, cx));
4297 auto output_hx = std::get<0>(output);
4298 auto output_cx = std::get<1>(output);
4299 auto expected_hx =
4300 torch::tensor({{-0.2462, 0.0810}, {-0.2206, 0.1867}, {-0.0146, 0.0429}});
4301 auto expected_cx =
4302 torch::tensor({{-0.4480, 0.1071}, {-0.6245, 0.2687}, {-0.0322, 0.0518}});
4303 ASSERT_TRUE(torch::allclose(output_hx, expected_hx, 1e-05, 2e-04));
4304 ASSERT_TRUE(torch::allclose(output_cx, expected_cx, 1e-05, 2e-04));
4305
4306 output = rnn(input);
4307 output_hx = std::get<0>(output);
4308 output_cx = std::get<1>(output);
4309 expected_hx =
4310 torch::tensor({{-0.1331, 0.1634}, {-0.1494, 0.2869}, {-0.1428, 0.2263}});
4311 expected_cx =
4312 torch::tensor({{-0.2679, 0.2180}, {-0.3049, 0.3493}, {-0.2896, 0.2853}});
4313 ASSERT_TRUE(torch::allclose(output_hx, expected_hx, 1e-05, 2e-04));
4314 ASSERT_TRUE(torch::allclose(output_cx, expected_cx, 1e-05, 2e-04));
4315}
4316
4317TEST_F(ModulesTest, GRUCell) {
4318 torch::manual_seed(0);
4319 auto rnn = GRUCell(1, 2);
4320 auto input = torch::randn({3, 1});
4321 auto hx = torch::randn({3, 2});
4322 auto output = rnn(input, hx);
4323 auto expected =
4324 torch::tensor({{1.0243, 0.3227}, {-0.5659, 0.0330}, {-0.4030, -0.2800}});
4325 ASSERT_TRUE(torch::allclose(output, expected, 1e-05, 2e-04));
4326
4327 output = rnn(input);
4328 expected =
4329 torch::tensor({{-0.0085, 0.1095}, {-0.1291, 0.2675}, {-0.1339, 0.2725}});
4330 ASSERT_TRUE(torch::allclose(output, expected, 1e-05, 2e-04));
4331}
4332
4333TEST_F(ModulesTest, PrettyPrintLinear) {
4334 ASSERT_EQ(
4335 c10::str(Linear(3, 4)),
4336 "torch::nn::Linear(in_features=3, out_features=4, bias=true)");
4337}
4338
4339TEST_F(ModulesTest, PrettyPrintBilinear) {
4340 ASSERT_EQ(
4341 c10::str(Bilinear(3, 2, 4)),
4342 "torch::nn::Bilinear(in1_features=3, in2_features=2, out_features=4, bias=true)");
4343 ASSERT_EQ(
4344 c10::str(Bilinear(BilinearOptions(3, 2, 4).bias(false))),
4345 "torch::nn::Bilinear(in1_features=3, in2_features=2, out_features=4, bias=false)");
4346}
4347
4348TEST_F(ModulesTest, PrettyPrintConv) {
4349 ASSERT_EQ(
4350 c10::str(Conv1d(3, 4, 5)),
4351 "torch::nn::Conv1d(3, 4, kernel_size=5, stride=1)");
4352
4353 ASSERT_EQ(
4354 c10::str(Conv2d(3, 4, 5)),
4355 "torch::nn::Conv2d(3, 4, kernel_size=[5, 5], stride=[1, 1])");
4356 ASSERT_EQ(
4357 c10::str(Conv2d(Conv2dOptions(3, 4, 5).stride(2))),
4358 "torch::nn::Conv2d(3, 4, kernel_size=[5, 5], stride=[2, 2])");
4359 {
4360 const auto options =
4361 Conv2dOptions(3, 4, std::vector<int64_t>{5, 6}).stride({1, 2});
4362 ASSERT_EQ(
4363 c10::str(Conv2d(options)),
4364 "torch::nn::Conv2d(3, 4, kernel_size=[5, 6], stride=[1, 2])");
4365 }
4366
4367 ASSERT_EQ(
4368 c10::str(Conv3d(4, 4, std::vector<int64_t>{5, 6, 7})),
4369 "torch::nn::Conv3d(4, 4, kernel_size=[5, 6, 7], stride=[1, 1, 1])");
4370 {
4371 const auto options = Conv3dOptions(4, 4, std::vector<int64_t>{5, 6, 7})
4372 .stride({1, 2, 3})
4373 .padding(1)
4374 .dilation(0)
4375 .groups(2)
4376 .bias(false)
4377 .padding_mode(torch::kCircular);
4378 ASSERT_EQ(
4379 c10::str(Conv3d(options)),
4380 "torch::nn::Conv3d("
4381 "4, "
4382 "4, "
4383 "kernel_size=[5, 6, 7], "
4384 "stride=[1, 2, 3], "
4385 "padding=[1, 1, 1], "
4386 "dilation=[0, 0, 0], "
4387 "groups=2, "
4388 "bias=false, "
4389 "padding_mode=kCircular)");
4390 }
4391}
4392
4393TEST_F(ModulesTest, PrettyPrintConvTranspose) {
4394 ASSERT_EQ(
4395 c10::str(ConvTranspose1d(3, 4, 5)),
4396 "torch::nn::ConvTranspose1d(3, 4, kernel_size=5, stride=1)");
4397
4398 ASSERT_EQ(
4399 c10::str(ConvTranspose2d(3, 4, 5)),
4400 "torch::nn::ConvTranspose2d(3, 4, kernel_size=[5, 5], stride=[1, 1])");
4401 ASSERT_EQ(
4402 c10::str(ConvTranspose2d(ConvTranspose2dOptions(3, 4, 5).stride(2))),
4403 "torch::nn::ConvTranspose2d(3, 4, kernel_size=[5, 5], stride=[2, 2])");
4404 {
4405 const auto options =
4406 ConvTranspose2dOptions(3, 4, std::vector<int64_t>{5, 6}).stride({1, 2});
4407 ASSERT_EQ(
4408 c10::str(ConvTranspose2d(options)),
4409 "torch::nn::ConvTranspose2d(3, 4, kernel_size=[5, 6], stride=[1, 2])");
4410 }
4411
4412 ASSERT_EQ(
4413 c10::str(ConvTranspose3d(4, 4, std::vector<int64_t>{5, 6, 7})),
4414 "torch::nn::ConvTranspose3d(4, 4, kernel_size=[5, 6, 7], stride=[1, 1, 1])");
4415 {
4416 const auto options =
4417 ConvTranspose3dOptions(4, 4, std::vector<int64_t>{5, 6, 7})
4418 .stride({1, 2, 3})
4419 .padding(1)
4420 .dilation(0)
4421 .groups(2)
4422 .bias(false)
4423 .padding_mode(torch::kCircular);
4424 ASSERT_EQ(
4425 c10::str(ConvTranspose3d(options)),
4426 "torch::nn::ConvTranspose3d("
4427 "4, "
4428 "4, "
4429 "kernel_size=[5, 6, 7], "
4430 "stride=[1, 2, 3], "
4431 "padding=[1, 1, 1], "
4432 "dilation=[0, 0, 0], "
4433 "groups=2, "
4434 "bias=false, "
4435 "padding_mode=kCircular)");
4436 }
4437}
4438
4439TEST_F(ModulesTest, PrettyPrintUpsample) {
4440 ASSERT_EQ(
4441 c10::str(
4442 Upsample(UpsampleOptions().size(std::vector<int64_t>({2, 4, 4})))),
4443 "torch::nn::Upsample(size=[2, 4, 4], mode=kNearest)");
4444 ASSERT_EQ(
4445 c10::str(Upsample(UpsampleOptions()
4446 .scale_factor(std::vector<double>({0.5, 1.5}))
4447 .mode(torch::kBilinear))),
4448 "torch::nn::Upsample(scale_factor=[0.5, 1.5], mode=kBilinear)");
4449}
4450
4451TEST_F(ModulesTest, PrettyPrintFold) {
4452 ASSERT_EQ(
4453 c10::str(Fold(FoldOptions({2, 2}, {5, 5}))),
4454 "torch::nn::Fold(output_size=[2, 2], kernel_size=[5, 5], dilation=[1, 1], padding=[0, 0], stride=[1, 1])");
4455 ASSERT_EQ(
4456 c10::str(Fold(
4457 FoldOptions({8, 8}, {3, 3}).dilation(2).padding({2, 1}).stride(2))),
4458 "torch::nn::Fold(output_size=[8, 8], kernel_size=[3, 3], dilation=[2, 2], padding=[2, 1], stride=[2, 2])");
4459}
4460
4461TEST_F(ModulesTest, PrettyPrintUnfold) {
4462 ASSERT_EQ(
4463 c10::str(Unfold(torch::IntArrayRef({2, 4}))),
4464 "torch::nn::Unfold(kernel_size=[2, 4], dilation=[1, 1], padding=[0, 0], stride=[1, 1])");
4465 ASSERT_EQ(
4466 c10::str(
4467 Unfold(UnfoldOptions({2, 4}).dilation(2).padding({2, 1}).stride(2))),
4468 "torch::nn::Unfold(kernel_size=[2, 4], dilation=[2, 2], padding=[2, 1], stride=[2, 2])");
4469}
4470
4471TEST_F(ModulesTest, PrettyPrintMaxPool) {
4472 ASSERT_EQ(
4473 c10::str(MaxPool1d(5)),
4474 "torch::nn::MaxPool1d(kernel_size=5, stride=5, padding=0, dilation=1, ceil_mode=false)");
4475 ASSERT_EQ(
4476 c10::str(MaxPool2d(5)),
4477 "torch::nn::MaxPool2d(kernel_size=[5, 5], stride=[5, 5], padding=[0, 0], dilation=[1, 1], ceil_mode=false)");
4478 ASSERT_EQ(
4479 c10::str(MaxPool2d(MaxPool2dOptions(5).stride(2))),
4480 "torch::nn::MaxPool2d(kernel_size=[5, 5], stride=[2, 2], padding=[0, 0], dilation=[1, 1], ceil_mode=false)");
4481 ASSERT_EQ(
4482 c10::str(MaxPool3d(5)),
4483 "torch::nn::MaxPool3d(kernel_size=[5, 5, 5], stride=[5, 5, 5], padding=[0, 0, 0], dilation=[1, 1, 1], ceil_mode=false)");
4484 ASSERT_EQ(
4485 c10::str(MaxPool3d(MaxPool3dOptions(5).stride(2))),
4486 "torch::nn::MaxPool3d(kernel_size=[5, 5, 5], stride=[2, 2, 2], padding=[0, 0, 0], dilation=[1, 1, 1], ceil_mode=false)");
4487
4488 const auto options =
4489 MaxPool2dOptions(std::vector<int64_t>{5, 6}).stride({1, 2});
4490 ASSERT_EQ(
4491 c10::str(MaxPool2d(options)),
4492 "torch::nn::MaxPool2d(kernel_size=[5, 6], stride=[1, 2], padding=[0, 0], dilation=[1, 1], ceil_mode=false)");
4493}
4494
4495TEST_F(ModulesTest, PrettyPrintAvgPool) {
4496 ASSERT_EQ(
4497 c10::str(AvgPool1d(5)),
4498 "torch::nn::AvgPool1d(kernel_size=5, stride=5, padding=0)");
4499 ASSERT_EQ(
4500 c10::str(AvgPool2d(5)),
4501 "torch::nn::AvgPool2d(kernel_size=[5, 5], stride=[5, 5], padding=[0, 0])");
4502 ASSERT_EQ(
4503 c10::str(AvgPool2d(AvgPool2dOptions(5).stride(2))),
4504 "torch::nn::AvgPool2d(kernel_size=[5, 5], stride=[2, 2], padding=[0, 0])");
4505 ASSERT_EQ(
4506 c10::str(AvgPool3d(5)),
4507 "torch::nn::AvgPool3d(kernel_size=[5, 5, 5], stride=[5, 5, 5], padding=[0, 0, 0])");
4508 ASSERT_EQ(
4509 c10::str(AvgPool3d(AvgPool3dOptions(5).stride(2))),
4510 "torch::nn::AvgPool3d(kernel_size=[5, 5, 5], stride=[2, 2, 2], padding=[0, 0, 0])");
4511
4512 const auto options =
4513 AvgPool2dOptions(std::vector<int64_t>{5, 6}).stride({1, 2});
4514 ASSERT_EQ(
4515 c10::str(AvgPool2d(options)),
4516 "torch::nn::AvgPool2d(kernel_size=[5, 6], stride=[1, 2], padding=[0, 0])");
4517}
4518
4519TEST_F(ModulesTest, PrettyPrinFractionalMaxPool) {
4520 ASSERT_EQ(
4521 c10::str(
4522 FractionalMaxPool2d(FractionalMaxPool2dOptions(5).output_size(1))),
4523 "torch::nn::FractionalMaxPool2d()");
4524 ASSERT_EQ(
4525 c10::str(
4526 FractionalMaxPool3d(FractionalMaxPool3dOptions(5).output_size(1))),
4527 "torch::nn::FractionalMaxPool3d()");
4528}
4529
4530TEST_F(ModulesTest, PrettyPrintLPPool) {
4531 ASSERT_EQ(
4532 c10::str(LPPool1d(2, 5)),
4533 "torch::nn::LPPool1d(norm_type=2, kernel_size=5, stride=5, ceil_mode=false)");
4534 ASSERT_EQ(
4535 c10::str(LPPool1d(LPPool1dOptions(1, 2).stride(5).ceil_mode(true))),
4536 "torch::nn::LPPool1d(norm_type=1, kernel_size=2, stride=5, ceil_mode=true)");
4537 ASSERT_EQ(
4538 c10::str(LPPool2d(2, std::vector<int64_t>({1, 2}))),
4539 "torch::nn::LPPool2d(norm_type=2, kernel_size=[1, 2], stride=[1, 2], ceil_mode=false)");
4540 ASSERT_EQ(
4541 c10::str(LPPool2d(LPPool2dOptions(1, std::vector<int64_t>({3, 4}))
4542 .stride({5, 6})
4543 .ceil_mode(true))),
4544 "torch::nn::LPPool2d(norm_type=1, kernel_size=[3, 4], stride=[5, 6], ceil_mode=true)");
4545}
4546
4547TEST_F(ModulesTest, PrettyPrintAdaptiveMaxPool) {
4548 ASSERT_EQ(
4549 c10::str(AdaptiveMaxPool1d(5)),
4550 "torch::nn::AdaptiveMaxPool1d(output_size=5)");
4551
4552 const auto options = AdaptiveMaxPool1dOptions(3);
4553 ASSERT_EQ(
4554 c10::str(AdaptiveMaxPool1d(options)),
4555 "torch::nn::AdaptiveMaxPool1d(output_size=3)");
4556
4557 ASSERT_EQ(
4558 c10::str(AdaptiveMaxPool2d(5)),
4559 "torch::nn::AdaptiveMaxPool2d(output_size=[5, 5])");
4560 ASSERT_EQ(
4561 c10::str(AdaptiveMaxPool2d(AdaptiveMaxPool2dOptions({5, 6}))),
4562 "torch::nn::AdaptiveMaxPool2d(output_size=[5, 6])");
4563 ASSERT_EQ(
4564 c10::str(AdaptiveMaxPool2d(AdaptiveMaxPool2dOptions({5, c10::nullopt}))),
4565 "torch::nn::AdaptiveMaxPool2d(output_size=[5, None])");
4566 ASSERT_EQ(
4567 c10::str(AdaptiveMaxPool2d(
4568 AdaptiveMaxPool2dOptions({c10::nullopt, c10::nullopt}))),
4569 "torch::nn::AdaptiveMaxPool2d(output_size=[None, None])");
4570
4571 ASSERT_EQ(
4572 c10::str(AdaptiveMaxPool3d(5)),
4573 "torch::nn::AdaptiveMaxPool3d(output_size=[5, 5, 5])");
4574 ASSERT_EQ(
4575 c10::str(AdaptiveMaxPool3d(AdaptiveMaxPool3dOptions({5, 6, 7}))),
4576 "torch::nn::AdaptiveMaxPool3d(output_size=[5, 6, 7])");
4577 ASSERT_EQ(
4578 c10::str(
4579 AdaptiveMaxPool3d(AdaptiveMaxPool3dOptions({5, c10::nullopt, 7}))),
4580 "torch::nn::AdaptiveMaxPool3d(output_size=[5, None, 7])");
4581 ASSERT_EQ(
4582 c10::str(AdaptiveMaxPool3d(AdaptiveMaxPool3dOptions(
4583 {c10::nullopt, c10::nullopt, c10::nullopt}))),
4584 "torch::nn::AdaptiveMaxPool3d(output_size=[None, None, None])");
4585}
4586
4587TEST_F(ModulesTest, PrettyPrintAdaptiveAvgPool) {
4588 ASSERT_EQ(
4589 c10::str(AdaptiveAvgPool1d(5)),
4590 "torch::nn::AdaptiveAvgPool1d(output_size=5)");
4591
4592 ASSERT_EQ(
4593 c10::str(AdaptiveAvgPool2d(5)),
4594 "torch::nn::AdaptiveAvgPool2d(output_size=[5, 5])");
4595 ASSERT_EQ(
4596 c10::str(AdaptiveAvgPool2d(AdaptiveAvgPool2dOptions({5, 6}))),
4597 "torch::nn::AdaptiveAvgPool2d(output_size=[5, 6])");
4598 ASSERT_EQ(
4599 c10::str(AdaptiveAvgPool2d(AdaptiveAvgPool2dOptions({5, c10::nullopt}))),
4600 "torch::nn::AdaptiveAvgPool2d(output_size=[5, None])");
4601 ASSERT_EQ(
4602 c10::str(AdaptiveAvgPool2d(
4603 AdaptiveAvgPool2dOptions({c10::nullopt, c10::nullopt}))),
4604 "torch::nn::AdaptiveAvgPool2d(output_size=[None, None])");
4605
4606 ASSERT_EQ(
4607 c10::str(AdaptiveAvgPool3d(5)),
4608 "torch::nn::AdaptiveAvgPool3d(output_size=[5, 5, 5])");
4609 ASSERT_EQ(
4610 c10::str(AdaptiveAvgPool3d(AdaptiveAvgPool3dOptions({5, 6, 7}))),
4611 "torch::nn::AdaptiveAvgPool3d(output_size=[5, 6, 7])");
4612 ASSERT_EQ(
4613 c10::str(
4614 AdaptiveAvgPool3d(AdaptiveAvgPool3dOptions({5, c10::nullopt, 7}))),
4615 "torch::nn::AdaptiveAvgPool3d(output_size=[5, None, 7])");
4616 ASSERT_EQ(
4617 c10::str(AdaptiveAvgPool3d(AdaptiveAvgPool3dOptions(
4618 {c10::nullopt, c10::nullopt, c10::nullopt}))),
4619 "torch::nn::AdaptiveAvgPool3d(output_size=[None, None, None])");
4620}
4621
4622TEST_F(ModulesTest, PrettyPrintMaxUnpool) {
4623 ASSERT_EQ(
4624 c10::str(MaxUnpool1d(5)),
4625 "torch::nn::MaxUnpool1d(kernel_size=5, stride=5, padding=0)");
4626 ASSERT_EQ(
4627 c10::str(MaxUnpool1d(MaxUnpool1dOptions(5).stride(3).padding(1))),
4628 "torch::nn::MaxUnpool1d(kernel_size=5, stride=3, padding=1)");
4629
4630 ASSERT_EQ(
4631 c10::str(MaxUnpool2d(5)),
4632 "torch::nn::MaxUnpool2d(kernel_size=[5, 5], stride=[5, 5], padding=[0, 0])");
4633 ASSERT_EQ(
4634 c10::str(MaxUnpool2d(std::vector<int64_t>{5, 6})),
4635 "torch::nn::MaxUnpool2d(kernel_size=[5, 6], stride=[5, 6], padding=[0, 0])");
4636 ASSERT_EQ(
4637 c10::str(MaxUnpool2d(MaxUnpool2dOptions(std::vector<int64_t>{5, 6})
4638 .stride({3, 4})
4639 .padding({1, 2}))),
4640 "torch::nn::MaxUnpool2d(kernel_size=[5, 6], stride=[3, 4], padding=[1, 2])");
4641}
4642
4643TEST_F(ModulesTest, PrettyPrintDropout) {
4644 ASSERT_EQ(c10::str(Dropout()), "torch::nn::Dropout(p=0.5, inplace=false)");
4645 ASSERT_EQ(
4646 c10::str(Dropout(0.42)), "torch::nn::Dropout(p=0.42, inplace=false)");
4647 ASSERT_EQ(
4648 c10::str(Dropout(DropoutOptions().p(0.42).inplace(true))),
4649 "torch::nn::Dropout(p=0.42, inplace=true)");
4650}
4651
4652TEST_F(ModulesTest, PrettyPrintDropout2d) {
4653 ASSERT_EQ(
4654 c10::str(Dropout2d()), "torch::nn::Dropout2d(p=0.5, inplace=false)");
4655 ASSERT_EQ(
4656 c10::str(Dropout2d(0.42)), "torch::nn::Dropout2d(p=0.42, inplace=false)");
4657 ASSERT_EQ(
4658 c10::str(Dropout2d(Dropout2dOptions().p(0.42).inplace(true))),
4659 "torch::nn::Dropout2d(p=0.42, inplace=true)");
4660}
4661
4662TEST_F(ModulesTest, PrettyPrintDropout3d) {
4663 ASSERT_EQ(
4664 c10::str(Dropout3d()), "torch::nn::Dropout3d(p=0.5, inplace=false)");
4665 ASSERT_EQ(
4666 c10::str(Dropout3d(0.42)), "torch::nn::Dropout3d(p=0.42, inplace=false)");
4667 ASSERT_EQ(
4668 c10::str(Dropout3d(Dropout3dOptions().p(0.42).inplace(true))),
4669 "torch::nn::Dropout3d(p=0.42, inplace=true)");
4670}
4671
4672TEST_F(ModulesTest, PrettyPrintFunctional) {
4673 ASSERT_EQ(c10::str(Functional(torch::relu)), "torch::nn::Functional()");
4674}
4675
4676TEST_F(ModulesTest, PrettyPrintBatchNorm1d) {
4677 ASSERT_EQ(
4678 c10::str(BatchNorm1d(BatchNorm1dOptions(4)
4679 .eps(0.5)
4680 .momentum(0.1)
4681 .affine(false)
4682 .track_running_stats(true))),
4683 "torch::nn::BatchNorm1d(4, eps=0.5, momentum=0.1, affine=false, track_running_stats=true)");
4684}
4685
4686TEST_F(ModulesTest, PrettyPrintBatchNorm2d) {
4687 ASSERT_EQ(
4688 c10::str(BatchNorm2d(BatchNorm2dOptions(4)
4689 .eps(0.5)
4690 .momentum(0.1)
4691 .affine(false)
4692 .track_running_stats(true))),
4693 "torch::nn::BatchNorm2d(4, eps=0.5, momentum=0.1, affine=false, track_running_stats=true)");
4694}
4695
4696TEST_F(ModulesTest, PrettyPrintBatchNorm3d) {
4697 ASSERT_EQ(
4698 c10::str(BatchNorm3d(BatchNorm3dOptions(4)
4699 .eps(0.5)
4700 .momentum(0.1)
4701 .affine(false)
4702 .track_running_stats(true))),
4703 "torch::nn::BatchNorm3d(4, eps=0.5, momentum=0.1, affine=false, track_running_stats=true)");
4704}
4705
4706TEST_F(ModulesTest, PrettyPrintInstanceNorm1d) {
4707 ASSERT_EQ(
4708 c10::str(InstanceNorm1d(InstanceNorm1dOptions(4)
4709 .eps(0.5)
4710 .momentum(0.1)
4711 .affine(false)
4712 .track_running_stats(true))),
4713 "torch::nn::InstanceNorm1d(4, eps=0.5, momentum=0.1, affine=false, track_running_stats=true)");
4714}
4715
4716TEST_F(ModulesTest, PrettyPrintInstanceNorm2d) {
4717 ASSERT_EQ(
4718 c10::str(InstanceNorm2d(InstanceNorm2dOptions(4)
4719 .eps(0.5)
4720 .momentum(0.1)
4721 .affine(false)
4722 .track_running_stats(true))),
4723 "torch::nn::InstanceNorm2d(4, eps=0.5, momentum=0.1, affine=false, track_running_stats=true)");
4724}
4725
4726TEST_F(ModulesTest, PrettyPrintInstanceNorm3d) {
4727 ASSERT_EQ(
4728 c10::str(InstanceNorm3d(InstanceNorm3dOptions(4)
4729 .eps(0.5)
4730 .momentum(0.1)
4731 .affine(false)
4732 .track_running_stats(true))),
4733 "torch::nn::InstanceNorm3d(4, eps=0.5, momentum=0.1, affine=false, track_running_stats=true)");
4734}
4735
4736TEST_F(ModulesTest, PrettyPrintLayerNorm) {
4737 ASSERT_EQ(
4738 c10::str(LayerNorm(LayerNormOptions({2, 2}))),
4739 "torch::nn::LayerNorm([2, 2], eps=1e-05, elementwise_affine=true)");
4740 ASSERT_EQ(
4741 c10::str(LayerNorm(
4742 LayerNormOptions({2, 2}).elementwise_affine(false).eps(2e-5))),
4743 "torch::nn::LayerNorm([2, 2], eps=2e-05, elementwise_affine=false)");
4744}
4745
4746TEST_F(ModulesTest, PrettyPrintGroupNorm) {
4747 ASSERT_EQ(
4748 c10::str(GroupNorm(GroupNormOptions(2, 2))),
4749 "torch::nn::GroupNorm(2, 2, eps=1e-05, affine=true)");
4750 ASSERT_EQ(
4751 c10::str(GroupNorm(GroupNormOptions(2, 2).eps(2e-5).affine(false))),
4752 "torch::nn::GroupNorm(2, 2, eps=2e-05, affine=false)");
4753}
4754
4755TEST_F(ModulesTest, PrettyPrintLocalResponseNorm) {
4756 ASSERT_EQ(
4757 c10::str(LocalResponseNorm(LocalResponseNormOptions(2))),
4758 "torch::nn::LocalResponseNorm(2, alpha=0.0001, beta=0.75, k=1)");
4759 ASSERT_EQ(
4760 c10::str(LocalResponseNorm(
4761 LocalResponseNormOptions(2).alpha(0.0002).beta(0.85).k(2.))),
4762 "torch::nn::LocalResponseNorm(2, alpha=0.0002, beta=0.85, k=2)");
4763}
4764
4765TEST_F(ModulesTest, PrettyPrintEmbedding) {
4766 ASSERT_EQ(
4767 c10::str(Embedding(EmbeddingOptions(10, 2))),
4768 "torch::nn::Embedding(num_embeddings=10, embedding_dim=2)");
4769 ASSERT_EQ(
4770 c10::str(Embedding(EmbeddingOptions(10, 2).padding_idx(3).max_norm(2))),
4771 "torch::nn::Embedding(num_embeddings=10, embedding_dim=2, padding_idx=3, max_norm=2)");
4772 ASSERT_EQ(
4773 c10::str(Embedding(EmbeddingOptions(10, 2)
4774 .padding_idx(3)
4775 .max_norm(2)
4776 .norm_type(2.5)
4777 .scale_grad_by_freq(true)
4778 .sparse(true))),
4779 "torch::nn::Embedding(num_embeddings=10, embedding_dim=2, padding_idx=3, max_norm=2, norm_type=2.5, scale_grad_by_freq=true, sparse=true)");
4780}
4781
4782TEST_F(ModulesTest, PrettyPrintEmbeddingBag) {
4783 ASSERT_EQ(
4784 c10::str(EmbeddingBag(EmbeddingBagOptions(10, 2))),
4785 "torch::nn::EmbeddingBag(num_embeddings=10, embedding_dim=2)");
4786 ASSERT_EQ(
4787 c10::str(EmbeddingBag(EmbeddingBagOptions(10, 2).max_norm(2))),
4788 "torch::nn::EmbeddingBag(num_embeddings=10, embedding_dim=2, max_norm=2)");
4789 ASSERT_EQ(
4790 c10::str(EmbeddingBag(EmbeddingBagOptions(10, 2)
4791 .max_norm(2)
4792 .norm_type(2.5)
4793 .scale_grad_by_freq(true)
4794 .sparse(true))),
4795 "torch::nn::EmbeddingBag(num_embeddings=10, embedding_dim=2, max_norm=2, norm_type=2.5, scale_grad_by_freq=true, sparse=true)");
4796 ASSERT_EQ(
4797 c10::str(EmbeddingBag(EmbeddingBagOptions(10, 2)
4798 .max_norm(2)
4799 .norm_type(2.5)
4800 .scale_grad_by_freq(true)
4801 .sparse(true)
4802 .mode(torch::kSum))),
4803 "torch::nn::EmbeddingBag(num_embeddings=10, embedding_dim=2, max_norm=2, norm_type=2.5, scale_grad_by_freq=true, sparse=true, mode=kSum)");
4804 ASSERT_EQ(
4805 c10::str(EmbeddingBag(EmbeddingBagOptions(10, 2)
4806 .max_norm(2)
4807 .norm_type(2.5)
4808 .scale_grad_by_freq(true)
4809 .sparse(true)
4810 .mode(torch::kSum)
4811 .padding_idx(5))),
4812 "torch::nn::EmbeddingBag(num_embeddings=10, embedding_dim=2, max_norm=2, norm_type=2.5, scale_grad_by_freq=true, sparse=true, mode=kSum, padding_idx=5)");
4813}
4814
4815TEST_F(ModulesTest, PrettyPrintL1Loss) {
4816 ASSERT_EQ(c10::str(L1Loss()), "torch::nn::L1Loss()");
4817}
4818TEST_F(ModulesTest, PrettyPrintKLDivLoss) {
4819 ASSERT_EQ(c10::str(KLDivLoss()), "torch::nn::KLDivLoss()");
4820}
4821TEST_F(ModulesTest, PrettyPrintMSELoss) {
4822 ASSERT_EQ(c10::str(MSELoss()), "torch::nn::MSELoss()");
4823}
4824TEST_F(ModulesTest, PrettyPrintBCELoss) {
4825 ASSERT_EQ(c10::str(BCELoss()), "torch::nn::BCELoss()");
4826}
4827TEST_F(ModulesTest, PrettyPrintHingeEmbeddingLoss) {
4828 ASSERT_EQ(
4829 c10::str(HingeEmbeddingLoss(HingeEmbeddingLossOptions().margin(4))),
4830 "torch::nn::HingeEmbeddingLoss(margin=4)");
4831}
4832
4833TEST_F(ModulesTest, PrettyPrintCosineEmbeddingLoss) {
4834 ASSERT_EQ(
4835 c10::str(CosineEmbeddingLoss(CosineEmbeddingLossOptions().margin(0.25))),
4836 "torch::nn::CosineEmbeddingLoss(margin=0.25)");
4837}
4838
4839TEST_F(ModulesTest, PrettyPrintTripletMarginLoss) {
4840 ASSERT_EQ(
4841 c10::str(TripletMarginLoss(
4842 TripletMarginLossOptions().margin(3).p(2).eps(1e-06).swap(false))),
4843 "torch::nn::TripletMarginLoss(margin=3, p=2, eps=1e-06, swap=false)");
4844}
4845
4846TEST_F(ModulesTest, PrettyPrintTripletMarginWithDistanceLoss) {
4847 auto distanceOptions = TripletMarginWithDistanceLossOptions()
4848 .distance_function([&](const torch::Tensor& x,
4849 const torch::Tensor& y) {
4850 return torch::pairwise_distance(x, y, 2.0, 1e-6);
4851 })
4852 .margin(1.5)
4853 .swap(true)
4854 .reduction(torch::kMean);
4855 ASSERT_EQ(
4856 c10::str(TripletMarginWithDistanceLoss(distanceOptions)),
4857 "torch::nn::TripletMarginWithDistanceLoss(margin=1.5, swap=true)");
4858}
4859
4860TEST_F(ModulesTest, PrettyPrintNLLLoss) {
4861 ASSERT_EQ(c10::str(NLLLoss()), "torch::nn::NLLLoss()");
4862}
4863
4864TEST_F(ModulesTest, PrettyPrinCrossEntropyLoss) {
4865 ASSERT_EQ(c10::str(CrossEntropyLoss()), "torch::nn::CrossEntropyLoss()");
4866}
4867
4868TEST_F(ModulesTest, PrettyPrintMultiLabelMarginLoss) {
4869 ASSERT_EQ(
4870 c10::str(MultiLabelMarginLoss()), "torch::nn::MultiLabelMarginLoss()");
4871}
4872
4873TEST_F(ModulesTest, PrettyPrintMultiLabelSoftMarginLoss) {
4874 ASSERT_EQ(
4875 c10::str(MultiLabelSoftMarginLoss()),
4876 "torch::nn::MultiLabelSoftMarginLoss()");
4877}
4878
4879TEST_F(ModulesTest, PrettyPrintSoftMarginLoss) {
4880 ASSERT_EQ(c10::str(SoftMarginLoss()), "torch::nn::SoftMarginLoss()");
4881}
4882
4883TEST_F(ModulesTest, PrettyPrintCosineSimilarity) {
4884 ASSERT_EQ(
4885 c10::str(CosineSimilarity()),
4886 "torch::nn::CosineSimilarity(dim=1, eps=1e-08)");
4887 ASSERT_EQ(
4888 c10::str(CosineSimilarity(CosineSimilarityOptions().dim(0).eps(0.5))),
4889 "torch::nn::CosineSimilarity(dim=0, eps=0.5)");
4890}
4891
4892TEST_F(ModulesTest, PrettyPrintPairwiseDistance) {
4893 ASSERT_EQ(
4894 c10::str(PairwiseDistance()),
4895 "torch::nn::PairwiseDistance(p=2, eps=1e-06, keepdim=false)");
4896 ASSERT_EQ(
4897 c10::str(PairwiseDistance(
4898 PairwiseDistanceOptions().p(3).eps(0.5).keepdim(true))),
4899 "torch::nn::PairwiseDistance(p=3, eps=0.5, keepdim=true)");
4900}
4901
4902TEST_F(ModulesTest, PrettyPrintReflectionPad) {
4903 ASSERT_EQ(
4904 c10::str(ReflectionPad1d(ReflectionPad1dOptions(2))),
4905 "torch::nn::ReflectionPad1d(padding=[2, 2])");
4906 ASSERT_EQ(
4907 c10::str(ReflectionPad1d(ReflectionPad1dOptions({3, 1}))),
4908 "torch::nn::ReflectionPad1d(padding=[3, 1])");
4909 ASSERT_EQ(
4910 c10::str(ReflectionPad2d(ReflectionPad2dOptions(2))),
4911 "torch::nn::ReflectionPad2d(padding=[2, 2, 2, 2])");
4912 ASSERT_EQ(
4913 c10::str(ReflectionPad2d(ReflectionPad2dOptions({1, 1, 2, 0}))),
4914 "torch::nn::ReflectionPad2d(padding=[1, 1, 2, 0])");
4915}
4916
4917TEST_F(ModulesTest, PrettyPrintReplicationPad) {
4918 ASSERT_EQ(
4919 c10::str(ReplicationPad1d(ReplicationPad1dOptions(2))),
4920 "torch::nn::ReplicationPad1d(padding=[2, 2])");
4921 ASSERT_EQ(
4922 c10::str(ReplicationPad1d(ReplicationPad1dOptions({3, 1}))),
4923 "torch::nn::ReplicationPad1d(padding=[3, 1])");
4924 ASSERT_EQ(
4925 c10::str(ReplicationPad2d(ReplicationPad2dOptions(2))),
4926 "torch::nn::ReplicationPad2d(padding=[2, 2, 2, 2])");
4927 ASSERT_EQ(
4928 c10::str(ReplicationPad2d(ReplicationPad2dOptions({1, 1, 2, 0}))),
4929 "torch::nn::ReplicationPad2d(padding=[1, 1, 2, 0])");
4930 ASSERT_EQ(
4931 c10::str(ReplicationPad3d(ReplicationPad3dOptions(1))),
4932 "torch::nn::ReplicationPad3d(padding=[1, 1, 1, 1, 1, 1])");
4933 ASSERT_EQ(
4934 c10::str(ReplicationPad3d(ReplicationPad3dOptions({1, 2, 1, 2, 1, 2}))),
4935 "torch::nn::ReplicationPad3d(padding=[1, 2, 1, 2, 1, 2])");
4936}
4937
4938TEST_F(ModulesTest, PrettyPrintZeroPad2d) {
4939 ASSERT_EQ(
4940 c10::str(ZeroPad2d(ZeroPad2dOptions(2))),
4941 "torch::nn::ZeroPad2d(padding=[2, 2, 2, 2])");
4942 ASSERT_EQ(
4943 c10::str(ZeroPad2d(ZeroPad2dOptions({1, 1, 2, 0}))),
4944 "torch::nn::ZeroPad2d(padding=[1, 1, 2, 0])");
4945}
4946
4947TEST_F(ModulesTest, PrettyPrintConstantPad) {
4948 ASSERT_EQ(
4949 c10::str(ConstantPad1d(ConstantPad1dOptions(2, 3.5))),
4950 "torch::nn::ConstantPad1d(padding=[2, 2], value=3.5)");
4951 ASSERT_EQ(
4952 c10::str(ConstantPad1d(ConstantPad1dOptions({3, 1}, 3.5))),
4953 "torch::nn::ConstantPad1d(padding=[3, 1], value=3.5)");
4954 ASSERT_EQ(
4955 c10::str(ConstantPad2d(ConstantPad2dOptions(2, 3.5))),
4956 "torch::nn::ConstantPad2d(padding=[2, 2, 2, 2], value=3.5)");
4957 ASSERT_EQ(
4958 c10::str(ConstantPad2d(ConstantPad2dOptions({3, 0, 2, 1}, 3.5))),
4959 "torch::nn::ConstantPad2d(padding=[3, 0, 2, 1], value=3.5)");
4960 ASSERT_EQ(
4961 c10::str(ConstantPad3d(ConstantPad3dOptions(1, 3.5))),
4962 "torch::nn::ConstantPad3d(padding=[1, 1, 1, 1, 1, 1], value=3.5)");
4963 ASSERT_EQ(
4964 c10::str(ConstantPad3d(ConstantPad3dOptions({1, 2, 1, 2, 1, 2}, 3.5))),
4965 "torch::nn::ConstantPad3d(padding=[1, 2, 1, 2, 1, 2], value=3.5)");
4966}
4967
4968TEST_F(ModulesTest, PrettyPrintNestedModel) {
4969 struct InnerTestModule : torch::nn::Module {
4970 InnerTestModule()
4971 : torch::nn::Module("InnerTestModule"),
4972 fc(register_module("fc", torch::nn::Linear(3, 4))),
4973 table(register_module("table", torch::nn::Embedding(10, 2))) {}
4974
4975 torch::nn::Linear fc;
4976 torch::nn::Embedding table;
4977 };
4978
4979 struct TestModule : torch::nn::Module {
4980 TestModule()
4981 : torch::nn::Module("TestModule"),
4982 fc(register_module("fc", torch::nn::Linear(4, 5))),
4983 table(register_module(
4984 "table",
4985 torch::nn::Embedding(EmbeddingOptions(10, 2)))),
4986 inner(register_module("inner", std::make_shared<InnerTestModule>())) {
4987 }
4988
4989 torch::nn::Linear fc;
4990 torch::nn::Embedding table;
4991 std::shared_ptr<InnerTestModule> inner;
4992 };
4993
4994 ASSERT_EQ(
4995 c10::str(TestModule{}),
4996 "TestModule(\n"
4997 " (fc): torch::nn::Linear(in_features=4, out_features=5, bias=true)\n"
4998 " (table): torch::nn::Embedding(num_embeddings=10, embedding_dim=2)\n"
4999 " (inner): InnerTestModule(\n"
5000 " (fc): torch::nn::Linear(in_features=3, out_features=4, bias=true)\n"
5001 " (table): torch::nn::Embedding(num_embeddings=10, embedding_dim=2)\n"
5002 " )\n"
5003 ")");
5004}
5005
5006TEST_F(ModulesTest, PrettyPrintELU) {
5007 ASSERT_EQ(c10::str(ELU()), "torch::nn::ELU(alpha=1)");
5008 ASSERT_EQ(
5009 c10::str(ELU(ELUOptions().alpha(42.42).inplace(true))),
5010 "torch::nn::ELU(alpha=42.42, inplace=true)");
5011}
5012
5013TEST_F(ModulesTest, PrettyPrintSELU) {
5014 ASSERT_EQ(c10::str(SELU()), "torch::nn::SELU()");
5015 ASSERT_EQ(
5016 c10::str(SELU(SELUOptions().inplace(true))),
5017 "torch::nn::SELU(inplace=true)");
5018}
5019
5020TEST_F(ModulesTest, PrettyPrintGLU) {
5021 ASSERT_EQ(c10::str(GLU()), "torch::nn::GLU(dim=-1)");
5022 ASSERT_EQ(c10::str(GLU(1)), "torch::nn::GLU(dim=1)");
5023}
5024
5025TEST_F(ModulesTest, PrettyPrintHardshrink) {
5026 ASSERT_EQ(c10::str(Hardshrink()), "torch::nn::Hardshrink(0.5)");
5027 ASSERT_EQ(
5028 c10::str(Hardshrink(HardshrinkOptions().lambda(42.42))),
5029 "torch::nn::Hardshrink(42.42)");
5030}
5031
5032TEST_F(ModulesTest, PrettyPrintHardtanh) {
5033 ASSERT_EQ(c10::str(Hardtanh()), "torch::nn::Hardtanh(min_val=-1, max_val=1)");
5034 ASSERT_EQ(
5035 c10::str(Hardtanh(
5036 HardtanhOptions().min_val(-42.42).max_val(0.42).inplace(true))),
5037 "torch::nn::Hardtanh(min_val=-42.42, max_val=0.42, inplace=true)");
5038}
5039
5040TEST_F(ModulesTest, PrettyPrintLeakyReLU) {
5041 ASSERT_EQ(c10::str(LeakyReLU()), "torch::nn::LeakyReLU(negative_slope=0.01)");
5042 ASSERT_EQ(
5043 c10::str(
5044 LeakyReLU(LeakyReLUOptions().negative_slope(0.42).inplace(true))),
5045 "torch::nn::LeakyReLU(negative_slope=0.42, inplace=true)");
5046}
5047
5048TEST_F(ModulesTest, PrettyPrintLogSigmoid) {
5049 ASSERT_EQ(c10::str(LogSigmoid()), "torch::nn::LogSigmoid()");
5050}
5051
5052TEST_F(ModulesTest, PrettyPrintSoftmax) {
5053 ASSERT_EQ(c10::str(Softmax(SoftmaxOptions(1))), "torch::nn::Softmax(dim=1)");
5054}
5055
5056TEST_F(ModulesTest, PrettyPrintSoftmin) {
5057 ASSERT_EQ(c10::str(Softmin(SoftminOptions(1))), "torch::nn::Softmin(dim=1)");
5058}
5059
5060TEST_F(ModulesTest, PrettyPrintLogSoftmax) {
5061 ASSERT_EQ(
5062 c10::str(LogSoftmax(LogSoftmaxOptions(1))),
5063 "torch::nn::LogSoftmax(dim=1)");
5064}
5065
5066TEST_F(ModulesTest, PrettyPrintSoftmax2d) {
5067 ASSERT_EQ(c10::str(Softmax2d()), "torch::nn::Softmax2d()");
5068}
5069
5070TEST_F(ModulesTest, PrettyPrintPReLU) {
5071 ASSERT_EQ(c10::str(PReLU()), "torch::nn::PReLU(num_parameters=1)");
5072 ASSERT_EQ(
5073 c10::str(PReLU(PReLUOptions().num_parameters(42))),
5074 "torch::nn::PReLU(num_parameters=42)");
5075}
5076
5077TEST_F(ModulesTest, PrettyPrintReLU) {
5078 ASSERT_EQ(c10::str(ReLU()), "torch::nn::ReLU()");
5079 ASSERT_EQ(
5080 c10::str(ReLU(ReLUOptions().inplace(true))),
5081 "torch::nn::ReLU(inplace=true)");
5082 ASSERT_EQ(c10::str(ReLU(/*inplace=*/true)), "torch::nn::ReLU(inplace=true)");
5083}
5084
5085TEST_F(ModulesTest, PrettyPrintReLU6) {
5086 ASSERT_EQ(c10::str(ReLU6()), "torch::nn::ReLU6()");
5087 ASSERT_EQ(
5088 c10::str(ReLU6(ReLU6Options().inplace(true))),
5089 "torch::nn::ReLU6(inplace=true)");
5090 ASSERT_EQ(
5091 c10::str(ReLU6(/*inplace=*/true)), "torch::nn::ReLU6(inplace=true)");
5092}
5093
5094TEST_F(ModulesTest, PrettyPrintRReLU) {
5095 ASSERT_EQ(c10::str(RReLU()), "torch::nn::RReLU(lower=0.125, upper=0.333333)");
5096 ASSERT_EQ(
5097 c10::str(RReLU(RReLUOptions().lower(0.24).upper(0.42).inplace(true))),
5098 "torch::nn::RReLU(lower=0.24, upper=0.42, inplace=true)");
5099}
5100
5101TEST_F(ModulesTest, PrettyPrintCELU) {
5102 ASSERT_EQ(c10::str(CELU()), "torch::nn::CELU(alpha=1)");
5103 ASSERT_EQ(
5104 c10::str(CELU(CELUOptions().alpha(42.42).inplace(true))),
5105 "torch::nn::CELU(alpha=42.42, inplace=true)");
5106}
5107
5108TEST_F(ModulesTest, PrettyPrintSigmoid) {
5109 ASSERT_EQ(c10::str(Sigmoid()), "torch::nn::Sigmoid()");
5110}
5111
5112TEST_F(ModulesTest, PrettyPrintPixelShuffle) {
5113 ASSERT_EQ(
5114 c10::str(PixelShuffle(PixelShuffleOptions(5))),
5115 "torch::nn::PixelShuffle(upscale_factor=5)");
5116}
5117
5118TEST_F(ModulesTest, PrettyPrintPixelUnshuffle) {
5119 ASSERT_EQ(
5120 c10::str(PixelUnshuffle(PixelUnshuffleOptions(5))),
5121 "torch::nn::PixelUnshuffle(downscale_factor=5)");
5122}
5123
5124TEST_F(ModulesTest, PrettyPrintSoftplus) {
5125 ASSERT_EQ(c10::str(Softplus()), "torch::nn::Softplus(beta=1, threshold=20)");
5126 ASSERT_EQ(
5127 c10::str(Softplus(SoftplusOptions().beta(0.24).threshold(42.42))),
5128 "torch::nn::Softplus(beta=0.24, threshold=42.42)");
5129}
5130
5131TEST_F(ModulesTest, PrettyPrintSoftshrink) {
5132 ASSERT_EQ(c10::str(Softshrink()), "torch::nn::Softshrink(0.5)");
5133 ASSERT_EQ(
5134 c10::str(Softshrink(SoftshrinkOptions(42.42))),
5135 "torch::nn::Softshrink(42.42)");
5136}
5137
5138TEST_F(ModulesTest, PrettyPrintSoftsign) {
5139 ASSERT_EQ(c10::str(Softsign()), "torch::nn::Softsign()");
5140}
5141
5142TEST_F(ModulesTest, PrettyPrintTanh) {
5143 ASSERT_EQ(c10::str(Tanh()), "torch::nn::Tanh()");
5144}
5145
5146TEST_F(ModulesTest, PrettyPrintTanhshrink) {
5147 ASSERT_EQ(c10::str(Tanhshrink()), "torch::nn::Tanhshrink()");
5148}
5149
5150TEST_F(ModulesTest, PrettyPrintThreshold) {
5151 ASSERT_EQ(
5152 c10::str(Threshold(24.24, 42.42)),
5153 "torch::nn::Threshold(threshold=24.24, value=42.42)");
5154 ASSERT_EQ(
5155 c10::str(Threshold(ThresholdOptions(42.42, 24.24).inplace(true))),
5156 "torch::nn::Threshold(threshold=42.42, value=24.24, inplace=true)");
5157}
5158
5159TEST_F(ModulesTest, PrettyPrintCTCLoss) {
5160 ASSERT_EQ(c10::str(CTCLoss()), "torch::nn::CTCLoss()");
5161 ASSERT_EQ(
5162 c10::str(
5163 CTCLoss(CTCLossOptions().blank(42).zero_infinity(false).reduction(
5164 torch::kSum))),
5165 "torch::nn::CTCLoss()");
5166}
5167
5168TEST_F(ModulesTest, PrettyPrintPoissonNLLLoss) {
5169 ASSERT_EQ(c10::str(PoissonNLLLoss()), "torch::nn::PoissonNLLLoss()");
5170 ASSERT_EQ(
5171 c10::str(PoissonNLLLoss(PoissonNLLLossOptions()
5172 .log_input(false)
5173 .full(true)
5174 .eps(0.42)
5175 .reduction(torch::kSum))),
5176 "torch::nn::PoissonNLLLoss()");
5177}
5178
5179TEST_F(ModulesTest, PrettyPrintMarginRankingLoss) {
5180 ASSERT_EQ(c10::str(MarginRankingLoss()), "torch::nn::MarginRankingLoss()");
5181 ASSERT_EQ(
5182 c10::str(MarginRankingLoss(
5183 MarginRankingLossOptions().margin(0.5).reduction(torch::kSum))),
5184 "torch::nn::MarginRankingLoss()");
5185}
5186
5187TEST_F(ModulesTest, PrettyPrintCrossMapLRN2d) {
5188 ASSERT_EQ(
5189 c10::str(CrossMapLRN2d(4)),
5190 "torch::nn::CrossMapLRN2d(4, alpha=0.0001, beta=0.75, k=1)");
5191 ASSERT_EQ(
5192 c10::str(
5193 CrossMapLRN2d(CrossMapLRN2dOptions(3).alpha(1e-5).beta(0.1).k(10))),
5194 "torch::nn::CrossMapLRN2d(3, alpha=1e-05, beta=0.1, k=10)");
5195}
5196
5197TEST_F(ModulesTest, PrettyPrintAlphaDropout) {
5198 ASSERT_EQ(
5199 c10::str(AlphaDropout()),
5200 "torch::nn::AlphaDropout(p=0.5, inplace=false)");
5201 ASSERT_EQ(
5202 c10::str(AlphaDropout(AlphaDropoutOptions(0.2))),
5203 "torch::nn::AlphaDropout(p=0.2, inplace=false)");
5204 ASSERT_EQ(
5205 c10::str(AlphaDropout(AlphaDropoutOptions(0.2).inplace(true))),
5206 "torch::nn::AlphaDropout(p=0.2, inplace=true)");
5207}
5208
5209TEST_F(ModulesTest, PrettyPrintFeatureAlphaDropout) {
5210 ASSERT_EQ(
5211 c10::str(FeatureAlphaDropout()),
5212 "torch::nn::FeatureAlphaDropout(p=0.5, inplace=false)");
5213 ASSERT_EQ(
5214 c10::str(FeatureAlphaDropout(FeatureAlphaDropoutOptions(0.2))),
5215 "torch::nn::FeatureAlphaDropout(p=0.2, inplace=false)");
5216 ASSERT_EQ(
5217 c10::str(
5218 FeatureAlphaDropout(FeatureAlphaDropoutOptions(0.2).inplace(true))),
5219 "torch::nn::FeatureAlphaDropout(p=0.2, inplace=true)");
5220}
5221
5222TEST_F(ModulesTest, PrettyPrintBCEWithLogitsLoss) {
5223 ASSERT_EQ(c10::str(BCEWithLogitsLoss()), "torch::nn::BCEWithLogitsLoss()");
5224 ASSERT_EQ(
5225 c10::str(BCEWithLogitsLoss(BCEWithLogitsLossOptions()
5226 .weight(torch::ones({3, 3}))
5227 .pos_weight(torch::ones({3, 3}))
5228 .reduction(torch::kSum))),
5229 "torch::nn::BCEWithLogitsLoss()");
5230}
5231
5232TEST_F(ModulesTest, PrettyPrintMultiheadAttention) {
5233 ASSERT_EQ(
5234 c10::str(MultiheadAttention(20, 10)),
5235 "torch::nn::MultiheadAttention(\n (out_proj): torch::nn::Linear(in_features=20, out_features=20, bias=true)\n)");
5236 ASSERT_EQ(
5237 c10::str(
5238 MultiheadAttention(MultiheadAttentionOptions(20, 10).bias(false))),
5239 "torch::nn::MultiheadAttention(\n (out_proj): torch::nn::Linear(in_features=20, out_features=20, bias=false)\n)");
5240}
5241
5242TEST_F(ModulesTest, PrettyPrintRNNCell) {
5243 ASSERT_EQ(c10::str(RNNCell(20, 10)), "torch::nn::RNNCell(20, 10)");
5244 ASSERT_EQ(
5245 c10::str(RNNCell(
5246 RNNCellOptions(20, 10).bias(false).nonlinearity(torch::kTanh))),
5247 "torch::nn::RNNCell(20, 10, bias=false)");
5248 ASSERT_EQ(
5249 c10::str(RNNCell(
5250 RNNCellOptions(20, 10).bias(false).nonlinearity(torch::kReLU))),
5251 "torch::nn::RNNCell(20, 10, bias=false, nonlinearity=kReLU)");
5252}
5253
5254TEST_F(ModulesTest, PrettyPrintLSTMCell) {
5255 ASSERT_EQ(c10::str(LSTMCell(20, 10)), "torch::nn::LSTMCell(20, 10)");
5256 ASSERT_EQ(
5257 c10::str(LSTMCell(LSTMCellOptions(20, 10).bias(false))),
5258 "torch::nn::LSTMCell(20, 10, bias=false)");
5259}
5260
5261TEST_F(ModulesTest, PrettyPrintGRUCell) {
5262 ASSERT_EQ(c10::str(GRUCell(20, 10)), "torch::nn::GRUCell(20, 10)");
5263 ASSERT_EQ(
5264 c10::str(GRUCell(GRUCellOptions(20, 10).bias(false))),
5265 "torch::nn::GRUCell(20, 10, bias=false)");
5266}
5267
5268TEST_F(ModulesTest, PrettyPrintAdaptiveLogSoftmaxWithLoss) {
5269 {
5270 AdaptiveLogSoftmaxWithLoss asfm(
5271 AdaptiveLogSoftmaxWithLossOptions(8, 4, {2}).div_value(2.));
5272 ASSERT_EQ(
5273 c10::str(asfm),
5274 "torch::nn::AdaptiveLogSoftmaxWithLoss(\n"
5275 " (head): torch::nn::Linear(in_features=8, out_features=3, bias=false)\n"
5276 " (tail): torch::nn::ModuleList(\n"
5277 " (0): torch::nn::Sequential(\n"
5278 " (0): torch::nn::Linear(in_features=8, out_features=4, bias=false)\n"
5279 " (1): torch::nn::Linear(in_features=4, out_features=2, bias=false)\n"
5280 " )\n"
5281 " )\n"
5282 ")");
5283 }
5284 {
5285 AdaptiveLogSoftmaxWithLoss asfm(
5286 AdaptiveLogSoftmaxWithLossOptions(8, 10, {4, 8})
5287 .div_value(2.)
5288 .head_bias(true));
5289 ASSERT_EQ(
5290 c10::str(asfm),
5291 "torch::nn::AdaptiveLogSoftmaxWithLoss(\n"
5292 " (head): torch::nn::Linear(in_features=8, out_features=6, bias=true)\n"
5293 " (tail): torch::nn::ModuleList(\n"
5294 " (0): torch::nn::Sequential(\n"
5295 " (0): torch::nn::Linear(in_features=8, out_features=4, bias=false)\n"
5296 " (1): torch::nn::Linear(in_features=4, out_features=4, bias=false)\n"
5297 " )\n"
5298 " (1): torch::nn::Sequential(\n"
5299 " (0): torch::nn::Linear(in_features=8, out_features=2, bias=false)\n"
5300 " (1): torch::nn::Linear(in_features=2, out_features=2, bias=false)\n"
5301 " )\n"
5302 " )\n"
5303 ")");
5304 }
5305}
5306