1#include <gtest/gtest.h>
2
3#include <c10/util/irange.h>
4#include <torch/torch.h>
5
6#include <test/cpp/api/support.h>
7
8namespace F = torch::nn::functional;
9
10using namespace torch::nn;
11
12struct FunctionalTest : torch::test::SeedingFixture {};
13
14TEST_F(FunctionalTest, Conv1d) {
15 auto x = torch::arange(30, torch::dtype(torch::kFloat).requires_grad(true))
16 .reshape({2, 3, 5});
17 auto weight =
18 torch::arange(18, torch::dtype(torch::kFloat).requires_grad(true))
19 .reshape({2, 3, 3});
20 auto y = F::conv1d(x, weight, F::Conv1dFuncOptions().stride(1));
21 auto expected = torch::tensor(
22 {{{312., 348., 384.}, {798., 915., 1032.}},
23
24 {{852., 888., 924.}, {2553., 2670., 2787.}}},
25 torch::kFloat);
26 ASSERT_TRUE(torch::allclose(y, expected));
27
28 auto y_no_options = F::conv1d(x, weight);
29 ASSERT_TRUE(torch::allclose(y_no_options, expected));
30}
31
32TEST_F(FunctionalTest, Conv2dEven) {
33 auto x = torch::arange(75, torch::dtype(torch::kFloat).requires_grad(true))
34 .reshape({1, 3, 5, 5});
35 auto weight =
36 torch::arange(54, torch::dtype(torch::kFloat).requires_grad(true))
37 .reshape({2, 3, 3, 3});
38 auto y = F::conv2d(x, weight, F::Conv2dFuncOptions().stride(1));
39 auto expected = torch::tensor(
40 {{{{15219., 15570., 15921.},
41 {16974., 17325., 17676.},
42 {18729., 19080., 19431.}},
43
44 {{37818., 38898., 39978.},
45 {43218., 44298., 45378.},
46 {48618., 49698., 50778.}}}},
47 torch::kFloat);
48 ASSERT_TRUE(torch::allclose(y, expected));
49
50 auto y_no_options = F::conv2d(x, weight);
51 ASSERT_TRUE(torch::allclose(y_no_options, expected));
52}
53
54TEST_F(FunctionalTest, Conv2dUneven) {
55 auto x = torch::arange(60, torch::dtype(torch::kFloat).requires_grad(true))
56 .reshape({1, 3, 5, 4});
57 auto weight =
58 torch::arange(36, torch::dtype(torch::kFloat).requires_grad(true))
59 .reshape({2, 3, 3, 2});
60 auto y = F::conv2d(x, weight, F::Conv2dFuncOptions().stride(1));
61 auto expected = torch::tensor(
62 {{{{5289., 5442., 5595.}, {5901., 6054., 6207.}, {6513., 6666., 6819.}},
63
64 {{13227., 13704., 14181.},
65 {15135., 15612., 16089.},
66 {17043., 17520., 17997.}}}},
67 torch::kFloat);
68 ASSERT_TRUE(torch::allclose(y, expected));
69
70 auto y_no_options = F::conv2d(x, weight);
71 ASSERT_TRUE(torch::allclose(y_no_options, expected));
72}
73
74TEST_F(FunctionalTest, Conv3d) {
75 auto x = torch::arange(375, torch::dtype(torch::kFloat).requires_grad(true))
76 .reshape({1, 3, 5, 5, 5});
77 auto weight =
78 torch::arange(162, torch::dtype(torch::kFloat).requires_grad(true))
79 .reshape({2, 3, 3, 3, 3});
80 auto y = F::conv3d(x, weight, F::Conv3dFuncOptions().stride(1));
81 auto expected = torch::tensor(
82 {{{{{700704., 703944., 707184.},
83 {716904., 720144., 723384.},
84 {733104., 736344., 739584.}},
85
86 {{781704., 784944., 788184.},
87 {797904., 801144., 804384.},
88 {814104., 817344., 820584.}},
89
90 {{862704., 865944., 869184.},
91 {878904., 882144., 885384.},
92 {895104., 898344., 901584.}}},
93
94 {{{1724220., 1734021., 1743822.},
95 {1773225., 1783026., 1792827.},
96 {1822230., 1832031., 1841832.}},
97
98 {{1969245., 1979046., 1988847.},
99 {2018250., 2028051., 2037852.},
100 {2067255., 2077056., 2086857.}},
101
102 {{2214270., 2224071., 2233872.},
103 {2263275., 2273076., 2282877.},
104 {2312280., 2322081., 2331882.}}}}},
105 torch::kFloat);
106 ASSERT_TRUE(torch::allclose(y, expected));
107
108 auto y_no_options = F::conv3d(x, weight);
109 ASSERT_TRUE(torch::allclose(y_no_options, expected));
110}
111
112TEST_F(FunctionalTest, MaxPool1d) {
113 auto x = torch::ones({1, 1, 5});
114 auto y = F::max_pool1d(x, F::MaxPool1dFuncOptions(3).stride(2));
115
116 ASSERT_EQ(y.ndimension(), 3);
117 ASSERT_TRUE(torch::allclose(y, torch::ones({1, 1, 2})));
118 ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 2}));
119}
120
121TEST_F(FunctionalTest, MaxPool2d) {
122 auto x = torch::ones({2, 5, 5});
123 auto y = F::max_pool2d(x, F::MaxPool2dFuncOptions(3).stride(2));
124
125 ASSERT_EQ(y.ndimension(), 3);
126 ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2})));
127 ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2}));
128}
129
130TEST_F(FunctionalTest, MaxPool2dBackward) {
131 auto input = torch::rand(
132 {1, 2, 4, 4}, torch::dtype(torch::kFloat).requires_grad(true));
133 auto output = F::max_pool2d(input, F::MaxPool2dFuncOptions(2));
134 auto s = output.sum();
135 s.backward();
136 ASSERT_TRUE(input.sizes() == input.grad().sizes());
137}
138
139TEST_F(FunctionalTest, MaxPool3d) {
140 auto x = torch::ones({2, 5, 5, 5});
141 auto y = F::max_pool3d(x, F::MaxPool3dFuncOptions(3).stride(2));
142
143 ASSERT_EQ(y.ndimension(), 4);
144 ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2, 2})));
145 ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2, 2}));
146}
147
148TEST_F(FunctionalTest, AvgPool1d) {
149 auto x = torch::ones({1, 1, 5});
150 auto y = F::avg_pool1d(x, F::AvgPool1dFuncOptions(3).stride(2));
151
152 ASSERT_EQ(y.ndimension(), 3);
153 ASSERT_TRUE(torch::allclose(y, torch::ones({1, 1, 2})));
154 ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 2}));
155}
156
157TEST_F(FunctionalTest, AvgPool2d) {
158 auto x = torch::ones({2, 5, 5});
159 auto y = F::avg_pool2d(x, F::AvgPool2dFuncOptions(3).stride(2));
160
161 ASSERT_EQ(y.ndimension(), 3);
162 ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2})));
163 ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2}));
164}
165
166TEST_F(FunctionalTest, AvgPool3d) {
167 auto x = torch::ones({2, 5, 5, 5});
168 auto y = F::avg_pool3d(x, F::AvgPool3dFuncOptions(3).stride(2));
169
170 ASSERT_EQ(y.ndimension(), 4);
171 ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2, 2})));
172 ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2, 2}));
173}
174
175TEST_F(FunctionalTest, FractionalMaxPool2d) {
176 auto x = torch::ones({2, 5, 5});
177 auto y = F::fractional_max_pool2d(
178 x, F::FractionalMaxPool2dFuncOptions(3).output_size(2));
179
180 ASSERT_EQ(y.ndimension(), 3);
181 ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2})));
182 ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2}));
183
184 auto y_with_indices = F::fractional_max_pool2d_with_indices(
185 x, F::FractionalMaxPool2dFuncOptions(3).output_size(2));
186 ASSERT_TRUE(torch::equal(y, std::get<0>(y_with_indices)));
187 ASSERT_TRUE(torch::allclose(
188 std::get<1>(y_with_indices),
189 torch::tensor({{{0, 2}, {10, 12}}, {{0, 2}, {10, 12}}})));
190 ASSERT_EQ(
191 std::get<1>(y_with_indices).sizes(), std::vector<int64_t>({2, 2, 2}));
192
193 auto x1 = torch::ones({2, 2, 5, 5});
194 auto y1 = F::fractional_max_pool2d(
195 x1, F::FractionalMaxPool2dFuncOptions(3).output_size(2));
196
197 ASSERT_EQ(y1.ndimension(), 4);
198 ASSERT_TRUE(torch::allclose(y1, torch::ones({2, 2, 2, 2})));
199 ASSERT_EQ(y1.sizes(), std::vector<int64_t>({2, 2, 2, 2}));
200
201 auto y1_with_indices = F::fractional_max_pool2d_with_indices(
202 x1, F::FractionalMaxPool2dFuncOptions(3).output_size(2));
203 ASSERT_TRUE(torch::equal(y1, std::get<0>(y1_with_indices)));
204 ASSERT_TRUE(torch::allclose(
205 std::get<1>(y1_with_indices),
206 torch::tensor(
207 {{{{0, 2}, {10, 12}}, {{0, 2}, {10, 12}}},
208 {{{0, 2}, {10, 12}}, {{0, 2}, {10, 12}}}})));
209 ASSERT_EQ(
210 std::get<1>(y1_with_indices).sizes(), std::vector<int64_t>({2, 2, 2, 2}));
211}
212
213TEST_F(FunctionalTest, FractionalMaxPool3d) {
214 auto x = torch::ones({2, 5, 5, 5});
215 auto y = F::fractional_max_pool3d(
216 x, F::FractionalMaxPool3dFuncOptions(3).output_size(2));
217
218 ASSERT_EQ(y.ndimension(), 4);
219 ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2, 2})));
220 ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2, 2}));
221
222 auto y_with_indices = F::fractional_max_pool3d_with_indices(
223 x, F::FractionalMaxPool3dFuncOptions(3).output_size(2));
224 ASSERT_TRUE(torch::equal(y, std::get<0>(y_with_indices)));
225 ASSERT_TRUE(torch::allclose(
226 std::get<1>(y_with_indices),
227 torch::tensor(
228 {{{{0, 2}, {10, 12}}, {{50, 52}, {60, 62}}},
229 {{{0, 2}, {10, 12}}, {{50, 52}, {60, 62}}}})));
230 ASSERT_EQ(
231 std::get<1>(y_with_indices).sizes(), std::vector<int64_t>({2, 2, 2, 2}));
232}
233
234TEST_F(FunctionalTest, LPPool1d) {
235 int norm_type = 2;
236 int stride = 2;
237 int kernel_size = 3;
238
239 auto x = torch::ones({1, 1, 5});
240 auto y = F::lp_pool1d(
241 x, F::LPPool1dFuncOptions(norm_type, kernel_size).stride(stride));
242 auto expected =
243 (torch::pow(torch::tensor({{{1, 1}}}, torch::kFloat), norm_type) *
244 kernel_size)
245 .pow(1. / norm_type);
246
247 ASSERT_EQ(y.ndimension(), 3);
248 ASSERT_TRUE(torch::allclose(y, expected));
249 ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 1, 2}));
250}
251
252TEST_F(FunctionalTest, LPPool2d) {
253 int norm_type = 2;
254 int stride = 2;
255 std::vector<int64_t> kernel_size({2, 3});
256
257 auto x = torch::ones({1, 2, 5});
258 auto y = F::lp_pool2d(
259 x, F::LPPool2dFuncOptions(norm_type, kernel_size).stride(stride));
260 auto expected =
261 (torch::pow(torch::tensor({{{1, 1}}}, torch::kFloat), norm_type) *
262 (kernel_size[0] * kernel_size[1]))
263 .pow(1. / norm_type);
264
265 ASSERT_EQ(y.ndimension(), 3);
266 ASSERT_TRUE(torch::allclose(y, expected));
267 ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 1, 2}));
268}
269
270TEST_F(FunctionalTest, CosineSimilarity) {
271 auto input1 = torch::tensor({{1, 2, 3}, {4, 5, 6}}, torch::kFloat);
272 auto input2 = torch::tensor({{1, 8, 3}, {2, 1, 6}}, torch::kFloat);
273 auto output = F::cosine_similarity(
274 input1, input2, F::CosineSimilarityFuncOptions().dim(1));
275 auto expected = torch::tensor({0.8078, 0.8721}, torch::kFloat);
276 ASSERT_TRUE(output.allclose(expected, 1e-04));
277}
278
279TEST_F(FunctionalTest, SmoothL1LossDefaultOptions) {
280 auto input = torch::tensor(
281 {0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true));
282 auto target = torch::tensor({0., 1., 5.}, torch::kFloat);
283 auto output = F::smooth_l1_loss(input, target);
284 auto expected = torch::tensor(0.0233335, torch::kFloat);
285 auto s = output.sum();
286 s.backward();
287 ASSERT_TRUE(output.allclose(expected));
288 ASSERT_TRUE(input.sizes() == input.grad().sizes());
289}
290
291TEST_F(FunctionalTest, SmoothL1LossBeta) {
292 auto input = torch::tensor(
293 {0.1, 1.5, 10.0}, torch::dtype(torch::kFloat).requires_grad(true));
294 auto target = torch::tensor({0., 1., 5.}, torch::kFloat);
295 auto output =
296 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,bugprone-argument-comment)
297 F::smooth_l1_loss(
298 input, target, /*reduction=*/torch::kMean, /*beta=*/0.5);
299 auto expected = torch::tensor(1.67, torch::kFloat);
300 auto s = output.sum();
301 s.backward();
302 ASSERT_TRUE(output.allclose(expected));
303 ASSERT_TRUE(input.sizes() == input.grad().sizes());
304}
305
306TEST_F(FunctionalTest, SmoothL1LossNoReduction) {
307 auto input = torch::tensor(
308 {0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true));
309 auto target = torch::tensor({0., 1., 5.}, torch::kFloat);
310 auto output =
311 // NOLINTNEXTLINE(bugprone-argument-comment)
312 F::smooth_l1_loss(input, target, /*reduction=*/torch::kNone);
313 auto expected = torch::tensor({0.005, 0.02, 0.045}, torch::kFloat);
314 auto s = output.sum();
315 s.backward();
316 ASSERT_TRUE(output.allclose(expected));
317 ASSERT_TRUE(input.sizes() == input.grad().sizes());
318}
319
320TEST_F(FunctionalTest, HuberLossDefaultOptions) {
321 auto input = torch::tensor(
322 {0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true));
323 auto target = torch::tensor({0., 1., 5.}, torch::kFloat);
324 auto output = F::huber_loss(input, target);
325 auto expected = torch::tensor(0.0233335, torch::kFloat);
326 auto s = output.sum();
327 s.backward();
328 ASSERT_TRUE(output.allclose(expected));
329 ASSERT_TRUE(input.sizes() == input.grad().sizes());
330}
331
332TEST_F(FunctionalTest, HuberLossDelta) {
333 auto input = torch::tensor(
334 {0.1, 1.5, 10.0}, torch::dtype(torch::kFloat).requires_grad(true));
335 auto target = torch::tensor({0., 1., 5.}, torch::kFloat);
336 auto options = F::HuberLossFuncOptions().reduction(torch::kMean).delta(0.5);
337 auto output = F::huber_loss(input, target, options);
338 auto expected = torch::tensor(1.67 * 0.5, torch::kFloat);
339 auto s = output.sum();
340 s.backward();
341 ASSERT_TRUE(output.allclose(expected));
342 ASSERT_TRUE(input.sizes() == input.grad().sizes());
343}
344
345TEST_F(FunctionalTest, HuberLossNoReduction) {
346 auto input = torch::tensor(
347 {0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true));
348 auto target = torch::tensor({0., 1., 5.}, torch::kFloat);
349 auto options = F::HuberLossFuncOptions().reduction(torch::kNone);
350 auto output = F::huber_loss(input, target, options);
351 auto expected = torch::tensor({0.005, 0.02, 0.045}, torch::kFloat);
352 auto s = output.sum();
353 s.backward();
354 ASSERT_TRUE(output.allclose(expected));
355 ASSERT_TRUE(input.sizes() == input.grad().sizes());
356}
357
358TEST_F(FunctionalTest, SoftMarginLossDefaultOptions) {
359 auto input = torch::tensor(
360 {2., 4., 1., 3.}, torch::dtype(torch::kFloat).requires_grad(true));
361 auto target = torch::tensor({-1., 1., 1., -1.}, torch::kFloat);
362 auto output = F::soft_margin_loss(input, target);
363 auto expected = torch::tensor({1.3767317}, torch::kFloat);
364 auto s = output.sum();
365 s.backward();
366
367 ASSERT_TRUE(output.allclose(expected));
368 ASSERT_EQ(input.sizes(), input.grad().sizes());
369}
370
371TEST_F(FunctionalTest, MultiLabelSoftMarginLossDefaultOptions) {
372 auto input = torch::tensor(
373 {{0., 2., 2., 0.}, {2., 1., 0., 1.}},
374 torch::dtype(torch::kFloat).requires_grad(true));
375 auto target =
376 torch::tensor({{0., 0., 1., 0.}, {1., 0., 1., 1.}}, torch::kFloat);
377 auto output = F::multilabel_soft_margin_loss(input, target);
378 auto expected = torch::tensor({0.7608436}, torch::kFloat);
379 auto s = output.sum();
380 s.backward();
381
382 ASSERT_TRUE(output.allclose(expected));
383 ASSERT_EQ(input.sizes(), input.grad().sizes());
384}
385
386TEST_F(FunctionalTest, SoftMarginLossNoReduction) {
387 auto input = torch::tensor(
388 {2., 4., 1., 3.}, torch::dtype(torch::kFloat).requires_grad(true));
389 auto target = torch::tensor({-1., 1., 1., -1.}, torch::kFloat);
390 auto output = F::soft_margin_loss(input, target, torch::kNone);
391 auto expected = torch::tensor(
392 {2.1269281, 0.01814993, 0.3132617, 3.0485873}, torch::kFloat);
393 auto s = output.sum();
394 s.backward();
395
396 ASSERT_TRUE(output.allclose(expected));
397 ASSERT_EQ(input.sizes(), input.grad().sizes());
398}
399
400TEST_F(FunctionalTest, MultiLabelSoftMarginLossWeightedNoReduction) {
401 auto input = torch::tensor(
402 {{0., 2., 2., 0.}, {2., 1., 0., 1.}},
403 torch::dtype(torch::kFloat).requires_grad(true));
404 auto target =
405 torch::tensor({{0., 0., 1., 0.}, {1., 0., 1., 1.}}, torch::kFloat);
406 auto weight = torch::tensor({0.1, 0.6, 0.4, 0.8}, torch::kFloat);
407 auto options = F::MultilabelSoftMarginLossFuncOptions()
408 .reduction(torch::kNone)
409 .weight(weight);
410 auto output = F::multilabel_soft_margin_loss(input, target, options);
411 auto expected = torch::tensor({0.4876902, 0.3321295}, torch::kFloat);
412 auto s = output.sum();
413 s.backward();
414
415 ASSERT_TRUE(output.allclose(expected));
416 ASSERT_EQ(input.sizes(), input.grad().sizes());
417}
418
419TEST_F(FunctionalTest, PairwiseDistance) {
420 auto input1 = torch::tensor({{1, 2, 3}, {4, 5, 6}}, torch::kFloat);
421 auto input2 = torch::tensor({{1, 8, 3}, {2, 1, 6}}, torch::kFloat);
422 auto output = F::pairwise_distance(
423 input1, input2, F::PairwiseDistanceFuncOptions().p(1));
424 auto expected = torch::tensor({6, 6}, torch::kFloat);
425 ASSERT_TRUE(output.allclose(expected));
426}
427
428TEST_F(FunctionalTest, PDist) {
429 {
430 auto input = torch::tensor({{-1.0, -5.0, -1.0}, {2.0, 4.0, 6.0}});
431 auto output = F::pdist(input);
432 auto expected = torch::tensor({11.7898});
433 ASSERT_TRUE(output.allclose(expected));
434 }
435 {
436 auto input = torch::tensor({{1.0, -1.0}, {1.0, 3.0}, {3.0, 3.0}});
437 auto output = F::pdist(input, 1.5);
438 auto expected = torch::tensor({4.0, 4.8945, 2.0});
439 ASSERT_TRUE(output.allclose(expected));
440 }
441}
442
443TEST_F(FunctionalTest, AdaptiveMaxPool1d) {
444 auto x = torch::ones({1, 1, 5});
445 auto y = F::adaptive_max_pool1d(x, F::AdaptiveMaxPool1dFuncOptions(3));
446
447 ASSERT_EQ(y.ndimension(), 3);
448 ASSERT_TRUE(torch::allclose(y, torch::ones({1, 1, 3})));
449 ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 3}));
450}
451
452TEST_F(FunctionalTest, AdaptiveMaxPool2d) {
453 auto x = torch::ones({2, 5, 5});
454 auto y = F::adaptive_max_pool2d(x, F::AdaptiveMaxPool2dFuncOptions(3));
455
456 ASSERT_EQ(y.ndimension(), 3);
457 ASSERT_TRUE(torch::allclose(y, torch::ones({2, 3, 3})));
458 ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 3, 3}));
459}
460
461TEST_F(FunctionalTest, AdaptiveMaxPool3d) {
462 auto x = torch::ones({2, 5, 5, 5});
463 auto y = F::adaptive_max_pool3d(x, F::AdaptiveMaxPool3dFuncOptions(3));
464
465 ASSERT_EQ(y.ndimension(), 4);
466 ASSERT_TRUE(torch::allclose(y, torch::ones({2, 3, 3, 3})));
467 ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 3, 3, 3}));
468}
469
470TEST_F(FunctionalTest, AdaptiveAvgPool1d) {
471 auto x = torch::ones({1, 1, 5});
472 auto y = F::adaptive_avg_pool1d(x, F::AdaptiveAvgPool1dFuncOptions(3));
473
474 ASSERT_EQ(y.ndimension(), 3);
475 ASSERT_TRUE(torch::allclose(y, torch::ones({1, 1, 3})));
476 ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 3}));
477}
478
479TEST_F(FunctionalTest, AdaptiveAvgPool2d) {
480 auto x = torch::ones({2, 5, 5});
481 auto y = F::adaptive_avg_pool2d(x, F::AdaptiveAvgPool2dFuncOptions(3));
482
483 ASSERT_EQ(y.ndimension(), 3);
484 ASSERT_TRUE(torch::allclose(y, torch::ones({2, 3, 3})));
485 ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 3, 3}));
486}
487
488TEST_F(FunctionalTest, AdaptiveAvgPool3d) {
489 auto x = torch::ones({2, 5, 5, 5});
490 auto y = F::adaptive_avg_pool3d(x, F::AdaptiveAvgPool3dFuncOptions(3));
491
492 ASSERT_EQ(y.ndimension(), 4);
493 ASSERT_TRUE(torch::allclose(y, torch::ones({2, 3, 3, 3})));
494 ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 3, 3, 3}));
495}
496
497TEST_F(FunctionalTest, L1Loss) {
498 auto input = torch::randn({5, 6}, torch::requires_grad());
499 auto target = torch::empty({5, 6}).random_(2);
500 auto output = F::l1_loss(torch::sigmoid(input), target);
501 auto s = output.sum();
502 s.backward();
503
504 ASSERT_EQ(output.sizes(), torch::IntArrayRef());
505 ASSERT_EQ(input.sizes(), input.grad().sizes());
506}
507
508TEST_F(FunctionalTest, MSELoss) {
509 auto input = torch::randn({5, 6}, torch::requires_grad());
510 auto target = torch::empty({5, 6}).random_(2);
511 auto output = F::mse_loss(torch::sigmoid(input), target);
512 auto s = output.sum();
513 s.backward();
514
515 ASSERT_EQ(output.sizes(), torch::IntArrayRef());
516 ASSERT_EQ(input.sizes(), input.grad().sizes());
517}
518
519TEST_F(FunctionalTest, BCELoss) {
520 auto input = torch::randn({5, 6}, torch::requires_grad());
521 auto target = torch::empty({5, 6}).random_(2);
522 auto output = F::binary_cross_entropy(torch::sigmoid(input), target);
523 auto s = output.sum();
524 s.backward();
525
526 ASSERT_EQ(output.sizes(), torch::IntArrayRef());
527 ASSERT_EQ(input.sizes(), input.grad().sizes());
528}
529
530TEST_F(FunctionalTest, KLDivLoss) {
531 KLDivLoss loss;
532 auto input = torch::randn({5, 6}, torch::requires_grad());
533 auto target = torch::empty({5, 6}).random_(2);
534 auto output = F::kl_div(torch::sigmoid(input), target);
535 auto s = output.sum();
536 s.backward();
537
538 ASSERT_EQ(output.sizes(), torch::IntArrayRef());
539 ASSERT_EQ(input.sizes(), input.grad().sizes());
540}
541
542TEST_F(FunctionalTest, HingeEmbeddingLoss) {
543 auto input = torch::tensor({{2, 22, 4}, {20, 10, 0}}, torch::kFloat);
544 auto target = torch::tensor({{2, 6, 4}, {1, 10, 0}}, torch::kFloat);
545 auto output = F::hinge_embedding_loss(
546 input, target, F::HingeEmbeddingLossFuncOptions().margin(2));
547 auto expected = torch::tensor({10}, torch::kFloat);
548
549 ASSERT_TRUE(output.allclose(expected));
550}
551
552TEST_F(FunctionalTest, GridSample) {
553 auto input =
554 torch::arange(9, torch::kFloat).view(std::vector<int64_t>({1, 1, 3, 3}));
555 auto grid = torch::tensor(
556 {{{{-2., -1.}, {-1., -1.}, {0., -1.}},
557 {{-1., 0.}, {0., 0.}, {1., 0.}},
558 {{0., 1.}, {1., 1.}, {2., 1.}}}},
559 torch::kFloat);
560
561 // bilinear, zeros, true
562 auto options = F::GridSampleFuncOptions()
563 .mode(torch::kBilinear)
564 .padding_mode(torch::kZeros)
565 .align_corners(true);
566 auto output = F::grid_sample(input, grid, options);
567 auto expected = torch::tensor(
568 {{{{0., 0., 1.}, {3., 4., 5.}, {7., 8., 0.}}}}, torch::kFloat);
569
570 ASSERT_TRUE(output.allclose(expected));
571
572 // bilinear, zeros, false
573 options = F::GridSampleFuncOptions()
574 .mode(torch::kBilinear)
575 .padding_mode(torch::kZeros)
576 .align_corners(false);
577 output = F::grid_sample(input, grid, options);
578 expected = torch::tensor(
579 {{{{0., 0., 0.5}, {1.5, 4., 2.5}, {3.5, 2., 0.}}}}, torch::kFloat);
580
581 ASSERT_TRUE(output.allclose(expected));
582
583 // default options (bilinear, zeros, false) same result as above
584 output = F::grid_sample(input, grid);
585
586 ASSERT_TRUE(output.allclose(expected));
587
588 // nearest, zeros, true
589 options = F::GridSampleFuncOptions()
590 .mode(torch::kNearest)
591 .padding_mode(torch::kZeros)
592 .align_corners(true);
593 output = F::grid_sample(input, grid, options);
594 expected = torch::tensor(
595 {{{{0., 0., 1.}, {3., 4., 5.}, {7., 8., 0.}}}}, torch::kFloat);
596
597 ASSERT_TRUE(output.allclose(expected));
598
599 // bilinear, border, true
600 options = F::GridSampleFuncOptions()
601 .mode(torch::kBilinear)
602 .padding_mode(torch::kBorder)
603 .align_corners(true);
604 output = F::grid_sample(input, grid, options);
605 expected = torch::tensor(
606 {{{{0., 0., 1.}, {3., 4., 5.}, {7., 8., 8.}}}}, torch::kFloat);
607
608 ASSERT_TRUE(output.allclose(expected));
609
610 // bilinear, reflection, true
611 options = F::GridSampleFuncOptions()
612 .mode(torch::kBilinear)
613 .padding_mode(torch::kReflection)
614 .align_corners(true);
615 output = F::grid_sample(input, grid, options);
616 expected = torch::tensor(
617 {{{{1., 0., 1.}, {3., 4., 5.}, {7., 8., 7.}}}}, torch::kFloat);
618
619 ASSERT_TRUE(output.allclose(expected));
620}
621
622TEST_F(FunctionalTest, AffineGrid) {
623 {
624 // 2D affine.
625 auto theta = torch::arange(1., 13).view(std::vector<int64_t>({2, 2, 3}));
626 auto size = std::vector<int64_t>({2, 3, 2, 2});
627 auto align_corners = true;
628 auto output = F::affine_grid(theta, size, !align_corners);
629 auto expected = torch::tensor(
630 {{{{1.50, 1.50}, {2.50, 5.50}}, {{3.50, 6.50}, {4.50, 10.50}}},
631 {{{1.50, 1.50}, {8.50, 11.50}}, {{9.50, 12.50}, {16.50, 22.50}}}});
632 auto output_aligned = F::affine_grid(theta, size, align_corners);
633 auto expected_aligned = torch::tensor(
634 {{{{0.0, -3.0}, {2.0, 5.0}}, {{4.0, 7.0}, {6.0, 15.0}}},
635 {{{-6.0, -9.0}, {8.0, 11.0}}, {{10.0, 13.0}, {24.0, 33.0}}}});
636
637 ASSERT_TRUE(output.allclose(expected));
638 ASSERT_TRUE(output_aligned.allclose(expected_aligned));
639 }
640 {
641 // 3D affine.
642 auto theta = torch::arange(1., 13).view(std::vector<int64_t>({1, 3, 4}));
643 auto size = std::vector<int64_t>({1, 1, 3, 2, 2});
644 auto align_corners = true;
645 auto output = F::affine_grid(theta, size, !align_corners);
646 auto expected = torch::tensor(
647 {{{{{0.5000, -2.1667, -4.8333}, {1.5000, 2.8333, 4.1667}},
648 {{2.5000, 3.8333, 5.1667}, {3.5000, 8.8333, 14.1667}}},
649 {{{2.5000, 2.5000, 2.5000}, {3.5000, 7.5000, 11.5000}},
650 {{4.5000, 8.5000, 12.5000}, {5.5000, 13.5000, 21.5000}}},
651 {{{4.5000, 7.1667, 9.8333}, {5.5000, 12.1667, 18.8333}},
652 {{6.5000, 13.1667, 19.8333}, {7.5000, 18.1667, 28.8333}}}}});
653 auto output_aligned = F::affine_grid(theta, size, align_corners);
654 auto expected_aligned = torch::tensor(
655 {{{{{-2.0, -10.0, -18.0}, {0.0, 0.0, 0.0}},
656 {{2.0, 2.0, 2.0}, {4.0, 12.0, 20.0}}},
657 {{{1.0, -3.0, -7.0}, {3.0, 7.0, 11.0}},
658 {{5.0, 9.0, 13.0}, {7.0, 19.0, 31.0}}},
659 {{{4.0, 4.0, 4.0}, {6.0, 14.0, 22.0}},
660 {{8.0, 16.0, 24.0}, {10.0, 26.0, 42.0}}}}});
661
662 ASSERT_TRUE(output.allclose(expected, 1e-2));
663 ASSERT_TRUE(output_aligned.allclose(expected_aligned));
664 }
665 {
666 auto theta = torch::empty({1, 2, 3}, torch::kDouble);
667 auto size = std::vector<int64_t>({1, 1, 2, 2});
668 ASSERT_THROWS_WITH(
669 F::affine_grid(torch::empty({2, 2, 3}), {-1, 1, 2, 2}),
670 "Expected non-zero, positive output size. Got [-1, 1, 2, 2]");
671 ASSERT_THROWS_WITH(
672 F::affine_grid(torch::empty({2, 2, 3}, torch::kInt), size),
673 "Expected theta to have floating point type, but got int");
674 ASSERT_THROWS_WITH(
675 F::affine_grid(theta[0], size),
676 "Expected a batch of 2D affine matrices of shape Nx2x3 for size "
677 "[1, 1, 2, 2]. Got [2, 3].");
678 ASSERT_THROWS_WITH(
679 F::affine_grid(theta.unsqueeze(0), size),
680 "Expected a batch of 2D affine matrices of shape Nx2x3 for size "
681 "[1, 1, 2, 2]. Got [1, 1, 2, 3].");
682 ASSERT_THROWS_WITH(
683 F::affine_grid(theta.repeat({1, 2, 1}), size),
684 "Expected a batch of 2D affine matrices of shape Nx2x3 for size "
685 "[1, 1, 2, 2]. Got [1, 4, 3].");
686 ASSERT_THROWS_WITH(
687 F::affine_grid(theta.repeat({1, 1, 2}), size),
688 "Expected a batch of 2D affine matrices of shape Nx2x3 for size "
689 "[1, 1, 2, 2]. Got [1, 2, 6].");
690 }
691 {
692 auto theta = torch::empty({1, 3, 4}, torch::kDouble);
693 auto size = std::vector<int64_t>({1, 1, 2, 2, 3});
694 ASSERT_THROWS_WITH(
695 F::affine_grid(theta[0], size),
696 "Expected a batch of 3D affine matrices of shape Nx3x4 for size "
697 "[1, 1, 2, 2, 3]. Got [3, 4].");
698 ASSERT_THROWS_WITH(
699 F::affine_grid(theta.unsqueeze(0), size),
700 "Expected a batch of 3D affine matrices of shape Nx3x4 for size "
701 "[1, 1, 2, 2, 3]. Got [1, 1, 3, 4].");
702 ASSERT_THROWS_WITH(
703 F::affine_grid(theta.repeat({1, 2, 1}), size),
704 "Expected a batch of 3D affine matrices of shape Nx3x4 for size "
705 "[1, 1, 2, 2, 3]. Got [1, 6, 4].");
706 ASSERT_THROWS_WITH(
707 F::affine_grid(theta.repeat({1, 1, 2}), size),
708 "Expected a batch of 3D affine matrices of shape Nx3x4 for size "
709 "[1, 1, 2, 2, 3]. Got [1, 3, 8].");
710 ASSERT_THROWS_WITH(
711 F::affine_grid(theta, {1, 1, 1, 2, 2, 3}),
712 "affine_grid only supports 4D and 5D sizes, for 2D and 3D affine "
713 "transforms, respectively. Got size [1, 1, 1, 2, 2, 3]");
714 ASSERT_THROWS_WITH(
715 F::affine_grid(theta, {1, 1}),
716 "affine_grid only supports 4D and 5D sizes, for 2D and 3D affine "
717 "transforms, respectively. Got size [1, 1]");
718 }
719}
720
721TEST_F(FunctionalTest, MultiMarginLoss) {
722 auto weight = torch::tensor({0.3, 0.3, 0.4}, torch::kFloat);
723 auto input = torch::tensor(
724 {{0.2, 0.2, 0.6}, {0.1, 0.8, 0.1}, {0.9, 0.09, 0.01}},
725 torch::dtype(torch::kFloat).requires_grad(true));
726 auto target = torch::tensor({2, 1, 0}, torch::kLong);
727 auto output = F::multi_margin_loss(
728 input, target, F::MultiMarginLossFuncOptions().margin(2).weight(weight));
729 auto expected = torch::tensor({0.305556}, torch::kFloat);
730
731 ASSERT_TRUE(output.allclose(expected, 1e-04));
732}
733
734TEST_F(FunctionalTest, CosineEmbeddingLoss) {
735 auto input1 = torch::tensor({{2, 3, 4}, {6, 2, 4}});
736 auto input2 = torch::tensor({{2, 3, 5}, {9, 12, 0}});
737 auto target = torch::tensor({1, -1});
738 auto output = F::cosine_embedding_loss(
739 input1, input2, target, F::CosineEmbeddingLossFuncOptions().margin(0.5));
740 auto expected = torch::tensor({0.1004}, torch::kFloat);
741
742 ASSERT_TRUE(output.allclose(expected, 1e-4));
743}
744
745TEST_F(FunctionalTest, MultiLabelMarginLossDefaultOptions) {
746 auto input = torch::tensor(
747 {{0.1, 0.2, 0.4, 0.8}}, torch::dtype(torch::kFloat).requires_grad(true));
748 auto target = torch::tensor({{3, 0, -1, 1}}, torch::kLong);
749 auto output = F::multilabel_margin_loss(input, target);
750 auto expected = torch::tensor({0.8500}, torch::kFloat);
751 auto s = output.sum();
752 s.backward();
753
754 ASSERT_TRUE(output.allclose(expected));
755 ASSERT_EQ(input.sizes(), input.grad().sizes());
756}
757
758TEST_F(FunctionalTest, MultiLabelMarginLossNoReduction) {
759 auto input = torch::tensor(
760 {{0.1, 0.2, 0.4, 0.8}}, torch::dtype(torch::kFloat).requires_grad(true));
761 auto target = torch::tensor({{3, 0, -1, 1}}, torch::kLong);
762 auto output = F::multilabel_margin_loss(input, target, torch::kNone);
763 auto expected = torch::tensor({0.8500}, torch::kFloat);
764 auto s = output.sum();
765 s.backward();
766
767 ASSERT_TRUE(output.allclose(expected));
768 ASSERT_EQ(input.sizes(), input.grad().sizes());
769}
770
771TEST_F(FunctionalTest, TripletMarginLoss) {
772 auto anchor = torch::tensor({{3., 3.}}, torch::kFloat);
773 auto positive = torch::tensor({{2., 2.}}, torch::kFloat);
774 auto negative = torch::tensor({{0., 0.}}, torch::kFloat);
775 auto output = F::triplet_margin_loss(
776 anchor,
777 positive,
778 negative,
779 F::TripletMarginLossFuncOptions().margin(1.0));
780 auto expected = torch::tensor({0.}, torch::kFloat);
781
782 ASSERT_TRUE(output.allclose(expected, 1e-04));
783}
784
785TEST_F(FunctionalTest, TripletMarginWithDistanceLossDefaultParity) {
786 // Check that if we use torch::pairwise_distance with the default
787 // TripletMarginLoss options as our distance function, the outputs
788 // are equal (i.e., equal under defaults).
789
790 std::vector<TripletMarginWithDistanceLossOptions::reduction_t> reductions = {
791 torch::kSum, torch::kMean, torch::kNone};
792 std::vector<float> margins = {0.5, 1.0, 1.5};
793 std::vector<bool> swaps = {true, false};
794
795 for (auto& reduction : reductions) {
796 for (auto& margin : margins) {
797 for (const auto& swap : swaps) {
798 auto anchor = torch::randn(
799 {100, 128}, torch::dtype(torch::kFloat).requires_grad(true));
800 auto positive = torch::randn(
801 {100, 128}, torch::dtype(torch::kFloat).requires_grad(true));
802 auto negative = torch::randn(
803 {100, 128}, torch::dtype(torch::kFloat).requires_grad(true));
804
805 auto basicOptions = F::TripletMarginLossFuncOptions()
806 .reduction(reduction)
807 .margin(margin)
808 .swap(swap);
809 auto distanceOptions = F::TripletMarginWithDistanceLossFuncOptions()
810 .reduction(reduction)
811 .margin(margin)
812 .swap(swap);
813 TripletMarginLoss basicLoss(basicOptions);
814 TripletMarginWithDistanceLoss distanceLoss(distanceOptions);
815
816 auto basicOutput =
817 F::triplet_margin_loss(anchor, positive, negative, basicOptions);
818 auto distanceOutput = F::triplet_margin_with_distance_loss(
819 anchor, positive, negative, distanceOptions);
820
821 ASSERT_TRUE(distanceOutput.allclose(basicOutput, 1e-6, 1e-6));
822
823 // handle for torch::kNone reduction
824 auto sum = distanceOutput.sum();
825 sum.backward();
826 ASSERT_EQ(anchor.sizes(), anchor.grad().sizes());
827 ASSERT_EQ(positive.sizes(), positive.grad().sizes());
828 ASSERT_EQ(negative.sizes(), negative.grad().sizes());
829 }
830 }
831 }
832}
833
834TEST_F(FunctionalTest, NLLLoss) {
835 auto input = torch::tensor(
836 {{-0.1315, -3.1315, -2.5315},
837 {-3.7038, -0.1038, -2.6038},
838 {-2.3422, -1.3422, -0.4422}},
839 torch::kFloat);
840 auto target = torch::tensor({1, 0, 2}, torch::kLong);
841 auto output = F::nll_loss(
842 input,
843 target,
844 F::NLLLossFuncOptions().ignore_index(-100).reduction(torch::kMean));
845 auto expected = torch::tensor(2.4258, torch::kFloat);
846 ASSERT_TRUE(output.allclose(expected, 1e-04));
847 ASSERT_TRUE(F::nll_loss(input, target).allclose(expected, 1e-04));
848}
849
850TEST_F(FunctionalTest, CrossEntropy) {
851 auto input = torch::tensor({{3., 3.}, {2., 2.}}, torch::kFloat);
852 auto target = torch::tensor({0, 1}, torch::kLong);
853 auto output = F::cross_entropy(
854 input,
855 target,
856 F::CrossEntropyFuncOptions().ignore_index(-100).reduction(torch::kMean));
857 auto expected = torch::tensor(0.6931, torch::kFloat);
858
859 ASSERT_TRUE(output.allclose(expected, 1e-04));
860 ASSERT_TRUE(F::cross_entropy(input, target).allclose(expected, 1e-04));
861
862 // label smoothing with class indices
863 input = torch::tensor({{3., 1.}, {1., 2.}}, torch::kFloat);
864 output = F::cross_entropy(
865 input,
866 target,
867 F::CrossEntropyFuncOptions().label_smoothing(0.15).reduction(
868 torch::kMean));
869 expected = torch::tensor(0.3326, torch::kFloat);
870 ASSERT_TRUE(output.allclose(expected, 1e-04));
871
872 // label smoothing with target probabilities
873 target = torch::tensor({{0.8, 0.2}, {0.1, 0.9}}, torch::kFloat);
874 output = F::cross_entropy(
875 input,
876 target,
877 F::CrossEntropyFuncOptions().label_smoothing(0.2).reduction(
878 torch::kMean));
879 expected = torch::tensor(0.5701, torch::kFloat);
880 ASSERT_TRUE(output.allclose(expected, 1e-04));
881}
882
883TEST_F(FunctionalTest, MaxUnpool1d) {
884 auto x = torch::tensor(
885 {{{2, 4, 5}}}, torch::dtype(torch::kFloat).requires_grad(true));
886 auto indices = torch::tensor({{{1, 3, 4}}}, torch::kLong);
887 auto y = F::max_unpool1d(x, indices, F::MaxUnpool1dFuncOptions(3));
888
889 ASSERT_EQ(y.ndimension(), 3);
890 ASSERT_TRUE(torch::allclose(
891 y, torch::tensor({{{0, 2, 0, 4, 5, 0, 0, 0, 0}}}, torch::kFloat)));
892 ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 9}));
893
894 x = torch::tensor(
895 {{{2, 4, 5}}}, torch::dtype(torch::kFloat).requires_grad(true));
896 indices = torch::tensor({{{1, 3, 4}}}, torch::kLong);
897 y = F::max_unpool1d(
898 x,
899 indices,
900 F::MaxUnpool1dFuncOptions(3).output_size(
901 std::vector<int64_t>({1, 1, 9})));
902
903 ASSERT_EQ(y.ndimension(), 3);
904 ASSERT_TRUE(torch::allclose(
905 y, torch::tensor({{{0, 2, 0, 4, 5, 0, 0, 0, 0}}}, torch::kFloat)));
906 ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 9}));
907
908 x = torch::tensor(
909 {{{2, 4, 5}}}, torch::dtype(torch::kFloat).requires_grad(true));
910 indices = torch::tensor({{{1, 3, 4}}}, torch::kLong);
911 y = F::max_unpool1d(
912 x, indices, F::MaxUnpool1dFuncOptions(3).stride(2).padding(1));
913
914 ASSERT_EQ(y.ndimension(), 3);
915 ASSERT_TRUE(
916 torch::allclose(y, torch::tensor({{{0, 2, 0, 4, 5}}}, torch::kFloat)));
917 ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 5}));
918}
919
920TEST_F(FunctionalTest, MaxUnpool2d) {
921 auto indices = torch::tensor(
922 {{{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}},
923 {{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}}},
924 torch::kLong);
925 auto x = torch::tensor(
926 {{{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}},
927 {{{31, 33, 34}, {41, 43, 44}, {46, 48, 49}}}},
928 torch::dtype(torch::kFloat).requires_grad(true));
929 auto y = F::max_unpool2d(
930 x, indices, F::MaxUnpool2dFuncOptions(3).stride(2).padding(1));
931
932 ASSERT_EQ(y.dim(), 4);
933 ASSERT_TRUE(torch::allclose(
934 y,
935 torch::tensor(
936 {{{{0, 0, 0, 0, 0},
937 {0, 6, 0, 8, 9},
938 {0, 0, 0, 0, 0},
939 {0, 16, 0, 18, 19},
940 {0, 21, 0, 23, 24}}},
941 {{{0, 0, 0, 0, 0},
942 {0, 31, 0, 33, 34},
943 {0, 0, 0, 0, 0},
944 {0, 41, 0, 43, 44},
945 {0, 46, 0, 48, 49}}}},
946 torch::kFloat)));
947 ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 1, 5, 5}));
948}
949
950TEST_F(FunctionalTest, MaxUnpool3d) {
951 auto indices = torch::tensor({{{{{26}}}}}, torch::kLong);
952 auto x = torch::tensor(
953 {{{{{26}}}}}, torch::dtype(torch::kFloat).requires_grad(true));
954 auto y = F::max_unpool3d(x, indices, F::MaxUnpool3dFuncOptions(3));
955
956 ASSERT_EQ(y.dim(), 5);
957 ASSERT_TRUE(torch::allclose(
958 y,
959 torch::tensor(
960 {{{{{0, 0, 0}, {0, 0, 0}, {0, 0, 0}},
961 {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}},
962 {{0, 0, 0}, {0, 0, 0}, {0, 0, 26}}}}},
963 torch::kFloat)));
964 ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 3, 3, 3}));
965}
966
967TEST_F(FunctionalTest, ELU) {
968 const auto size = 3;
969 for (const auto inplace : {false, true}) {
970 for (const auto alpha : {0.0, 0.42, 1.0, 4.2, 42.42}) {
971 auto x = torch::linspace(-10.0, 10.0, size * size * size);
972 x.resize_({size, size, size});
973 auto x_bf16 =
974 torch::linspace(-10.0, 10.0, size * size * size).to(torch::kBFloat16);
975 x_bf16.resize_({size, size, size});
976
977 auto y_exp = torch::max(torch::zeros_like(x), x) +
978 torch::min(torch::zeros_like(x), alpha * (torch::exp(x) - 1.0));
979 auto y = F::elu(x, F::ELUFuncOptions().alpha(alpha).inplace(inplace));
980 auto y_bf16 =
981 F::elu(x_bf16, F::ELUFuncOptions().alpha(alpha).inplace(inplace));
982
983 ASSERT_EQ(y.ndimension(), 3);
984 ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
985 ASSERT_TRUE(torch::allclose(y, y_exp));
986 ASSERT_TRUE(torch::allclose(y_bf16.to(torch::kFloat), y, 1e-2, 1e-2));
987 if (inplace) {
988 ASSERT_TRUE(torch::allclose(x, y_exp));
989 ASSERT_TRUE(torch::allclose(x_bf16.to(torch::kFloat), y, 1e-2, 1e-2));
990 }
991 }
992 }
993 ASSERT_TRUE(F::elu(torch::tensor(1.)).defined());
994}
995
996TEST_F(FunctionalTest, SELU) {
997 {
998 const double scale = 1.0507009873554804934193349852946;
999 const double alpha = 1.6732632423543772848170429916717;
1000 for (const auto inplace : {false, true}) {
1001 auto input = torch::randn({5, 5});
1002 auto input_bf16 = input.clone().to(torch::kBFloat16);
1003 auto expected = scale *
1004 (torch::max(torch::zeros_like(input), input) +
1005 torch::min(
1006 torch::zeros_like(input), alpha * (torch::exp(input) - 1)));
1007 auto output = F::selu(input, inplace);
1008 auto output_bf16 = F::selu(input_bf16, inplace);
1009
1010 ASSERT_TRUE(output.allclose(expected));
1011 ASSERT_TRUE(output_bf16.to(torch::kFloat).allclose(output, 1e-2, 1e-2));
1012 if (inplace) {
1013 ASSERT_TRUE(input.allclose(expected));
1014 ASSERT_TRUE(input_bf16.to(torch::kFloat).allclose(output, 1e-2, 1e-2));
1015 }
1016 }
1017 }
1018 {
1019 auto input = torch::arange(0, 9, torch::kDouble).view({3, 3});
1020 auto output = F::selu(input);
1021 auto expected = F::selu(input, false);
1022 ASSERT_TRUE(output.allclose(expected));
1023 }
1024 ASSERT_TRUE(F::selu(torch::tensor(1.)).defined());
1025}
1026
1027TEST_F(FunctionalTest, GLU) {
1028 int64_t dim = 1;
1029 auto input = torch::randn({4, 2}, torch::requires_grad());
1030 auto output = F::glu(input, dim);
1031 auto input_size = input.sizes()[dim] / 2;
1032 auto first_half = input.narrow(dim, 0, input_size);
1033 auto second_half = input.narrow(dim, input_size, input_size);
1034 auto expected = first_half * torch::sigmoid(second_half);
1035
1036 ASSERT_TRUE(output.allclose(expected));
1037 ASSERT_TRUE(F::glu(input).allclose(expected));
1038}
1039
1040TEST_F(FunctionalTest, GELU) {
1041 const auto x = torch::linspace(-3.0, 3.0, 100);
1042 const auto y_exp = x * 0.5 * (1.0 + torch::erf(x / std::sqrt(2.0)));
1043 const auto y = F::gelu(x, F::GELUFuncOptions().approximate("none"));
1044 ASSERT_TRUE(torch::allclose(y, y_exp, 1.4e-06, 1e-05));
1045}
1046
1047TEST_F(FunctionalTest, TanhGELU) {
1048 const auto x = torch::linspace(-3.0, 3.0, 100);
1049 const auto inner = std::sqrt(2 / M_PI) * (x + 0.044715 * x.pow(3.0));
1050 const auto y_exp = 0.5 * x * (1.0 + inner.tanh());
1051 const auto y = F::gelu(x, F::GELUFuncOptions().approximate("tanh"));
1052 ASSERT_TRUE(torch::allclose(y, y_exp, 1.4e-06, 1e-05));
1053}
1054
1055TEST_F(FunctionalTest, Hardshrink) {
1056 const auto size = 3;
1057 for (const auto lambda : {-4.2, -1.0, -0.42, 0.0, 0.42, 1.0, 4.2, 42.42}) {
1058 auto x = torch::linspace(-10.0, 10.0, size * size * size);
1059 x.resize_({size, size, size}).set_requires_grad(true);
1060 auto y = F::hardshrink(x, F::HardshrinkFuncOptions().lambda(lambda));
1061 torch::Tensor s = y.sum();
1062
1063 s.backward();
1064 ASSERT_EQ(s.ndimension(), 0);
1065
1066 ASSERT_EQ(y.ndimension(), 3);
1067 ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
1068 auto y_exp = (x.abs() > lambda) * x;
1069 ASSERT_TRUE(torch::allclose(y, y_exp));
1070 }
1071}
1072
1073TEST_F(FunctionalTest, OneHot) {
1074 { // Test #1
1075 auto x = torch::arange(0, 5, torch::kLong);
1076 auto y = F::one_hot(x % 3);
1077 auto expected = torch::tensor(
1078 {{1, 0, 0}, {0, 1, 0}, {0, 0, 1}, {1, 0, 0}, {0, 1, 0}}, torch::kLong);
1079
1080 ASSERT_EQ(y.ndimension(), 2);
1081 ASSERT_TRUE(torch::allclose(y, expected));
1082 ASSERT_EQ(y.sizes(), std::vector<int64_t>({5, 3}));
1083 }
1084
1085 { // Test #2
1086 auto x = torch::arange(0, 5, torch::kLong);
1087 auto y = F::one_hot(x % 3, 5);
1088 auto expected = torch::tensor(
1089 {{1, 0, 0, 0, 0},
1090 {0, 1, 0, 0, 0},
1091 {0, 0, 1, 0, 0},
1092 {1, 0, 0, 0, 0},
1093 {0, 1, 0, 0, 0}},
1094 torch::kLong);
1095
1096 ASSERT_EQ(y.ndimension(), 2);
1097 ASSERT_TRUE(torch::allclose(y, expected));
1098 ASSERT_EQ(y.sizes(), std::vector<int64_t>({5, 5}));
1099 }
1100
1101 { // Test #3
1102 auto x = torch::arange(0, 6, torch::kLong);
1103 auto y = F::one_hot(x.view(std::vector<int64_t>({3, 2})) % 3);
1104 auto expected = torch::tensor(
1105 {{{1, 0, 0}, {0, 1, 0}},
1106 {{0, 0, 1}, {1, 0, 0}},
1107 {{0, 1, 0}, {0, 0, 1}}},
1108 torch::kLong);
1109
1110 ASSERT_EQ(y.ndimension(), 3);
1111 ASSERT_TRUE(torch::allclose(y, expected));
1112 ASSERT_EQ(y.sizes(), std::vector<int64_t>({3, 2, 3}));
1113 }
1114}
1115
1116TEST_F(FunctionalTest, Hardtanh) {
1117 const auto size = 3;
1118 for (const auto min_val : {-4.2, -1.0, -0.42, 0.0}) {
1119 for (const auto max_val : {0.0, 0.42, 1.0, 4.2}) {
1120 for (const auto inplace : {false, true}) {
1121 auto x = torch::linspace(-10.0, 10.0, size * size * size);
1122 x.resize_({size, size, size});
1123 auto y_exp = (x < min_val) * min_val +
1124 ((x >= min_val) * (x <= max_val)) * x + (x > max_val) * max_val;
1125 auto y = F::hardtanh(
1126 x,
1127 F::HardtanhFuncOptions().min_val(min_val).max_val(max_val).inplace(
1128 inplace));
1129
1130 ASSERT_EQ(y.ndimension(), 3);
1131 ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
1132 ASSERT_TRUE(torch::allclose(y, y_exp));
1133 if (inplace) {
1134 ASSERT_TRUE(torch::allclose(x, y_exp));
1135 }
1136 }
1137 }
1138 }
1139 ASSERT_TRUE(F::hardtanh(torch::tensor(1.)).defined());
1140}
1141
1142TEST_F(FunctionalTest, LeakyReLU) {
1143 const auto size = 3;
1144 for (const auto negative_slope : {0.0, 0.42, 1.0}) {
1145 for (const auto inplace : {false, true}) {
1146 for (const auto type : {torch::kFloat, torch::kBFloat16}) {
1147 auto x = torch::linspace(-10.0, 10.0, size * size * size).to(type);
1148 x.resize_({size, size, size});
1149 auto y_exp = (x < 0) * x * negative_slope + (x >= 0) * x;
1150 auto y = F::leaky_relu(
1151 x,
1152 F::LeakyReLUFuncOptions()
1153 .negative_slope(negative_slope)
1154 .inplace(inplace));
1155
1156 ASSERT_EQ(y.ndimension(), 3);
1157 ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
1158 ASSERT_TRUE(torch::allclose(y, y_exp));
1159 if (inplace) {
1160 ASSERT_TRUE(torch::allclose(x, y_exp));
1161 }
1162 }
1163 }
1164 }
1165 ASSERT_TRUE(F::leaky_relu(torch::tensor(1.)).defined());
1166}
1167
1168TEST_F(FunctionalTest, LogSigmoid) {
1169 const auto size = 3;
1170 LogSigmoid model;
1171 auto x = torch::linspace(-10.0, 10.0, size * size * size);
1172 x.resize_({size, size, size});
1173 auto y = F::logsigmoid(x);
1174
1175 ASSERT_EQ(y.ndimension(), 3);
1176 ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
1177 auto y_exp = torch::log(
1178 torch::ones_like(x) / (torch::ones_like(x) + torch::exp(torch::neg(x))));
1179 ASSERT_TRUE(torch::allclose(y, y_exp, 1e-4, 1e-7));
1180}
1181
1182TEST_F(FunctionalTest, GumbelSoftmax) {
1183 // Test 1: No-options
1184 {
1185 auto logits = torch::randn({5});
1186 int expected_count = 1;
1187 auto y_draw = F::gumbel_softmax(logits);
1188
1189 // All values positive
1190 ASSERT_GE(y_draw.min().item<int>(), 0);
1191 // Shape unchanged
1192 ASSERT_TRUE(y_draw.sizes() == logits.sizes());
1193 // One choice per draw
1194 ASSERT_TRUE(torch::allclose(
1195 y_draw.sum(), torch::tensor(expected_count, torch::kFloat)));
1196 }
1197
1198 // Test 2: 1D shape, 0 and -1 dim
1199 for (const auto dim : {0, -1}) {
1200 auto logits = torch::randn({5});
1201 int expected_count = 1;
1202 auto y_draw = F::gumbel_softmax(
1203 logits, F::GumbelSoftmaxFuncOptions().hard(true).dim(dim));
1204
1205 // All values positive
1206 ASSERT_GE(y_draw.min().item<int>(), 0);
1207 // Shape unchanged
1208 ASSERT_TRUE(y_draw.sizes() == logits.sizes());
1209 // One choice per draw
1210 ASSERT_TRUE(torch::allclose(
1211 y_draw.sum(), torch::tensor(expected_count, torch::kFloat)));
1212 }
1213
1214 { // Test 3: 2D shape, 1 dim
1215 auto logits = torch::randn({5, 4});
1216 int expected_count = 5;
1217 auto y_draw = F::gumbel_softmax(
1218 logits, F::GumbelSoftmaxFuncOptions().hard(true).dim(1));
1219
1220 // All values positive
1221 ASSERT_GE(y_draw.min().item<int>(), 0);
1222 // Shape unchanged
1223 ASSERT_TRUE(y_draw.sizes() == logits.sizes());
1224 // One choice per draw
1225 ASSERT_TRUE(torch::allclose(
1226 y_draw.sum(), torch::tensor(expected_count, torch::kFloat)));
1227 }
1228
1229 // Test 4: 3D shape, 1 and -1 dim
1230 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
1231 int dims[] = {1, -1};
1232 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers)
1233 int expected[] = {5 * 3, 5 * 4};
1234 for (const auto i : c10::irange(2)) {
1235 auto logits = torch::randn({5, 4, 3});
1236 int expected_count = expected[i];
1237 auto y_draw = F::gumbel_softmax(
1238 logits, F::GumbelSoftmaxFuncOptions().hard(true).dim(dims[i]));
1239
1240 // All values positive
1241 ASSERT_GE(y_draw.min().item<int>(), 0);
1242 // Shape unchanged
1243 ASSERT_TRUE(y_draw.sizes() == logits.sizes());
1244 // One choice per draw
1245 ASSERT_TRUE(torch::allclose(
1246 y_draw.sum(), torch::tensor(expected_count, torch::kFloat)));
1247 }
1248
1249 { // Test 5: Straight through
1250 int num_draws = 100;
1251 auto logits = torch::tensor({{0.2, 0.8, 0.1}});
1252 logits = logits.reshape({1, 3});
1253 logits.requires_grad();
1254 auto probs = logits.softmax(-1);
1255
1256 auto counts = torch::zeros_like(logits);
1257 torch::Tensor y_draw;
1258 for (const auto i : c10::irange(num_draws)) {
1259 (void)i; // Suppress unused variable warning
1260 y_draw =
1261 F::gumbel_softmax(logits, F::GumbelSoftmaxFuncOptions().hard(true));
1262 counts += y_draw;
1263 }
1264
1265 // All values positive
1266 ASSERT_GE(y_draw.min().item<int>(), 0);
1267 // Each experiment should result in 1 draw
1268 ASSERT_EQ(counts.sum().item<int>(), num_draws);
1269
1270 // Check results are asymptotically as expected
1271 auto expected = probs * num_draws;
1272 // ~z is approximately N(0,1) for unbiased count
1273 auto z = (counts - expected) / (expected * (1 - probs)).sqrt();
1274 // A (lazy) approximate 99% two-sided test:
1275 // occurs with prob alpha~>=0.01 if unbiased
1276 ASSERT_LT(z.abs().max().item<float>(), 2.58);
1277 }
1278}
1279
1280TEST_F(FunctionalTest, Softmax) {
1281 auto input = torch::arange(10, torch::kFloat).reshape({2, 5});
1282 // NOLINTNEXTLINE(bugprone-argument-comment)
1283 auto output = F::softmax(input, /*dim=*/1);
1284 auto sum = torch::sum(torch::exp(input), 1);
1285
1286 for (const auto i : c10::irange(2)) {
1287 auto expected = torch::exp(input[i]) / sum[i];
1288 ASSERT_TRUE(torch::allclose(output[i], expected));
1289 }
1290}
1291
1292TEST_F(FunctionalTest, Softmin) {
1293 auto input = torch::arange(10, torch::kFloat).reshape({2, 5});
1294 // NOLINTNEXTLINE(bugprone-argument-comment)
1295 auto output = F::softmin(input, /*dim=*/1);
1296 auto sum = torch::sum(torch::exp(-input), 1);
1297
1298 for (const auto i : c10::irange(2)) {
1299 auto expected = torch::exp(-input[i]) / sum[i];
1300 ASSERT_TRUE(torch::allclose(output[i], expected));
1301 }
1302}
1303
1304TEST_F(FunctionalTest, LogSoftmax) {
1305 auto input = torch::arange(10, torch::kFloat).reshape({2, 5});
1306 // NOLINTNEXTLINE(bugprone-argument-comment)
1307 auto output = F::log_softmax(input, /*dim=*/1);
1308 auto sum = torch::sum(torch::exp(input), 1);
1309
1310 for (const auto i : c10::irange(2)) {
1311 auto expected = torch::log(torch::exp(input[i]) / sum[i]);
1312 ASSERT_TRUE(torch::allclose(output[i], expected));
1313 }
1314}
1315
1316TEST_F(FunctionalTest, PReLU) {
1317 const auto x = torch::rand({42, 24}) * 200 - 100;
1318 const auto w = torch::rand(24) * 200 - 100;
1319 const auto y = F::prelu(x, w);
1320 ASSERT_EQ(y.sizes(), std::vector<int64_t>({42, 24}));
1321 const auto y_exp = (x < 0) * w * x + (x >= 0) * x;
1322 ASSERT_TRUE(torch::allclose(y, y_exp));
1323}
1324
1325TEST_F(FunctionalTest, LayerNorm) {
1326 const auto input = torch::randn({2, 2});
1327 auto y = F::layer_norm(input, F::LayerNormFuncOptions({2, 2}).eps(2e-5));
1328 auto y_exp =
1329 torch::layer_norm(input, {2, 2}, torch::Tensor(), torch::Tensor(), 2e-5);
1330 ASSERT_TRUE(torch::allclose(y, y_exp));
1331}
1332
1333TEST_F(FunctionalTest, GroupNorm) {
1334 const auto input = torch::randn({2, 2});
1335 auto y = F::group_norm(input, F::GroupNormFuncOptions(2).eps(2e-5));
1336 auto y_exp =
1337 torch::group_norm(input, 2, torch::Tensor(), torch::Tensor(), 2e-5);
1338 ASSERT_TRUE(torch::allclose(y, y_exp));
1339}
1340
1341TEST_F(FunctionalTest, LocalResponseNorm) {
1342 const auto x = torch::arange(100, 118).resize_({3, 3, 2});
1343 const auto y = F::local_response_norm(x, F::LocalResponseNormFuncOptions(2));
1344 ASSERT_EQ(y.ndimension(), 3);
1345 ASSERT_EQ(y.sizes(), torch::IntArrayRef({3, 3, 2}));
1346 const auto y_exp = torch::tensor(
1347 {{{73.7788, 74.1462}, {60.1942, 60.3302}, {60.4609, 60.5865}},
1348 {{75.8729, 76.2011}, {60.9331, 61.0390}, {61.1403, 61.2370}},
1349 {{77.7387, 78.0303}, {61.5011, 61.5807}, {61.6563, 61.7279}}},
1350 torch::kFloat);
1351 ASSERT_TRUE(torch::allclose(y, y_exp, 1e-4, 1e-7));
1352}
1353
1354TEST_F(FunctionalTest, Linear) {
1355 {
1356 const auto x = torch::arange(100., 118).resize_({3, 3, 2});
1357 const auto w = torch::arange(200., 206).resize_({3, 2});
1358 const auto b = torch::arange(300., 303);
1359 const auto y = F::linear(x, w, b);
1360 ASSERT_EQ(y.ndimension(), 3);
1361 ASSERT_EQ(y.sizes(), torch::IntArrayRef({3, 3, 3}));
1362 const auto y_exp = torch::tensor(
1363 {{{40601, 41004, 41407}, {41403, 41814, 42225}, {42205, 42624, 43043}},
1364 {{43007, 43434, 43861}, {43809, 44244, 44679}, {44611, 45054, 45497}},
1365 {{45413, 45864, 46315}, {46215, 46674, 47133}, {47017, 47484, 47951}}},
1366 torch::kFloat);
1367 ASSERT_TRUE(torch::allclose(y, y_exp));
1368 }
1369 {
1370 const auto x = torch::arange(100., 118).resize_({3, 3, 2});
1371 const auto w = torch::arange(200., 206).resize_({3, 2});
1372 const auto y = F::linear(x, w);
1373 ASSERT_EQ(y.ndimension(), 3);
1374 ASSERT_EQ(y.sizes(), torch::IntArrayRef({3, 3, 3}));
1375 const auto y_exp = torch::tensor(
1376 {{{40301, 40703, 41105}, {41103, 41513, 41923}, {41905, 42323, 42741}},
1377 {{42707, 43133, 43559}, {43509, 43943, 44377}, {44311, 44753, 45195}},
1378 {{45113, 45563, 46013}, {45915, 46373, 46831}, {46717, 47183, 47649}}},
1379 torch::kFloat);
1380 ASSERT_TRUE(torch::allclose(y, y_exp));
1381 }
1382}
1383
1384TEST_F(FunctionalTest, Embedding) {
1385 const auto input = torch::tensor({{1, 2, 4, 5}, {4, 3, 2, 9}}, torch::kLong);
1386 auto weight = torch::empty({10, 3});
1387 torch::nn::init::normal_(weight);
1388 auto y = F::embedding(input, weight);
1389 auto y_exp = torch::embedding(weight, input.contiguous(), -1, false, false);
1390 ASSERT_TRUE(torch::allclose(y, y_exp));
1391}
1392
1393TEST_F(FunctionalTest, EmbeddingBag) {
1394 const auto input = torch::tensor({1, 2, 4, 5, 4, 3, 2, 9}, torch::kLong);
1395 auto offsets = torch::tensor({0, 4}, torch::kLong);
1396 auto weight = torch::empty({10, 3});
1397 torch::nn::init::normal_(weight);
1398 auto y = F::embedding_bag(
1399 input,
1400 weight,
1401 F::EmbeddingBagFuncOptions()
1402 .mode(torch::kSum)
1403 .offsets(offsets)
1404 .padding_idx(4));
1405 auto y_exp = std::get<0>(torch::embedding_bag(
1406 weight, input, offsets, false, 0, false, torch::Tensor(), false, 4));
1407 ASSERT_TRUE(torch::allclose(y, y_exp));
1408
1409 // no options test
1410 const auto input_ = torch::tensor({{1, 2, 4, 5}, {4, 3, 2, 9}}, torch::kLong);
1411 auto offsets_ = torch::arange(
1412 0,
1413 input_.numel(),
1414 input_.size(1),
1415 torch::TensorOptions().dtype(torch::kLong).device(input.device()));
1416 y = F::embedding_bag(input_, weight);
1417 y_exp = std::get<0>(torch::embedding_bag(
1418 weight, input_.reshape(-1), offsets_, false, 1, false, torch::Tensor()));
1419 ASSERT_TRUE(torch::allclose(y, y_exp));
1420}
1421
1422TEST_F(FunctionalTest, Bilinear) {
1423 auto input1 = torch::tensor({{1, 2, 3}, {7, 6, 5}});
1424 auto input2 = torch::tensor({{7, 4}, {8, 9}});
1425 auto weight = torch::tensor({{{2, 3}, {9, 7}, {8, 6}}});
1426 auto bias = torch::tensor({1});
1427
1428 auto y_with_bias = F::bilinear(input1, input2, weight, bias);
1429 ASSERT_EQ(y_with_bias.ndimension(), 2);
1430 ASSERT_EQ(y_with_bias.sizes(), torch::IntArrayRef({2, 1}));
1431 auto y_with_bias_exp = torch::tensor({{449}, {1702}}).reshape({2, 1});
1432 ASSERT_TRUE(torch::allclose(y_with_bias, y_with_bias_exp, 1e-4, 1e-7));
1433
1434 auto y_no_bias = F::bilinear(input1, input2, weight);
1435 ASSERT_EQ(y_no_bias.ndimension(), 2);
1436 ASSERT_EQ(y_no_bias.sizes(), torch::IntArrayRef({2, 1}));
1437 auto y_no_bias_exp = torch::tensor({{448, 1701}}).reshape({2, 1});
1438 ASSERT_TRUE(torch::allclose(y_no_bias, y_no_bias_exp, 1e-4, 1e-7));
1439}
1440
1441TEST_F(FunctionalTest, Normalize) {
1442 const auto expected = torch::tensor(
1443 {{{0.00000000, 0.10000000, 0.2000, 0.30000000, 0.40000000},
1444 {0.14285715, 0.17142858, 0.2000, 0.22857143, 0.25714287}}},
1445 torch::requires_grad().dtype(torch::kFloat));
1446 { // Test #1
1447 auto input = torch::tensor(
1448 {{{0, 1, 2, 3, 4}, {5, 6, 7, 8, 9}}},
1449 torch::dtype(torch::kFloat).requires_grad(true));
1450 auto norm = F::normalize(input, F::NormalizeFuncOptions().p(1).dim(-1));
1451
1452 // reduce to scalar to call .backward()
1453 torch::Tensor s = norm.sum();
1454 s.backward();
1455
1456 ASSERT_EQ(s.ndimension(), 0);
1457 ASSERT_EQ(input.grad().numel(), 10);
1458 ASSERT_TRUE(torch::allclose(norm, expected));
1459 }
1460
1461 { // Test #2 Check variations of optional arguments
1462 auto input = torch::tensor(
1463 {{{0, 1, 2, 3, 4}, {5, 6, 7, 8, 9}}}, torch::dtype(torch::kFloat));
1464 auto output = torch::randn({1, 2, 5}, torch::dtype(torch::kFloat));
1465 // non-null output argument
1466 F::normalize(input, F::NormalizeFuncOptions().p(1).dim(-1).out(output));
1467 // default options
1468 F::normalize(input);
1469
1470 ASSERT_TRUE(torch::allclose(output, expected));
1471 }
1472
1473 { // Test #3 Base case of scalar tensor
1474 auto input = torch::randn({}, torch::requires_grad());
1475 torch::Tensor norm =
1476 F::normalize(input, F::NormalizeFuncOptions().p(1).dim(-1));
1477 norm.backward();
1478
1479 ASSERT_EQ(input.grad().numel(), 1);
1480 }
1481}
1482
1483TEST_F(FunctionalTest, ReLU) {
1484 const auto size = 3;
1485 for (const auto inplace : {false, true}) {
1486 auto x = torch::linspace(-10.0, 10.0, size * size * size);
1487 x.resize_({size, size, size});
1488 auto y_exp = (x < 0) * 0 + (x >= 0) * x;
1489 auto y = F::relu(x, F::ReLUFuncOptions().inplace(inplace));
1490
1491 ASSERT_EQ(y.ndimension(), 3);
1492 ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
1493 ASSERT_TRUE(torch::allclose(y, y_exp));
1494 if (inplace) {
1495 ASSERT_TRUE(torch::allclose(x, y_exp));
1496 }
1497
1498 // NOLINTNEXTLINE(bugprone-argument-comment)
1499 y = F::relu(x, /*inplace=*/inplace);
1500
1501 ASSERT_EQ(y.ndimension(), 3);
1502 ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
1503 ASSERT_TRUE(torch::allclose(y, y_exp));
1504 if (inplace) {
1505 ASSERT_TRUE(torch::allclose(x, y_exp));
1506 }
1507 }
1508 ASSERT_TRUE(F::relu(torch::tensor(1.)).defined());
1509}
1510
1511TEST_F(FunctionalTest, ReLUDefaultOptions) {
1512 const auto size = 3;
1513 auto x = torch::linspace(-10.0, 10.0, size * size * size);
1514 x.resize_({size, size, size});
1515 auto y_exp = (x < 0) * 0 + (x >= 0) * x;
1516 auto y = F::relu(x);
1517
1518 ASSERT_EQ(y.ndimension(), 3);
1519 ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
1520 ASSERT_TRUE(torch::allclose(y, y_exp));
1521}
1522
1523TEST_F(FunctionalTest, ReLU6) {
1524 const auto size = 3;
1525 for (const auto inplace : {false, true}) {
1526 auto x = torch::linspace(-10.0, 10.0, size * size * size);
1527 x.resize_({size, size, size});
1528 auto y_exp = (x < 0) * 0 + ((x >= 0) * (x <= 6)) * x + (x > 6) * 6;
1529 auto y = F::relu6(x, F::ReLU6FuncOptions().inplace(inplace));
1530
1531 ASSERT_EQ(y.ndimension(), 3);
1532 ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
1533 ASSERT_TRUE(torch::allclose(y, y_exp));
1534 if (inplace) {
1535 ASSERT_TRUE(torch::allclose(x, y_exp));
1536 }
1537
1538 // NOLINTNEXTLINE(bugprone-argument-comment)
1539 y = F::relu6(x, /*inplace=*/inplace);
1540
1541 ASSERT_EQ(y.ndimension(), 3);
1542 ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
1543 ASSERT_TRUE(torch::allclose(y, y_exp));
1544 if (inplace) {
1545 ASSERT_TRUE(torch::allclose(x, y_exp));
1546 }
1547 }
1548 ASSERT_TRUE(F::relu6(torch::tensor(1.)).defined());
1549}
1550
1551TEST_F(FunctionalTest, ReLU6DefaultOptions) {
1552 const auto size = 3;
1553 auto x = torch::linspace(-10.0, 10.0, size * size * size);
1554 x.resize_({size, size, size});
1555 auto y_exp = (x < 0) * 0 + ((x >= 0) * (x <= 6)) * x + (x > 6) * 6;
1556 auto y = F::relu6(x);
1557
1558 ASSERT_EQ(y.ndimension(), 3);
1559 ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
1560 ASSERT_TRUE(torch::allclose(y, y_exp));
1561}
1562
1563TEST_F(FunctionalTest, RReLU) {
1564 const auto size = 3;
1565 for (const auto lower : {0.01, 0.1, 0.2}) {
1566 for (const auto upper : {0.3, 0.4, 0.5}) {
1567 for (const auto inplace : {false, true}) {
1568 for (const auto type : {torch::kFloat, torch::kBFloat16}) {
1569 auto x = torch::linspace(-10.0, 10.0, size * size * size).to(type);
1570 x.resize_({size, size, size});
1571 auto x_copy = x.clone();
1572 auto y = F::rrelu(
1573 x,
1574 F::RReLUFuncOptions().lower(lower).upper(upper).inplace(inplace));
1575 auto z =
1576 ((x_copy >= 0) * (x_copy == y) +
1577 (x_copy < 0) * (y >= x_copy * upper) * (y <= lower * x_copy)) *
1578 1.0;
1579
1580 ASSERT_EQ(y.ndimension(), 3);
1581 ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
1582 ASSERT_TRUE(torch::allclose(z, torch::ones_like(z)));
1583 if (inplace) {
1584 ASSERT_TRUE(torch::allclose(x, y));
1585 }
1586 }
1587 }
1588 }
1589 }
1590 ASSERT_TRUE(F::rrelu(torch::tensor(1.)).defined());
1591}
1592
1593TEST_F(FunctionalTest, RReLUDefaultOptions) {
1594 const auto size = 3;
1595 const auto lower = 1.0 / 8.0;
1596 const auto upper = 1.0 / 3.0;
1597 for (const auto type : {torch::kFloat, torch::kBFloat16}) {
1598 auto x = torch::linspace(-10.0, 10.0, size * size * size).to(type);
1599 x.resize_({size, size, size});
1600 auto x_copy = x.clone();
1601 auto y = F::rrelu(x);
1602 auto z = ((x_copy >= 0) * (x_copy == y) +
1603 (x_copy < 0) * (y >= x_copy * upper) * (y <= lower * x_copy)) *
1604 1.0;
1605
1606 ASSERT_EQ(y.ndimension(), 3);
1607 ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
1608 ASSERT_TRUE(torch::allclose(z, torch::ones_like(z)));
1609 }
1610}
1611
1612TEST_F(FunctionalTest, CELU) {
1613 const auto size = 3;
1614 for (const auto inplace : {false, true}) {
1615 for (const auto alpha : {0.42, 1.0, 4.2, 42.42}) {
1616 auto x = torch::linspace(-10.0, 10.0, size * size * size);
1617 x.resize_({size, size, size});
1618 auto x_bf16 = x.clone().to(torch::kBFloat16);
1619 auto y_exp = torch::max(torch::zeros_like(x), x) +
1620 torch::min(torch::zeros_like(x),
1621 alpha * (torch::exp(x / alpha) - 1.0));
1622 auto y = F::celu(x, F::CELUFuncOptions().alpha(alpha).inplace(inplace));
1623 auto y_bf16 =
1624 F::celu(x_bf16, F::CELUFuncOptions().alpha(alpha).inplace(inplace));
1625
1626 ASSERT_EQ(y.ndimension(), 3);
1627 ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
1628 ASSERT_TRUE(torch::allclose(y, y_exp));
1629 ASSERT_TRUE(torch::allclose(y_bf16.to(torch::kFloat), y, 1e-2, 1e-2));
1630 if (inplace) {
1631 ASSERT_TRUE(torch::allclose(x, y_exp));
1632 ASSERT_TRUE(torch::allclose(x_bf16.to(torch::kFloat), y, 1e-2, 1e-2));
1633 }
1634 }
1635 }
1636 ASSERT_TRUE(F::celu(torch::tensor(1.)).defined());
1637}
1638
1639TEST_F(FunctionalTest, CELUDefaultOptions) {
1640 const auto size = 3;
1641 const auto alpha = 1.0;
1642 auto x = torch::linspace(-10.0, 10.0, size * size * size);
1643 x.resize_({size, size, size});
1644 auto x_bf16 = x.clone().to(torch::kBFloat16);
1645 auto y_exp = torch::max(torch::zeros_like(x), x) +
1646 torch::min(torch::zeros_like(x), alpha * (torch::exp(x / alpha) - 1.0));
1647 auto y = F::celu(x);
1648 auto y_bf16 = F::celu(x_bf16);
1649
1650 ASSERT_EQ(y.ndimension(), 3);
1651 ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
1652 ASSERT_TRUE(torch::allclose(y, y_exp));
1653 ASSERT_TRUE(torch::allclose(y_bf16.to(torch::kFloat), y, 1e-2, 1e-2));
1654}
1655
1656TEST_F(FunctionalTest, PixelShuffle) {
1657 auto x = torch::tensor(
1658 {{{{-17, 19}, {-1, 2}},
1659 {{7, 14}, {-3, 1}},
1660 {{0, -2}, {-12, 14}},
1661 {{-15, 0}, {-3, 9}}}},
1662 torch::kFloat);
1663 auto y_exp = torch::tensor(
1664 {{{{-17, 7, 19, 14}, {0, -15, -2, 0}, {-1, -3, 2, 1}, {-12, -3, 14, 9}}}},
1665 torch::kFloat);
1666 auto y = F::pixel_shuffle(x, 2);
1667
1668 ASSERT_EQ(y.ndimension(), 4);
1669 ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 1, 4, 4}));
1670 ASSERT_TRUE(y.allclose(y_exp));
1671}
1672
1673TEST_F(FunctionalTest, PixelUnshuffle) {
1674 auto x = torch::tensor(
1675 {{{{-17, 7, 19, 14}, {0, -15, -2, 0}, {-1, -3, 2, 1}, {-12, -3, 14, 9}}}},
1676 torch::kFloat);
1677 auto y_exp = torch::tensor(
1678 {{{{-17, 19}, {-1, 2}},
1679 {{7, 14}, {-3, 1}},
1680 {{0, -2}, {-12, 14}},
1681 {{-15, 0}, {-3, 9}}}},
1682 torch::kFloat);
1683 auto y = F::pixel_unshuffle(x, 2);
1684
1685 ASSERT_EQ(y.ndimension(), 4);
1686 ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 4, 2, 2}));
1687 ASSERT_TRUE(y.allclose(y_exp));
1688}
1689
1690TEST_F(FunctionalTest, Softplus) {
1691 const auto size = 3;
1692 for (const auto beta : {0.5, 1.0, 2.0}) {
1693 for (const auto threshold : {1.0, 3.0, 5.0}) {
1694 auto x = torch::linspace(-3.0, 3.0, 61);
1695 x.resize_({size, size, size});
1696 auto y_exp =
1697 (x <= threshold) * torch::log(1 + torch::exp(x * beta)) / beta +
1698 (x > threshold) * x;
1699 auto y = F::softplus(
1700 x, F::SoftplusFuncOptions().beta(beta).threshold(threshold));
1701
1702 ASSERT_EQ(y.ndimension(), 3);
1703 ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
1704 ASSERT_TRUE(torch::allclose(y, y_exp));
1705 }
1706 }
1707}
1708
1709TEST_F(FunctionalTest, SoftplusDefaultOptions) {
1710 const auto size = 3;
1711 const auto beta = 1.0;
1712 const auto threshold = 20.0;
1713 auto x = torch::linspace(-3.0, 3.0, 61);
1714 x.resize_({size, size, size});
1715 auto y_exp = (x <= threshold) * torch::log(1 + torch::exp(x * beta)) / beta +
1716 (x > threshold) * x;
1717 auto y = F::softplus(x);
1718
1719 ASSERT_EQ(y.ndimension(), 3);
1720 ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
1721 ASSERT_TRUE(torch::allclose(y, y_exp));
1722}
1723
1724TEST_F(FunctionalTest, Fold) {
1725 auto input = torch::ones({1, 3 * 2 * 2, 2}, torch::kDouble);
1726 auto output = F::fold(input, F::FoldFuncOptions({3, 2}, {2, 2}));
1727 auto expected = torch::tensor(
1728 {{{{1.0, 1.0}, {2.0, 2.0}, {1.0, 1.0}},
1729 {{1.0, 1.0}, {2.0, 2.0}, {1.0, 1.0}},
1730 {{1.0, 1.0}, {2.0, 2.0}, {1.0, 1.0}}}},
1731 torch::kDouble);
1732
1733 ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 3, 3, 2}));
1734 ASSERT_TRUE(output.allclose(expected));
1735}
1736
1737TEST_F(FunctionalTest, Unfold) {
1738 auto input = torch::arange(0, 12, torch::kDouble).view({1, 2, 2, 3});
1739 auto output =
1740 F::unfold(input, F::UnfoldFuncOptions({2, 2}).padding(1).stride(2));
1741 auto expected = torch::tensor(
1742 {{{0.0, 0.0, 0.0, 4.0},
1743 {0.0, 0.0, 3.0, 5.0},
1744 {0.0, 1.0, 0.0, 0.0},
1745 {0.0, 2.0, 0.0, 0.0},
1746 {0.0, 0.0, 0.0, 10.0},
1747 {0.0, 0.0, 9.0, 11.0},
1748 {0.0, 7.0, 0.0, 0.0},
1749 {6.0, 8.0, 0.0, 0.0}}},
1750 torch::kDouble);
1751
1752 ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 8, 4}));
1753 ASSERT_TRUE(output.allclose(expected));
1754}
1755
1756TEST_F(FunctionalTest, Softshrink) {
1757 const auto size = 3;
1758 for (const auto lambda : {0.0, 0.42, 1.0, 4.2, 42.42}) {
1759 auto x = torch::linspace(-10.0, 10.0, size * size * size);
1760 x.resize_({size, size, size}).set_requires_grad(true);
1761 // NOLINTNEXTLINE(bugprone-argument-comment)
1762 auto y = F::softshrink(x, /*lambda=*/lambda);
1763 torch::Tensor s = y.sum();
1764
1765 s.backward();
1766 ASSERT_EQ(s.ndimension(), 0);
1767
1768 ASSERT_EQ(y.ndimension(), 3);
1769 ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
1770 auto y_exp = (x < -lambda) * (x + lambda) + (x > lambda) * (x - lambda);
1771 ASSERT_TRUE(torch::allclose(y, y_exp));
1772 }
1773}
1774
1775TEST_F(FunctionalTest, SoftshrinkDefaultOptions) {
1776 const auto size = 3;
1777 const auto lambda = 0.5;
1778 auto x = torch::linspace(-10.0, 10.0, size * size * size);
1779 x.resize_({size, size, size}).set_requires_grad(true);
1780 auto y = F::softshrink(x);
1781 torch::Tensor s = y.sum();
1782
1783 s.backward();
1784 ASSERT_EQ(s.ndimension(), 0);
1785
1786 ASSERT_EQ(y.ndimension(), 3);
1787 ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
1788 auto y_exp = (x < -lambda) * (x + lambda) + (x > lambda) * (x - lambda);
1789}
1790
1791TEST_F(FunctionalTest, Softsign) {
1792 auto x = torch::randn(100) * 10;
1793 auto y_exp = x / (1 + x.abs());
1794 auto y = F::softsign(x);
1795
1796 ASSERT_TRUE(torch::allclose(y, y_exp));
1797}
1798
1799TEST_F(FunctionalTest, Mish) {
1800 auto x = torch::randn(100) * 10;
1801 auto y_exp = x * x.exp().log1p().tanh();
1802 auto y = F::mish(x);
1803
1804 ASSERT_TRUE(torch::allclose(y, y_exp));
1805}
1806
1807TEST_F(FunctionalTest, Tanhshrink) {
1808 auto x = torch::randn(100) * 10;
1809 auto y_exp = x - x.tanh();
1810 auto y = F::tanhshrink(x);
1811
1812 ASSERT_TRUE(torch::allclose(y, y_exp));
1813}
1814
1815TEST_F(FunctionalTest, Threshold) {
1816 const auto size = 3;
1817 for (const auto threshold : {0.5, 1.0, 2.0}) {
1818 for (const auto value : {0.5, 1.0, 2.0}) {
1819 for (const auto inplace : {false, true}) {
1820 auto x = torch::linspace(-3.0, 3.0, 61);
1821 x.resize_({size, size, size});
1822 auto y_exp = (x <= threshold) * value + (x > threshold) * x;
1823 auto y = F::threshold(
1824 x, F::ThresholdFuncOptions(threshold, value).inplace(inplace));
1825
1826 ASSERT_EQ(y.ndimension(), 3);
1827 ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
1828 ASSERT_TRUE(torch::allclose(y, y_exp));
1829 if (inplace) {
1830 ASSERT_TRUE(torch::allclose(x, y_exp));
1831 }
1832 }
1833 }
1834 }
1835 ASSERT_TRUE(F::threshold(torch::tensor(1.), F::ThresholdFuncOptions(0.5, 0.5))
1836 .defined());
1837}
1838
1839TEST_F(FunctionalTest, BatchNorm1d) {
1840 int num_features = 5;
1841 double eps = 1e-05;
1842 double momentum = 0.1;
1843
1844 auto input = torch::randn({2, 5});
1845 auto mean = torch::randn(5);
1846 auto variance = torch::rand(5);
1847 auto weight = torch::ones({num_features});
1848 auto bias = torch::zeros({num_features});
1849 auto output = F::batch_norm(
1850 input,
1851 mean,
1852 variance,
1853 F::BatchNormFuncOptions()
1854 .weight(weight)
1855 .bias(bias)
1856 .momentum(momentum)
1857 .eps(eps)
1858 .training(false));
1859 auto expected = (input - mean) / torch::sqrt(variance + eps);
1860 ASSERT_TRUE(output.allclose(expected));
1861}
1862
1863TEST_F(FunctionalTest, BatchNorm1dDefaultOptions) {
1864 auto input = torch::randn({2, 5});
1865 auto mean = torch::randn(5);
1866 auto variance = torch::rand(5);
1867 auto output = F::batch_norm(input, mean, variance);
1868 auto expected = (input - mean) / torch::sqrt(variance + 1e-5);
1869 ASSERT_TRUE(output.allclose(expected));
1870}
1871
1872TEST_F(FunctionalTest, BatchNorm2d) {
1873 int num_features = 5;
1874 double eps = 1e-05;
1875 double momentum = 0.1;
1876
1877 auto input = torch::randn({2, num_features, 4, 4});
1878 auto mean = torch::randn(num_features);
1879 auto variance = torch::rand(num_features);
1880 auto weight = torch::ones({num_features});
1881 auto bias = torch::zeros({num_features});
1882 auto output = F::batch_norm(
1883 input,
1884 mean,
1885 variance,
1886 F::BatchNormFuncOptions()
1887 .weight(weight)
1888 .bias(bias)
1889 .momentum(momentum)
1890 .eps(eps)
1891 .training(false));
1892 auto expected = torch::transpose(
1893 (torch::transpose(input, 1, 3) - mean) / torch::sqrt(variance + eps),
1894 1,
1895 3);
1896 ASSERT_TRUE(output.allclose(expected));
1897}
1898
1899TEST_F(FunctionalTest, BatchNorm2dDefaultOptions) {
1900 int num_features = 5;
1901 double eps = 1e-05;
1902
1903 auto input = torch::randn({2, num_features, 4, 4});
1904 auto mean = torch::randn(num_features);
1905 auto variance = torch::rand(num_features);
1906 auto output = F::batch_norm(input, mean, variance);
1907 auto expected = torch::transpose(
1908 (torch::transpose(input, 1, 3) - mean) / torch::sqrt(variance + eps),
1909 1,
1910 3);
1911 ASSERT_TRUE(output.allclose(expected));
1912}
1913
1914TEST_F(FunctionalTest, BatchNorm3d) {
1915 int num_features = 5;
1916 double eps = 1e-05;
1917 double momentum = 0.1;
1918
1919 auto input = torch::randn({2, num_features, 2, 2, 2});
1920 auto mean = torch::randn(num_features);
1921 auto variance = torch::rand(num_features);
1922 auto weight = torch::ones({num_features});
1923 auto bias = torch::zeros({num_features});
1924 auto output = F::batch_norm(
1925 input,
1926 mean,
1927 variance,
1928 F::BatchNormFuncOptions()
1929 .weight(weight)
1930 .bias(bias)
1931 .momentum(momentum)
1932 .eps(eps)
1933 .training(false));
1934 auto expected = torch::transpose(
1935 (torch::transpose(input, 1, 4) - mean) / torch::sqrt(variance + eps),
1936 1,
1937 4);
1938 ASSERT_TRUE(output.allclose(expected));
1939}
1940
1941TEST_F(FunctionalTest, BatchNorm3dDefaultOptions) {
1942 int num_features = 5;
1943 double eps = 1e-05;
1944
1945 auto input = torch::randn({2, num_features, 2, 2, 2});
1946 auto mean = torch::randn(num_features);
1947 auto variance = torch::rand(num_features);
1948 auto output = F::batch_norm(input, mean, variance);
1949 auto expected = torch::transpose(
1950 (torch::transpose(input, 1, 4) - mean) / torch::sqrt(variance + eps),
1951 1,
1952 4);
1953 ASSERT_TRUE(output.allclose(expected));
1954}
1955
1956TEST_F(FunctionalTest, InstanceNorm1d) {
1957 int num_features = 5;
1958 double eps = 1e-05;
1959 double momentum = 0.1;
1960
1961 auto input = torch::arange(40.).view({2, 5, 4});
1962 auto mean = torch::arange(5.);
1963 auto variance = torch::arange(5.);
1964 auto weight = torch::arange((double)num_features);
1965 auto bias = torch::arange((double)num_features);
1966 auto output = F::instance_norm(
1967 input,
1968 F::InstanceNormFuncOptions()
1969 .running_mean(mean)
1970 .running_var(variance)
1971 .weight(weight)
1972 .bias(bias)
1973 .momentum(momentum)
1974 .eps(eps));
1975 auto expected = torch::tensor(
1976 {{{0.0000, 0.0000, 0.0000, 0.0000},
1977 {-0.3416, 0.5528, 1.4472, 2.3416},
1978 {-0.6833, 1.1056, 2.8944, 4.6833},
1979 {-1.0249, 1.6584, 4.3416, 7.0249},
1980 {-1.3665, 2.2112, 5.7888, 9.3665}},
1981 {{0.0000, 0.0000, 0.0000, 0.0000},
1982 {-0.3416, 0.5528, 1.4472, 2.3416},
1983 {-0.6833, 1.1056, 2.8944, 4.6833},
1984 {-1.0249, 1.6584, 4.3416, 7.0249},
1985 {-1.3665, 2.2112, 5.7888, 9.3665}}});
1986 ASSERT_TRUE(output.allclose(expected, 2e-04));
1987}
1988
1989TEST_F(FunctionalTest, InstanceNorm1dDefaultOptions) {
1990 auto input = torch::arange(40.).view({2, 5, 4});
1991 auto output = F::instance_norm(input);
1992 auto expected = torch::tensor(
1993 {{{-1.3416, -0.4472, 0.4472, 1.3416},
1994 {-1.3416, -0.4472, 0.4472, 1.3416},
1995 {-1.3416, -0.4472, 0.4472, 1.3416},
1996 {-1.3416, -0.4472, 0.4472, 1.3416},
1997 {-1.3416, -0.4472, 0.4472, 1.3416}},
1998 {{-1.3416, -0.4472, 0.4472, 1.3416},
1999 {-1.3416, -0.4472, 0.4472, 1.3416},
2000 {-1.3416, -0.4472, 0.4472, 1.3416},
2001 {-1.3416, -0.4472, 0.4472, 1.3416},
2002 {-1.3416, -0.4472, 0.4472, 1.3416}}});
2003 ASSERT_TRUE(output.allclose(expected, 2e-04));
2004}
2005
2006TEST_F(FunctionalTest, InstanceNorm2d) {
2007 int num_features = 5;
2008 double eps = 1e-05;
2009 double momentum = 0.1;
2010
2011 auto input =
2012 torch::arange(2. * num_features * 2 * 2).view({2, num_features, 2, 2});
2013 auto mean = torch::arange((double)num_features);
2014 auto variance = torch::arange((double)num_features);
2015 auto weight = torch::arange((double)num_features);
2016 auto bias = torch::arange((double)num_features);
2017 auto output = F::instance_norm(
2018 input,
2019 F::InstanceNormFuncOptions()
2020 .running_mean(mean)
2021 .running_var(variance)
2022 .weight(weight)
2023 .bias(bias)
2024 .momentum(momentum)
2025 .eps(eps));
2026 auto expected = torch::tensor(
2027 {{{{0.0000, 0.0000}, {0.0000, 0.0000}},
2028 {{-0.3416, 0.5528}, {1.4472, 2.3416}},
2029 {{-0.6833, 1.1056}, {2.8944, 4.6833}},
2030 {{-1.0249, 1.6584}, {4.3416, 7.0249}},
2031 {{-1.3665, 2.2112}, {5.7888, 9.3665}}},
2032 {{{0.0000, 0.0000}, {0.0000, 0.0000}},
2033 {{-0.3416, 0.5528}, {1.4472, 2.3416}},
2034 {{-0.6833, 1.1056}, {2.8944, 4.6833}},
2035 {{-1.0249, 1.6584}, {4.3416, 7.0249}},
2036 {{-1.3665, 2.2112}, {5.7888, 9.3665}}}});
2037 ASSERT_TRUE(output.allclose(expected, 2e-04));
2038}
2039
2040TEST_F(FunctionalTest, InstanceNorm2dDefaultOptions) {
2041 int num_features = 5;
2042
2043 auto input =
2044 torch::arange(2. * num_features * 2 * 2).view({2, num_features, 2, 2});
2045 auto output = F::instance_norm(input);
2046 auto expected = torch::tensor(
2047 {{{{-1.3416, -0.4472}, {0.4472, 1.3416}},
2048 {{-1.3416, -0.4472}, {0.4472, 1.3416}},
2049 {{-1.3416, -0.4472}, {0.4472, 1.3416}},
2050 {{-1.3416, -0.4472}, {0.4472, 1.3416}},
2051 {{-1.3416, -0.4472}, {0.4472, 1.3416}}},
2052 {{{-1.3416, -0.4472}, {0.4472, 1.3416}},
2053 {{-1.3416, -0.4472}, {0.4472, 1.3416}},
2054 {{-1.3416, -0.4472}, {0.4472, 1.3416}},
2055 {{-1.3416, -0.4472}, {0.4472, 1.3416}},
2056 {{-1.3416, -0.4472}, {0.4472, 1.3416}}}});
2057 ASSERT_TRUE(output.allclose(expected, 2e-04));
2058}
2059
2060TEST_F(FunctionalTest, InstanceNorm3d) {
2061 int num_features = 5;
2062 double eps = 1e-05;
2063 double momentum = 0.1;
2064
2065 auto input = torch::arange(2. * num_features * 2 * 2 * 2)
2066 .view({2, num_features, 2, 2, 2});
2067 auto mean = torch::arange((double)num_features);
2068 auto variance = torch::arange((double)num_features);
2069 auto weight = torch::arange((double)num_features);
2070 auto bias = torch::arange((double)num_features);
2071 auto output = F::instance_norm(
2072 input,
2073 F::InstanceNormFuncOptions()
2074 .running_mean(mean)
2075 .running_var(variance)
2076 .weight(weight)
2077 .bias(bias)
2078 .momentum(momentum)
2079 .eps(eps));
2080 auto expected = torch::tensor(
2081 {{{{{0.0000, 0.0000}, {0.0000, 0.0000}},
2082 {{0.0000, 0.0000}, {0.0000, 0.0000}}},
2083 {{{-0.5275, -0.0911}, {0.3453, 0.7818}},
2084 {{1.2182, 1.6547}, {2.0911, 2.5275}}},
2085 {{{-1.0550, -0.1822}, {0.6907, 1.5636}},
2086 {{2.4364, 3.3093}, {4.1822, 5.0550}}},
2087 {{{-1.5826, -0.2733}, {1.0360, 2.3453}},
2088 {{3.6547, 4.9640}, {6.2733, 7.5826}}},
2089 {{{-2.1101, -0.3644}, {1.3814, 3.1271}},
2090 {{4.8729, 6.6186}, {8.3644, 10.1101}}}},
2091 {{{{0.0000, 0.0000}, {0.0000, 0.0000}},
2092 {{0.0000, 0.0000}, {0.0000, 0.0000}}},
2093 {{{-0.5275, -0.0911}, {0.3453, 0.7818}},
2094 {{1.2182, 1.6547}, {2.0911, 2.5275}}},
2095 {{{-1.0550, -0.1822}, {0.6907, 1.5636}},
2096 {{2.4364, 3.3093}, {4.1822, 5.0550}}},
2097 {{{-1.5826, -0.2733}, {1.0360, 2.3453}},
2098 {{3.6547, 4.9640}, {6.2733, 7.5826}}},
2099 {{{-2.1101, -0.3644}, {1.3814, 3.1271}},
2100 {{4.8729, 6.6186}, {8.3644, 10.1101}}}}});
2101 ASSERT_TRUE(output.allclose(expected, 2e-04));
2102}
2103
2104TEST_F(FunctionalTest, InstanceNorm3dDefaultOptions) {
2105 int num_features = 5;
2106
2107 auto input = torch::arange(2. * num_features * 2 * 2 * 2)
2108 .view({2, num_features, 2, 2, 2});
2109 auto output = F::instance_norm(input);
2110 auto expected = torch::tensor(
2111 {{{{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
2112 {{0.2182, 0.6547}, {1.0911, 1.5275}}},
2113 {{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
2114 {{0.2182, 0.6547}, {1.0911, 1.5275}}},
2115 {{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
2116 {{0.2182, 0.6547}, {1.0911, 1.5275}}},
2117 {{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
2118 {{0.2182, 0.6547}, {1.0911, 1.5275}}},
2119 {{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
2120 {{0.2182, 0.6547}, {1.0911, 1.5275}}}},
2121 {{{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
2122 {{0.2182, 0.6547}, {1.0911, 1.5275}}},
2123 {{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
2124 {{0.2182, 0.6547}, {1.0911, 1.5275}}},
2125 {{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
2126 {{0.2182, 0.6547}, {1.0911, 1.5275}}},
2127 {{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
2128 {{0.2182, 0.6547}, {1.0911, 1.5275}}},
2129 {{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
2130 {{0.2182, 0.6547}, {1.0911, 1.5275}}}}});
2131 ASSERT_TRUE(output.allclose(expected, 2e-04));
2132}
2133
2134TEST_F(FunctionalTest, Interpolate) {
2135 {
2136 // 1D interpolation
2137 auto input = torch::ones({1, 1, 2});
2138 auto options = F::InterpolateFuncOptions()
2139 .size(std::vector<int64_t>({4}))
2140 .mode(torch::kNearest);
2141 auto output = F::interpolate(input, options);
2142 auto expected = torch::ones({1, 1, 4});
2143
2144 ASSERT_TRUE(output.allclose(expected));
2145 }
2146 {
2147 // 2D interpolation
2148 for (const auto align_corners : {true, false}) {
2149 // test float scale factor up & down sampling
2150 for (const auto scale_factor : {0.5, 1.5, 2.0}) {
2151 auto input = torch::ones({1, 1, 2, 2});
2152 auto options =
2153 F::InterpolateFuncOptions()
2154 .scale_factor(std::vector<double>({scale_factor, scale_factor}))
2155 .mode(torch::kBilinear)
2156 .align_corners(align_corners);
2157 auto output = F::interpolate(input, options);
2158 auto expected_size =
2159 static_cast<int64_t>(std::floor(input.size(-1) * scale_factor));
2160 auto expected = torch::ones({1, 1, expected_size, expected_size});
2161
2162 ASSERT_TRUE(output.allclose(expected));
2163 }
2164 }
2165 }
2166 {
2167 // 3D interpolation
2168 for (const auto align_corners : {true, false}) {
2169 for (const auto scale_factor : {0.5, 1.5, 2.0}) {
2170 auto input = torch::ones({1, 1, 2, 2, 2});
2171 auto options = F::InterpolateFuncOptions()
2172 .scale_factor(std::vector<double>(
2173 {scale_factor, scale_factor, scale_factor}))
2174 .mode(torch::kTrilinear)
2175 .align_corners(align_corners);
2176 auto output = F::interpolate(input, options);
2177 auto expected_size =
2178 static_cast<int64_t>(std::floor(input.size(-1) * scale_factor));
2179 auto expected =
2180 torch::ones({1, 1, expected_size, expected_size, expected_size});
2181
2182 ASSERT_TRUE(output.allclose(expected));
2183 }
2184 }
2185 }
2186 {
2187 auto input = torch::randn({3, 2, 2});
2188 ASSERT_THROWS_WITH(
2189 F::interpolate(
2190 input[0],
2191 F::InterpolateFuncOptions().size(std::vector<int64_t>({4, 4}))),
2192 "Input Error: Only 3D, 4D and 5D input Tensors supported (got 2D) "
2193 "for the modes: nearest | linear | bilinear | bicubic | trilinear (got kNearest)");
2194 ASSERT_THROWS_WITH(
2195 F::interpolate(
2196 torch::reshape(input, {1, 1, 1, 3, 2, 2}),
2197 F::InterpolateFuncOptions().size(
2198 std::vector<int64_t>({1, 1, 1, 3, 4, 4}))),
2199 "Input Error: Only 3D, 4D and 5D input Tensors supported (got 6D) "
2200 "for the modes: nearest | linear | bilinear | bicubic | trilinear (got kNearest)");
2201 ASSERT_THROWS_WITH(
2202 F::interpolate(input, F::InterpolateFuncOptions()),
2203 "either size or scale_factor should be defined");
2204 ASSERT_THROWS_WITH(
2205 F::interpolate(
2206 input,
2207 F::InterpolateFuncOptions()
2208 .size(std::vector<int64_t>({3, 4, 4}))
2209 .scale_factor(std::vector<double>({0.5}))),
2210 "only one of size or scale_factor should be defined");
2211 ASSERT_THROWS_WITH(
2212 F::interpolate(
2213 input,
2214 F::InterpolateFuncOptions().scale_factor(
2215 std::vector<double>({3, 2}))),
2216 "scale_factor shape must match input shape. "
2217 "Input is 1D, scale_factor size is [3, 2]");
2218 ASSERT_THROWS_WITH(
2219 F::interpolate(
2220 input,
2221 F::InterpolateFuncOptions()
2222 .mode(torch::kNearest)
2223 .align_corners(true)),
2224 "align_corners option can only be set with the "
2225 "interpolating modes: linear | bilinear | bicubic | trilinear");
2226 }
2227 {
2228 auto tensor = torch::rand({2, 3, 32, 32});
2229 std::vector<int64_t> osize = {8, 10};
2230 auto expected =
2231 at::native::_upsample_nearest_exact2d(tensor, osize, torch::nullopt);
2232
2233 auto options = F::InterpolateFuncOptions()
2234 .size(osize)
2235 .mode(torch::kNearestExact)
2236 .align_corners(false);
2237 auto output = F::interpolate(tensor, options);
2238
2239 ASSERT_TRUE(output.allclose(expected));
2240 }
2241 {
2242 auto tensor = torch::rand({2, 3, 32, 32});
2243 std::vector<int64_t> osize = {8, 10};
2244 auto expected = at::native::_upsample_bilinear2d_aa(
2245 tensor, osize, false, torch::nullopt);
2246
2247 auto options = F::InterpolateFuncOptions()
2248 .size(osize)
2249 .mode(torch::kBilinear)
2250 .align_corners(false)
2251 .antialias(true);
2252 auto output = F::interpolate(tensor, options);
2253 ASSERT_TRUE(output.allclose(expected));
2254 }
2255 {
2256 auto tensor = torch::rand({2, 3, 32, 32});
2257 std::vector<int64_t> osize = {8, 10};
2258 auto expected = at::native::_upsample_bicubic2d_aa(
2259 tensor, osize, false, torch::nullopt);
2260
2261 auto options = F::InterpolateFuncOptions()
2262 .size(osize)
2263 .mode(torch::kBicubic)
2264 .align_corners(false)
2265 .antialias(true);
2266 auto output = F::interpolate(tensor, options);
2267 ASSERT_TRUE(output.allclose(expected));
2268 }
2269}
2270
2271TEST_F(FunctionalTest, Pad1) {
2272 {
2273 auto input = torch::arange(6, torch::kDouble).reshape({1, 2, 3});
2274 auto output =
2275 F::pad(input, F::PadFuncOptions({1, 2}).mode(torch::kCircular));
2276 auto expected = torch::tensor(
2277 {{{2., 0., 1., 2., 0., 1.}, {5., 3., 4., 5., 3., 4.}}}, torch::kDouble);
2278 ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 2, 6}));
2279 ASSERT_TRUE(output.allclose(expected, 1e-04));
2280 }
2281}
2282TEST_F(FunctionalTest, Pad2) {
2283 {
2284 auto input = torch::arange(9, torch::kDouble).reshape({1, 1, 3, 3});
2285 auto output =
2286 F::pad(input, F::PadFuncOptions({3, 3, 3, 1}).mode(torch::kCircular));
2287 auto expected = torch::tensor(
2288 {{{{0., 1., 2., 0., 1., 2., 0., 1., 2.},
2289 {3., 4., 5., 3., 4., 5., 3., 4., 5.},
2290 {6., 7., 8., 6., 7., 8., 6., 7., 8.},
2291 {0., 1., 2., 0., 1., 2., 0., 1., 2.},
2292 {3., 4., 5., 3., 4., 5., 3., 4., 5.},
2293 {6., 7., 8., 6., 7., 8., 6., 7., 8.},
2294 {0., 1., 2., 0., 1., 2., 0., 1., 2.}}}},
2295 torch::kDouble);
2296 ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 1, 7, 9}));
2297 ASSERT_TRUE(output.allclose(expected, 1e-04));
2298 }
2299}
2300TEST_F(FunctionalTest, Pad3) {
2301 {
2302 auto input = torch::arange(12, torch::kDouble).reshape({1, 1, 2, 2, 3});
2303 auto output = F::pad(
2304 input, F::PadFuncOptions({3, 3, 2, 1, 2, 2}).mode(torch::kCircular));
2305 auto expected = torch::tensor(
2306 {{{{{0., 1., 2., 0., 1., 2., 0., 1., 2.},
2307 {3., 4., 5., 3., 4., 5., 3., 4., 5.},
2308 {0., 1., 2., 0., 1., 2., 0., 1., 2.},
2309 {3., 4., 5., 3., 4., 5., 3., 4., 5.},
2310 {0., 1., 2., 0., 1., 2., 0., 1., 2.}},
2311
2312 {{6., 7., 8., 6., 7., 8., 6., 7., 8.},
2313 {9., 10., 11., 9., 10., 11., 9., 10., 11.},
2314 {6., 7., 8., 6., 7., 8., 6., 7., 8.},
2315 {9., 10., 11., 9., 10., 11., 9., 10., 11.},
2316 {6., 7., 8., 6., 7., 8., 6., 7., 8.}},
2317
2318 {{0., 1., 2., 0., 1., 2., 0., 1., 2.},
2319 {3., 4., 5., 3., 4., 5., 3., 4., 5.},
2320 {0., 1., 2., 0., 1., 2., 0., 1., 2.},
2321 {3., 4., 5., 3., 4., 5., 3., 4., 5.},
2322 {0., 1., 2., 0., 1., 2., 0., 1., 2.}},
2323
2324 {{6., 7., 8., 6., 7., 8., 6., 7., 8.},
2325 {9., 10., 11., 9., 10., 11., 9., 10., 11.},
2326 {6., 7., 8., 6., 7., 8., 6., 7., 8.},
2327 {9., 10., 11., 9., 10., 11., 9., 10., 11.},
2328 {6., 7., 8., 6., 7., 8., 6., 7., 8.}},
2329
2330 {{0., 1., 2., 0., 1., 2., 0., 1., 2.},
2331 {3., 4., 5., 3., 4., 5., 3., 4., 5.},
2332 {0., 1., 2., 0., 1., 2., 0., 1., 2.},
2333 {3., 4., 5., 3., 4., 5., 3., 4., 5.},
2334 {0., 1., 2., 0., 1., 2., 0., 1., 2.}},
2335
2336 {{6., 7., 8., 6., 7., 8., 6., 7., 8.},
2337 {9., 10., 11., 9., 10., 11., 9., 10., 11.},
2338 {6., 7., 8., 6., 7., 8., 6., 7., 8.},
2339 {9., 10., 11., 9., 10., 11., 9., 10., 11.},
2340 {6., 7., 8., 6., 7., 8., 6., 7., 8.}}}}},
2341 torch::kDouble);
2342 ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 1, 6, 5, 9}));
2343 ASSERT_TRUE(output.allclose(expected, 1e-04));
2344 }
2345}
2346TEST_F(FunctionalTest, Pad4) {
2347 {
2348 auto input = torch::arange(16, torch::kDouble).reshape({2, 2, 2, 2});
2349 auto output =
2350 F::pad(input, F::PadFuncOptions({1, 1, 1, 1}).mode(torch::kReflect));
2351 auto expected = torch::tensor(
2352 {{{{3., 2., 3., 2.},
2353 {1., 0., 1., 0.},
2354 {3., 2., 3., 2.},
2355 {1., 0., 1., 0.}},
2356
2357 {{7., 6., 7., 6.},
2358 {5., 4., 5., 4.},
2359 {7., 6., 7., 6.},
2360 {5., 4., 5., 4.}}},
2361
2362 {{{11., 10., 11., 10.},
2363 {9., 8., 9., 8.},
2364 {11., 10., 11., 10.},
2365 {9., 8., 9., 8.}},
2366
2367 {{15., 14., 15., 14.},
2368 {13., 12., 13., 12.},
2369 {15., 14., 15., 14.},
2370 {13., 12., 13., 12.}}}},
2371 torch::kDouble);
2372 ASSERT_EQ(output.sizes(), std::vector<int64_t>({2, 2, 4, 4}));
2373 ASSERT_TRUE(output.allclose(expected, 1e-04));
2374 }
2375}
2376TEST_F(FunctionalTest, Pad5) {
2377 {
2378 auto input = torch::arange(12, torch::kDouble).reshape({1, 1, 2, 2, 3});
2379 auto output = F::pad(
2380 input, F::PadFuncOptions({1, 2, 2, 1, 1, 2}).mode(torch::kReplicate));
2381 auto expected = torch::tensor(
2382 {{{{{0., 0., 1., 2., 2., 2.},
2383 {0., 0., 1., 2., 2., 2.},
2384 {0., 0., 1., 2., 2., 2.},
2385 {3., 3., 4., 5., 5., 5.},
2386 {3., 3., 4., 5., 5., 5.}},
2387
2388 {{0., 0., 1., 2., 2., 2.},
2389 {0., 0., 1., 2., 2., 2.},
2390 {0., 0., 1., 2., 2., 2.},
2391 {3., 3., 4., 5., 5., 5.},
2392 {3., 3., 4., 5., 5., 5.}},
2393
2394 {{6., 6., 7., 8., 8., 8.},
2395 {6., 6., 7., 8., 8., 8.},
2396 {6., 6., 7., 8., 8., 8.},
2397 {9., 9., 10., 11., 11., 11.},
2398 {9., 9., 10., 11., 11., 11.}},
2399
2400 {{6., 6., 7., 8., 8., 8.},
2401 {6., 6., 7., 8., 8., 8.},
2402 {6., 6., 7., 8., 8., 8.},
2403 {9., 9., 10., 11., 11., 11.},
2404 {9., 9., 10., 11., 11., 11.}},
2405
2406 {{6., 6., 7., 8., 8., 8.},
2407 {6., 6., 7., 8., 8., 8.},
2408 {6., 6., 7., 8., 8., 8.},
2409 {9., 9., 10., 11., 11., 11.},
2410 {9., 9., 10., 11., 11., 11.}}}}},
2411 torch::kDouble);
2412 ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 1, 5, 5, 6}));
2413 ASSERT_TRUE(output.allclose(expected, 1e-04));
2414 }
2415}
2416TEST_F(FunctionalTest, Pad6) {
2417 {
2418 auto input = torch::arange(18, torch::kDouble).reshape({1, 1, 3, 2, 3});
2419 auto output = F::pad(
2420 input, F::PadFuncOptions({0, 2, 1, 0, 1, 2}).mode(torch::kReflect));
2421 auto expected = torch::tensor(
2422 {{{{{9., 10., 11., 10., 9.},
2423 {6., 7., 8., 7., 6.},
2424 {9., 10., 11., 10., 9.}},
2425
2426 {{3., 4., 5., 4., 3.}, {0., 1., 2., 1., 0.}, {3., 4., 5., 4., 3.}},
2427
2428 {{9., 10., 11., 10., 9.},
2429 {6., 7., 8., 7., 6.},
2430 {9., 10., 11., 10., 9.}},
2431
2432 {{15., 16., 17., 16., 15.},
2433 {12., 13., 14., 13., 12.},
2434 {15., 16., 17., 16., 15.}},
2435
2436 {{9., 10., 11., 10., 9.},
2437 {6., 7., 8., 7., 6.},
2438 {9., 10., 11., 10., 9.}},
2439
2440 {{3., 4., 5., 4., 3.},
2441 {0., 1., 2., 1., 0.},
2442 {3., 4., 5., 4., 3.}}}}},
2443 torch::kDouble);
2444 ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 1, 6, 3, 5}));
2445 ASSERT_TRUE(output.allclose(expected, 1e-04));
2446 }
2447}
2448TEST_F(FunctionalTest, Pad7) {
2449 {
2450 auto input = torch::ones({1, 1, 1, 1}, torch::kDouble);
2451 auto output = F::pad(
2452 input, F::PadFuncOptions({1, 1}).mode(torch::kConstant).value(0));
2453 ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 1, 1, 3}));
2454 auto expected = torch::tensor({{{{0., 1., 0.}}}}, torch::kDouble);
2455 }
2456}
2457TEST_F(FunctionalTest, Pad8) {
2458 {
2459 auto input = torch::ones({1, 1, 1, 1}, torch::kDouble);
2460 auto output = F::pad(input, F::PadFuncOptions({1, 1}));
2461 ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 1, 1, 3}));
2462 auto expected = torch::tensor({{{{0., 1., 0.}}}}, torch::kDouble);
2463 }
2464}
2465
2466TEST_F(FunctionalTest, CTCLoss) {
2467 { // test CTCLoss typechecks
2468 const auto target_lengths = torch::tensor({30, 25, 20});
2469 const auto input_lengths = torch::tensor({50, 50, 50});
2470 const auto targets =
2471 torch::randint(1, 15, {target_lengths.sum().item<int>()}, torch::kInt);
2472 const auto log_probs =
2473 torch::randn({50, 3, 15}, torch::kFloat).log_softmax(2);
2474
2475 const auto _input_lengths = input_lengths.to(torch::kFloat);
2476 ASSERT_THROWS_WITH(
2477 F::ctc_loss(log_probs, targets, _input_lengths, target_lengths),
2478 "input_lengths must be integral");
2479
2480 const auto target_lengths_ = target_lengths.to(torch::kFloat);
2481 ASSERT_THROWS_WITH(
2482 F::ctc_loss(log_probs, targets, input_lengths, target_lengths_),
2483 "target_lengths must be integral");
2484 }
2485 { // test CTCLoss length checks
2486 const auto target_lengths = torch::tensor({30, 25, 20});
2487 const auto input_lengths = torch::tensor({50, 50, 50});
2488 const auto targets = torch::randint(1, 15, {3, 29}, torch::kInt);
2489 const auto log_probs =
2490 torch::randn({50, 3, 15}, torch::kFloat).log_softmax(2);
2491 ASSERT_THROWS_WITH(
2492 F::ctc_loss(log_probs, targets, input_lengths, target_lengths),
2493 "Expected tensor to have size at least 30 at dimension 1");
2494 }
2495 { // test CTCLoss empty target
2496 {
2497 const auto target_lengths = torch::tensor({0, 0, 0});
2498 const auto input_lengths = torch::tensor({50, 50, 50});
2499 const auto targets =
2500 torch::randint(1, 15, at::IntArrayRef({0}), torch::kLong);
2501 const auto log_probs =
2502 torch::randn({50, 3, 15}, torch::kDouble).log_softmax(2);
2503 const auto loss = F::ctc_loss(
2504 log_probs,
2505 targets,
2506 input_lengths,
2507 target_lengths,
2508 F::CTCLossFuncOptions().reduction(torch::kNone));
2509 ASSERT_TRUE(loss.ge(0).all().item<bool>());
2510 ASSERT_TRUE(torch::allclose(
2511 -log_probs.sum(0).slice(1, 0, 1).view_as(loss), loss));
2512 }
2513 {
2514 const auto target_lengths = torch::tensor({0, 9, 0});
2515 const auto input_lengths = torch::tensor({50, 50, 50});
2516 const auto targets = torch::randint(1, 15, {9}, torch::kLong);
2517 const auto log_probs =
2518 torch::randn({50, 3, 15}, torch::kDouble).log_softmax(2);
2519 const auto loss = F::ctc_loss(
2520 log_probs,
2521 targets,
2522 input_lengths,
2523 target_lengths,
2524 F::CTCLossFuncOptions().reduction(torch::kNone));
2525 ASSERT_TRUE(loss.ge(0).all().item<bool>());
2526 ASSERT_TRUE(torch::allclose(
2527 -log_probs.sum(0)
2528 .index_select(0, torch::tensor({0, 2}, torch::kLong))
2529 .slice(1, 0, 1)
2530 .view({2}),
2531 loss.index_select(0, torch::tensor({0, 2}, torch::kLong))));
2532 }
2533 }
2534}
2535
2536TEST_F(FunctionalTest, PoissonNLLLoss) {
2537 const auto input = torch::tensor({0.5, 1.5, 2.5});
2538 const auto target = torch::tensor({1., 2., 3.});
2539 const auto component_wise_loss = torch::exp(input) - target * input;
2540 ASSERT_TRUE(torch::allclose(
2541 torch::mean(component_wise_loss), F::poisson_nll_loss(input, target)));
2542 ASSERT_TRUE(torch::allclose(
2543 component_wise_loss,
2544 F::poisson_nll_loss(
2545 input,
2546 target,
2547 F::PoissonNLLLossFuncOptions().reduction(torch::kNone))));
2548 ASSERT_TRUE(torch::allclose(
2549 torch::sum(component_wise_loss),
2550 F::poisson_nll_loss(
2551 input,
2552 target,
2553 F::PoissonNLLLossFuncOptions().reduction(torch::kSum))));
2554 ASSERT_TRUE(torch::allclose(
2555 torch::mean(component_wise_loss),
2556 F::poisson_nll_loss(
2557 input,
2558 target,
2559 F::PoissonNLLLossFuncOptions().reduction(torch::kMean))));
2560}
2561
2562TEST_F(FunctionalTest, MarginRankingLoss) {
2563 {
2564 const auto input1 = torch::randn(15) * 10;
2565 const auto input2 = torch::randn(15) * 10;
2566 const auto target = torch::randn(15).sign();
2567 ASSERT_TRUE(torch::allclose(
2568 F::margin_ranking_loss(input1, input2, target),
2569 (-target * (input1 - input2)).clamp(0).mean()));
2570 }
2571 {
2572 const auto input1 = torch::randn(15) * 10;
2573 const auto input2 = torch::randn(15) * 10;
2574 const auto target = torch::randn(15).sign();
2575 const auto margin = 0.5;
2576 ASSERT_TRUE(torch::allclose(
2577 F::margin_ranking_loss(
2578 input1,
2579 input2,
2580 target,
2581 F::MarginRankingLossFuncOptions().margin(0.5).reduction(
2582 torch::kSum)),
2583 (-target * (input1 - input2) + margin).clamp(0).sum()));
2584 }
2585 {
2586 const auto input1 = torch::randn(15) * 10;
2587 const auto input2 = torch::randn(15) * 10;
2588 const auto target = torch::randn(15).sign();
2589 const auto margin = 0.5;
2590 ASSERT_TRUE(torch::allclose(
2591 F::margin_ranking_loss(
2592 input1,
2593 input2,
2594 target,
2595 F::MarginRankingLossFuncOptions().margin(0.5).reduction(
2596 torch::kMean)),
2597 (-target * (input1 - input2) + margin).clamp(0).mean()));
2598 }
2599}
2600
2601TEST_F(FunctionalTest, ConvTranspose1d) {
2602 auto x = torch::arange(20.).view({2, 2, 5});
2603 auto weight = torch::arange(18.).view({2, 3, 3});
2604 auto y =
2605 F::conv_transpose1d(x, weight, F::ConvTranspose1dFuncOptions().stride(1));
2606 auto expected = torch::tensor(
2607 {{{45., 104., 179., 212., 245., 188., 107.},
2608 {60., 140., 242., 293., 344., 260., 146.},
2609 {75., 176., 305., 374., 443., 332., 185.}},
2610 {{135., 304., 509., 542., 575., 428., 237.},
2611 {210., 460., 752., 803., 854., 620., 336.},
2612 {285., 616., 995., 1064., 1133., 812., 435.}}});
2613 ASSERT_TRUE(torch::allclose(y, expected));
2614
2615 auto y_no_options = F::conv_transpose1d(x, weight);
2616 ASSERT_TRUE(torch::allclose(y_no_options, expected));
2617}
2618
2619TEST_F(FunctionalTest, ConvTranspose2dEven) {
2620 auto x = torch::arange(50.).view({1, 2, 5, 5});
2621 auto weight = torch::arange(54.).view({2, 3, 3, 3});
2622 auto y =
2623 F::conv_transpose2d(x, weight, F::ConvTranspose2dFuncOptions().stride(1));
2624 auto expected = torch::tensor(
2625 {{{{675., 1402., 2183., 2270., 2357., 1634., 849.},
2626 {1560., 3240., 5044., 5236., 5428., 3760., 1952.},
2627 {2685., 5574., 8673., 8988., 9303., 6438., 3339.},
2628 {3180., 6594., 10248., 10563., 10878., 7518., 3894.},
2629 {3675., 7614., 11823., 12138., 12453., 8598., 4449.},
2630 {2820., 5832., 9040., 9268., 9496., 6544., 3380.},
2631 {1605., 3314., 5129., 5252., 5375., 3698., 1907.}},
2632 {{900., 1870., 2912., 3053., 3194., 2210., 1146.},
2633 {2100., 4356., 6772., 7072., 7372., 5092., 2636.},
2634 {3630., 7518., 11670., 12147., 12624., 8706., 4500.},
2635 {4395., 9078., 14055., 14532., 15009., 10326., 5325.},
2636 {5160., 10638., 16440., 16917., 17394., 11946., 6150.},
2637 {3900., 8028., 12388., 12724., 13060., 8956., 4604.},
2638 {2190., 4502., 6938., 7115., 7292., 4994., 2564.}},
2639 {{1125., 2338., 3641., 3836., 4031., 2786., 1443.},
2640 {2640., 5472., 8500., 8908., 9316., 6424., 3320.},
2641 {4575., 9462., 14667., 15306., 15945., 10974., 5661.},
2642 {5610., 11562., 17862., 18501., 19140., 13134., 6756.},
2643 {6645., 13662., 21057., 21696., 22335., 15294., 7851.},
2644 {4980., 10224., 15736., 16180., 16624., 11368., 5828.},
2645 {2775., 5690., 8747., 8978., 9209., 6290., 3221.}}}});
2646 ASSERT_TRUE(torch::allclose(y, expected));
2647
2648 auto y_no_options = F::conv_transpose2d(x, weight);
2649 ASSERT_TRUE(torch::allclose(y_no_options, expected));
2650}
2651
2652TEST_F(FunctionalTest, ConvTranspose2dUneven) {
2653 auto x = torch::arange(40.).view({1, 2, 5, 4});
2654 auto weight = torch::arange(36.).view({2, 3, 3, 2});
2655 auto y =
2656 F::conv_transpose2d(x, weight, F::ConvTranspose2dFuncOptions().stride(1));
2657 auto expected = torch::tensor(
2658 {{{{360., 758., 796., 834., 440.},
2659 {832., 1752., 1836., 1920., 1012.},
2660 {1432., 3014., 3152., 3290., 1732.},
2661 {1696., 3566., 3704., 3842., 2020.},
2662 {1960., 4118., 4256., 4394., 2308.},
2663 {1504., 3152., 3252., 3352., 1756.},
2664 {856., 1790., 1844., 1898., 992.}},
2665 {{480., 1010., 1072., 1134., 596.},
2666 {1120., 2352., 2484., 2616., 1372.},
2667 {1936., 4058., 4268., 4478., 2344.},
2668 {2344., 4898., 5108., 5318., 2776.},
2669 {2752., 5738., 5948., 6158., 3208.},
2670 {2080., 4328., 4476., 4624., 2404.},
2671 {1168., 2426., 2504., 2582., 1340.}},
2672 {{600., 1262., 1348., 1434., 752.},
2673 {1408., 2952., 3132., 3312., 1732.},
2674 {2440., 5102., 5384., 5666., 2956.},
2675 {2992., 6230., 6512., 6794., 3532.},
2676 {3544., 7358., 7640., 7922., 4108.},
2677 {2656., 5504., 5700., 5896., 3052.},
2678 {1480., 3062., 3164., 3266., 1688.}}}});
2679 ASSERT_TRUE(torch::allclose(y, expected));
2680
2681 auto y_no_options = F::conv_transpose2d(x, weight);
2682 ASSERT_TRUE(torch::allclose(y_no_options, expected));
2683}
2684
2685TEST_F(FunctionalTest, ConvTranspose3d) {
2686 auto x = torch::arange(16.).view({1, 2, 2, 2, 2});
2687 auto weight = torch::arange(32.).view({2, 2, 2, 2, 2});
2688 auto y =
2689 F::conv_transpose3d(x, weight, F::ConvTranspose3dFuncOptions().stride(1));
2690 auto expected = torch::tensor(
2691 {{{{{128., 280., 154.}, {304., 664., 364.}, {184., 400., 218.}},
2692 {{352., 768., 420.}, {832., 1808., 984.}, {496., 1072., 580.}},
2693 {{256., 552., 298.}, {592., 1272., 684.}, {344., 736., 394.}}},
2694 {{{192., 424., 234.}, {464., 1016., 556.}, {280., 608., 330.}},
2695 {{544., 1184., 644.}, {1280., 2768., 1496.}, {752., 1616., 868.}},
2696 {{384., 824., 442.}, {880., 1880., 1004.}, {504., 1072., 570.}}}}});
2697 ASSERT_TRUE(torch::allclose(y, expected));
2698
2699 auto y_no_options = F::conv_transpose3d(x, weight);
2700 ASSERT_TRUE(torch::allclose(y_no_options, expected));
2701}
2702
2703TEST_F(FunctionalTest, AlphaDropout) {
2704 auto input = torch::randn(5000);
2705 auto input_mean = input.mean();
2706 auto input_std = input.std();
2707
2708 for (const auto rate : {0.2, 0.5, 0.8}) {
2709 for (const auto inplace : {false, true}) {
2710 auto input_ = input.clone();
2711 auto output = F::alpha_dropout(
2712 input_,
2713 F::AlphaDropoutFuncOptions().p(rate).training(false).inplace(
2714 inplace));
2715 ASSERT_TRUE(torch::allclose(input_mean, output.mean(), 0.1));
2716 ASSERT_TRUE(torch::allclose(input_std, output.std(), 0.1));
2717 if (inplace) {
2718 ASSERT_TRUE(torch::allclose(input_, output));
2719 }
2720 }
2721 }
2722 auto output = F::detail::alpha_dropout(input, 0.5, false, false);
2723 ASSERT_TRUE(torch::allclose(input_mean, output.mean(), 0.1));
2724 ASSERT_TRUE(torch::allclose(input_std, output.std(), 0.1));
2725}
2726
2727TEST_F(FunctionalTest, FeatureAlphaDropout) {
2728 auto input = torch::randn(5000);
2729 auto input_mean = input.mean();
2730 auto input_std = input.std();
2731
2732 for (const auto rate : {0.2, 0.5, 0.8}) {
2733 for (const auto inplace : {false, true}) {
2734 auto input_ = input.clone();
2735 auto output = F::feature_alpha_dropout(
2736 input_,
2737 F::FeatureAlphaDropoutFuncOptions().p(rate).training(false).inplace(
2738 inplace));
2739 ASSERT_TRUE(torch::allclose(input_mean, output.mean(), 0.1));
2740 ASSERT_TRUE(torch::allclose(input_std, output.std(), 0.1));
2741 if (inplace) {
2742 ASSERT_TRUE(torch::allclose(input_, output));
2743 }
2744 }
2745 }
2746 auto output = F::feature_alpha_dropout(input);
2747 ASSERT_TRUE(torch::allclose(input_mean, output.mean(), 0.1));
2748 ASSERT_TRUE(torch::allclose(input_std, output.std(), 0.1));
2749}
2750
2751TEST_F(FunctionalTest, Dropout) {
2752 auto input = torch::randn(5000);
2753 auto input_mean = input.mean();
2754 auto input_std = input.std();
2755
2756 for (const auto rate : {0.2, 0.5, 0.8}) {
2757 auto output = F::dropout(input, F::DropoutFuncOptions().p(rate));
2758 ASSERT_TRUE(torch::allclose(input_mean, output.mean(), 0.01, 0.05));
2759 ASSERT_TRUE((input_std <= output.std()).all().item<bool>());
2760 }
2761 auto output = F::dropout(input);
2762 ASSERT_TRUE(torch::allclose(input_mean, output.mean(), 0.01, 0.05));
2763 ASSERT_TRUE((input_std <= output.std()).all().item<bool>());
2764 ASSERT_TRUE(F::dropout(torch::tensor(1.)).defined());
2765}
2766
2767TEST_F(FunctionalTest, Dropout2d) {
2768 auto input = torch::randn({2, 2, 50, 100});
2769 auto input_mean = input.mean();
2770 auto input_std = input.std();
2771
2772 for (const auto rate : {0.2, 0.5, 0.8}) {
2773 auto output = F::dropout2d(input, F::Dropout2dFuncOptions().p(rate));
2774 ASSERT_TRUE(torch::allclose(input_mean, output.mean(), 0.01, 0.05));
2775 }
2776 auto output = F::dropout2d(input);
2777 ASSERT_TRUE(torch::allclose(input_mean, output.mean(), 0.01, 0.05));
2778 ASSERT_TRUE(F::dropout2d(torch::randn({2, 50, 100})).defined());
2779}
2780
2781TEST_F(FunctionalTest, Dropout3d) {
2782 auto input = torch::randn({2, 2, 50, 10, 10});
2783 auto input_mean = input.mean();
2784 auto input_std = input.std();
2785
2786 for (const auto rate : {0.2, 0.5, 0.8}) {
2787 auto output = F::dropout3d(input, F::Dropout3dFuncOptions().p(rate));
2788 ASSERT_TRUE(torch::allclose(input_mean, output.mean(), 0.01, 0.05));
2789 }
2790 auto output = F::dropout3d(input);
2791 ASSERT_TRUE(torch::allclose(input_mean, output.mean(), 0.01, 0.05));
2792 ASSERT_TRUE(F::dropout3d(torch::randn({2, 50, 10, 10})).defined());
2793}
2794
2795template <c10::ScalarType S, typename T>
2796void test_isfinite(const at::Device& device) {
2797 const std::vector<T> values = {
2798 std::numeric_limits<T>::lowest(),
2799 0,
2800 1,
2801 42,
2802 std::numeric_limits<T>::min(),
2803 std::numeric_limits<T>::max()};
2804 for (const auto value : values) {
2805 const auto x = torch::full(
2806 {3, 3}, value, torch::TensorOptions().dtype(S).device(device));
2807 ASSERT_TRUE(torch::isfinite(x).all().template item<bool>());
2808 }
2809 if (std::numeric_limits<T>::has_infinity) {
2810 const auto inf = std::numeric_limits<T>::infinity();
2811 const auto x = torch::tensor(
2812 {-inf,
2813 std::numeric_limits<T>::lowest(),
2814 static_cast<T>(0),
2815 static_cast<T>(1),
2816 static_cast<T>(42),
2817 std::numeric_limits<T>::min(),
2818 std::numeric_limits<T>::max(),
2819 inf},
2820 torch::TensorOptions().dtype(S).device(device));
2821 ASSERT_TRUE(torch::allclose(
2822 // torch::allclose does not support comparing torch::kBool
2823 torch::isfinite(x).toType(torch::kInt),
2824 torch::tensor(
2825 {false, true, true, true, true, true, true, false},
2826 torch::TensorOptions().device(device))
2827 .toType(torch::kInt)));
2828 }
2829 if (std::numeric_limits<T>::has_quiet_NaN) {
2830 const auto x = torch::tensor(
2831 {std::numeric_limits<T>::quiet_NaN()},
2832 torch::TensorOptions().dtype(S).device(device));
2833 ASSERT_FALSE(torch::isfinite(x).all().template item<bool>());
2834 }
2835 if (std::numeric_limits<T>::has_signaling_NaN) {
2836 const auto x = torch::tensor(
2837 {std::numeric_limits<T>::signaling_NaN()},
2838 torch::TensorOptions().dtype(S).device(device));
2839 ASSERT_FALSE(torch::isfinite(x).all().template item<bool>());
2840 }
2841}
2842
2843TEST_F(FunctionalTest, isfinite) {
2844 const at::Device device("cpu");
2845 test_isfinite<torch::kUInt8, uint8_t>(device);
2846 test_isfinite<torch::kInt8, int8_t>(device);
2847 test_isfinite<torch::kInt16, int16_t>(device);
2848 test_isfinite<torch::kInt32, int32_t>(device);
2849 test_isfinite<torch::kInt64, int64_t>(device);
2850 test_isfinite<torch::kFloat32, float>(device);
2851 test_isfinite<torch::kFloat64, double>(device);
2852}
2853
2854TEST_F(FunctionalTest, isfinite_CUDA) {
2855 const at::Device device("cuda");
2856 test_isfinite<torch::kUInt8, uint8_t>(device);
2857 test_isfinite<torch::kInt8, int8_t>(device);
2858 test_isfinite<torch::kInt16, int16_t>(device);
2859 test_isfinite<torch::kInt32, int32_t>(device);
2860 test_isfinite<torch::kInt64, int64_t>(device);
2861 test_isfinite<torch::kFloat32, float>(device);
2862 test_isfinite<torch::kFloat64, double>(device);
2863 test_isfinite<torch::kFloat16, c10::Half>(device);
2864}
2865
2866template <c10::ScalarType S, typename T>
2867void test_isinf(const at::Device& device) {
2868 const std::vector<T> values = {
2869 std::numeric_limits<T>::lowest(),
2870 0,
2871 1,
2872 42,
2873 std::numeric_limits<T>::min(),
2874 std::numeric_limits<T>::max()};
2875 for (const auto value : values) {
2876 const auto x = torch::full(
2877 {3, 3}, value, torch::TensorOptions().dtype(S).device(device));
2878 ASSERT_FALSE(torch::isinf(x).all().template item<bool>());
2879 }
2880 if (std::numeric_limits<T>::has_infinity) {
2881 const auto inf = std::numeric_limits<T>::infinity();
2882 const auto x = torch::tensor(
2883 {-inf,
2884 std::numeric_limits<T>::lowest(),
2885 static_cast<T>(0),
2886 static_cast<T>(1),
2887 static_cast<T>(42),
2888 std::numeric_limits<T>::min(),
2889 std::numeric_limits<T>::max(),
2890 inf},
2891 torch::TensorOptions().dtype(S).device(device));
2892 ASSERT_TRUE(torch::allclose(
2893 // torch::allclose does not support comparing torch::kBool
2894 torch::isinf(x).toType(torch::kInt),
2895 torch::tensor(
2896 {true, false, false, false, false, false, false, true},
2897 torch::TensorOptions().device(device))
2898 .toType(torch::kInt)));
2899 }
2900 if (std::numeric_limits<T>::has_quiet_NaN) {
2901 const auto x = torch::tensor(
2902 {std::numeric_limits<T>::quiet_NaN()},
2903 torch::TensorOptions().dtype(S).device(device));
2904 ASSERT_FALSE(torch::isinf(x).all().template item<bool>());
2905 }
2906 if (std::numeric_limits<T>::has_signaling_NaN) {
2907 const auto x = torch::tensor(
2908 {std::numeric_limits<T>::signaling_NaN()},
2909 torch::TensorOptions().dtype(S).device(device));
2910 ASSERT_FALSE(torch::isinf(x).all().template item<bool>());
2911 }
2912}
2913
2914TEST_F(FunctionalTest, isinf) {
2915 const at::Device device("cpu");
2916 test_isinf<torch::kUInt8, uint8_t>(device);
2917 test_isinf<torch::kInt8, int8_t>(device);
2918 test_isinf<torch::kInt16, int16_t>(device);
2919 test_isinf<torch::kInt32, int32_t>(device);
2920 test_isinf<torch::kInt64, int64_t>(device);
2921 test_isinf<torch::kFloat32, float>(device);
2922 test_isinf<torch::kFloat64, double>(device);
2923}
2924
2925TEST_F(FunctionalTest, isinf_CUDA) {
2926 const at::Device device("cuda");
2927 test_isinf<torch::kUInt8, uint8_t>(device);
2928 test_isinf<torch::kInt8, int8_t>(device);
2929 test_isinf<torch::kInt16, int16_t>(device);
2930 test_isinf<torch::kInt32, int32_t>(device);
2931 test_isinf<torch::kInt64, int64_t>(device);
2932 test_isinf<torch::kFloat32, float>(device);
2933 test_isinf<torch::kFloat64, double>(device);
2934 test_isinf<torch::kFloat16, c10::Half>(device);
2935}
2936
2937template <c10::ScalarType S, typename T>
2938void test_allclose(const at::Device& device) {
2939 const std::vector<T> values = {
2940 std::numeric_limits<T>::lowest(),
2941 0,
2942 1,
2943 42,
2944 std::numeric_limits<T>::min(),
2945 std::numeric_limits<T>::max()};
2946 for (const auto value : values) {
2947 const auto x =
2948 torch::full({1}, value, torch::TensorOptions().dtype(S).device(device));
2949 const auto y =
2950 torch::full({1}, value, torch::TensorOptions().dtype(S).device(device));
2951 ASSERT_TRUE(torch::allclose(x, x));
2952 ASSERT_TRUE(torch::allclose(x, y));
2953 ASSERT_TRUE(torch::allclose(y, x));
2954 ASSERT_FALSE(torch::allclose(1.1 * x + 0.1, 1.0 * x));
2955 ASSERT_TRUE(torch::allclose(0.99 * x + 0.1, 1.0 * x, 1.1, 0.1));
2956 }
2957 if (std::numeric_limits<T>::has_infinity) {
2958 const auto inf = std::numeric_limits<T>::infinity();
2959 const auto x = torch::tensor(
2960 {-inf, inf}, torch::TensorOptions().dtype(S).device(device));
2961 const auto y = torch::tensor(
2962 {-inf, inf}, torch::TensorOptions().dtype(S).device(device));
2963 ASSERT_TRUE(torch::allclose(x, x));
2964 ASSERT_TRUE(torch::allclose(x, y));
2965 ASSERT_TRUE(torch::allclose(y, x));
2966 }
2967 if (std::numeric_limits<T>::has_quiet_NaN) {
2968 const auto x = torch::tensor(
2969 {std::numeric_limits<T>::quiet_NaN()},
2970 torch::TensorOptions().dtype(S).device(device));
2971 const auto y = torch::tensor(
2972 {std::numeric_limits<T>::quiet_NaN()},
2973 torch::TensorOptions().dtype(S).device(device));
2974 ASSERT_TRUE(torch::allclose(x, x, 1.0, 0.0, /*equal_nan=*/true));
2975 ASSERT_TRUE(torch::allclose(x, y, 1.0, 0.0, /*equal_nan=*/true));
2976 ASSERT_TRUE(torch::allclose(y, x, 1.0, 0.0, /*equal_nan=*/true));
2977 }
2978 if (std::numeric_limits<T>::has_signaling_NaN) {
2979 const auto x = torch::tensor(
2980 {std::numeric_limits<T>::signaling_NaN()},
2981 torch::TensorOptions().dtype(S).device(device));
2982 const auto y = torch::tensor(
2983 {std::numeric_limits<T>::signaling_NaN()},
2984 torch::TensorOptions().dtype(S).device(device));
2985 ASSERT_TRUE(torch::allclose(x, x, 1.0, 0.0, /*equal_nan=*/true));
2986 ASSERT_TRUE(torch::allclose(x, y, 1.0, 0.0, /*equal_nan=*/true));
2987 ASSERT_TRUE(torch::allclose(y, x, 1.0, 0.0, /*equal_nan=*/true));
2988 }
2989}
2990
2991TEST_F(FunctionalTest, AllClose) {
2992 const at::Device device("cpu");
2993 test_allclose<torch::kUInt8, uint8_t>(device);
2994 test_allclose<torch::kInt8, int8_t>(device);
2995 test_allclose<torch::kInt16, int16_t>(device);
2996 test_allclose<torch::kInt32, int32_t>(device);
2997 test_allclose<torch::kInt64, int64_t>(device);
2998 test_allclose<torch::kFloat32, float>(device);
2999 test_allclose<torch::kFloat64, double>(device);
3000}
3001
3002TEST_F(FunctionalTest, AllClose_CUDA) {
3003 const at::Device device("cuda");
3004 test_allclose<torch::kUInt8, uint8_t>(device);
3005 test_allclose<torch::kInt8, int8_t>(device);
3006 test_allclose<torch::kInt16, int16_t>(device);
3007 test_allclose<torch::kInt32, int32_t>(device);
3008 test_allclose<torch::kInt64, int64_t>(device);
3009 test_allclose<torch::kFloat32, float>(device);
3010 test_allclose<torch::kFloat64, double>(device);
3011 test_allclose<torch::kFloat16, c10::Half>(device);
3012}
3013
3014TEST_F(FunctionalTest, BCEWithLogitsLoss) {
3015 {// test BCE with logits raises if target and input are different size
3016 {const auto target = torch::rand(5);
3017 const auto input = torch::rand({5, 1});
3018 ASSERT_THROWS_WITH(
3019 F::binary_cross_entropy_with_logits(input, target),
3020 "must be the same as input size");
3021}
3022{
3023 const auto target = torch::rand({5, 1});
3024 const auto input = torch::rand(5);
3025 ASSERT_THROWS_WITH(
3026 F::binary_cross_entropy_with_logits(input, target),
3027 "must be the same as input size");
3028}
3029}
3030{ // test BCE with logits gives same result as sigmoid and bce loss
3031 auto sigmoid = Sigmoid();
3032
3033 auto target = torch::rand({64, 4});
3034 auto output = torch::rand({64, 4}) - 0.5;
3035
3036 ASSERT_TRUE(torch::allclose(
3037 F::binary_cross_entropy_with_logits(output, target),
3038 F::binary_cross_entropy(sigmoid(output), target)));
3039
3040 auto weight = torch::rand(4);
3041 ASSERT_TRUE(torch::allclose(
3042 F::binary_cross_entropy_with_logits(
3043 output,
3044 target,
3045 F::BinaryCrossEntropyWithLogitsFuncOptions().weight(weight)),
3046 F::binary_cross_entropy(
3047 sigmoid(output),
3048 target,
3049 F::BinaryCrossEntropyFuncOptions().weight(weight))));
3050
3051 target = torch::zeros({4, 1}, torch::kFloat);
3052 output = torch::empty({4, 1}, torch::kFloat).fill_(-100);
3053
3054 ASSERT_TRUE(torch::allclose(
3055 F::binary_cross_entropy_with_logits(output, target),
3056 F::binary_cross_entropy(sigmoid(output), target)));
3057
3058 ASSERT_TRUE(torch::allclose(
3059 F::binary_cross_entropy_with_logits(
3060 output,
3061 target,
3062 F::BinaryCrossEntropyWithLogitsFuncOptions().reduction(torch::kNone)),
3063 F::binary_cross_entropy(
3064 sigmoid(output),
3065 target,
3066 F::BinaryCrossEntropyFuncOptions().reduction(torch::kNone))));
3067
3068 weight = torch::rand({1}, torch::kFloat);
3069 ASSERT_TRUE(torch::allclose(
3070 F::binary_cross_entropy_with_logits(
3071 output,
3072 target,
3073 F::BinaryCrossEntropyWithLogitsFuncOptions().weight(weight)),
3074 F::binary_cross_entropy(
3075 sigmoid(output),
3076 target,
3077 F::BinaryCrossEntropyFuncOptions().weight(weight))));
3078}
3079{ // test BCE with logits has correct grad at zero
3080 const auto output = torch::zeros({3, 1}, torch::requires_grad());
3081 const auto target = torch::zeros({3, 1});
3082 F::binary_cross_entropy_with_logits(
3083 output,
3084 target,
3085 F::BinaryCrossEntropyWithLogitsFuncOptions().reduction(torch::kSum))
3086 .backward();
3087 const auto expected_grad = torch::empty({3, 1}).fill_(0.5);
3088 ASSERT_TRUE(torch::allclose(output.grad(), expected_grad));
3089}
3090{ // test BCE with logits broadcasts weights
3091 const auto target = torch::rand({16, 4});
3092 const auto output = torch::rand({16, 4}) - 0.5;
3093
3094 auto weight = torch::rand(4);
3095 auto out1 = F::binary_cross_entropy_with_logits(
3096 output,
3097 target,
3098 F::BinaryCrossEntropyWithLogitsFuncOptions().weight(weight));
3099
3100 weight = weight.expand({16, 4}).contiguous();
3101 auto out2 = F::binary_cross_entropy_with_logits(
3102 output,
3103 target,
3104 F::BinaryCrossEntropyWithLogitsFuncOptions().weight(weight));
3105
3106 ASSERT_TRUE(torch::allclose(out1, out2));
3107
3108 weight = torch::rand({16, 1});
3109 out1 = F::binary_cross_entropy_with_logits(
3110 output,
3111 target,
3112 F::BinaryCrossEntropyWithLogitsFuncOptions().weight(weight));
3113
3114 weight = weight.expand({16, 4}).contiguous();
3115 out2 = F::binary_cross_entropy_with_logits(
3116 output,
3117 target,
3118 F::BinaryCrossEntropyWithLogitsFuncOptions().weight(weight));
3119
3120 ASSERT_TRUE(torch::allclose(out1, out2));
3121}
3122{ // test BCE with logits ones in pos weights are the same as none
3123 const auto target = torch::rand({64, 4});
3124 const auto output = torch::rand({64, 4}) - 0.5;
3125 const auto pos_weight = torch::ones({64, 4});
3126
3127 ASSERT_TRUE(torch::allclose(
3128 F::binary_cross_entropy_with_logits(output, target),
3129 F::binary_cross_entropy_with_logits(
3130 output,
3131 target,
3132 F::BinaryCrossEntropyWithLogitsFuncOptions().pos_weight(
3133 pos_weight))));
3134}
3135{ // test BCE with logits broadcasts pos weights
3136 const auto target = torch::rand({64, 4});
3137 const auto output = torch::rand({64, 4}) - 0.5;
3138 const auto pos_weight = torch::rand(4);
3139 const auto out1 = F::binary_cross_entropy_with_logits(
3140 output,
3141 target,
3142 F::BinaryCrossEntropyWithLogitsFuncOptions().pos_weight(pos_weight));
3143
3144 const auto pos_weight1 = pos_weight.expand({1, 4});
3145 const auto out2 = F::binary_cross_entropy_with_logits(
3146 output,
3147 target,
3148 F::BinaryCrossEntropyWithLogitsFuncOptions().pos_weight(pos_weight));
3149
3150 const auto pos_weight2 = pos_weight.expand({64, 4});
3151 const auto out3 = F::binary_cross_entropy_with_logits(
3152 output,
3153 target,
3154 F::BinaryCrossEntropyWithLogitsFuncOptions().pos_weight(pos_weight));
3155
3156 ASSERT_TRUE(torch::allclose(out1, out2));
3157 ASSERT_TRUE(torch::allclose(out1, out3));
3158}
3159{ // test BCE with logits with pos weight has correct grad at zero
3160 const auto output = torch::zeros({3, 1}, torch::requires_grad());
3161 const auto target = torch::zeros({3, 1});
3162 const auto pos_weight = torch::ones({3, 1});
3163 F::binary_cross_entropy_with_logits(
3164 output,
3165 target,
3166 F::BinaryCrossEntropyWithLogitsFuncOptions()
3167 .pos_weight(pos_weight)
3168 .reduction(torch::kSum))
3169 .backward();
3170 const auto expected_grad = torch::empty({3, 1}).fill_(0.5);
3171 // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
3172 const auto grad = output.grad();
3173 ASSERT_TRUE(torch::allclose(grad, expected_grad));
3174}
3175{ // test BCE with logits stability
3176 const auto output = torch::tensor({0., -120.});
3177 const auto target = torch::tensor({0., 1.});
3178 const auto pos_weight = torch::tensor({1., 1.});
3179
3180 const auto out1 = F::binary_cross_entropy_with_logits(output, target);
3181 ASSERT_TRUE(torch::isfinite(out1).all().item<bool>());
3182
3183 const auto out2 = F::binary_cross_entropy_with_logits(
3184 output,
3185 target,
3186 F::BinaryCrossEntropyWithLogitsFuncOptions().pos_weight(pos_weight));
3187 ASSERT_TRUE(torch::isfinite(out2).all().item<bool>());
3188}
3189}
3190