1 | #include <gtest/gtest.h> |
2 | |
3 | #include <c10/util/irange.h> |
4 | #include <torch/torch.h> |
5 | |
6 | #include <test/cpp/api/support.h> |
7 | |
8 | namespace F = torch::nn::functional; |
9 | |
10 | using namespace torch::nn; |
11 | |
// Shared fixture for all functional-API tests. SeedingFixture presumably
// re-seeds the RNG before each test for determinism — see
// test/cpp/api/support.h to confirm.
struct FunctionalTest : torch::test::SeedingFixture {};
13 | |
14 | TEST_F(FunctionalTest, Conv1d) { |
15 | auto x = torch::arange(30, torch::dtype(torch::kFloat).requires_grad(true)) |
16 | .reshape({2, 3, 5}); |
17 | auto weight = |
18 | torch::arange(18, torch::dtype(torch::kFloat).requires_grad(true)) |
19 | .reshape({2, 3, 3}); |
20 | auto y = F::conv1d(x, weight, F::Conv1dFuncOptions().stride(1)); |
21 | auto expected = torch::tensor( |
22 | {{{312., 348., 384.}, {798., 915., 1032.}}, |
23 | |
24 | {{852., 888., 924.}, {2553., 2670., 2787.}}}, |
25 | torch::kFloat); |
26 | ASSERT_TRUE(torch::allclose(y, expected)); |
27 | |
28 | auto y_no_options = F::conv1d(x, weight); |
29 | ASSERT_TRUE(torch::allclose(y_no_options, expected)); |
30 | } |
31 | |
// F::conv2d on a square 5x5 input with a 3x3 kernel, checked against
// precomputed values; also verifies default options match stride(1).
TEST_F(FunctionalTest, Conv2dEven) {
  auto x = torch::arange(75, torch::dtype(torch::kFloat).requires_grad(true))
               .reshape({1, 3, 5, 5});
  auto weight =
      torch::arange(54, torch::dtype(torch::kFloat).requires_grad(true))
          .reshape({2, 3, 3, 3});
  auto y = F::conv2d(x, weight, F::Conv2dFuncOptions().stride(1));
  // Precomputed reference for this arange input/weight.
  auto expected = torch::tensor(
      {{{{15219., 15570., 15921.},
         {16974., 17325., 17676.},
         {18729., 19080., 19431.}},

        {{37818., 38898., 39978.},
         {43218., 44298., 45378.},
         {48618., 49698., 50778.}}}},
      torch::kFloat);
  ASSERT_TRUE(torch::allclose(y, expected));

  // Default options should behave the same as stride(1).
  auto y_no_options = F::conv2d(x, weight);
  ASSERT_TRUE(torch::allclose(y_no_options, expected));
}
53 | |
// Same as Conv2dEven but with a non-square input (5x4) and kernel (3x2)
// to exercise uneven spatial dimensions.
TEST_F(FunctionalTest, Conv2dUneven) {
  auto x = torch::arange(60, torch::dtype(torch::kFloat).requires_grad(true))
               .reshape({1, 3, 5, 4});
  auto weight =
      torch::arange(36, torch::dtype(torch::kFloat).requires_grad(true))
          .reshape({2, 3, 3, 2});
  auto y = F::conv2d(x, weight, F::Conv2dFuncOptions().stride(1));
  // Precomputed reference for this arange input/weight.
  auto expected = torch::tensor(
      {{{{5289., 5442., 5595.}, {5901., 6054., 6207.}, {6513., 6666., 6819.}},

        {{13227., 13704., 14181.},
         {15135., 15612., 16089.},
         {17043., 17520., 17997.}}}},
      torch::kFloat);
  ASSERT_TRUE(torch::allclose(y, expected));

  // Default options should behave the same as stride(1).
  auto y_no_options = F::conv2d(x, weight);
  ASSERT_TRUE(torch::allclose(y_no_options, expected));
}
73 | |
// F::conv3d on a 5x5x5 input with a 3x3x3 kernel, checked against
// precomputed values; also verifies default options match stride(1).
TEST_F(FunctionalTest, Conv3d) {
  auto x = torch::arange(375, torch::dtype(torch::kFloat).requires_grad(true))
               .reshape({1, 3, 5, 5, 5});
  auto weight =
      torch::arange(162, torch::dtype(torch::kFloat).requires_grad(true))
          .reshape({2, 3, 3, 3, 3});
  auto y = F::conv3d(x, weight, F::Conv3dFuncOptions().stride(1));
  // Precomputed reference for this arange input/weight (2 output channels,
  // 3x3x3 output volume each).
  auto expected = torch::tensor(
      {{{{{700704., 703944., 707184.},
          {716904., 720144., 723384.},
          {733104., 736344., 739584.}},

         {{781704., 784944., 788184.},
          {797904., 801144., 804384.},
          {814104., 817344., 820584.}},

         {{862704., 865944., 869184.},
          {878904., 882144., 885384.},
          {895104., 898344., 901584.}}},

        {{{1724220., 1734021., 1743822.},
          {1773225., 1783026., 1792827.},
          {1822230., 1832031., 1841832.}},

         {{1969245., 1979046., 1988847.},
          {2018250., 2028051., 2037852.},
          {2067255., 2077056., 2086857.}},

         {{2214270., 2224071., 2233872.},
          {2263275., 2273076., 2282877.},
          {2312280., 2322081., 2331882.}}}}},
      torch::kFloat);
  ASSERT_TRUE(torch::allclose(y, expected));

  // Default options should behave the same as stride(1).
  auto y_no_options = F::conv3d(x, weight);
  ASSERT_TRUE(torch::allclose(y_no_options, expected));
}
111 | |
112 | TEST_F(FunctionalTest, MaxPool1d) { |
113 | auto x = torch::ones({1, 1, 5}); |
114 | auto y = F::max_pool1d(x, F::MaxPool1dFuncOptions(3).stride(2)); |
115 | |
116 | ASSERT_EQ(y.ndimension(), 3); |
117 | ASSERT_TRUE(torch::allclose(y, torch::ones({1, 1, 2}))); |
118 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 2})); |
119 | } |
120 | |
121 | TEST_F(FunctionalTest, MaxPool2d) { |
122 | auto x = torch::ones({2, 5, 5}); |
123 | auto y = F::max_pool2d(x, F::MaxPool2dFuncOptions(3).stride(2)); |
124 | |
125 | ASSERT_EQ(y.ndimension(), 3); |
126 | ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2}))); |
127 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2})); |
128 | } |
129 | |
130 | TEST_F(FunctionalTest, MaxPool2dBackward) { |
131 | auto input = torch::rand( |
132 | {1, 2, 4, 4}, torch::dtype(torch::kFloat).requires_grad(true)); |
133 | auto output = F::max_pool2d(input, F::MaxPool2dFuncOptions(2)); |
134 | auto s = output.sum(); |
135 | s.backward(); |
136 | ASSERT_TRUE(input.sizes() == input.grad().sizes()); |
137 | } |
138 | |
139 | TEST_F(FunctionalTest, MaxPool3d) { |
140 | auto x = torch::ones({2, 5, 5, 5}); |
141 | auto y = F::max_pool3d(x, F::MaxPool3dFuncOptions(3).stride(2)); |
142 | |
143 | ASSERT_EQ(y.ndimension(), 4); |
144 | ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2, 2}))); |
145 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2, 2})); |
146 | } |
147 | |
148 | TEST_F(FunctionalTest, AvgPool1d) { |
149 | auto x = torch::ones({1, 1, 5}); |
150 | auto y = F::avg_pool1d(x, F::AvgPool1dFuncOptions(3).stride(2)); |
151 | |
152 | ASSERT_EQ(y.ndimension(), 3); |
153 | ASSERT_TRUE(torch::allclose(y, torch::ones({1, 1, 2}))); |
154 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 2})); |
155 | } |
156 | |
157 | TEST_F(FunctionalTest, AvgPool2d) { |
158 | auto x = torch::ones({2, 5, 5}); |
159 | auto y = F::avg_pool2d(x, F::AvgPool2dFuncOptions(3).stride(2)); |
160 | |
161 | ASSERT_EQ(y.ndimension(), 3); |
162 | ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2}))); |
163 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2})); |
164 | } |
165 | |
166 | TEST_F(FunctionalTest, AvgPool3d) { |
167 | auto x = torch::ones({2, 5, 5, 5}); |
168 | auto y = F::avg_pool3d(x, F::AvgPool3dFuncOptions(3).stride(2)); |
169 | |
170 | ASSERT_EQ(y.ndimension(), 4); |
171 | ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2, 2}))); |
172 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2, 2})); |
173 | } |
174 | |
// F::fractional_max_pool2d on all-ones inputs, both unbatched (C, H, W)
// and batched (N, C, H, W), plus the *_with_indices variant. The index
// tensors below are flat positions into each HxW plane — with constant
// input the pooling presumably always picks these positions; confirm if
// the kernel's tie-breaking ever changes.
TEST_F(FunctionalTest, FractionalMaxPool2d) {
  auto x = torch::ones({2, 5, 5});
  auto y = F::fractional_max_pool2d(
      x, F::FractionalMaxPool2dFuncOptions(3).output_size(2));

  ASSERT_EQ(y.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2})));
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2}));

  // Values must match the plain overload; indices come along as tuple slot 1.
  auto y_with_indices = F::fractional_max_pool2d_with_indices(
      x, F::FractionalMaxPool2dFuncOptions(3).output_size(2));
  ASSERT_TRUE(torch::equal(y, std::get<0>(y_with_indices)));
  ASSERT_TRUE(torch::allclose(
      std::get<1>(y_with_indices),
      torch::tensor({{{0, 2}, {10, 12}}, {{0, 2}, {10, 12}}})));
  ASSERT_EQ(
      std::get<1>(y_with_indices).sizes(), std::vector<int64_t>({2, 2, 2}));

  // Same checks with an explicit batch dimension.
  auto x1 = torch::ones({2, 2, 5, 5});
  auto y1 = F::fractional_max_pool2d(
      x1, F::FractionalMaxPool2dFuncOptions(3).output_size(2));

  ASSERT_EQ(y1.ndimension(), 4);
  ASSERT_TRUE(torch::allclose(y1, torch::ones({2, 2, 2, 2})));
  ASSERT_EQ(y1.sizes(), std::vector<int64_t>({2, 2, 2, 2}));

  auto y1_with_indices = F::fractional_max_pool2d_with_indices(
      x1, F::FractionalMaxPool2dFuncOptions(3).output_size(2));
  ASSERT_TRUE(torch::equal(y1, std::get<0>(y1_with_indices)));
  ASSERT_TRUE(torch::allclose(
      std::get<1>(y1_with_indices),
      torch::tensor(
          {{{{0, 2}, {10, 12}}, {{0, 2}, {10, 12}}},
           {{{0, 2}, {10, 12}}, {{0, 2}, {10, 12}}}})));
  ASSERT_EQ(
      std::get<1>(y1_with_indices).sizes(), std::vector<int64_t>({2, 2, 2, 2}));
}
212 | |
// F::fractional_max_pool3d on an all-ones volume, plus the *_with_indices
// variant. Indices are flat positions into each DxHxW volume (e.g. 50
// = start of depth slice 2 in a 5x5x5 volume).
TEST_F(FunctionalTest, FractionalMaxPool3d) {
  auto x = torch::ones({2, 5, 5, 5});
  auto y = F::fractional_max_pool3d(
      x, F::FractionalMaxPool3dFuncOptions(3).output_size(2));

  ASSERT_EQ(y.ndimension(), 4);
  ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2, 2})));
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 2, 2, 2}));

  // Values must match the plain overload; indices come along as tuple slot 1.
  auto y_with_indices = F::fractional_max_pool3d_with_indices(
      x, F::FractionalMaxPool3dFuncOptions(3).output_size(2));
  ASSERT_TRUE(torch::equal(y, std::get<0>(y_with_indices)));
  ASSERT_TRUE(torch::allclose(
      std::get<1>(y_with_indices),
      torch::tensor(
          {{{{0, 2}, {10, 12}}, {{50, 52}, {60, 62}}},
           {{{0, 2}, {10, 12}}, {{50, 52}, {60, 62}}}})));
  ASSERT_EQ(
      std::get<1>(y_with_indices).sizes(), std::vector<int64_t>({2, 2, 2, 2}));
}
233 | |
234 | TEST_F(FunctionalTest, LPPool1d) { |
235 | int norm_type = 2; |
236 | int stride = 2; |
237 | int kernel_size = 3; |
238 | |
239 | auto x = torch::ones({1, 1, 5}); |
240 | auto y = F::lp_pool1d( |
241 | x, F::LPPool1dFuncOptions(norm_type, kernel_size).stride(stride)); |
242 | auto expected = |
243 | (torch::pow(torch::tensor({{{1, 1}}}, torch::kFloat), norm_type) * |
244 | kernel_size) |
245 | .pow(1. / norm_type); |
246 | |
247 | ASSERT_EQ(y.ndimension(), 3); |
248 | ASSERT_TRUE(torch::allclose(y, expected)); |
249 | ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 1, 2})); |
250 | } |
251 | |
// lp_pool2d with p = 2 and a {2, 3} kernel. NOTE(review): the input is 3D
// ({1, 2, 5}, i.e. unbatched C x H x W) rather than 4D — presumably
// intentional, since the expected output shape {1, 1, 2} matches; confirm
// against the lp_pool2d unbatched-input contract.
TEST_F(FunctionalTest, LPPool2d) {
  int norm_type = 2;
  int stride = 2;
  std::vector<int64_t> kernel_size({2, 3});

  auto x = torch::ones({1, 2, 5});
  auto y = F::lp_pool2d(
      x, F::LPPool2dFuncOptions(norm_type, kernel_size).stride(stride));
  // Each window over ones evaluates to (kH * kW * 1^p)^(1/p).
  auto expected =
      (torch::pow(torch::tensor({{{1, 1}}}, torch::kFloat), norm_type) *
       (kernel_size[0] * kernel_size[1]))
          .pow(1. / norm_type);

  ASSERT_EQ(y.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(y, expected));
  ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 1, 2}));
}
269 | |
270 | TEST_F(FunctionalTest, CosineSimilarity) { |
271 | auto input1 = torch::tensor({{1, 2, 3}, {4, 5, 6}}, torch::kFloat); |
272 | auto input2 = torch::tensor({{1, 8, 3}, {2, 1, 6}}, torch::kFloat); |
273 | auto output = F::cosine_similarity( |
274 | input1, input2, F::CosineSimilarityFuncOptions().dim(1)); |
275 | auto expected = torch::tensor({0.8078, 0.8721}, torch::kFloat); |
276 | ASSERT_TRUE(output.allclose(expected, 1e-04)); |
277 | } |
278 | |
279 | TEST_F(FunctionalTest, SmoothL1LossDefaultOptions) { |
280 | auto input = torch::tensor( |
281 | {0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true)); |
282 | auto target = torch::tensor({0., 1., 5.}, torch::kFloat); |
283 | auto output = F::smooth_l1_loss(input, target); |
284 | auto expected = torch::tensor(0.0233335, torch::kFloat); |
285 | auto s = output.sum(); |
286 | s.backward(); |
287 | ASSERT_TRUE(output.allclose(expected)); |
288 | ASSERT_TRUE(input.sizes() == input.grad().sizes()); |
289 | } |
290 | |
// F::smooth_l1_loss with an explicit beta = 0.5 and mean reduction;
// expected value precomputed for this input/target pair.
TEST_F(FunctionalTest, SmoothL1LossBeta) {
  auto input = torch::tensor(
      {0.1, 1.5, 10.0}, torch::dtype(torch::kFloat).requires_grad(true));
  auto target = torch::tensor({0., 1., 5.}, torch::kFloat);
  auto output =
      // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,bugprone-argument-comment)
      F::smooth_l1_loss(
          input, target, /*reduction=*/torch::kMean, /*beta=*/0.5);
  auto expected = torch::tensor(1.67, torch::kFloat);
  auto s = output.sum();
  s.backward();
  // Gradient must flow back with the input's shape.
  ASSERT_TRUE(output.allclose(expected));
  ASSERT_TRUE(input.sizes() == input.grad().sizes());
}
305 | |
// F::smooth_l1_loss with reduction disabled: returns a per-element loss
// tensor instead of a scalar.
TEST_F(FunctionalTest, SmoothL1LossNoReduction) {
  auto input = torch::tensor(
      {0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true));
  auto target = torch::tensor({0., 1., 5.}, torch::kFloat);
  auto output =
      // NOLINTNEXTLINE(bugprone-argument-comment)
      F::smooth_l1_loss(input, target, /*reduction=*/torch::kNone);
  auto expected = torch::tensor({0.005, 0.02, 0.045}, torch::kFloat);
  auto s = output.sum();
  s.backward();
  ASSERT_TRUE(output.allclose(expected));
  ASSERT_TRUE(input.sizes() == input.grad().sizes());
}
319 | |
320 | TEST_F(FunctionalTest, HuberLossDefaultOptions) { |
321 | auto input = torch::tensor( |
322 | {0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true)); |
323 | auto target = torch::tensor({0., 1., 5.}, torch::kFloat); |
324 | auto output = F::huber_loss(input, target); |
325 | auto expected = torch::tensor(0.0233335, torch::kFloat); |
326 | auto s = output.sum(); |
327 | s.backward(); |
328 | ASSERT_TRUE(output.allclose(expected)); |
329 | ASSERT_TRUE(input.sizes() == input.grad().sizes()); |
330 | } |
331 | |
332 | TEST_F(FunctionalTest, HuberLossDelta) { |
333 | auto input = torch::tensor( |
334 | {0.1, 1.5, 10.0}, torch::dtype(torch::kFloat).requires_grad(true)); |
335 | auto target = torch::tensor({0., 1., 5.}, torch::kFloat); |
336 | auto options = F::HuberLossFuncOptions().reduction(torch::kMean).delta(0.5); |
337 | auto output = F::huber_loss(input, target, options); |
338 | auto expected = torch::tensor(1.67 * 0.5, torch::kFloat); |
339 | auto s = output.sum(); |
340 | s.backward(); |
341 | ASSERT_TRUE(output.allclose(expected)); |
342 | ASSERT_TRUE(input.sizes() == input.grad().sizes()); |
343 | } |
344 | |
345 | TEST_F(FunctionalTest, HuberLossNoReduction) { |
346 | auto input = torch::tensor( |
347 | {0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true)); |
348 | auto target = torch::tensor({0., 1., 5.}, torch::kFloat); |
349 | auto options = F::HuberLossFuncOptions().reduction(torch::kNone); |
350 | auto output = F::huber_loss(input, target, options); |
351 | auto expected = torch::tensor({0.005, 0.02, 0.045}, torch::kFloat); |
352 | auto s = output.sum(); |
353 | s.backward(); |
354 | ASSERT_TRUE(output.allclose(expected)); |
355 | ASSERT_TRUE(input.sizes() == input.grad().sizes()); |
356 | } |
357 | |
358 | TEST_F(FunctionalTest, SoftMarginLossDefaultOptions) { |
359 | auto input = torch::tensor( |
360 | {2., 4., 1., 3.}, torch::dtype(torch::kFloat).requires_grad(true)); |
361 | auto target = torch::tensor({-1., 1., 1., -1.}, torch::kFloat); |
362 | auto output = F::soft_margin_loss(input, target); |
363 | auto expected = torch::tensor({1.3767317}, torch::kFloat); |
364 | auto s = output.sum(); |
365 | s.backward(); |
366 | |
367 | ASSERT_TRUE(output.allclose(expected)); |
368 | ASSERT_EQ(input.sizes(), input.grad().sizes()); |
369 | } |
370 | |
// F::multilabel_soft_margin_loss with defaults (mean reduction) on a
// 2x4 batch; expected value precomputed.
TEST_F(FunctionalTest, MultiLabelSoftMarginLossDefaultOptions) {
  auto input = torch::tensor(
      {{0., 2., 2., 0.}, {2., 1., 0., 1.}},
      torch::dtype(torch::kFloat).requires_grad(true));
  auto target =
      torch::tensor({{0., 0., 1., 0.}, {1., 0., 1., 1.}}, torch::kFloat);
  auto output = F::multilabel_soft_margin_loss(input, target);
  auto expected = torch::tensor({0.7608436}, torch::kFloat);
  auto s = output.sum();
  s.backward();

  // Gradient must flow back with the input's shape.
  ASSERT_TRUE(output.allclose(expected));
  ASSERT_EQ(input.sizes(), input.grad().sizes());
}
385 | |
// F::soft_margin_loss with reduction disabled: one loss per element.
TEST_F(FunctionalTest, SoftMarginLossNoReduction) {
  auto input = torch::tensor(
      {2., 4., 1., 3.}, torch::dtype(torch::kFloat).requires_grad(true));
  auto target = torch::tensor({-1., 1., 1., -1.}, torch::kFloat);
  auto output = F::soft_margin_loss(input, target, torch::kNone);
  // Per-element losses, precomputed.
  auto expected = torch::tensor(
      {2.1269281, 0.01814993, 0.3132617, 3.0485873}, torch::kFloat);
  auto s = output.sum();
  s.backward();

  ASSERT_TRUE(output.allclose(expected));
  ASSERT_EQ(input.sizes(), input.grad().sizes());
}
399 | |
// F::multilabel_soft_margin_loss with per-class weights and reduction
// disabled: one loss per sample in the batch.
TEST_F(FunctionalTest, MultiLabelSoftMarginLossWeightedNoReduction) {
  auto input = torch::tensor(
      {{0., 2., 2., 0.}, {2., 1., 0., 1.}},
      torch::dtype(torch::kFloat).requires_grad(true));
  auto target =
      torch::tensor({{0., 0., 1., 0.}, {1., 0., 1., 1.}}, torch::kFloat);
  // One weight per class (4 classes).
  auto weight = torch::tensor({0.1, 0.6, 0.4, 0.8}, torch::kFloat);
  auto options = F::MultilabelSoftMarginLossFuncOptions()
                     .reduction(torch::kNone)
                     .weight(weight);
  auto output = F::multilabel_soft_margin_loss(input, target, options);
  // One precomputed loss per sample.
  auto expected = torch::tensor({0.4876902, 0.3321295}, torch::kFloat);
  auto s = output.sum();
  s.backward();

  ASSERT_TRUE(output.allclose(expected));
  ASSERT_EQ(input.sizes(), input.grad().sizes());
}
418 | |
419 | TEST_F(FunctionalTest, PairwiseDistance) { |
420 | auto input1 = torch::tensor({{1, 2, 3}, {4, 5, 6}}, torch::kFloat); |
421 | auto input2 = torch::tensor({{1, 8, 3}, {2, 1, 6}}, torch::kFloat); |
422 | auto output = F::pairwise_distance( |
423 | input1, input2, F::PairwiseDistanceFuncOptions().p(1)); |
424 | auto expected = torch::tensor({6, 6}, torch::kFloat); |
425 | ASSERT_TRUE(output.allclose(expected)); |
426 | } |
427 | |
428 | TEST_F(FunctionalTest, PDist) { |
429 | { |
430 | auto input = torch::tensor({{-1.0, -5.0, -1.0}, {2.0, 4.0, 6.0}}); |
431 | auto output = F::pdist(input); |
432 | auto expected = torch::tensor({11.7898}); |
433 | ASSERT_TRUE(output.allclose(expected)); |
434 | } |
435 | { |
436 | auto input = torch::tensor({{1.0, -1.0}, {1.0, 3.0}, {3.0, 3.0}}); |
437 | auto output = F::pdist(input, 1.5); |
438 | auto expected = torch::tensor({4.0, 4.8945, 2.0}); |
439 | ASSERT_TRUE(output.allclose(expected)); |
440 | } |
441 | } |
442 | |
443 | TEST_F(FunctionalTest, AdaptiveMaxPool1d) { |
444 | auto x = torch::ones({1, 1, 5}); |
445 | auto y = F::adaptive_max_pool1d(x, F::AdaptiveMaxPool1dFuncOptions(3)); |
446 | |
447 | ASSERT_EQ(y.ndimension(), 3); |
448 | ASSERT_TRUE(torch::allclose(y, torch::ones({1, 1, 3}))); |
449 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 3})); |
450 | } |
451 | |
452 | TEST_F(FunctionalTest, AdaptiveMaxPool2d) { |
453 | auto x = torch::ones({2, 5, 5}); |
454 | auto y = F::adaptive_max_pool2d(x, F::AdaptiveMaxPool2dFuncOptions(3)); |
455 | |
456 | ASSERT_EQ(y.ndimension(), 3); |
457 | ASSERT_TRUE(torch::allclose(y, torch::ones({2, 3, 3}))); |
458 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 3, 3})); |
459 | } |
460 | |
461 | TEST_F(FunctionalTest, AdaptiveMaxPool3d) { |
462 | auto x = torch::ones({2, 5, 5, 5}); |
463 | auto y = F::adaptive_max_pool3d(x, F::AdaptiveMaxPool3dFuncOptions(3)); |
464 | |
465 | ASSERT_EQ(y.ndimension(), 4); |
466 | ASSERT_TRUE(torch::allclose(y, torch::ones({2, 3, 3, 3}))); |
467 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 3, 3, 3})); |
468 | } |
469 | |
470 | TEST_F(FunctionalTest, AdaptiveAvgPool1d) { |
471 | auto x = torch::ones({1, 1, 5}); |
472 | auto y = F::adaptive_avg_pool1d(x, F::AdaptiveAvgPool1dFuncOptions(3)); |
473 | |
474 | ASSERT_EQ(y.ndimension(), 3); |
475 | ASSERT_TRUE(torch::allclose(y, torch::ones({1, 1, 3}))); |
476 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 3})); |
477 | } |
478 | |
479 | TEST_F(FunctionalTest, AdaptiveAvgPool2d) { |
480 | auto x = torch::ones({2, 5, 5}); |
481 | auto y = F::adaptive_avg_pool2d(x, F::AdaptiveAvgPool2dFuncOptions(3)); |
482 | |
483 | ASSERT_EQ(y.ndimension(), 3); |
484 | ASSERT_TRUE(torch::allclose(y, torch::ones({2, 3, 3}))); |
485 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 3, 3})); |
486 | } |
487 | |
488 | TEST_F(FunctionalTest, AdaptiveAvgPool3d) { |
489 | auto x = torch::ones({2, 5, 5, 5}); |
490 | auto y = F::adaptive_avg_pool3d(x, F::AdaptiveAvgPool3dFuncOptions(3)); |
491 | |
492 | ASSERT_EQ(y.ndimension(), 4); |
493 | ASSERT_TRUE(torch::allclose(y, torch::ones({2, 3, 3, 3}))); |
494 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 3, 3, 3})); |
495 | } |
496 | |
497 | TEST_F(FunctionalTest, L1Loss) { |
498 | auto input = torch::randn({5, 6}, torch::requires_grad()); |
499 | auto target = torch::empty({5, 6}).random_(2); |
500 | auto output = F::l1_loss(torch::sigmoid(input), target); |
501 | auto s = output.sum(); |
502 | s.backward(); |
503 | |
504 | ASSERT_EQ(output.sizes(), torch::IntArrayRef()); |
505 | ASSERT_EQ(input.sizes(), input.grad().sizes()); |
506 | } |
507 | |
508 | TEST_F(FunctionalTest, MSELoss) { |
509 | auto input = torch::randn({5, 6}, torch::requires_grad()); |
510 | auto target = torch::empty({5, 6}).random_(2); |
511 | auto output = F::mse_loss(torch::sigmoid(input), target); |
512 | auto s = output.sum(); |
513 | s.backward(); |
514 | |
515 | ASSERT_EQ(output.sizes(), torch::IntArrayRef()); |
516 | ASSERT_EQ(input.sizes(), input.grad().sizes()); |
517 | } |
518 | |
519 | TEST_F(FunctionalTest, BCELoss) { |
520 | auto input = torch::randn({5, 6}, torch::requires_grad()); |
521 | auto target = torch::empty({5, 6}).random_(2); |
522 | auto output = F::binary_cross_entropy(torch::sigmoid(input), target); |
523 | auto s = output.sum(); |
524 | s.backward(); |
525 | |
526 | ASSERT_EQ(output.sizes(), torch::IntArrayRef()); |
527 | ASSERT_EQ(input.sizes(), input.grad().sizes()); |
528 | } |
529 | |
530 | TEST_F(FunctionalTest, KLDivLoss) { |
531 | KLDivLoss loss; |
532 | auto input = torch::randn({5, 6}, torch::requires_grad()); |
533 | auto target = torch::empty({5, 6}).random_(2); |
534 | auto output = F::kl_div(torch::sigmoid(input), target); |
535 | auto s = output.sum(); |
536 | s.backward(); |
537 | |
538 | ASSERT_EQ(output.sizes(), torch::IntArrayRef()); |
539 | ASSERT_EQ(input.sizes(), input.grad().sizes()); |
540 | } |
541 | |
542 | TEST_F(FunctionalTest, HingeEmbeddingLoss) { |
543 | auto input = torch::tensor({{2, 22, 4}, {20, 10, 0}}, torch::kFloat); |
544 | auto target = torch::tensor({{2, 6, 4}, {1, 10, 0}}, torch::kFloat); |
545 | auto output = F::hinge_embedding_loss( |
546 | input, target, F::HingeEmbeddingLossFuncOptions().margin(2)); |
547 | auto expected = torch::tensor({10}, torch::kFloat); |
548 | |
549 | ASSERT_TRUE(output.allclose(expected)); |
550 | } |
551 | |
// F::grid_sample over a 3x3 arange image with a fixed sampling grid,
// covering each (mode, padding_mode, align_corners) combination below.
// Each comment names the combination being checked; expected tensors are
// precomputed for this grid.
TEST_F(FunctionalTest, GridSample) {
  auto input =
      torch::arange(9, torch::kFloat).view(std::vector<int64_t>({1, 1, 3, 3}));
  // Grid coordinates are in [-1, 1]; some points fall outside to exercise
  // the padding modes.
  auto grid = torch::tensor(
      {{{{-2., -1.}, {-1., -1.}, {0., -1.}},
        {{-1., 0.}, {0., 0.}, {1., 0.}},
        {{0., 1.}, {1., 1.}, {2., 1.}}}},
      torch::kFloat);

  // bilinear, zeros, true
  auto options = F::GridSampleFuncOptions()
                     .mode(torch::kBilinear)
                     .padding_mode(torch::kZeros)
                     .align_corners(true);
  auto output = F::grid_sample(input, grid, options);
  auto expected = torch::tensor(
      {{{{0., 0., 1.}, {3., 4., 5.}, {7., 8., 0.}}}}, torch::kFloat);

  ASSERT_TRUE(output.allclose(expected));

  // bilinear, zeros, false
  options = F::GridSampleFuncOptions()
                .mode(torch::kBilinear)
                .padding_mode(torch::kZeros)
                .align_corners(false);
  output = F::grid_sample(input, grid, options);
  expected = torch::tensor(
      {{{{0., 0., 0.5}, {1.5, 4., 2.5}, {3.5, 2., 0.}}}}, torch::kFloat);

  ASSERT_TRUE(output.allclose(expected));

  // default options (bilinear, zeros, false) same result as above
  output = F::grid_sample(input, grid);

  ASSERT_TRUE(output.allclose(expected));

  // nearest, zeros, true
  options = F::GridSampleFuncOptions()
                .mode(torch::kNearest)
                .padding_mode(torch::kZeros)
                .align_corners(true);
  output = F::grid_sample(input, grid, options);
  expected = torch::tensor(
      {{{{0., 0., 1.}, {3., 4., 5.}, {7., 8., 0.}}}}, torch::kFloat);

  ASSERT_TRUE(output.allclose(expected));

  // bilinear, border, true
  options = F::GridSampleFuncOptions()
                .mode(torch::kBilinear)
                .padding_mode(torch::kBorder)
                .align_corners(true);
  output = F::grid_sample(input, grid, options);
  expected = torch::tensor(
      {{{{0., 0., 1.}, {3., 4., 5.}, {7., 8., 8.}}}}, torch::kFloat);

  ASSERT_TRUE(output.allclose(expected));

  // bilinear, reflection, true
  options = F::GridSampleFuncOptions()
                .mode(torch::kBilinear)
                .padding_mode(torch::kReflection)
                .align_corners(true);
  output = F::grid_sample(input, grid, options);
  expected = torch::tensor(
      {{{{1., 0., 1.}, {3., 4., 5.}, {7., 8., 7.}}}}, torch::kFloat);

  ASSERT_TRUE(output.allclose(expected));
}
621 | |
622 | TEST_F(FunctionalTest, AffineGrid) { |
623 | { |
624 | // 2D affine. |
625 | auto theta = torch::arange(1., 13).view(std::vector<int64_t>({2, 2, 3})); |
626 | auto size = std::vector<int64_t>({2, 3, 2, 2}); |
627 | auto align_corners = true; |
628 | auto output = F::affine_grid(theta, size, !align_corners); |
629 | auto expected = torch::tensor( |
630 | {{{{1.50, 1.50}, {2.50, 5.50}}, {{3.50, 6.50}, {4.50, 10.50}}}, |
631 | {{{1.50, 1.50}, {8.50, 11.50}}, {{9.50, 12.50}, {16.50, 22.50}}}}); |
632 | auto output_aligned = F::affine_grid(theta, size, align_corners); |
633 | auto expected_aligned = torch::tensor( |
634 | {{{{0.0, -3.0}, {2.0, 5.0}}, {{4.0, 7.0}, {6.0, 15.0}}}, |
635 | {{{-6.0, -9.0}, {8.0, 11.0}}, {{10.0, 13.0}, {24.0, 33.0}}}}); |
636 | |
637 | ASSERT_TRUE(output.allclose(expected)); |
638 | ASSERT_TRUE(output_aligned.allclose(expected_aligned)); |
639 | } |
640 | { |
641 | // 3D affine. |
642 | auto theta = torch::arange(1., 13).view(std::vector<int64_t>({1, 3, 4})); |
643 | auto size = std::vector<int64_t>({1, 1, 3, 2, 2}); |
644 | auto align_corners = true; |
645 | auto output = F::affine_grid(theta, size, !align_corners); |
646 | auto expected = torch::tensor( |
647 | {{{{{0.5000, -2.1667, -4.8333}, {1.5000, 2.8333, 4.1667}}, |
648 | {{2.5000, 3.8333, 5.1667}, {3.5000, 8.8333, 14.1667}}}, |
649 | {{{2.5000, 2.5000, 2.5000}, {3.5000, 7.5000, 11.5000}}, |
650 | {{4.5000, 8.5000, 12.5000}, {5.5000, 13.5000, 21.5000}}}, |
651 | {{{4.5000, 7.1667, 9.8333}, {5.5000, 12.1667, 18.8333}}, |
652 | {{6.5000, 13.1667, 19.8333}, {7.5000, 18.1667, 28.8333}}}}}); |
653 | auto output_aligned = F::affine_grid(theta, size, align_corners); |
654 | auto expected_aligned = torch::tensor( |
655 | {{{{{-2.0, -10.0, -18.0}, {0.0, 0.0, 0.0}}, |
656 | {{2.0, 2.0, 2.0}, {4.0, 12.0, 20.0}}}, |
657 | {{{1.0, -3.0, -7.0}, {3.0, 7.0, 11.0}}, |
658 | {{5.0, 9.0, 13.0}, {7.0, 19.0, 31.0}}}, |
659 | {{{4.0, 4.0, 4.0}, {6.0, 14.0, 22.0}}, |
660 | {{8.0, 16.0, 24.0}, {10.0, 26.0, 42.0}}}}}); |
661 | |
662 | ASSERT_TRUE(output.allclose(expected, 1e-2)); |
663 | ASSERT_TRUE(output_aligned.allclose(expected_aligned)); |
664 | } |
665 | { |
666 | auto theta = torch::empty({1, 2, 3}, torch::kDouble); |
667 | auto size = std::vector<int64_t>({1, 1, 2, 2}); |
668 | ASSERT_THROWS_WITH( |
669 | F::affine_grid(torch::empty({2, 2, 3}), {-1, 1, 2, 2}), |
670 | "Expected non-zero, positive output size. Got [-1, 1, 2, 2]" ); |
671 | ASSERT_THROWS_WITH( |
672 | F::affine_grid(torch::empty({2, 2, 3}, torch::kInt), size), |
673 | "Expected theta to have floating point type, but got int" ); |
674 | ASSERT_THROWS_WITH( |
675 | F::affine_grid(theta[0], size), |
676 | "Expected a batch of 2D affine matrices of shape Nx2x3 for size " |
677 | "[1, 1, 2, 2]. Got [2, 3]." ); |
678 | ASSERT_THROWS_WITH( |
679 | F::affine_grid(theta.unsqueeze(0), size), |
680 | "Expected a batch of 2D affine matrices of shape Nx2x3 for size " |
681 | "[1, 1, 2, 2]. Got [1, 1, 2, 3]." ); |
682 | ASSERT_THROWS_WITH( |
683 | F::affine_grid(theta.repeat({1, 2, 1}), size), |
684 | "Expected a batch of 2D affine matrices of shape Nx2x3 for size " |
685 | "[1, 1, 2, 2]. Got [1, 4, 3]." ); |
686 | ASSERT_THROWS_WITH( |
687 | F::affine_grid(theta.repeat({1, 1, 2}), size), |
688 | "Expected a batch of 2D affine matrices of shape Nx2x3 for size " |
689 | "[1, 1, 2, 2]. Got [1, 2, 6]." ); |
690 | } |
691 | { |
692 | auto theta = torch::empty({1, 3, 4}, torch::kDouble); |
693 | auto size = std::vector<int64_t>({1, 1, 2, 2, 3}); |
694 | ASSERT_THROWS_WITH( |
695 | F::affine_grid(theta[0], size), |
696 | "Expected a batch of 3D affine matrices of shape Nx3x4 for size " |
697 | "[1, 1, 2, 2, 3]. Got [3, 4]." ); |
698 | ASSERT_THROWS_WITH( |
699 | F::affine_grid(theta.unsqueeze(0), size), |
700 | "Expected a batch of 3D affine matrices of shape Nx3x4 for size " |
701 | "[1, 1, 2, 2, 3]. Got [1, 1, 3, 4]." ); |
702 | ASSERT_THROWS_WITH( |
703 | F::affine_grid(theta.repeat({1, 2, 1}), size), |
704 | "Expected a batch of 3D affine matrices of shape Nx3x4 for size " |
705 | "[1, 1, 2, 2, 3]. Got [1, 6, 4]." ); |
706 | ASSERT_THROWS_WITH( |
707 | F::affine_grid(theta.repeat({1, 1, 2}), size), |
708 | "Expected a batch of 3D affine matrices of shape Nx3x4 for size " |
709 | "[1, 1, 2, 2, 3]. Got [1, 3, 8]." ); |
710 | ASSERT_THROWS_WITH( |
711 | F::affine_grid(theta, {1, 1, 1, 2, 2, 3}), |
712 | "affine_grid only supports 4D and 5D sizes, for 2D and 3D affine " |
713 | "transforms, respectively. Got size [1, 1, 1, 2, 2, 3]" ); |
714 | ASSERT_THROWS_WITH( |
715 | F::affine_grid(theta, {1, 1}), |
716 | "affine_grid only supports 4D and 5D sizes, for 2D and 3D affine " |
717 | "transforms, respectively. Got size [1, 1]" ); |
718 | } |
719 | } |
720 | |
721 | TEST_F(FunctionalTest, MultiMarginLoss) { |
722 | auto weight = torch::tensor({0.3, 0.3, 0.4}, torch::kFloat); |
723 | auto input = torch::tensor( |
724 | {{0.2, 0.2, 0.6}, {0.1, 0.8, 0.1}, {0.9, 0.09, 0.01}}, |
725 | torch::dtype(torch::kFloat).requires_grad(true)); |
726 | auto target = torch::tensor({2, 1, 0}, torch::kLong); |
727 | auto output = F::multi_margin_loss( |
728 | input, target, F::MultiMarginLossFuncOptions().margin(2).weight(weight)); |
729 | auto expected = torch::tensor({0.305556}, torch::kFloat); |
730 | |
731 | ASSERT_TRUE(output.allclose(expected, 1e-04)); |
732 | } |
733 | |
734 | TEST_F(FunctionalTest, CosineEmbeddingLoss) { |
735 | auto input1 = torch::tensor({{2, 3, 4}, {6, 2, 4}}); |
736 | auto input2 = torch::tensor({{2, 3, 5}, {9, 12, 0}}); |
737 | auto target = torch::tensor({1, -1}); |
738 | auto output = F::cosine_embedding_loss( |
739 | input1, input2, target, F::CosineEmbeddingLossFuncOptions().margin(0.5)); |
740 | auto expected = torch::tensor({0.1004}, torch::kFloat); |
741 | |
742 | ASSERT_TRUE(output.allclose(expected, 1e-4)); |
743 | } |
744 | |
745 | TEST_F(FunctionalTest, MultiLabelMarginLossDefaultOptions) { |
746 | auto input = torch::tensor( |
747 | {{0.1, 0.2, 0.4, 0.8}}, torch::dtype(torch::kFloat).requires_grad(true)); |
748 | auto target = torch::tensor({{3, 0, -1, 1}}, torch::kLong); |
749 | auto output = F::multilabel_margin_loss(input, target); |
750 | auto expected = torch::tensor({0.8500}, torch::kFloat); |
751 | auto s = output.sum(); |
752 | s.backward(); |
753 | |
754 | ASSERT_TRUE(output.allclose(expected)); |
755 | ASSERT_EQ(input.sizes(), input.grad().sizes()); |
756 | } |
757 | |
758 | TEST_F(FunctionalTest, MultiLabelMarginLossNoReduction) { |
759 | auto input = torch::tensor( |
760 | {{0.1, 0.2, 0.4, 0.8}}, torch::dtype(torch::kFloat).requires_grad(true)); |
761 | auto target = torch::tensor({{3, 0, -1, 1}}, torch::kLong); |
762 | auto output = F::multilabel_margin_loss(input, target, torch::kNone); |
763 | auto expected = torch::tensor({0.8500}, torch::kFloat); |
764 | auto s = output.sum(); |
765 | s.backward(); |
766 | |
767 | ASSERT_TRUE(output.allclose(expected)); |
768 | ASSERT_EQ(input.sizes(), input.grad().sizes()); |
769 | } |
770 | |
771 | TEST_F(FunctionalTest, TripletMarginLoss) { |
772 | auto anchor = torch::tensor({{3., 3.}}, torch::kFloat); |
773 | auto positive = torch::tensor({{2., 2.}}, torch::kFloat); |
774 | auto negative = torch::tensor({{0., 0.}}, torch::kFloat); |
775 | auto output = F::triplet_margin_loss( |
776 | anchor, |
777 | positive, |
778 | negative, |
779 | F::TripletMarginLossFuncOptions().margin(1.0)); |
780 | auto expected = torch::tensor({0.}, torch::kFloat); |
781 | |
782 | ASSERT_TRUE(output.allclose(expected, 1e-04)); |
783 | } |
784 | |
785 | TEST_F(FunctionalTest, TripletMarginWithDistanceLossDefaultParity) { |
786 | // Check that if we use torch::pairwise_distance with the default |
787 | // TripletMarginLoss options as our distance function, the outputs |
788 | // are equal (i.e., equal under defaults). |
789 | |
790 | std::vector<TripletMarginWithDistanceLossOptions::reduction_t> reductions = { |
791 | torch::kSum, torch::kMean, torch::kNone}; |
792 | std::vector<float> margins = {0.5, 1.0, 1.5}; |
793 | std::vector<bool> swaps = {true, false}; |
794 | |
795 | for (auto& reduction : reductions) { |
796 | for (auto& margin : margins) { |
797 | for (const auto& swap : swaps) { |
798 | auto anchor = torch::randn( |
799 | {100, 128}, torch::dtype(torch::kFloat).requires_grad(true)); |
800 | auto positive = torch::randn( |
801 | {100, 128}, torch::dtype(torch::kFloat).requires_grad(true)); |
802 | auto negative = torch::randn( |
803 | {100, 128}, torch::dtype(torch::kFloat).requires_grad(true)); |
804 | |
805 | auto basicOptions = F::TripletMarginLossFuncOptions() |
806 | .reduction(reduction) |
807 | .margin(margin) |
808 | .swap(swap); |
809 | auto distanceOptions = F::TripletMarginWithDistanceLossFuncOptions() |
810 | .reduction(reduction) |
811 | .margin(margin) |
812 | .swap(swap); |
813 | TripletMarginLoss basicLoss(basicOptions); |
814 | TripletMarginWithDistanceLoss distanceLoss(distanceOptions); |
815 | |
816 | auto basicOutput = |
817 | F::triplet_margin_loss(anchor, positive, negative, basicOptions); |
818 | auto distanceOutput = F::triplet_margin_with_distance_loss( |
819 | anchor, positive, negative, distanceOptions); |
820 | |
821 | ASSERT_TRUE(distanceOutput.allclose(basicOutput, 1e-6, 1e-6)); |
822 | |
823 | // handle for torch::kNone reduction |
824 | auto sum = distanceOutput.sum(); |
825 | sum.backward(); |
826 | ASSERT_EQ(anchor.sizes(), anchor.grad().sizes()); |
827 | ASSERT_EQ(positive.sizes(), positive.grad().sizes()); |
828 | ASSERT_EQ(negative.sizes(), negative.grad().sizes()); |
829 | } |
830 | } |
831 | } |
832 | } |
833 | |
834 | TEST_F(FunctionalTest, NLLLoss) { |
835 | auto input = torch::tensor( |
836 | {{-0.1315, -3.1315, -2.5315}, |
837 | {-3.7038, -0.1038, -2.6038}, |
838 | {-2.3422, -1.3422, -0.4422}}, |
839 | torch::kFloat); |
840 | auto target = torch::tensor({1, 0, 2}, torch::kLong); |
841 | auto output = F::nll_loss( |
842 | input, |
843 | target, |
844 | F::NLLLossFuncOptions().ignore_index(-100).reduction(torch::kMean)); |
845 | auto expected = torch::tensor(2.4258, torch::kFloat); |
846 | ASSERT_TRUE(output.allclose(expected, 1e-04)); |
847 | ASSERT_TRUE(F::nll_loss(input, target).allclose(expected, 1e-04)); |
848 | } |
849 | |
TEST_F(FunctionalTest, CrossEntropy) {
  // Case 1: class-index targets; equal logits per row give loss ln(2)=0.6931.
  auto input = torch::tensor({{3., 3.}, {2., 2.}}, torch::kFloat);
  auto target = torch::tensor({0, 1}, torch::kLong);
  auto output = F::cross_entropy(
      input,
      target,
      F::CrossEntropyFuncOptions().ignore_index(-100).reduction(torch::kMean));
  auto expected = torch::tensor(0.6931, torch::kFloat);

  ASSERT_TRUE(output.allclose(expected, 1e-04));
  // Explicit options above are the defaults, so the no-options call matches.
  ASSERT_TRUE(F::cross_entropy(input, target).allclose(expected, 1e-04));

  // label smoothing with class indices
  input = torch::tensor({{3., 1.}, {1., 2.}}, torch::kFloat);
  output = F::cross_entropy(
      input,
      target,
      F::CrossEntropyFuncOptions().label_smoothing(0.15).reduction(
          torch::kMean));
  expected = torch::tensor(0.3326, torch::kFloat);
  ASSERT_TRUE(output.allclose(expected, 1e-04));

  // label smoothing with target probabilities
  // (target is now a float distribution per sample, not class indices)
  target = torch::tensor({{0.8, 0.2}, {0.1, 0.9}}, torch::kFloat);
  output = F::cross_entropy(
      input,
      target,
      F::CrossEntropyFuncOptions().label_smoothing(0.2).reduction(
          torch::kMean));
  expected = torch::tensor(0.5701, torch::kFloat);
  ASSERT_TRUE(output.allclose(expected, 1e-04));
}
882 | |
TEST_F(FunctionalTest, MaxUnpool1d) {
  // Case 1: kernel_size=3 only; output length is inferred as 3 * 3 = 9,
  // with pooled values scattered to the positions given by `indices`.
  auto x = torch::tensor(
      {{{2, 4, 5}}}, torch::dtype(torch::kFloat).requires_grad(true));
  auto indices = torch::tensor({{{1, 3, 4}}}, torch::kLong);
  auto y = F::max_unpool1d(x, indices, F::MaxUnpool1dFuncOptions(3));

  ASSERT_EQ(y.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(
      y, torch::tensor({{{0, 2, 0, 4, 5, 0, 0, 0, 0}}}, torch::kFloat)));
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 9}));

  // Case 2: same as above but with the output size given explicitly.
  x = torch::tensor(
      {{{2, 4, 5}}}, torch::dtype(torch::kFloat).requires_grad(true));
  indices = torch::tensor({{{1, 3, 4}}}, torch::kLong);
  y = F::max_unpool1d(
      x,
      indices,
      F::MaxUnpool1dFuncOptions(3).output_size(
          std::vector<int64_t>({1, 1, 9})));

  ASSERT_EQ(y.ndimension(), 3);
  ASSERT_TRUE(torch::allclose(
      y, torch::tensor({{{0, 2, 0, 4, 5, 0, 0, 0, 0}}}, torch::kFloat)));
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 9}));

  // Case 3: stride=2, padding=1 shrinks the inferred output to length 5.
  x = torch::tensor(
      {{{2, 4, 5}}}, torch::dtype(torch::kFloat).requires_grad(true));
  indices = torch::tensor({{{1, 3, 4}}}, torch::kLong);
  y = F::max_unpool1d(
      x, indices, F::MaxUnpool1dFuncOptions(3).stride(2).padding(1));

  ASSERT_EQ(y.ndimension(), 3);
  ASSERT_TRUE(
      torch::allclose(y, torch::tensor({{{0, 2, 0, 4, 5}}}, torch::kFloat)));
  ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 5}));
}
919 | |
920 | TEST_F(FunctionalTest, MaxUnpool2d) { |
921 | auto indices = torch::tensor( |
922 | {{{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}}, |
923 | {{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}}}, |
924 | torch::kLong); |
925 | auto x = torch::tensor( |
926 | {{{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}}, |
927 | {{{31, 33, 34}, {41, 43, 44}, {46, 48, 49}}}}, |
928 | torch::dtype(torch::kFloat).requires_grad(true)); |
929 | auto y = F::max_unpool2d( |
930 | x, indices, F::MaxUnpool2dFuncOptions(3).stride(2).padding(1)); |
931 | |
932 | ASSERT_EQ(y.dim(), 4); |
933 | ASSERT_TRUE(torch::allclose( |
934 | y, |
935 | torch::tensor( |
936 | {{{{0, 0, 0, 0, 0}, |
937 | {0, 6, 0, 8, 9}, |
938 | {0, 0, 0, 0, 0}, |
939 | {0, 16, 0, 18, 19}, |
940 | {0, 21, 0, 23, 24}}}, |
941 | {{{0, 0, 0, 0, 0}, |
942 | {0, 31, 0, 33, 34}, |
943 | {0, 0, 0, 0, 0}, |
944 | {0, 41, 0, 43, 44}, |
945 | {0, 46, 0, 48, 49}}}}, |
946 | torch::kFloat))); |
947 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({2, 1, 5, 5})); |
948 | } |
949 | |
950 | TEST_F(FunctionalTest, MaxUnpool3d) { |
951 | auto indices = torch::tensor({{{{{26}}}}}, torch::kLong); |
952 | auto x = torch::tensor( |
953 | {{{{{26}}}}}, torch::dtype(torch::kFloat).requires_grad(true)); |
954 | auto y = F::max_unpool3d(x, indices, F::MaxUnpool3dFuncOptions(3)); |
955 | |
956 | ASSERT_EQ(y.dim(), 5); |
957 | ASSERT_TRUE(torch::allclose( |
958 | y, |
959 | torch::tensor( |
960 | {{{{{0, 0, 0}, {0, 0, 0}, {0, 0, 0}}, |
961 | {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}}, |
962 | {{0, 0, 0}, {0, 0, 0}, {0, 0, 26}}}}}, |
963 | torch::kFloat))); |
964 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({1, 1, 3, 3, 3})); |
965 | } |
966 | |
TEST_F(FunctionalTest, ELU) {
  const auto size = 3;
  for (const auto inplace : {false, true}) {
    for (const auto alpha : {0.0, 0.42, 1.0, 4.2, 42.42}) {
      auto x = torch::linspace(-10.0, 10.0, size * size * size);
      x.resize_({size, size, size});
      auto x_bf16 =
          torch::linspace(-10.0, 10.0, size * size * size).to(torch::kBFloat16);
      x_bf16.resize_({size, size, size});

      // Reference: ELU(x) = x for x > 0, alpha * (exp(x) - 1) otherwise.
      // Must be computed BEFORE calling F::elu, which may mutate x in place.
      auto y_exp = torch::max(torch::zeros_like(x), x) +
          torch::min(torch::zeros_like(x), alpha * (torch::exp(x) - 1.0));
      auto y = F::elu(x, F::ELUFuncOptions().alpha(alpha).inplace(inplace));
      auto y_bf16 =
          F::elu(x_bf16, F::ELUFuncOptions().alpha(alpha).inplace(inplace));

      ASSERT_EQ(y.ndimension(), 3);
      ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
      ASSERT_TRUE(torch::allclose(y, y_exp));
      // bfloat16 path only to ~2 decimal places due to reduced precision.
      ASSERT_TRUE(torch::allclose(y_bf16.to(torch::kFloat), y, 1e-2, 1e-2));
      if (inplace) {
        // In-place: the inputs themselves must now hold the activated values.
        ASSERT_TRUE(torch::allclose(x, y_exp));
        ASSERT_TRUE(torch::allclose(x_bf16.to(torch::kFloat), y, 1e-2, 1e-2));
      }
    }
  }
  // Scalar-tensor smoke test: the call must produce a defined tensor.
  ASSERT_TRUE(F::elu(torch::tensor(1.)).defined());
}
995 | |
TEST_F(FunctionalTest, SELU) {
  {
    // The fixed SELU constants from Klambauer et al. (2017).
    const double scale = 1.0507009873554804934193349852946;
    const double alpha = 1.6732632423543772848170429916717;
    for (const auto inplace : {false, true}) {
      auto input = torch::randn({5, 5});
      auto input_bf16 = input.clone().to(torch::kBFloat16);
      // Reference computed BEFORE F::selu, which may mutate input in place:
      // SELU(x) = scale * (max(0, x) + min(0, alpha * (exp(x) - 1))).
      auto expected = scale *
          (torch::max(torch::zeros_like(input), input) +
           torch::min(
               torch::zeros_like(input), alpha * (torch::exp(input) - 1)));
      auto output = F::selu(input, inplace);
      auto output_bf16 = F::selu(input_bf16, inplace);

      ASSERT_TRUE(output.allclose(expected));
      // bfloat16 only agrees to ~2 decimal places.
      ASSERT_TRUE(output_bf16.to(torch::kFloat).allclose(output, 1e-2, 1e-2));
      if (inplace) {
        ASSERT_TRUE(input.allclose(expected));
        ASSERT_TRUE(input_bf16.to(torch::kFloat).allclose(output, 1e-2, 1e-2));
      }
    }
  }
  {
    // Calling with the default inplace argument equals inplace=false.
    auto input = torch::arange(0, 9, torch::kDouble).view({3, 3});
    auto output = F::selu(input);
    auto expected = F::selu(input, false);
    ASSERT_TRUE(output.allclose(expected));
  }
  ASSERT_TRUE(F::selu(torch::tensor(1.)).defined());
}
1026 | |
1027 | TEST_F(FunctionalTest, GLU) { |
1028 | int64_t dim = 1; |
1029 | auto input = torch::randn({4, 2}, torch::requires_grad()); |
1030 | auto output = F::glu(input, dim); |
1031 | auto input_size = input.sizes()[dim] / 2; |
1032 | auto first_half = input.narrow(dim, 0, input_size); |
1033 | auto second_half = input.narrow(dim, input_size, input_size); |
1034 | auto expected = first_half * torch::sigmoid(second_half); |
1035 | |
1036 | ASSERT_TRUE(output.allclose(expected)); |
1037 | ASSERT_TRUE(F::glu(input).allclose(expected)); |
1038 | } |
1039 | |
1040 | TEST_F(FunctionalTest, GELU) { |
1041 | const auto x = torch::linspace(-3.0, 3.0, 100); |
1042 | const auto y_exp = x * 0.5 * (1.0 + torch::erf(x / std::sqrt(2.0))); |
1043 | const auto y = F::gelu(x, F::GELUFuncOptions().approximate("none" )); |
1044 | ASSERT_TRUE(torch::allclose(y, y_exp, 1.4e-06, 1e-05)); |
1045 | } |
1046 | |
1047 | TEST_F(FunctionalTest, TanhGELU) { |
1048 | const auto x = torch::linspace(-3.0, 3.0, 100); |
1049 | const auto inner = std::sqrt(2 / M_PI) * (x + 0.044715 * x.pow(3.0)); |
1050 | const auto y_exp = 0.5 * x * (1.0 + inner.tanh()); |
1051 | const auto y = F::gelu(x, F::GELUFuncOptions().approximate("tanh" )); |
1052 | ASSERT_TRUE(torch::allclose(y, y_exp, 1.4e-06, 1e-05)); |
1053 | } |
1054 | |
1055 | TEST_F(FunctionalTest, Hardshrink) { |
1056 | const auto size = 3; |
1057 | for (const auto lambda : {-4.2, -1.0, -0.42, 0.0, 0.42, 1.0, 4.2, 42.42}) { |
1058 | auto x = torch::linspace(-10.0, 10.0, size * size * size); |
1059 | x.resize_({size, size, size}).set_requires_grad(true); |
1060 | auto y = F::hardshrink(x, F::HardshrinkFuncOptions().lambda(lambda)); |
1061 | torch::Tensor s = y.sum(); |
1062 | |
1063 | s.backward(); |
1064 | ASSERT_EQ(s.ndimension(), 0); |
1065 | |
1066 | ASSERT_EQ(y.ndimension(), 3); |
1067 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size})); |
1068 | auto y_exp = (x.abs() > lambda) * x; |
1069 | ASSERT_TRUE(torch::allclose(y, y_exp)); |
1070 | } |
1071 | } |
1072 | |
TEST_F(FunctionalTest, OneHot) {
  { // Test #1: num_classes inferred from the max value (here 2 -> 3 classes)
    auto x = torch::arange(0, 5, torch::kLong);
    auto y = F::one_hot(x % 3);
    auto expected = torch::tensor(
        {{1, 0, 0}, {0, 1, 0}, {0, 0, 1}, {1, 0, 0}, {0, 1, 0}}, torch::kLong);

    ASSERT_EQ(y.ndimension(), 2);
    ASSERT_TRUE(torch::allclose(y, expected));
    ASSERT_EQ(y.sizes(), std::vector<int64_t>({5, 3}));
  }

  { // Test #2: explicit num_classes wider than the observed values
    auto x = torch::arange(0, 5, torch::kLong);
    auto y = F::one_hot(x % 3, 5);
    auto expected = torch::tensor(
        {{1, 0, 0, 0, 0},
         {0, 1, 0, 0, 0},
         {0, 0, 1, 0, 0},
         {1, 0, 0, 0, 0},
         {0, 1, 0, 0, 0}},
        torch::kLong);

    ASSERT_EQ(y.ndimension(), 2);
    ASSERT_TRUE(torch::allclose(y, expected));
    ASSERT_EQ(y.sizes(), std::vector<int64_t>({5, 5}));
  }

  { // Test #3: multi-dimensional input; one-hot axis is appended last
    auto x = torch::arange(0, 6, torch::kLong);
    auto y = F::one_hot(x.view(std::vector<int64_t>({3, 2})) % 3);
    auto expected = torch::tensor(
        {{{1, 0, 0}, {0, 1, 0}},
         {{0, 0, 1}, {1, 0, 0}},
         {{0, 1, 0}, {0, 0, 1}}},
        torch::kLong);

    ASSERT_EQ(y.ndimension(), 3);
    ASSERT_TRUE(torch::allclose(y, expected));
    ASSERT_EQ(y.sizes(), std::vector<int64_t>({3, 2, 3}));
  }
}
1115 | |
TEST_F(FunctionalTest, Hardtanh) {
  const auto size = 3;
  // Sweep min/max clamp bounds and both in-place modes.
  for (const auto min_val : {-4.2, -1.0, -0.42, 0.0}) {
    for (const auto max_val : {0.0, 0.42, 1.0, 4.2}) {
      for (const auto inplace : {false, true}) {
        auto x = torch::linspace(-10.0, 10.0, size * size * size);
        x.resize_({size, size, size});
        // Reference clamp, computed BEFORE F::hardtanh may mutate x:
        // values below min_val clamp to min_val, above max_val to max_val.
        auto y_exp = (x < min_val) * min_val +
            ((x >= min_val) * (x <= max_val)) * x + (x > max_val) * max_val;
        auto y = F::hardtanh(
            x,
            F::HardtanhFuncOptions().min_val(min_val).max_val(max_val).inplace(
                inplace));

        ASSERT_EQ(y.ndimension(), 3);
        ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
        ASSERT_TRUE(torch::allclose(y, y_exp));
        if (inplace) {
          // In-place: the input tensor itself must be clamped.
          ASSERT_TRUE(torch::allclose(x, y_exp));
        }
      }
    }
  }
  // Scalar-tensor smoke test with default options.
  ASSERT_TRUE(F::hardtanh(torch::tensor(1.)).defined());
}
1141 | |
TEST_F(FunctionalTest, LeakyReLU) {
  const auto size = 3;
  // Sweep slopes, in-place modes, and float vs bfloat16 dtypes.
  for (const auto negative_slope : {0.0, 0.42, 1.0}) {
    for (const auto inplace : {false, true}) {
      for (const auto type : {torch::kFloat, torch::kBFloat16}) {
        auto x = torch::linspace(-10.0, 10.0, size * size * size).to(type);
        x.resize_({size, size, size});
        // Reference, computed BEFORE F::leaky_relu may mutate x:
        // negatives scaled by the slope, positives passed through.
        auto y_exp = (x < 0) * x * negative_slope + (x >= 0) * x;
        auto y = F::leaky_relu(
            x,
            F::LeakyReLUFuncOptions()
                .negative_slope(negative_slope)
                .inplace(inplace));

        ASSERT_EQ(y.ndimension(), 3);
        ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
        ASSERT_TRUE(torch::allclose(y, y_exp));
        if (inplace) {
          ASSERT_TRUE(torch::allclose(x, y_exp));
        }
      }
    }
  }
  // Scalar-tensor smoke test with default options.
  ASSERT_TRUE(F::leaky_relu(torch::tensor(1.)).defined());
}
1167 | |
1168 | TEST_F(FunctionalTest, LogSigmoid) { |
1169 | const auto size = 3; |
1170 | LogSigmoid model; |
1171 | auto x = torch::linspace(-10.0, 10.0, size * size * size); |
1172 | x.resize_({size, size, size}); |
1173 | auto y = F::logsigmoid(x); |
1174 | |
1175 | ASSERT_EQ(y.ndimension(), 3); |
1176 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size})); |
1177 | auto y_exp = torch::log( |
1178 | torch::ones_like(x) / (torch::ones_like(x) + torch::exp(torch::neg(x)))); |
1179 | ASSERT_TRUE(torch::allclose(y, y_exp, 1e-4, 1e-7)); |
1180 | } |
1181 | |
1182 | TEST_F(FunctionalTest, GumbelSoftmax) { |
1183 | // Test 1: No-options |
1184 | { |
1185 | auto logits = torch::randn({5}); |
1186 | int expected_count = 1; |
1187 | auto y_draw = F::gumbel_softmax(logits); |
1188 | |
1189 | // All values positive |
1190 | ASSERT_GE(y_draw.min().item<int>(), 0); |
1191 | // Shape unchanged |
1192 | ASSERT_TRUE(y_draw.sizes() == logits.sizes()); |
1193 | // One choice per draw |
1194 | ASSERT_TRUE(torch::allclose( |
1195 | y_draw.sum(), torch::tensor(expected_count, torch::kFloat))); |
1196 | } |
1197 | |
1198 | // Test 2: 1D shape, 0 and -1 dim |
1199 | for (const auto dim : {0, -1}) { |
1200 | auto logits = torch::randn({5}); |
1201 | int expected_count = 1; |
1202 | auto y_draw = F::gumbel_softmax( |
1203 | logits, F::GumbelSoftmaxFuncOptions().hard(true).dim(dim)); |
1204 | |
1205 | // All values positive |
1206 | ASSERT_GE(y_draw.min().item<int>(), 0); |
1207 | // Shape unchanged |
1208 | ASSERT_TRUE(y_draw.sizes() == logits.sizes()); |
1209 | // One choice per draw |
1210 | ASSERT_TRUE(torch::allclose( |
1211 | y_draw.sum(), torch::tensor(expected_count, torch::kFloat))); |
1212 | } |
1213 | |
1214 | { // Test 3: 2D shape, 1 dim |
1215 | auto logits = torch::randn({5, 4}); |
1216 | int expected_count = 5; |
1217 | auto y_draw = F::gumbel_softmax( |
1218 | logits, F::GumbelSoftmaxFuncOptions().hard(true).dim(1)); |
1219 | |
1220 | // All values positive |
1221 | ASSERT_GE(y_draw.min().item<int>(), 0); |
1222 | // Shape unchanged |
1223 | ASSERT_TRUE(y_draw.sizes() == logits.sizes()); |
1224 | // One choice per draw |
1225 | ASSERT_TRUE(torch::allclose( |
1226 | y_draw.sum(), torch::tensor(expected_count, torch::kFloat))); |
1227 | } |
1228 | |
1229 | // Test 4: 3D shape, 1 and -1 dim |
1230 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) |
1231 | int dims[] = {1, -1}; |
1232 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers) |
1233 | int expected[] = {5 * 3, 5 * 4}; |
1234 | for (const auto i : c10::irange(2)) { |
1235 | auto logits = torch::randn({5, 4, 3}); |
1236 | int expected_count = expected[i]; |
1237 | auto y_draw = F::gumbel_softmax( |
1238 | logits, F::GumbelSoftmaxFuncOptions().hard(true).dim(dims[i])); |
1239 | |
1240 | // All values positive |
1241 | ASSERT_GE(y_draw.min().item<int>(), 0); |
1242 | // Shape unchanged |
1243 | ASSERT_TRUE(y_draw.sizes() == logits.sizes()); |
1244 | // One choice per draw |
1245 | ASSERT_TRUE(torch::allclose( |
1246 | y_draw.sum(), torch::tensor(expected_count, torch::kFloat))); |
1247 | } |
1248 | |
1249 | { // Test 5: Straight through |
1250 | int num_draws = 100; |
1251 | auto logits = torch::tensor({{0.2, 0.8, 0.1}}); |
1252 | logits = logits.reshape({1, 3}); |
1253 | logits.requires_grad(); |
1254 | auto probs = logits.softmax(-1); |
1255 | |
1256 | auto counts = torch::zeros_like(logits); |
1257 | torch::Tensor y_draw; |
1258 | for (const auto i : c10::irange(num_draws)) { |
1259 | (void)i; // Suppress unused variable warning |
1260 | y_draw = |
1261 | F::gumbel_softmax(logits, F::GumbelSoftmaxFuncOptions().hard(true)); |
1262 | counts += y_draw; |
1263 | } |
1264 | |
1265 | // All values positive |
1266 | ASSERT_GE(y_draw.min().item<int>(), 0); |
1267 | // Each experiment should result in 1 draw |
1268 | ASSERT_EQ(counts.sum().item<int>(), num_draws); |
1269 | |
1270 | // Check results are asymptotically as expected |
1271 | auto expected = probs * num_draws; |
1272 | // ~z is approximately N(0,1) for unbiased count |
1273 | auto z = (counts - expected) / (expected * (1 - probs)).sqrt(); |
1274 | // A (lazy) approximate 99% two-sided test: |
1275 | // occurs with prob alpha~>=0.01 if unbiased |
1276 | ASSERT_LT(z.abs().max().item<float>(), 2.58); |
1277 | } |
1278 | } |
1279 | |
1280 | TEST_F(FunctionalTest, Softmax) { |
1281 | auto input = torch::arange(10, torch::kFloat).reshape({2, 5}); |
1282 | // NOLINTNEXTLINE(bugprone-argument-comment) |
1283 | auto output = F::softmax(input, /*dim=*/1); |
1284 | auto sum = torch::sum(torch::exp(input), 1); |
1285 | |
1286 | for (const auto i : c10::irange(2)) { |
1287 | auto expected = torch::exp(input[i]) / sum[i]; |
1288 | ASSERT_TRUE(torch::allclose(output[i], expected)); |
1289 | } |
1290 | } |
1291 | |
1292 | TEST_F(FunctionalTest, Softmin) { |
1293 | auto input = torch::arange(10, torch::kFloat).reshape({2, 5}); |
1294 | // NOLINTNEXTLINE(bugprone-argument-comment) |
1295 | auto output = F::softmin(input, /*dim=*/1); |
1296 | auto sum = torch::sum(torch::exp(-input), 1); |
1297 | |
1298 | for (const auto i : c10::irange(2)) { |
1299 | auto expected = torch::exp(-input[i]) / sum[i]; |
1300 | ASSERT_TRUE(torch::allclose(output[i], expected)); |
1301 | } |
1302 | } |
1303 | |
1304 | TEST_F(FunctionalTest, LogSoftmax) { |
1305 | auto input = torch::arange(10, torch::kFloat).reshape({2, 5}); |
1306 | // NOLINTNEXTLINE(bugprone-argument-comment) |
1307 | auto output = F::log_softmax(input, /*dim=*/1); |
1308 | auto sum = torch::sum(torch::exp(input), 1); |
1309 | |
1310 | for (const auto i : c10::irange(2)) { |
1311 | auto expected = torch::log(torch::exp(input[i]) / sum[i]); |
1312 | ASSERT_TRUE(torch::allclose(output[i], expected)); |
1313 | } |
1314 | } |
1315 | |
1316 | TEST_F(FunctionalTest, PReLU) { |
1317 | const auto x = torch::rand({42, 24}) * 200 - 100; |
1318 | const auto w = torch::rand(24) * 200 - 100; |
1319 | const auto y = F::prelu(x, w); |
1320 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({42, 24})); |
1321 | const auto y_exp = (x < 0) * w * x + (x >= 0) * x; |
1322 | ASSERT_TRUE(torch::allclose(y, y_exp)); |
1323 | } |
1324 | |
1325 | TEST_F(FunctionalTest, LayerNorm) { |
1326 | const auto input = torch::randn({2, 2}); |
1327 | auto y = F::layer_norm(input, F::LayerNormFuncOptions({2, 2}).eps(2e-5)); |
1328 | auto y_exp = |
1329 | torch::layer_norm(input, {2, 2}, torch::Tensor(), torch::Tensor(), 2e-5); |
1330 | ASSERT_TRUE(torch::allclose(y, y_exp)); |
1331 | } |
1332 | |
1333 | TEST_F(FunctionalTest, GroupNorm) { |
1334 | const auto input = torch::randn({2, 2}); |
1335 | auto y = F::group_norm(input, F::GroupNormFuncOptions(2).eps(2e-5)); |
1336 | auto y_exp = |
1337 | torch::group_norm(input, 2, torch::Tensor(), torch::Tensor(), 2e-5); |
1338 | ASSERT_TRUE(torch::allclose(y, y_exp)); |
1339 | } |
1340 | |
1341 | TEST_F(FunctionalTest, LocalResponseNorm) { |
1342 | const auto x = torch::arange(100, 118).resize_({3, 3, 2}); |
1343 | const auto y = F::local_response_norm(x, F::LocalResponseNormFuncOptions(2)); |
1344 | ASSERT_EQ(y.ndimension(), 3); |
1345 | ASSERT_EQ(y.sizes(), torch::IntArrayRef({3, 3, 2})); |
1346 | const auto y_exp = torch::tensor( |
1347 | {{{73.7788, 74.1462}, {60.1942, 60.3302}, {60.4609, 60.5865}}, |
1348 | {{75.8729, 76.2011}, {60.9331, 61.0390}, {61.1403, 61.2370}}, |
1349 | {{77.7387, 78.0303}, {61.5011, 61.5807}, {61.6563, 61.7279}}}, |
1350 | torch::kFloat); |
1351 | ASSERT_TRUE(torch::allclose(y, y_exp, 1e-4, 1e-7)); |
1352 | } |
1353 | |
TEST_F(FunctionalTest, Linear) {
  { // Case 1: with bias -- y = x @ w^T + b on a batched 3-D input.
    const auto x = torch::arange(100., 118).resize_({3, 3, 2});
    const auto w = torch::arange(200., 206).resize_({3, 2});
    const auto b = torch::arange(300., 303);
    const auto y = F::linear(x, w, b);
    ASSERT_EQ(y.ndimension(), 3);
    ASSERT_EQ(y.sizes(), torch::IntArrayRef({3, 3, 3}));
    const auto y_exp = torch::tensor(
        {{{40601, 41004, 41407}, {41403, 41814, 42225}, {42205, 42624, 43043}},
         {{43007, 43434, 43861}, {43809, 44244, 44679}, {44611, 45054, 45497}},
         {{45413, 45864, 46315}, {46215, 46674, 47133}, {47017, 47484, 47951}}},
        torch::kFloat);
    ASSERT_TRUE(torch::allclose(y, y_exp));
  }
  { // Case 2: without bias -- expected values are exactly 300/301/302 lower
    // per output column than in Case 1.
    const auto x = torch::arange(100., 118).resize_({3, 3, 2});
    const auto w = torch::arange(200., 206).resize_({3, 2});
    const auto y = F::linear(x, w);
    ASSERT_EQ(y.ndimension(), 3);
    ASSERT_EQ(y.sizes(), torch::IntArrayRef({3, 3, 3}));
    const auto y_exp = torch::tensor(
        {{{40301, 40703, 41105}, {41103, 41513, 41923}, {41905, 42323, 42741}},
         {{42707, 43133, 43559}, {43509, 43943, 44377}, {44311, 44753, 45195}},
         {{45113, 45563, 46013}, {45915, 46373, 46831}, {46717, 47183, 47649}}},
        torch::kFloat);
    ASSERT_TRUE(torch::allclose(y, y_exp));
  }
}
1383 | |
1384 | TEST_F(FunctionalTest, Embedding) { |
1385 | const auto input = torch::tensor({{1, 2, 4, 5}, {4, 3, 2, 9}}, torch::kLong); |
1386 | auto weight = torch::empty({10, 3}); |
1387 | torch::nn::init::normal_(weight); |
1388 | auto y = F::embedding(input, weight); |
1389 | auto y_exp = torch::embedding(weight, input.contiguous(), -1, false, false); |
1390 | ASSERT_TRUE(torch::allclose(y, y_exp)); |
1391 | } |
1392 | |
TEST_F(FunctionalTest, EmbeddingBag) {
  // 1-D input with explicit offsets splitting it into two bags of four.
  const auto input = torch::tensor({1, 2, 4, 5, 4, 3, 2, 9}, torch::kLong);
  auto offsets = torch::tensor({0, 4}, torch::kLong);
  auto weight = torch::empty({10, 3});
  torch::nn::init::normal_(weight);
  auto y = F::embedding_bag(
      input,
      weight,
      F::EmbeddingBagFuncOptions()
          .mode(torch::kSum)
          .offsets(offsets)
          .padding_idx(4));
  // Raw ATen op args: (weight, input, offsets, scale_grad_by_freq=false,
  // mode=0 (sum), sparse=false, per_sample_weights=none,
  // include_last_offset=false, padding_idx=4).
  auto y_exp = std::get<0>(torch::embedding_bag(
      weight, input, offsets, false, 0, false, torch::Tensor(), false, 4));
  ASSERT_TRUE(torch::allclose(y, y_exp));

  // no options test
  // 2-D input: each row is its own bag; default mode is mean (mode=1).
  const auto input_ = torch::tensor({{1, 2, 4, 5}, {4, 3, 2, 9}}, torch::kLong);
  auto offsets_ = torch::arange(
      0,
      input_.numel(),
      input_.size(1),
      torch::TensorOptions().dtype(torch::kLong).device(input.device()));
  y = F::embedding_bag(input_, weight);
  y_exp = std::get<0>(torch::embedding_bag(
      weight, input_.reshape(-1), offsets_, false, 1, false, torch::Tensor()));
  ASSERT_TRUE(torch::allclose(y, y_exp));
}
1421 | |
1422 | TEST_F(FunctionalTest, Bilinear) { |
1423 | auto input1 = torch::tensor({{1, 2, 3}, {7, 6, 5}}); |
1424 | auto input2 = torch::tensor({{7, 4}, {8, 9}}); |
1425 | auto weight = torch::tensor({{{2, 3}, {9, 7}, {8, 6}}}); |
1426 | auto bias = torch::tensor({1}); |
1427 | |
1428 | auto y_with_bias = F::bilinear(input1, input2, weight, bias); |
1429 | ASSERT_EQ(y_with_bias.ndimension(), 2); |
1430 | ASSERT_EQ(y_with_bias.sizes(), torch::IntArrayRef({2, 1})); |
1431 | auto y_with_bias_exp = torch::tensor({{449}, {1702}}).reshape({2, 1}); |
1432 | ASSERT_TRUE(torch::allclose(y_with_bias, y_with_bias_exp, 1e-4, 1e-7)); |
1433 | |
1434 | auto y_no_bias = F::bilinear(input1, input2, weight); |
1435 | ASSERT_EQ(y_no_bias.ndimension(), 2); |
1436 | ASSERT_EQ(y_no_bias.sizes(), torch::IntArrayRef({2, 1})); |
1437 | auto y_no_bias_exp = torch::tensor({{448, 1701}}).reshape({2, 1}); |
1438 | ASSERT_TRUE(torch::allclose(y_no_bias, y_no_bias_exp, 1e-4, 1e-7)); |
1439 | } |
1440 | |
TEST_F(FunctionalTest, Normalize) {
  // Shared expected result for Tests #1 and #2: L1-normalization of
  // {{0..4},{5..9}} along the last dimension.
  const auto expected = torch::tensor(
      {{{0.00000000, 0.10000000, 0.2000, 0.30000000, 0.40000000},
        {0.14285715, 0.17142858, 0.2000, 0.22857143, 0.25714287}}},
      torch::requires_grad().dtype(torch::kFloat));
  { // Test #1
    auto input = torch::tensor(
        {{{0, 1, 2, 3, 4}, {5, 6, 7, 8, 9}}},
        torch::dtype(torch::kFloat).requires_grad(true));
    auto norm = F::normalize(input, F::NormalizeFuncOptions().p(1).dim(-1));

    // reduce to scalar to call .backward()
    torch::Tensor s = norm.sum();
    s.backward();

    ASSERT_EQ(s.ndimension(), 0);
    ASSERT_EQ(input.grad().numel(), 10);
    ASSERT_TRUE(torch::allclose(norm, expected));
  }

  { // Test #2 Check variations of optional arguments
    auto input = torch::tensor(
        {{{0, 1, 2, 3, 4}, {5, 6, 7, 8, 9}}}, torch::dtype(torch::kFloat));
    auto output = torch::randn({1, 2, 5}, torch::dtype(torch::kFloat));
    // non-null output argument: result is written into `output` in place
    F::normalize(input, F::NormalizeFuncOptions().p(1).dim(-1).out(output));
    // default options
    F::normalize(input);

    ASSERT_TRUE(torch::allclose(output, expected));
  }

  { // Test #3 Base case of scalar tensor
    auto input = torch::randn({}, torch::requires_grad());
    torch::Tensor norm =
        F::normalize(input, F::NormalizeFuncOptions().p(1).dim(-1));
    norm.backward();

    ASSERT_EQ(input.grad().numel(), 1);
  }
}
1482 | |
TEST_F(FunctionalTest, ReLU) {
  const auto size = 3;
  for (const auto inplace : {false, true}) {
    auto x = torch::linspace(-10.0, 10.0, size * size * size);
    x.resize_({size, size, size});
    // Reference, computed BEFORE F::relu may mutate x in place:
    // negatives clamp to zero, positives pass through.
    auto y_exp = (x < 0) * 0 + (x >= 0) * x;
    auto y = F::relu(x, F::ReLUFuncOptions().inplace(inplace));

    ASSERT_EQ(y.ndimension(), 3);
    ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
    ASSERT_TRUE(torch::allclose(y, y_exp));
    if (inplace) {
      ASSERT_TRUE(torch::allclose(x, y_exp));
    }

    // Same checks through the bool-overload of F::relu. When inplace is
    // true, x already holds y_exp so the call is idempotent.
    // NOLINTNEXTLINE(bugprone-argument-comment)
    y = F::relu(x, /*inplace=*/inplace);

    ASSERT_EQ(y.ndimension(), 3);
    ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
    ASSERT_TRUE(torch::allclose(y, y_exp));
    if (inplace) {
      ASSERT_TRUE(torch::allclose(x, y_exp));
    }
  }
  // Scalar-tensor smoke test with default options.
  ASSERT_TRUE(F::relu(torch::tensor(1.)).defined());
}
1510 | |
1511 | TEST_F(FunctionalTest, ReLUDefaultOptions) { |
1512 | const auto size = 3; |
1513 | auto x = torch::linspace(-10.0, 10.0, size * size * size); |
1514 | x.resize_({size, size, size}); |
1515 | auto y_exp = (x < 0) * 0 + (x >= 0) * x; |
1516 | auto y = F::relu(x); |
1517 | |
1518 | ASSERT_EQ(y.ndimension(), 3); |
1519 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size})); |
1520 | ASSERT_TRUE(torch::allclose(y, y_exp)); |
1521 | } |
1522 | |
TEST_F(FunctionalTest, ReLU6) {
  const auto size = 3;
  for (const auto inplace : {false, true}) {
    auto x = torch::linspace(-10.0, 10.0, size * size * size);
    x.resize_({size, size, size});
    // Reference, computed BEFORE F::relu6 may mutate x in place:
    // clamp to [0, 6].
    auto y_exp = (x < 0) * 0 + ((x >= 0) * (x <= 6)) * x + (x > 6) * 6;
    auto y = F::relu6(x, F::ReLU6FuncOptions().inplace(inplace));

    ASSERT_EQ(y.ndimension(), 3);
    ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
    ASSERT_TRUE(torch::allclose(y, y_exp));
    if (inplace) {
      ASSERT_TRUE(torch::allclose(x, y_exp));
    }

    // Same checks through the bool-overload. When inplace is true, x
    // already holds y_exp so the second call is idempotent.
    // NOLINTNEXTLINE(bugprone-argument-comment)
    y = F::relu6(x, /*inplace=*/inplace);

    ASSERT_EQ(y.ndimension(), 3);
    ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
    ASSERT_TRUE(torch::allclose(y, y_exp));
    if (inplace) {
      ASSERT_TRUE(torch::allclose(x, y_exp));
    }
  }
  // Scalar-tensor smoke test with default options.
  ASSERT_TRUE(F::relu6(torch::tensor(1.)).defined());
}
1550 | |
1551 | TEST_F(FunctionalTest, ReLU6DefaultOptions) { |
1552 | const auto size = 3; |
1553 | auto x = torch::linspace(-10.0, 10.0, size * size * size); |
1554 | x.resize_({size, size, size}); |
1555 | auto y_exp = (x < 0) * 0 + ((x >= 0) * (x <= 6)) * x + (x > 6) * 6; |
1556 | auto y = F::relu6(x); |
1557 | |
1558 | ASSERT_EQ(y.ndimension(), 3); |
1559 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size})); |
1560 | ASSERT_TRUE(torch::allclose(y, y_exp)); |
1561 | } |
1562 | |
1563 | TEST_F(FunctionalTest, RReLU) { |
1564 | const auto size = 3; |
1565 | for (const auto lower : {0.01, 0.1, 0.2}) { |
1566 | for (const auto upper : {0.3, 0.4, 0.5}) { |
1567 | for (const auto inplace : {false, true}) { |
1568 | for (const auto type : {torch::kFloat, torch::kBFloat16}) { |
1569 | auto x = torch::linspace(-10.0, 10.0, size * size * size).to(type); |
1570 | x.resize_({size, size, size}); |
1571 | auto x_copy = x.clone(); |
1572 | auto y = F::rrelu( |
1573 | x, |
1574 | F::RReLUFuncOptions().lower(lower).upper(upper).inplace(inplace)); |
1575 | auto z = |
1576 | ((x_copy >= 0) * (x_copy == y) + |
1577 | (x_copy < 0) * (y >= x_copy * upper) * (y <= lower * x_copy)) * |
1578 | 1.0; |
1579 | |
1580 | ASSERT_EQ(y.ndimension(), 3); |
1581 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size})); |
1582 | ASSERT_TRUE(torch::allclose(z, torch::ones_like(z))); |
1583 | if (inplace) { |
1584 | ASSERT_TRUE(torch::allclose(x, y)); |
1585 | } |
1586 | } |
1587 | } |
1588 | } |
1589 | } |
1590 | ASSERT_TRUE(F::rrelu(torch::tensor(1.)).defined()); |
1591 | } |
1592 | |
1593 | TEST_F(FunctionalTest, RReLUDefaultOptions) { |
1594 | const auto size = 3; |
1595 | const auto lower = 1.0 / 8.0; |
1596 | const auto upper = 1.0 / 3.0; |
1597 | for (const auto type : {torch::kFloat, torch::kBFloat16}) { |
1598 | auto x = torch::linspace(-10.0, 10.0, size * size * size).to(type); |
1599 | x.resize_({size, size, size}); |
1600 | auto x_copy = x.clone(); |
1601 | auto y = F::rrelu(x); |
1602 | auto z = ((x_copy >= 0) * (x_copy == y) + |
1603 | (x_copy < 0) * (y >= x_copy * upper) * (y <= lower * x_copy)) * |
1604 | 1.0; |
1605 | |
1606 | ASSERT_EQ(y.ndimension(), 3); |
1607 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size})); |
1608 | ASSERT_TRUE(torch::allclose(z, torch::ones_like(z))); |
1609 | } |
1610 | } |
1611 | |
1612 | TEST_F(FunctionalTest, CELU) { |
1613 | const auto size = 3; |
1614 | for (const auto inplace : {false, true}) { |
1615 | for (const auto alpha : {0.42, 1.0, 4.2, 42.42}) { |
1616 | auto x = torch::linspace(-10.0, 10.0, size * size * size); |
1617 | x.resize_({size, size, size}); |
1618 | auto x_bf16 = x.clone().to(torch::kBFloat16); |
1619 | auto y_exp = torch::max(torch::zeros_like(x), x) + |
1620 | torch::min(torch::zeros_like(x), |
1621 | alpha * (torch::exp(x / alpha) - 1.0)); |
1622 | auto y = F::celu(x, F::CELUFuncOptions().alpha(alpha).inplace(inplace)); |
1623 | auto y_bf16 = |
1624 | F::celu(x_bf16, F::CELUFuncOptions().alpha(alpha).inplace(inplace)); |
1625 | |
1626 | ASSERT_EQ(y.ndimension(), 3); |
1627 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size})); |
1628 | ASSERT_TRUE(torch::allclose(y, y_exp)); |
1629 | ASSERT_TRUE(torch::allclose(y_bf16.to(torch::kFloat), y, 1e-2, 1e-2)); |
1630 | if (inplace) { |
1631 | ASSERT_TRUE(torch::allclose(x, y_exp)); |
1632 | ASSERT_TRUE(torch::allclose(x_bf16.to(torch::kFloat), y, 1e-2, 1e-2)); |
1633 | } |
1634 | } |
1635 | } |
1636 | ASSERT_TRUE(F::celu(torch::tensor(1.)).defined()); |
1637 | } |
1638 | |
1639 | TEST_F(FunctionalTest, CELUDefaultOptions) { |
1640 | const auto size = 3; |
1641 | const auto alpha = 1.0; |
1642 | auto x = torch::linspace(-10.0, 10.0, size * size * size); |
1643 | x.resize_({size, size, size}); |
1644 | auto x_bf16 = x.clone().to(torch::kBFloat16); |
1645 | auto y_exp = torch::max(torch::zeros_like(x), x) + |
1646 | torch::min(torch::zeros_like(x), alpha * (torch::exp(x / alpha) - 1.0)); |
1647 | auto y = F::celu(x); |
1648 | auto y_bf16 = F::celu(x_bf16); |
1649 | |
1650 | ASSERT_EQ(y.ndimension(), 3); |
1651 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size})); |
1652 | ASSERT_TRUE(torch::allclose(y, y_exp)); |
1653 | ASSERT_TRUE(torch::allclose(y_bf16.to(torch::kFloat), y, 1e-2, 1e-2)); |
1654 | } |
1655 | |
1656 | TEST_F(FunctionalTest, PixelShuffle) { |
1657 | auto x = torch::tensor( |
1658 | {{{{-17, 19}, {-1, 2}}, |
1659 | {{7, 14}, {-3, 1}}, |
1660 | {{0, -2}, {-12, 14}}, |
1661 | {{-15, 0}, {-3, 9}}}}, |
1662 | torch::kFloat); |
1663 | auto y_exp = torch::tensor( |
1664 | {{{{-17, 7, 19, 14}, {0, -15, -2, 0}, {-1, -3, 2, 1}, {-12, -3, 14, 9}}}}, |
1665 | torch::kFloat); |
1666 | auto y = F::pixel_shuffle(x, 2); |
1667 | |
1668 | ASSERT_EQ(y.ndimension(), 4); |
1669 | ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 1, 4, 4})); |
1670 | ASSERT_TRUE(y.allclose(y_exp)); |
1671 | } |
1672 | |
1673 | TEST_F(FunctionalTest, PixelUnshuffle) { |
1674 | auto x = torch::tensor( |
1675 | {{{{-17, 7, 19, 14}, {0, -15, -2, 0}, {-1, -3, 2, 1}, {-12, -3, 14, 9}}}}, |
1676 | torch::kFloat); |
1677 | auto y_exp = torch::tensor( |
1678 | {{{{-17, 19}, {-1, 2}}, |
1679 | {{7, 14}, {-3, 1}}, |
1680 | {{0, -2}, {-12, 14}}, |
1681 | {{-15, 0}, {-3, 9}}}}, |
1682 | torch::kFloat); |
1683 | auto y = F::pixel_unshuffle(x, 2); |
1684 | |
1685 | ASSERT_EQ(y.ndimension(), 4); |
1686 | ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 4, 2, 2})); |
1687 | ASSERT_TRUE(y.allclose(y_exp)); |
1688 | } |
1689 | |
1690 | TEST_F(FunctionalTest, Softplus) { |
1691 | const auto size = 3; |
1692 | for (const auto beta : {0.5, 1.0, 2.0}) { |
1693 | for (const auto threshold : {1.0, 3.0, 5.0}) { |
1694 | auto x = torch::linspace(-3.0, 3.0, 61); |
1695 | x.resize_({size, size, size}); |
1696 | auto y_exp = |
1697 | (x <= threshold) * torch::log(1 + torch::exp(x * beta)) / beta + |
1698 | (x > threshold) * x; |
1699 | auto y = F::softplus( |
1700 | x, F::SoftplusFuncOptions().beta(beta).threshold(threshold)); |
1701 | |
1702 | ASSERT_EQ(y.ndimension(), 3); |
1703 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size})); |
1704 | ASSERT_TRUE(torch::allclose(y, y_exp)); |
1705 | } |
1706 | } |
1707 | } |
1708 | |
1709 | TEST_F(FunctionalTest, SoftplusDefaultOptions) { |
1710 | const auto size = 3; |
1711 | const auto beta = 1.0; |
1712 | const auto threshold = 20.0; |
1713 | auto x = torch::linspace(-3.0, 3.0, 61); |
1714 | x.resize_({size, size, size}); |
1715 | auto y_exp = (x <= threshold) * torch::log(1 + torch::exp(x * beta)) / beta + |
1716 | (x > threshold) * x; |
1717 | auto y = F::softplus(x); |
1718 | |
1719 | ASSERT_EQ(y.ndimension(), 3); |
1720 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size})); |
1721 | ASSERT_TRUE(torch::allclose(y, y_exp)); |
1722 | } |
1723 | |
1724 | TEST_F(FunctionalTest, Fold) { |
1725 | auto input = torch::ones({1, 3 * 2 * 2, 2}, torch::kDouble); |
1726 | auto output = F::fold(input, F::FoldFuncOptions({3, 2}, {2, 2})); |
1727 | auto expected = torch::tensor( |
1728 | {{{{1.0, 1.0}, {2.0, 2.0}, {1.0, 1.0}}, |
1729 | {{1.0, 1.0}, {2.0, 2.0}, {1.0, 1.0}}, |
1730 | {{1.0, 1.0}, {2.0, 2.0}, {1.0, 1.0}}}}, |
1731 | torch::kDouble); |
1732 | |
1733 | ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 3, 3, 2})); |
1734 | ASSERT_TRUE(output.allclose(expected)); |
1735 | } |
1736 | |
1737 | TEST_F(FunctionalTest, Unfold) { |
1738 | auto input = torch::arange(0, 12, torch::kDouble).view({1, 2, 2, 3}); |
1739 | auto output = |
1740 | F::unfold(input, F::UnfoldFuncOptions({2, 2}).padding(1).stride(2)); |
1741 | auto expected = torch::tensor( |
1742 | {{{0.0, 0.0, 0.0, 4.0}, |
1743 | {0.0, 0.0, 3.0, 5.0}, |
1744 | {0.0, 1.0, 0.0, 0.0}, |
1745 | {0.0, 2.0, 0.0, 0.0}, |
1746 | {0.0, 0.0, 0.0, 10.0}, |
1747 | {0.0, 0.0, 9.0, 11.0}, |
1748 | {0.0, 7.0, 0.0, 0.0}, |
1749 | {6.0, 8.0, 0.0, 0.0}}}, |
1750 | torch::kDouble); |
1751 | |
1752 | ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 8, 4})); |
1753 | ASSERT_TRUE(output.allclose(expected)); |
1754 | } |
1755 | |
1756 | TEST_F(FunctionalTest, Softshrink) { |
1757 | const auto size = 3; |
1758 | for (const auto lambda : {0.0, 0.42, 1.0, 4.2, 42.42}) { |
1759 | auto x = torch::linspace(-10.0, 10.0, size * size * size); |
1760 | x.resize_({size, size, size}).set_requires_grad(true); |
1761 | // NOLINTNEXTLINE(bugprone-argument-comment) |
1762 | auto y = F::softshrink(x, /*lambda=*/lambda); |
1763 | torch::Tensor s = y.sum(); |
1764 | |
1765 | s.backward(); |
1766 | ASSERT_EQ(s.ndimension(), 0); |
1767 | |
1768 | ASSERT_EQ(y.ndimension(), 3); |
1769 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size})); |
1770 | auto y_exp = (x < -lambda) * (x + lambda) + (x > lambda) * (x - lambda); |
1771 | ASSERT_TRUE(torch::allclose(y, y_exp)); |
1772 | } |
1773 | } |
1774 | |
1775 | TEST_F(FunctionalTest, SoftshrinkDefaultOptions) { |
1776 | const auto size = 3; |
1777 | const auto lambda = 0.5; |
1778 | auto x = torch::linspace(-10.0, 10.0, size * size * size); |
1779 | x.resize_({size, size, size}).set_requires_grad(true); |
1780 | auto y = F::softshrink(x); |
1781 | torch::Tensor s = y.sum(); |
1782 | |
1783 | s.backward(); |
1784 | ASSERT_EQ(s.ndimension(), 0); |
1785 | |
1786 | ASSERT_EQ(y.ndimension(), 3); |
1787 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size})); |
1788 | auto y_exp = (x < -lambda) * (x + lambda) + (x > lambda) * (x - lambda); |
1789 | } |
1790 | |
1791 | TEST_F(FunctionalTest, Softsign) { |
1792 | auto x = torch::randn(100) * 10; |
1793 | auto y_exp = x / (1 + x.abs()); |
1794 | auto y = F::softsign(x); |
1795 | |
1796 | ASSERT_TRUE(torch::allclose(y, y_exp)); |
1797 | } |
1798 | |
1799 | TEST_F(FunctionalTest, Mish) { |
1800 | auto x = torch::randn(100) * 10; |
1801 | auto y_exp = x * x.exp().log1p().tanh(); |
1802 | auto y = F::mish(x); |
1803 | |
1804 | ASSERT_TRUE(torch::allclose(y, y_exp)); |
1805 | } |
1806 | |
1807 | TEST_F(FunctionalTest, Tanhshrink) { |
1808 | auto x = torch::randn(100) * 10; |
1809 | auto y_exp = x - x.tanh(); |
1810 | auto y = F::tanhshrink(x); |
1811 | |
1812 | ASSERT_TRUE(torch::allclose(y, y_exp)); |
1813 | } |
1814 | |
1815 | TEST_F(FunctionalTest, Threshold) { |
1816 | const auto size = 3; |
1817 | for (const auto threshold : {0.5, 1.0, 2.0}) { |
1818 | for (const auto value : {0.5, 1.0, 2.0}) { |
1819 | for (const auto inplace : {false, true}) { |
1820 | auto x = torch::linspace(-3.0, 3.0, 61); |
1821 | x.resize_({size, size, size}); |
1822 | auto y_exp = (x <= threshold) * value + (x > threshold) * x; |
1823 | auto y = F::threshold( |
1824 | x, F::ThresholdFuncOptions(threshold, value).inplace(inplace)); |
1825 | |
1826 | ASSERT_EQ(y.ndimension(), 3); |
1827 | ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size})); |
1828 | ASSERT_TRUE(torch::allclose(y, y_exp)); |
1829 | if (inplace) { |
1830 | ASSERT_TRUE(torch::allclose(x, y_exp)); |
1831 | } |
1832 | } |
1833 | } |
1834 | } |
1835 | ASSERT_TRUE(F::threshold(torch::tensor(1.), F::ThresholdFuncOptions(0.5, 0.5)) |
1836 | .defined()); |
1837 | } |
1838 | |
1839 | TEST_F(FunctionalTest, BatchNorm1d) { |
1840 | int num_features = 5; |
1841 | double eps = 1e-05; |
1842 | double momentum = 0.1; |
1843 | |
1844 | auto input = torch::randn({2, 5}); |
1845 | auto mean = torch::randn(5); |
1846 | auto variance = torch::rand(5); |
1847 | auto weight = torch::ones({num_features}); |
1848 | auto bias = torch::zeros({num_features}); |
1849 | auto output = F::batch_norm( |
1850 | input, |
1851 | mean, |
1852 | variance, |
1853 | F::BatchNormFuncOptions() |
1854 | .weight(weight) |
1855 | .bias(bias) |
1856 | .momentum(momentum) |
1857 | .eps(eps) |
1858 | .training(false)); |
1859 | auto expected = (input - mean) / torch::sqrt(variance + eps); |
1860 | ASSERT_TRUE(output.allclose(expected)); |
1861 | } |
1862 | |
1863 | TEST_F(FunctionalTest, BatchNorm1dDefaultOptions) { |
1864 | auto input = torch::randn({2, 5}); |
1865 | auto mean = torch::randn(5); |
1866 | auto variance = torch::rand(5); |
1867 | auto output = F::batch_norm(input, mean, variance); |
1868 | auto expected = (input - mean) / torch::sqrt(variance + 1e-5); |
1869 | ASSERT_TRUE(output.allclose(expected)); |
1870 | } |
1871 | |
1872 | TEST_F(FunctionalTest, BatchNorm2d) { |
1873 | int num_features = 5; |
1874 | double eps = 1e-05; |
1875 | double momentum = 0.1; |
1876 | |
1877 | auto input = torch::randn({2, num_features, 4, 4}); |
1878 | auto mean = torch::randn(num_features); |
1879 | auto variance = torch::rand(num_features); |
1880 | auto weight = torch::ones({num_features}); |
1881 | auto bias = torch::zeros({num_features}); |
1882 | auto output = F::batch_norm( |
1883 | input, |
1884 | mean, |
1885 | variance, |
1886 | F::BatchNormFuncOptions() |
1887 | .weight(weight) |
1888 | .bias(bias) |
1889 | .momentum(momentum) |
1890 | .eps(eps) |
1891 | .training(false)); |
1892 | auto expected = torch::transpose( |
1893 | (torch::transpose(input, 1, 3) - mean) / torch::sqrt(variance + eps), |
1894 | 1, |
1895 | 3); |
1896 | ASSERT_TRUE(output.allclose(expected)); |
1897 | } |
1898 | |
1899 | TEST_F(FunctionalTest, BatchNorm2dDefaultOptions) { |
1900 | int num_features = 5; |
1901 | double eps = 1e-05; |
1902 | |
1903 | auto input = torch::randn({2, num_features, 4, 4}); |
1904 | auto mean = torch::randn(num_features); |
1905 | auto variance = torch::rand(num_features); |
1906 | auto output = F::batch_norm(input, mean, variance); |
1907 | auto expected = torch::transpose( |
1908 | (torch::transpose(input, 1, 3) - mean) / torch::sqrt(variance + eps), |
1909 | 1, |
1910 | 3); |
1911 | ASSERT_TRUE(output.allclose(expected)); |
1912 | } |
1913 | |
1914 | TEST_F(FunctionalTest, BatchNorm3d) { |
1915 | int num_features = 5; |
1916 | double eps = 1e-05; |
1917 | double momentum = 0.1; |
1918 | |
1919 | auto input = torch::randn({2, num_features, 2, 2, 2}); |
1920 | auto mean = torch::randn(num_features); |
1921 | auto variance = torch::rand(num_features); |
1922 | auto weight = torch::ones({num_features}); |
1923 | auto bias = torch::zeros({num_features}); |
1924 | auto output = F::batch_norm( |
1925 | input, |
1926 | mean, |
1927 | variance, |
1928 | F::BatchNormFuncOptions() |
1929 | .weight(weight) |
1930 | .bias(bias) |
1931 | .momentum(momentum) |
1932 | .eps(eps) |
1933 | .training(false)); |
1934 | auto expected = torch::transpose( |
1935 | (torch::transpose(input, 1, 4) - mean) / torch::sqrt(variance + eps), |
1936 | 1, |
1937 | 4); |
1938 | ASSERT_TRUE(output.allclose(expected)); |
1939 | } |
1940 | |
1941 | TEST_F(FunctionalTest, BatchNorm3dDefaultOptions) { |
1942 | int num_features = 5; |
1943 | double eps = 1e-05; |
1944 | |
1945 | auto input = torch::randn({2, num_features, 2, 2, 2}); |
1946 | auto mean = torch::randn(num_features); |
1947 | auto variance = torch::rand(num_features); |
1948 | auto output = F::batch_norm(input, mean, variance); |
1949 | auto expected = torch::transpose( |
1950 | (torch::transpose(input, 1, 4) - mean) / torch::sqrt(variance + eps), |
1951 | 1, |
1952 | 4); |
1953 | ASSERT_TRUE(output.allclose(expected)); |
1954 | } |
1955 | |
1956 | TEST_F(FunctionalTest, InstanceNorm1d) { |
1957 | int num_features = 5; |
1958 | double eps = 1e-05; |
1959 | double momentum = 0.1; |
1960 | |
1961 | auto input = torch::arange(40.).view({2, 5, 4}); |
1962 | auto mean = torch::arange(5.); |
1963 | auto variance = torch::arange(5.); |
1964 | auto weight = torch::arange((double)num_features); |
1965 | auto bias = torch::arange((double)num_features); |
1966 | auto output = F::instance_norm( |
1967 | input, |
1968 | F::InstanceNormFuncOptions() |
1969 | .running_mean(mean) |
1970 | .running_var(variance) |
1971 | .weight(weight) |
1972 | .bias(bias) |
1973 | .momentum(momentum) |
1974 | .eps(eps)); |
1975 | auto expected = torch::tensor( |
1976 | {{{0.0000, 0.0000, 0.0000, 0.0000}, |
1977 | {-0.3416, 0.5528, 1.4472, 2.3416}, |
1978 | {-0.6833, 1.1056, 2.8944, 4.6833}, |
1979 | {-1.0249, 1.6584, 4.3416, 7.0249}, |
1980 | {-1.3665, 2.2112, 5.7888, 9.3665}}, |
1981 | {{0.0000, 0.0000, 0.0000, 0.0000}, |
1982 | {-0.3416, 0.5528, 1.4472, 2.3416}, |
1983 | {-0.6833, 1.1056, 2.8944, 4.6833}, |
1984 | {-1.0249, 1.6584, 4.3416, 7.0249}, |
1985 | {-1.3665, 2.2112, 5.7888, 9.3665}}}); |
1986 | ASSERT_TRUE(output.allclose(expected, 2e-04)); |
1987 | } |
1988 | |
1989 | TEST_F(FunctionalTest, InstanceNorm1dDefaultOptions) { |
1990 | auto input = torch::arange(40.).view({2, 5, 4}); |
1991 | auto output = F::instance_norm(input); |
1992 | auto expected = torch::tensor( |
1993 | {{{-1.3416, -0.4472, 0.4472, 1.3416}, |
1994 | {-1.3416, -0.4472, 0.4472, 1.3416}, |
1995 | {-1.3416, -0.4472, 0.4472, 1.3416}, |
1996 | {-1.3416, -0.4472, 0.4472, 1.3416}, |
1997 | {-1.3416, -0.4472, 0.4472, 1.3416}}, |
1998 | {{-1.3416, -0.4472, 0.4472, 1.3416}, |
1999 | {-1.3416, -0.4472, 0.4472, 1.3416}, |
2000 | {-1.3416, -0.4472, 0.4472, 1.3416}, |
2001 | {-1.3416, -0.4472, 0.4472, 1.3416}, |
2002 | {-1.3416, -0.4472, 0.4472, 1.3416}}}); |
2003 | ASSERT_TRUE(output.allclose(expected, 2e-04)); |
2004 | } |
2005 | |
2006 | TEST_F(FunctionalTest, InstanceNorm2d) { |
2007 | int num_features = 5; |
2008 | double eps = 1e-05; |
2009 | double momentum = 0.1; |
2010 | |
2011 | auto input = |
2012 | torch::arange(2. * num_features * 2 * 2).view({2, num_features, 2, 2}); |
2013 | auto mean = torch::arange((double)num_features); |
2014 | auto variance = torch::arange((double)num_features); |
2015 | auto weight = torch::arange((double)num_features); |
2016 | auto bias = torch::arange((double)num_features); |
2017 | auto output = F::instance_norm( |
2018 | input, |
2019 | F::InstanceNormFuncOptions() |
2020 | .running_mean(mean) |
2021 | .running_var(variance) |
2022 | .weight(weight) |
2023 | .bias(bias) |
2024 | .momentum(momentum) |
2025 | .eps(eps)); |
2026 | auto expected = torch::tensor( |
2027 | {{{{0.0000, 0.0000}, {0.0000, 0.0000}}, |
2028 | {{-0.3416, 0.5528}, {1.4472, 2.3416}}, |
2029 | {{-0.6833, 1.1056}, {2.8944, 4.6833}}, |
2030 | {{-1.0249, 1.6584}, {4.3416, 7.0249}}, |
2031 | {{-1.3665, 2.2112}, {5.7888, 9.3665}}}, |
2032 | {{{0.0000, 0.0000}, {0.0000, 0.0000}}, |
2033 | {{-0.3416, 0.5528}, {1.4472, 2.3416}}, |
2034 | {{-0.6833, 1.1056}, {2.8944, 4.6833}}, |
2035 | {{-1.0249, 1.6584}, {4.3416, 7.0249}}, |
2036 | {{-1.3665, 2.2112}, {5.7888, 9.3665}}}}); |
2037 | ASSERT_TRUE(output.allclose(expected, 2e-04)); |
2038 | } |
2039 | |
2040 | TEST_F(FunctionalTest, InstanceNorm2dDefaultOptions) { |
2041 | int num_features = 5; |
2042 | |
2043 | auto input = |
2044 | torch::arange(2. * num_features * 2 * 2).view({2, num_features, 2, 2}); |
2045 | auto output = F::instance_norm(input); |
2046 | auto expected = torch::tensor( |
2047 | {{{{-1.3416, -0.4472}, {0.4472, 1.3416}}, |
2048 | {{-1.3416, -0.4472}, {0.4472, 1.3416}}, |
2049 | {{-1.3416, -0.4472}, {0.4472, 1.3416}}, |
2050 | {{-1.3416, -0.4472}, {0.4472, 1.3416}}, |
2051 | {{-1.3416, -0.4472}, {0.4472, 1.3416}}}, |
2052 | {{{-1.3416, -0.4472}, {0.4472, 1.3416}}, |
2053 | {{-1.3416, -0.4472}, {0.4472, 1.3416}}, |
2054 | {{-1.3416, -0.4472}, {0.4472, 1.3416}}, |
2055 | {{-1.3416, -0.4472}, {0.4472, 1.3416}}, |
2056 | {{-1.3416, -0.4472}, {0.4472, 1.3416}}}}); |
2057 | ASSERT_TRUE(output.allclose(expected, 2e-04)); |
2058 | } |
2059 | |
TEST_F(FunctionalTest, InstanceNorm3d) {
  // instance_norm over a (2, 5, 2, 2, 2) ramp with running stats and an
  // affine transform; weight and bias both equal the channel index, so
  // channel c's normalized values are scaled by c and shifted by c
  // (channel 0 therefore collapses to zeros in the expected tensor).
  int num_features = 5;
  double eps = 1e-05;
  double momentum = 0.1;

  auto input = torch::arange(2. * num_features * 2 * 2 * 2)
                   .view({2, num_features, 2, 2, 2});
  auto mean = torch::arange((double)num_features);
  auto variance = torch::arange((double)num_features);
  auto weight = torch::arange((double)num_features);
  auto bias = torch::arange((double)num_features);
  auto output = F::instance_norm(
      input,
      F::InstanceNormFuncOptions()
          .running_mean(mean)
          .running_var(variance)
          .weight(weight)
          .bias(bias)
          .momentum(momentum)
          .eps(eps));
  // Hand-checked reference; both instances contain identical ramps per
  // channel, so the two outer entries are the same.
  auto expected = torch::tensor(
      {{{{{0.0000, 0.0000}, {0.0000, 0.0000}},
         {{0.0000, 0.0000}, {0.0000, 0.0000}}},
        {{{-0.5275, -0.0911}, {0.3453, 0.7818}},
         {{1.2182, 1.6547}, {2.0911, 2.5275}}},
        {{{-1.0550, -0.1822}, {0.6907, 1.5636}},
         {{2.4364, 3.3093}, {4.1822, 5.0550}}},
        {{{-1.5826, -0.2733}, {1.0360, 2.3453}},
         {{3.6547, 4.9640}, {6.2733, 7.5826}}},
        {{{-2.1101, -0.3644}, {1.3814, 3.1271}},
         {{4.8729, 6.6186}, {8.3644, 10.1101}}}},
       {{{{0.0000, 0.0000}, {0.0000, 0.0000}},
         {{0.0000, 0.0000}, {0.0000, 0.0000}}},
        {{{-0.5275, -0.0911}, {0.3453, 0.7818}},
         {{1.2182, 1.6547}, {2.0911, 2.5275}}},
        {{{-1.0550, -0.1822}, {0.6907, 1.5636}},
         {{2.4364, 3.3093}, {4.1822, 5.0550}}},
        {{{-1.5826, -0.2733}, {1.0360, 2.3453}},
         {{3.6547, 4.9640}, {6.2733, 7.5826}}},
        {{{-2.1101, -0.3644}, {1.3814, 3.1271}},
         {{4.8729, 6.6186}, {8.3644, 10.1101}}}}});
  ASSERT_TRUE(output.allclose(expected, 2e-04));
}
2103 | |
TEST_F(FunctionalTest, InstanceNorm3dDefaultOptions) {
  // Default instance_norm on a (2, 5, 2, 2, 2) ramp; every 2x2x2 channel
  // block is a length-8 ramp, so each normalizes to the same eight values.
  int num_features = 5;

  auto input = torch::arange(2. * num_features * 2 * 2 * 2)
                   .view({2, num_features, 2, 2, 2});
  auto output = F::instance_norm(input);
  auto expected = torch::tensor(
      {{{{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
         {{0.2182, 0.6547}, {1.0911, 1.5275}}},
        {{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
         {{0.2182, 0.6547}, {1.0911, 1.5275}}},
        {{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
         {{0.2182, 0.6547}, {1.0911, 1.5275}}},
        {{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
         {{0.2182, 0.6547}, {1.0911, 1.5275}}},
        {{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
         {{0.2182, 0.6547}, {1.0911, 1.5275}}}},
       {{{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
         {{0.2182, 0.6547}, {1.0911, 1.5275}}},
        {{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
         {{0.2182, 0.6547}, {1.0911, 1.5275}}},
        {{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
         {{0.2182, 0.6547}, {1.0911, 1.5275}}},
        {{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
         {{0.2182, 0.6547}, {1.0911, 1.5275}}},
        {{{-1.5275, -1.0911}, {-0.6547, -0.2182}},
         {{0.2182, 0.6547}, {1.0911, 1.5275}}}}});
  ASSERT_TRUE(output.allclose(expected, 2e-04));
}
2133 | |
TEST_F(FunctionalTest, Interpolate) {
  // Smoke-tests F::interpolate across 1D/2D/3D modes, its argument
  // validation error paths, and the nearest-exact / antialias variants.
  {
    // 1D interpolation
    auto input = torch::ones({1, 1, 2});
    auto options = F::InterpolateFuncOptions()
                       .size(std::vector<int64_t>({4}))
                       .mode(torch::kNearest);
    auto output = F::interpolate(input, options);
    auto expected = torch::ones({1, 1, 4});

    ASSERT_TRUE(output.allclose(expected));
  }
  {
    // 2D interpolation
    for (const auto align_corners : {true, false}) {
      // test float scale factor up & down sampling
      for (const auto scale_factor : {0.5, 1.5, 2.0}) {
        auto input = torch::ones({1, 1, 2, 2});
        auto options =
            F::InterpolateFuncOptions()
                .scale_factor(std::vector<double>({scale_factor, scale_factor}))
                .mode(torch::kBilinear)
                .align_corners(align_corners);
        auto output = F::interpolate(input, options);
        // Interpolating an all-ones tensor yields all ones at the scaled
        // size regardless of mode/align_corners.
        auto expected_size =
            static_cast<int64_t>(std::floor(input.size(-1) * scale_factor));
        auto expected = torch::ones({1, 1, expected_size, expected_size});

        ASSERT_TRUE(output.allclose(expected));
      }
    }
  }
  {
    // 3D interpolation
    for (const auto align_corners : {true, false}) {
      for (const auto scale_factor : {0.5, 1.5, 2.0}) {
        auto input = torch::ones({1, 1, 2, 2, 2});
        auto options = F::InterpolateFuncOptions()
                           .scale_factor(std::vector<double>(
                               {scale_factor, scale_factor, scale_factor}))
                           .mode(torch::kTrilinear)
                           .align_corners(align_corners);
        auto output = F::interpolate(input, options);
        auto expected_size =
            static_cast<int64_t>(std::floor(input.size(-1) * scale_factor));
        auto expected =
            torch::ones({1, 1, expected_size, expected_size, expected_size});

        ASSERT_TRUE(output.allclose(expected));
      }
    }
  }
  {
    // Invalid-argument error paths; each asserts the exact message text.
    auto input = torch::randn({3, 2, 2});
    // 2D input (too few dims for any supported mode).
    ASSERT_THROWS_WITH(
        F::interpolate(
            input[0],
            F::InterpolateFuncOptions().size(std::vector<int64_t>({4, 4}))),
        "Input Error: Only 3D, 4D and 5D input Tensors supported (got 2D) "
        "for the modes: nearest | linear | bilinear | bicubic | trilinear (got kNearest)" );
    // 6D input (too many dims).
    ASSERT_THROWS_WITH(
        F::interpolate(
            torch::reshape(input, {1, 1, 1, 3, 2, 2}),
            F::InterpolateFuncOptions().size(
                std::vector<int64_t>({1, 1, 1, 3, 4, 4}))),
        "Input Error: Only 3D, 4D and 5D input Tensors supported (got 6D) "
        "for the modes: nearest | linear | bilinear | bicubic | trilinear (got kNearest)" );
    // Neither size nor scale_factor given.
    ASSERT_THROWS_WITH(
        F::interpolate(input, F::InterpolateFuncOptions()),
        "either size or scale_factor should be defined" );
    // Both size and scale_factor given.
    ASSERT_THROWS_WITH(
        F::interpolate(
            input,
            F::InterpolateFuncOptions()
                .size(std::vector<int64_t>({3, 4, 4}))
                .scale_factor(std::vector<double>({0.5}))),
        "only one of size or scale_factor should be defined" );
    // scale_factor rank does not match the input's spatial rank.
    ASSERT_THROWS_WITH(
        F::interpolate(
            input,
            F::InterpolateFuncOptions().scale_factor(
                std::vector<double>({3, 2}))),
        "scale_factor shape must match input shape. "
        "Input is 1D, scale_factor size is [3, 2]" );
    // align_corners is meaningless for nearest mode.
    ASSERT_THROWS_WITH(
        F::interpolate(
            input,
            F::InterpolateFuncOptions()
                .mode(torch::kNearest)
                .align_corners(true)),
        "align_corners option can only be set with the "
        "interpolating modes: linear | bilinear | bicubic | trilinear" );
  }
  {
    // kNearestExact must match the ATen nearest-exact kernel directly.
    auto tensor = torch::rand({2, 3, 32, 32});
    std::vector<int64_t> osize = {8, 10};
    auto expected =
        at::native::_upsample_nearest_exact2d(tensor, osize, torch::nullopt);

    auto options = F::InterpolateFuncOptions()
                       .size(osize)
                       .mode(torch::kNearestExact)
                       .align_corners(false);
    auto output = F::interpolate(tensor, options);

    ASSERT_TRUE(output.allclose(expected));
  }
  {
    // antialias(true) with bilinear must match the ATen AA kernel.
    auto tensor = torch::rand({2, 3, 32, 32});
    std::vector<int64_t> osize = {8, 10};
    auto expected = at::native::_upsample_bilinear2d_aa(
        tensor, osize, false, torch::nullopt);

    auto options = F::InterpolateFuncOptions()
                       .size(osize)
                       .mode(torch::kBilinear)
                       .align_corners(false)
                       .antialias(true);
    auto output = F::interpolate(tensor, options);
    ASSERT_TRUE(output.allclose(expected));
  }
  {
    // antialias(true) with bicubic must match the ATen AA kernel.
    auto tensor = torch::rand({2, 3, 32, 32});
    std::vector<int64_t> osize = {8, 10};
    auto expected = at::native::_upsample_bicubic2d_aa(
        tensor, osize, false, torch::nullopt);

    auto options = F::InterpolateFuncOptions()
                       .size(osize)
                       .mode(torch::kBicubic)
                       .align_corners(false)
                       .antialias(true);
    auto output = F::interpolate(tensor, options);
    ASSERT_TRUE(output.allclose(expected));
  }
}
2270 | |
2271 | TEST_F(FunctionalTest, Pad1) { |
2272 | { |
2273 | auto input = torch::arange(6, torch::kDouble).reshape({1, 2, 3}); |
2274 | auto output = |
2275 | F::pad(input, F::PadFuncOptions({1, 2}).mode(torch::kCircular)); |
2276 | auto expected = torch::tensor( |
2277 | {{{2., 0., 1., 2., 0., 1.}, {5., 3., 4., 5., 3., 4.}}}, torch::kDouble); |
2278 | ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 2, 6})); |
2279 | ASSERT_TRUE(output.allclose(expected, 1e-04)); |
2280 | } |
2281 | } |
2282 | TEST_F(FunctionalTest, Pad2) { |
2283 | { |
2284 | auto input = torch::arange(9, torch::kDouble).reshape({1, 1, 3, 3}); |
2285 | auto output = |
2286 | F::pad(input, F::PadFuncOptions({3, 3, 3, 1}).mode(torch::kCircular)); |
2287 | auto expected = torch::tensor( |
2288 | {{{{0., 1., 2., 0., 1., 2., 0., 1., 2.}, |
2289 | {3., 4., 5., 3., 4., 5., 3., 4., 5.}, |
2290 | {6., 7., 8., 6., 7., 8., 6., 7., 8.}, |
2291 | {0., 1., 2., 0., 1., 2., 0., 1., 2.}, |
2292 | {3., 4., 5., 3., 4., 5., 3., 4., 5.}, |
2293 | {6., 7., 8., 6., 7., 8., 6., 7., 8.}, |
2294 | {0., 1., 2., 0., 1., 2., 0., 1., 2.}}}}, |
2295 | torch::kDouble); |
2296 | ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 1, 7, 9})); |
2297 | ASSERT_TRUE(output.allclose(expected, 1e-04)); |
2298 | } |
2299 | } |
TEST_F(FunctionalTest, Pad3) {
  {
    // Circular padding of a (1, 1, 2, 2, 3) volume. Pad widths are given
    // innermost-dimension first: {left, right, top, bottom, front, back}
    // = {3, 3, 2, 1, 2, 2}, giving a (1, 1, 6, 5, 9) output whose values
    // wrap around each padded dimension.
    auto input = torch::arange(12, torch::kDouble).reshape({1, 1, 2, 2, 3});
    auto output = F::pad(
        input, F::PadFuncOptions({3, 3, 2, 1, 2, 2}).mode(torch::kCircular));
    auto expected = torch::tensor(
        {{{{{0., 1., 2., 0., 1., 2., 0., 1., 2.},
            {3., 4., 5., 3., 4., 5., 3., 4., 5.},
            {0., 1., 2., 0., 1., 2., 0., 1., 2.},
            {3., 4., 5., 3., 4., 5., 3., 4., 5.},
            {0., 1., 2., 0., 1., 2., 0., 1., 2.}},

           {{6., 7., 8., 6., 7., 8., 6., 7., 8.},
            {9., 10., 11., 9., 10., 11., 9., 10., 11.},
            {6., 7., 8., 6., 7., 8., 6., 7., 8.},
            {9., 10., 11., 9., 10., 11., 9., 10., 11.},
            {6., 7., 8., 6., 7., 8., 6., 7., 8.}},

           {{0., 1., 2., 0., 1., 2., 0., 1., 2.},
            {3., 4., 5., 3., 4., 5., 3., 4., 5.},
            {0., 1., 2., 0., 1., 2., 0., 1., 2.},
            {3., 4., 5., 3., 4., 5., 3., 4., 5.},
            {0., 1., 2., 0., 1., 2., 0., 1., 2.}},

           {{6., 7., 8., 6., 7., 8., 6., 7., 8.},
            {9., 10., 11., 9., 10., 11., 9., 10., 11.},
            {6., 7., 8., 6., 7., 8., 6., 7., 8.},
            {9., 10., 11., 9., 10., 11., 9., 10., 11.},
            {6., 7., 8., 6., 7., 8., 6., 7., 8.}},

           {{0., 1., 2., 0., 1., 2., 0., 1., 2.},
            {3., 4., 5., 3., 4., 5., 3., 4., 5.},
            {0., 1., 2., 0., 1., 2., 0., 1., 2.},
            {3., 4., 5., 3., 4., 5., 3., 4., 5.},
            {0., 1., 2., 0., 1., 2., 0., 1., 2.}},

           {{6., 7., 8., 6., 7., 8., 6., 7., 8.},
            {9., 10., 11., 9., 10., 11., 9., 10., 11.},
            {6., 7., 8., 6., 7., 8., 6., 7., 8.},
            {9., 10., 11., 9., 10., 11., 9., 10., 11.},
            {6., 7., 8., 6., 7., 8., 6., 7., 8.}}}}},
        torch::kDouble);
    ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 1, 6, 5, 9}));
    ASSERT_TRUE(output.allclose(expected, 1e-04));
  }
}
2346 | TEST_F(FunctionalTest, Pad4) { |
2347 | { |
2348 | auto input = torch::arange(16, torch::kDouble).reshape({2, 2, 2, 2}); |
2349 | auto output = |
2350 | F::pad(input, F::PadFuncOptions({1, 1, 1, 1}).mode(torch::kReflect)); |
2351 | auto expected = torch::tensor( |
2352 | {{{{3., 2., 3., 2.}, |
2353 | {1., 0., 1., 0.}, |
2354 | {3., 2., 3., 2.}, |
2355 | {1., 0., 1., 0.}}, |
2356 | |
2357 | {{7., 6., 7., 6.}, |
2358 | {5., 4., 5., 4.}, |
2359 | {7., 6., 7., 6.}, |
2360 | {5., 4., 5., 4.}}}, |
2361 | |
2362 | {{{11., 10., 11., 10.}, |
2363 | {9., 8., 9., 8.}, |
2364 | {11., 10., 11., 10.}, |
2365 | {9., 8., 9., 8.}}, |
2366 | |
2367 | {{15., 14., 15., 14.}, |
2368 | {13., 12., 13., 12.}, |
2369 | {15., 14., 15., 14.}, |
2370 | {13., 12., 13., 12.}}}}, |
2371 | torch::kDouble); |
2372 | ASSERT_EQ(output.sizes(), std::vector<int64_t>({2, 2, 4, 4})); |
2373 | ASSERT_TRUE(output.allclose(expected, 1e-04)); |
2374 | } |
2375 | } |
TEST_F(FunctionalTest, Pad5) {
  {
    // Replication padding of a (1, 1, 2, 2, 3) volume. Pad widths are
    // innermost-dimension first: {left, right, top, bottom, front, back}
    // = {1, 2, 2, 1, 1, 2}, so edge values are repeated outward, giving a
    // (1, 1, 5, 5, 6) output.
    auto input = torch::arange(12, torch::kDouble).reshape({1, 1, 2, 2, 3});
    auto output = F::pad(
        input, F::PadFuncOptions({1, 2, 2, 1, 1, 2}).mode(torch::kReplicate));
    auto expected = torch::tensor(
        {{{{{0., 0., 1., 2., 2., 2.},
            {0., 0., 1., 2., 2., 2.},
            {0., 0., 1., 2., 2., 2.},
            {3., 3., 4., 5., 5., 5.},
            {3., 3., 4., 5., 5., 5.}},

           {{0., 0., 1., 2., 2., 2.},
            {0., 0., 1., 2., 2., 2.},
            {0., 0., 1., 2., 2., 2.},
            {3., 3., 4., 5., 5., 5.},
            {3., 3., 4., 5., 5., 5.}},

           {{6., 6., 7., 8., 8., 8.},
            {6., 6., 7., 8., 8., 8.},
            {6., 6., 7., 8., 8., 8.},
            {9., 9., 10., 11., 11., 11.},
            {9., 9., 10., 11., 11., 11.}},

           {{6., 6., 7., 8., 8., 8.},
            {6., 6., 7., 8., 8., 8.},
            {6., 6., 7., 8., 8., 8.},
            {9., 9., 10., 11., 11., 11.},
            {9., 9., 10., 11., 11., 11.}},

           {{6., 6., 7., 8., 8., 8.},
            {6., 6., 7., 8., 8., 8.},
            {6., 6., 7., 8., 8., 8.},
            {9., 9., 10., 11., 11., 11.},
            {9., 9., 10., 11., 11., 11.}}}}},
        torch::kDouble);
    ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 1, 5, 5, 6}));
    ASSERT_TRUE(output.allclose(expected, 1e-04));
  }
}
// F::pad with kReflect on a 5-D (NCDHW) input: {0, 2, 1, 0, 1, 2} pads
// width by (0, 2), height by (1, 0) and depth by (1, 2), mirroring about
// the edges without duplicating them.
TEST_F(FunctionalTest, Pad6) {
  {
    auto input = torch::arange(18, torch::kDouble).reshape({1, 1, 3, 2, 3});
    auto output = F::pad(
        input, F::PadFuncOptions({0, 2, 1, 0, 1, 2}).mode(torch::kReflect));
    // Expected values derived by hand from reflecting the 3x2x3 volume.
    auto expected = torch::tensor(
        {{{{{9., 10., 11., 10., 9.},
            {6., 7., 8., 7., 6.},
            {9., 10., 11., 10., 9.}},

           {{3., 4., 5., 4., 3.}, {0., 1., 2., 1., 0.}, {3., 4., 5., 4., 3.}},

           {{9., 10., 11., 10., 9.},
            {6., 7., 8., 7., 6.},
            {9., 10., 11., 10., 9.}},

           {{15., 16., 17., 16., 15.},
            {12., 13., 14., 13., 12.},
            {15., 16., 17., 16., 15.}},

           {{9., 10., 11., 10., 9.},
            {6., 7., 8., 7., 6.},
            {9., 10., 11., 10., 9.}},

           {{3., 4., 5., 4., 3.},
            {0., 1., 2., 1., 0.},
            {3., 4., 5., 4., 3.}}}}},
        torch::kDouble);
    // Sizes: depth 3 -> 6 (1+2), height 2 -> 3 (1+0), width 3 -> 5 (0+2).
    ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 1, 6, 3, 5}));
    ASSERT_TRUE(output.allclose(expected, 1e-04));
  }
}
2448 | TEST_F(FunctionalTest, Pad7) { |
2449 | { |
2450 | auto input = torch::ones({1, 1, 1, 1}, torch::kDouble); |
2451 | auto output = F::pad( |
2452 | input, F::PadFuncOptions({1, 1}).mode(torch::kConstant).value(0)); |
2453 | ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 1, 1, 3})); |
2454 | auto expected = torch::tensor({{{{0., 1., 0.}}}}, torch::kDouble); |
2455 | } |
2456 | } |
2457 | TEST_F(FunctionalTest, Pad8) { |
2458 | { |
2459 | auto input = torch::ones({1, 1, 1, 1}, torch::kDouble); |
2460 | auto output = F::pad(input, F::PadFuncOptions({1, 1})); |
2461 | ASSERT_EQ(output.sizes(), std::vector<int64_t>({1, 1, 1, 3})); |
2462 | auto expected = torch::tensor({{{{0., 1., 0.}}}}, torch::kDouble); |
2463 | } |
2464 | } |
2465 | |
// Exercises F::ctc_loss: dtype validation of the length tensors, target
// size validation, and the closed-form loss value for empty targets
// (where the loss reduces to the negative log-probability of emitting
// blanks for the whole input).
TEST_F(FunctionalTest, CTCLoss) {
  { // test CTCLoss typechecks
    const auto target_lengths = torch::tensor({30, 25, 20});
    const auto input_lengths = torch::tensor({50, 50, 50});
    const auto targets =
        torch::randint(1, 15, {target_lengths.sum().item<int>()}, torch::kInt);
    const auto log_probs =
        torch::randn({50, 3, 15}, torch::kFloat).log_softmax(2);

    // Floating-point (non-integral) length tensors must be rejected.
    const auto _input_lengths = input_lengths.to(torch::kFloat);
    ASSERT_THROWS_WITH(
        F::ctc_loss(log_probs, targets, _input_lengths, target_lengths),
        "input_lengths must be integral");

    const auto target_lengths_ = target_lengths.to(torch::kFloat);
    ASSERT_THROWS_WITH(
        F::ctc_loss(log_probs, targets, input_lengths, target_lengths_),
        "target_lengths must be integral");
  }
  { // test CTCLoss length checks
    const auto target_lengths = torch::tensor({30, 25, 20});
    const auto input_lengths = torch::tensor({50, 50, 50});
    // A batched targets tensor of width 29 cannot hold a length-30 target.
    const auto targets = torch::randint(1, 15, {3, 29}, torch::kInt);
    const auto log_probs =
        torch::randn({50, 3, 15}, torch::kFloat).log_softmax(2);
    ASSERT_THROWS_WITH(
        F::ctc_loss(log_probs, targets, input_lengths, target_lengths),
        "Expected tensor to have size at least 30 at dimension 1");
  }
  { // test CTCLoss empty target
    {
      // All targets empty: each per-sample loss equals the summed negative
      // log-probability of the blank symbol (class 0) over all time steps.
      const auto target_lengths = torch::tensor({0, 0, 0});
      const auto input_lengths = torch::tensor({50, 50, 50});
      const auto targets =
          torch::randint(1, 15, at::IntArrayRef({0}), torch::kLong);
      const auto log_probs =
          torch::randn({50, 3, 15}, torch::kDouble).log_softmax(2);
      const auto loss = F::ctc_loss(
          log_probs,
          targets,
          input_lengths,
          target_lengths,
          F::CTCLossFuncOptions().reduction(torch::kNone));
      ASSERT_TRUE(loss.ge(0).all().item<bool>());
      ASSERT_TRUE(torch::allclose(
          -log_probs.sum(0).slice(1, 0, 1).view_as(loss), loss));
    }
    {
      // Mixed batch: samples 0 and 2 have empty targets; only those are
      // compared against the all-blank closed form.
      const auto target_lengths = torch::tensor({0, 9, 0});
      const auto input_lengths = torch::tensor({50, 50, 50});
      const auto targets = torch::randint(1, 15, {9}, torch::kLong);
      const auto log_probs =
          torch::randn({50, 3, 15}, torch::kDouble).log_softmax(2);
      const auto loss = F::ctc_loss(
          log_probs,
          targets,
          input_lengths,
          target_lengths,
          F::CTCLossFuncOptions().reduction(torch::kNone));
      ASSERT_TRUE(loss.ge(0).all().item<bool>());
      ASSERT_TRUE(torch::allclose(
          -log_probs.sum(0)
               .index_select(0, torch::tensor({0, 2}, torch::kLong))
               .slice(1, 0, 1)
               .view({2}),
          loss.index_select(0, torch::tensor({0, 2}, torch::kLong))));
    }
  }
}
2535 | |
2536 | TEST_F(FunctionalTest, PoissonNLLLoss) { |
2537 | const auto input = torch::tensor({0.5, 1.5, 2.5}); |
2538 | const auto target = torch::tensor({1., 2., 3.}); |
2539 | const auto component_wise_loss = torch::exp(input) - target * input; |
2540 | ASSERT_TRUE(torch::allclose( |
2541 | torch::mean(component_wise_loss), F::poisson_nll_loss(input, target))); |
2542 | ASSERT_TRUE(torch::allclose( |
2543 | component_wise_loss, |
2544 | F::poisson_nll_loss( |
2545 | input, |
2546 | target, |
2547 | F::PoissonNLLLossFuncOptions().reduction(torch::kNone)))); |
2548 | ASSERT_TRUE(torch::allclose( |
2549 | torch::sum(component_wise_loss), |
2550 | F::poisson_nll_loss( |
2551 | input, |
2552 | target, |
2553 | F::PoissonNLLLossFuncOptions().reduction(torch::kSum)))); |
2554 | ASSERT_TRUE(torch::allclose( |
2555 | torch::mean(component_wise_loss), |
2556 | F::poisson_nll_loss( |
2557 | input, |
2558 | target, |
2559 | F::PoissonNLLLossFuncOptions().reduction(torch::kMean)))); |
2560 | } |
2561 | |
2562 | TEST_F(FunctionalTest, MarginRankingLoss) { |
2563 | { |
2564 | const auto input1 = torch::randn(15) * 10; |
2565 | const auto input2 = torch::randn(15) * 10; |
2566 | const auto target = torch::randn(15).sign(); |
2567 | ASSERT_TRUE(torch::allclose( |
2568 | F::margin_ranking_loss(input1, input2, target), |
2569 | (-target * (input1 - input2)).clamp(0).mean())); |
2570 | } |
2571 | { |
2572 | const auto input1 = torch::randn(15) * 10; |
2573 | const auto input2 = torch::randn(15) * 10; |
2574 | const auto target = torch::randn(15).sign(); |
2575 | const auto margin = 0.5; |
2576 | ASSERT_TRUE(torch::allclose( |
2577 | F::margin_ranking_loss( |
2578 | input1, |
2579 | input2, |
2580 | target, |
2581 | F::MarginRankingLossFuncOptions().margin(0.5).reduction( |
2582 | torch::kSum)), |
2583 | (-target * (input1 - input2) + margin).clamp(0).sum())); |
2584 | } |
2585 | { |
2586 | const auto input1 = torch::randn(15) * 10; |
2587 | const auto input2 = torch::randn(15) * 10; |
2588 | const auto target = torch::randn(15).sign(); |
2589 | const auto margin = 0.5; |
2590 | ASSERT_TRUE(torch::allclose( |
2591 | F::margin_ranking_loss( |
2592 | input1, |
2593 | input2, |
2594 | target, |
2595 | F::MarginRankingLossFuncOptions().margin(0.5).reduction( |
2596 | torch::kMean)), |
2597 | (-target * (input1 - input2) + margin).clamp(0).mean())); |
2598 | } |
2599 | } |
2600 | |
// F::conv_transpose1d on deterministic arange inputs: output length is
// input_len + kernel - 1 = 5 + 3 - 1 = 7 with stride 1. Also checks the
// options-free overload gives the same result (stride defaults to 1).
TEST_F(FunctionalTest, ConvTranspose1d) {
  auto x = torch::arange(20.).view({2, 2, 5});
  auto weight = torch::arange(18.).view({2, 3, 3});
  auto y =
      F::conv_transpose1d(x, weight, F::ConvTranspose1dFuncOptions().stride(1));
  // Expected values precomputed from the arange input/weight.
  auto expected = torch::tensor(
      {{{45., 104., 179., 212., 245., 188., 107.},
        {60., 140., 242., 293., 344., 260., 146.},
        {75., 176., 305., 374., 443., 332., 185.}},
       {{135., 304., 509., 542., 575., 428., 237.},
        {210., 460., 752., 803., 854., 620., 336.},
        {285., 616., 995., 1064., 1133., 812., 435.}}});
  ASSERT_TRUE(torch::allclose(y, expected));

  auto y_no_options = F::conv_transpose1d(x, weight);
  ASSERT_TRUE(torch::allclose(y_no_options, expected));
}
2618 | |
// F::conv_transpose2d with a square (3x3) kernel: 5x5 input becomes
// 7x7 output (5 + 3 - 1) per output channel with stride 1. Also checks
// the options-free overload matches.
TEST_F(FunctionalTest, ConvTranspose2dEven) {
  auto x = torch::arange(50.).view({1, 2, 5, 5});
  auto weight = torch::arange(54.).view({2, 3, 3, 3});
  auto y =
      F::conv_transpose2d(x, weight, F::ConvTranspose2dFuncOptions().stride(1));
  // Expected values precomputed from the arange input/weight.
  auto expected = torch::tensor(
      {{{{675., 1402., 2183., 2270., 2357., 1634., 849.},
         {1560., 3240., 5044., 5236., 5428., 3760., 1952.},
         {2685., 5574., 8673., 8988., 9303., 6438., 3339.},
         {3180., 6594., 10248., 10563., 10878., 7518., 3894.},
         {3675., 7614., 11823., 12138., 12453., 8598., 4449.},
         {2820., 5832., 9040., 9268., 9496., 6544., 3380.},
         {1605., 3314., 5129., 5252., 5375., 3698., 1907.}},
        {{900., 1870., 2912., 3053., 3194., 2210., 1146.},
         {2100., 4356., 6772., 7072., 7372., 5092., 2636.},
         {3630., 7518., 11670., 12147., 12624., 8706., 4500.},
         {4395., 9078., 14055., 14532., 15009., 10326., 5325.},
         {5160., 10638., 16440., 16917., 17394., 11946., 6150.},
         {3900., 8028., 12388., 12724., 13060., 8956., 4604.},
         {2190., 4502., 6938., 7115., 7292., 4994., 2564.}},
        {{1125., 2338., 3641., 3836., 4031., 2786., 1443.},
         {2640., 5472., 8500., 8908., 9316., 6424., 3320.},
         {4575., 9462., 14667., 15306., 15945., 10974., 5661.},
         {5610., 11562., 17862., 18501., 19140., 13134., 6756.},
         {6645., 13662., 21057., 21696., 22335., 15294., 7851.},
         {4980., 10224., 15736., 16180., 16624., 11368., 5828.},
         {2775., 5690., 8747., 8978., 9209., 6290., 3221.}}}});
  ASSERT_TRUE(torch::allclose(y, expected));

  auto y_no_options = F::conv_transpose2d(x, weight);
  ASSERT_TRUE(torch::allclose(y_no_options, expected));
}
2651 | |
// F::conv_transpose2d with a non-square (3x2) kernel: 5x4 input becomes
// 7x5 output (H: 5 + 3 - 1, W: 4 + 2 - 1) with stride 1. Also checks
// the options-free overload matches.
TEST_F(FunctionalTest, ConvTranspose2dUneven) {
  auto x = torch::arange(40.).view({1, 2, 5, 4});
  auto weight = torch::arange(36.).view({2, 3, 3, 2});
  auto y =
      F::conv_transpose2d(x, weight, F::ConvTranspose2dFuncOptions().stride(1));
  // Expected values precomputed from the arange input/weight.
  auto expected = torch::tensor(
      {{{{360., 758., 796., 834., 440.},
         {832., 1752., 1836., 1920., 1012.},
         {1432., 3014., 3152., 3290., 1732.},
         {1696., 3566., 3704., 3842., 2020.},
         {1960., 4118., 4256., 4394., 2308.},
         {1504., 3152., 3252., 3352., 1756.},
         {856., 1790., 1844., 1898., 992.}},
        {{480., 1010., 1072., 1134., 596.},
         {1120., 2352., 2484., 2616., 1372.},
         {1936., 4058., 4268., 4478., 2344.},
         {2344., 4898., 5108., 5318., 2776.},
         {2752., 5738., 5948., 6158., 3208.},
         {2080., 4328., 4476., 4624., 2404.},
         {1168., 2426., 2504., 2582., 1340.}},
        {{600., 1262., 1348., 1434., 752.},
         {1408., 2952., 3132., 3312., 1732.},
         {2440., 5102., 5384., 5666., 2956.},
         {2992., 6230., 6512., 6794., 3532.},
         {3544., 7358., 7640., 7922., 4108.},
         {2656., 5504., 5700., 5896., 3052.},
         {1480., 3062., 3164., 3266., 1688.}}}});
  ASSERT_TRUE(torch::allclose(y, expected));

  auto y_no_options = F::conv_transpose2d(x, weight);
  ASSERT_TRUE(torch::allclose(y_no_options, expected));
}
2684 | |
// F::conv_transpose3d with a 2x2x2 kernel: each 2x2x2 spatial dim grows
// to 3 (2 + 2 - 1) with stride 1. Also checks the options-free overload
// matches.
TEST_F(FunctionalTest, ConvTranspose3d) {
  auto x = torch::arange(16.).view({1, 2, 2, 2, 2});
  auto weight = torch::arange(32.).view({2, 2, 2, 2, 2});
  auto y =
      F::conv_transpose3d(x, weight, F::ConvTranspose3dFuncOptions().stride(1));
  // Expected values precomputed from the arange input/weight.
  auto expected = torch::tensor(
      {{{{{128., 280., 154.}, {304., 664., 364.}, {184., 400., 218.}},
         {{352., 768., 420.}, {832., 1808., 984.}, {496., 1072., 580.}},
         {{256., 552., 298.}, {592., 1272., 684.}, {344., 736., 394.}}},
        {{{192., 424., 234.}, {464., 1016., 556.}, {280., 608., 330.}},
         {{544., 1184., 644.}, {1280., 2768., 1496.}, {752., 1616., 868.}},
         {{384., 824., 442.}, {880., 1880., 1004.}, {504., 1072., 570.}}}}});
  ASSERT_TRUE(torch::allclose(y, expected));

  auto y_no_options = F::conv_transpose3d(x, weight);
  ASSERT_TRUE(torch::allclose(y_no_options, expected));
}
2702 | |
2703 | TEST_F(FunctionalTest, AlphaDropout) { |
2704 | auto input = torch::randn(5000); |
2705 | auto input_mean = input.mean(); |
2706 | auto input_std = input.std(); |
2707 | |
2708 | for (const auto rate : {0.2, 0.5, 0.8}) { |
2709 | for (const auto inplace : {false, true}) { |
2710 | auto input_ = input.clone(); |
2711 | auto output = F::alpha_dropout( |
2712 | input_, |
2713 | F::AlphaDropoutFuncOptions().p(rate).training(false).inplace( |
2714 | inplace)); |
2715 | ASSERT_TRUE(torch::allclose(input_mean, output.mean(), 0.1)); |
2716 | ASSERT_TRUE(torch::allclose(input_std, output.std(), 0.1)); |
2717 | if (inplace) { |
2718 | ASSERT_TRUE(torch::allclose(input_, output)); |
2719 | } |
2720 | } |
2721 | } |
2722 | auto output = F::detail::alpha_dropout(input, 0.5, false, false); |
2723 | ASSERT_TRUE(torch::allclose(input_mean, output.mean(), 0.1)); |
2724 | ASSERT_TRUE(torch::allclose(input_std, output.std(), 0.1)); |
2725 | } |
2726 | |
2727 | TEST_F(FunctionalTest, FeatureAlphaDropout) { |
2728 | auto input = torch::randn(5000); |
2729 | auto input_mean = input.mean(); |
2730 | auto input_std = input.std(); |
2731 | |
2732 | for (const auto rate : {0.2, 0.5, 0.8}) { |
2733 | for (const auto inplace : {false, true}) { |
2734 | auto input_ = input.clone(); |
2735 | auto output = F::feature_alpha_dropout( |
2736 | input_, |
2737 | F::FeatureAlphaDropoutFuncOptions().p(rate).training(false).inplace( |
2738 | inplace)); |
2739 | ASSERT_TRUE(torch::allclose(input_mean, output.mean(), 0.1)); |
2740 | ASSERT_TRUE(torch::allclose(input_std, output.std(), 0.1)); |
2741 | if (inplace) { |
2742 | ASSERT_TRUE(torch::allclose(input_, output)); |
2743 | } |
2744 | } |
2745 | } |
2746 | auto output = F::feature_alpha_dropout(input); |
2747 | ASSERT_TRUE(torch::allclose(input_mean, output.mean(), 0.1)); |
2748 | ASSERT_TRUE(torch::allclose(input_std, output.std(), 0.1)); |
2749 | } |
2750 | |
2751 | TEST_F(FunctionalTest, Dropout) { |
2752 | auto input = torch::randn(5000); |
2753 | auto input_mean = input.mean(); |
2754 | auto input_std = input.std(); |
2755 | |
2756 | for (const auto rate : {0.2, 0.5, 0.8}) { |
2757 | auto output = F::dropout(input, F::DropoutFuncOptions().p(rate)); |
2758 | ASSERT_TRUE(torch::allclose(input_mean, output.mean(), 0.01, 0.05)); |
2759 | ASSERT_TRUE((input_std <= output.std()).all().item<bool>()); |
2760 | } |
2761 | auto output = F::dropout(input); |
2762 | ASSERT_TRUE(torch::allclose(input_mean, output.mean(), 0.01, 0.05)); |
2763 | ASSERT_TRUE((input_std <= output.std()).all().item<bool>()); |
2764 | ASSERT_TRUE(F::dropout(torch::tensor(1.)).defined()); |
2765 | } |
2766 | |
2767 | TEST_F(FunctionalTest, Dropout2d) { |
2768 | auto input = torch::randn({2, 2, 50, 100}); |
2769 | auto input_mean = input.mean(); |
2770 | auto input_std = input.std(); |
2771 | |
2772 | for (const auto rate : {0.2, 0.5, 0.8}) { |
2773 | auto output = F::dropout2d(input, F::Dropout2dFuncOptions().p(rate)); |
2774 | ASSERT_TRUE(torch::allclose(input_mean, output.mean(), 0.01, 0.05)); |
2775 | } |
2776 | auto output = F::dropout2d(input); |
2777 | ASSERT_TRUE(torch::allclose(input_mean, output.mean(), 0.01, 0.05)); |
2778 | ASSERT_TRUE(F::dropout2d(torch::randn({2, 50, 100})).defined()); |
2779 | } |
2780 | |
2781 | TEST_F(FunctionalTest, Dropout3d) { |
2782 | auto input = torch::randn({2, 2, 50, 10, 10}); |
2783 | auto input_mean = input.mean(); |
2784 | auto input_std = input.std(); |
2785 | |
2786 | for (const auto rate : {0.2, 0.5, 0.8}) { |
2787 | auto output = F::dropout3d(input, F::Dropout3dFuncOptions().p(rate)); |
2788 | ASSERT_TRUE(torch::allclose(input_mean, output.mean(), 0.01, 0.05)); |
2789 | } |
2790 | auto output = F::dropout3d(input); |
2791 | ASSERT_TRUE(torch::allclose(input_mean, output.mean(), 0.01, 0.05)); |
2792 | ASSERT_TRUE(F::dropout3d(torch::randn({2, 50, 10, 10})).defined()); |
2793 | } |
2794 | |
2795 | template <c10::ScalarType S, typename T> |
2796 | void test_isfinite(const at::Device& device) { |
2797 | const std::vector<T> values = { |
2798 | std::numeric_limits<T>::lowest(), |
2799 | 0, |
2800 | 1, |
2801 | 42, |
2802 | std::numeric_limits<T>::min(), |
2803 | std::numeric_limits<T>::max()}; |
2804 | for (const auto value : values) { |
2805 | const auto x = torch::full( |
2806 | {3, 3}, value, torch::TensorOptions().dtype(S).device(device)); |
2807 | ASSERT_TRUE(torch::isfinite(x).all().template item<bool>()); |
2808 | } |
2809 | if (std::numeric_limits<T>::has_infinity) { |
2810 | const auto inf = std::numeric_limits<T>::infinity(); |
2811 | const auto x = torch::tensor( |
2812 | {-inf, |
2813 | std::numeric_limits<T>::lowest(), |
2814 | static_cast<T>(0), |
2815 | static_cast<T>(1), |
2816 | static_cast<T>(42), |
2817 | std::numeric_limits<T>::min(), |
2818 | std::numeric_limits<T>::max(), |
2819 | inf}, |
2820 | torch::TensorOptions().dtype(S).device(device)); |
2821 | ASSERT_TRUE(torch::allclose( |
2822 | // torch::allclose does not support comparing torch::kBool |
2823 | torch::isfinite(x).toType(torch::kInt), |
2824 | torch::tensor( |
2825 | {false, true, true, true, true, true, true, false}, |
2826 | torch::TensorOptions().device(device)) |
2827 | .toType(torch::kInt))); |
2828 | } |
2829 | if (std::numeric_limits<T>::has_quiet_NaN) { |
2830 | const auto x = torch::tensor( |
2831 | {std::numeric_limits<T>::quiet_NaN()}, |
2832 | torch::TensorOptions().dtype(S).device(device)); |
2833 | ASSERT_FALSE(torch::isfinite(x).all().template item<bool>()); |
2834 | } |
2835 | if (std::numeric_limits<T>::has_signaling_NaN) { |
2836 | const auto x = torch::tensor( |
2837 | {std::numeric_limits<T>::signaling_NaN()}, |
2838 | torch::TensorOptions().dtype(S).device(device)); |
2839 | ASSERT_FALSE(torch::isfinite(x).all().template item<bool>()); |
2840 | } |
2841 | } |
2842 | |
// Runs the isfinite checks on CPU for every standard integral and
// floating-point dtype (see test_isfinite above).
TEST_F(FunctionalTest, isfinite) {
  const at::Device device("cpu");
  test_isfinite<torch::kUInt8, uint8_t>(device);
  test_isfinite<torch::kInt8, int8_t>(device);
  test_isfinite<torch::kInt16, int16_t>(device);
  test_isfinite<torch::kInt32, int32_t>(device);
  test_isfinite<torch::kInt64, int64_t>(device);
  test_isfinite<torch::kFloat32, float>(device);
  test_isfinite<torch::kFloat64, double>(device);
}
2853 | |
// Same isfinite checks on CUDA; additionally covers kFloat16 (half),
// which is only available on the GPU path here.
TEST_F(FunctionalTest, isfinite_CUDA) {
  const at::Device device("cuda");
  test_isfinite<torch::kUInt8, uint8_t>(device);
  test_isfinite<torch::kInt8, int8_t>(device);
  test_isfinite<torch::kInt16, int16_t>(device);
  test_isfinite<torch::kInt32, int32_t>(device);
  test_isfinite<torch::kInt64, int64_t>(device);
  test_isfinite<torch::kFloat32, float>(device);
  test_isfinite<torch::kFloat64, double>(device);
  test_isfinite<torch::kFloat16, c10::Half>(device);
}
2865 | |
2866 | template <c10::ScalarType S, typename T> |
2867 | void test_isinf(const at::Device& device) { |
2868 | const std::vector<T> values = { |
2869 | std::numeric_limits<T>::lowest(), |
2870 | 0, |
2871 | 1, |
2872 | 42, |
2873 | std::numeric_limits<T>::min(), |
2874 | std::numeric_limits<T>::max()}; |
2875 | for (const auto value : values) { |
2876 | const auto x = torch::full( |
2877 | {3, 3}, value, torch::TensorOptions().dtype(S).device(device)); |
2878 | ASSERT_FALSE(torch::isinf(x).all().template item<bool>()); |
2879 | } |
2880 | if (std::numeric_limits<T>::has_infinity) { |
2881 | const auto inf = std::numeric_limits<T>::infinity(); |
2882 | const auto x = torch::tensor( |
2883 | {-inf, |
2884 | std::numeric_limits<T>::lowest(), |
2885 | static_cast<T>(0), |
2886 | static_cast<T>(1), |
2887 | static_cast<T>(42), |
2888 | std::numeric_limits<T>::min(), |
2889 | std::numeric_limits<T>::max(), |
2890 | inf}, |
2891 | torch::TensorOptions().dtype(S).device(device)); |
2892 | ASSERT_TRUE(torch::allclose( |
2893 | // torch::allclose does not support comparing torch::kBool |
2894 | torch::isinf(x).toType(torch::kInt), |
2895 | torch::tensor( |
2896 | {true, false, false, false, false, false, false, true}, |
2897 | torch::TensorOptions().device(device)) |
2898 | .toType(torch::kInt))); |
2899 | } |
2900 | if (std::numeric_limits<T>::has_quiet_NaN) { |
2901 | const auto x = torch::tensor( |
2902 | {std::numeric_limits<T>::quiet_NaN()}, |
2903 | torch::TensorOptions().dtype(S).device(device)); |
2904 | ASSERT_FALSE(torch::isinf(x).all().template item<bool>()); |
2905 | } |
2906 | if (std::numeric_limits<T>::has_signaling_NaN) { |
2907 | const auto x = torch::tensor( |
2908 | {std::numeric_limits<T>::signaling_NaN()}, |
2909 | torch::TensorOptions().dtype(S).device(device)); |
2910 | ASSERT_FALSE(torch::isinf(x).all().template item<bool>()); |
2911 | } |
2912 | } |
2913 | |
// Runs the isinf checks on CPU for every standard integral and
// floating-point dtype (see test_isinf above).
TEST_F(FunctionalTest, isinf) {
  const at::Device device("cpu");
  test_isinf<torch::kUInt8, uint8_t>(device);
  test_isinf<torch::kInt8, int8_t>(device);
  test_isinf<torch::kInt16, int16_t>(device);
  test_isinf<torch::kInt32, int32_t>(device);
  test_isinf<torch::kInt64, int64_t>(device);
  test_isinf<torch::kFloat32, float>(device);
  test_isinf<torch::kFloat64, double>(device);
}
2924 | |
// Same isinf checks on CUDA; additionally covers kFloat16 (half),
// which is only available on the GPU path here.
TEST_F(FunctionalTest, isinf_CUDA) {
  const at::Device device("cuda");
  test_isinf<torch::kUInt8, uint8_t>(device);
  test_isinf<torch::kInt8, int8_t>(device);
  test_isinf<torch::kInt16, int16_t>(device);
  test_isinf<torch::kInt32, int32_t>(device);
  test_isinf<torch::kInt64, int64_t>(device);
  test_isinf<torch::kFloat32, float>(device);
  test_isinf<torch::kFloat64, double>(device);
  test_isinf<torch::kFloat16, c10::Half>(device);
}
2936 | |
2937 | template <c10::ScalarType S, typename T> |
2938 | void test_allclose(const at::Device& device) { |
2939 | const std::vector<T> values = { |
2940 | std::numeric_limits<T>::lowest(), |
2941 | 0, |
2942 | 1, |
2943 | 42, |
2944 | std::numeric_limits<T>::min(), |
2945 | std::numeric_limits<T>::max()}; |
2946 | for (const auto value : values) { |
2947 | const auto x = |
2948 | torch::full({1}, value, torch::TensorOptions().dtype(S).device(device)); |
2949 | const auto y = |
2950 | torch::full({1}, value, torch::TensorOptions().dtype(S).device(device)); |
2951 | ASSERT_TRUE(torch::allclose(x, x)); |
2952 | ASSERT_TRUE(torch::allclose(x, y)); |
2953 | ASSERT_TRUE(torch::allclose(y, x)); |
2954 | ASSERT_FALSE(torch::allclose(1.1 * x + 0.1, 1.0 * x)); |
2955 | ASSERT_TRUE(torch::allclose(0.99 * x + 0.1, 1.0 * x, 1.1, 0.1)); |
2956 | } |
2957 | if (std::numeric_limits<T>::has_infinity) { |
2958 | const auto inf = std::numeric_limits<T>::infinity(); |
2959 | const auto x = torch::tensor( |
2960 | {-inf, inf}, torch::TensorOptions().dtype(S).device(device)); |
2961 | const auto y = torch::tensor( |
2962 | {-inf, inf}, torch::TensorOptions().dtype(S).device(device)); |
2963 | ASSERT_TRUE(torch::allclose(x, x)); |
2964 | ASSERT_TRUE(torch::allclose(x, y)); |
2965 | ASSERT_TRUE(torch::allclose(y, x)); |
2966 | } |
2967 | if (std::numeric_limits<T>::has_quiet_NaN) { |
2968 | const auto x = torch::tensor( |
2969 | {std::numeric_limits<T>::quiet_NaN()}, |
2970 | torch::TensorOptions().dtype(S).device(device)); |
2971 | const auto y = torch::tensor( |
2972 | {std::numeric_limits<T>::quiet_NaN()}, |
2973 | torch::TensorOptions().dtype(S).device(device)); |
2974 | ASSERT_TRUE(torch::allclose(x, x, 1.0, 0.0, /*equal_nan=*/true)); |
2975 | ASSERT_TRUE(torch::allclose(x, y, 1.0, 0.0, /*equal_nan=*/true)); |
2976 | ASSERT_TRUE(torch::allclose(y, x, 1.0, 0.0, /*equal_nan=*/true)); |
2977 | } |
2978 | if (std::numeric_limits<T>::has_signaling_NaN) { |
2979 | const auto x = torch::tensor( |
2980 | {std::numeric_limits<T>::signaling_NaN()}, |
2981 | torch::TensorOptions().dtype(S).device(device)); |
2982 | const auto y = torch::tensor( |
2983 | {std::numeric_limits<T>::signaling_NaN()}, |
2984 | torch::TensorOptions().dtype(S).device(device)); |
2985 | ASSERT_TRUE(torch::allclose(x, x, 1.0, 0.0, /*equal_nan=*/true)); |
2986 | ASSERT_TRUE(torch::allclose(x, y, 1.0, 0.0, /*equal_nan=*/true)); |
2987 | ASSERT_TRUE(torch::allclose(y, x, 1.0, 0.0, /*equal_nan=*/true)); |
2988 | } |
2989 | } |
2990 | |
// Runs the allclose checks on CPU for every standard integral and
// floating-point dtype (see test_allclose above).
TEST_F(FunctionalTest, AllClose) {
  const at::Device device("cpu");
  test_allclose<torch::kUInt8, uint8_t>(device);
  test_allclose<torch::kInt8, int8_t>(device);
  test_allclose<torch::kInt16, int16_t>(device);
  test_allclose<torch::kInt32, int32_t>(device);
  test_allclose<torch::kInt64, int64_t>(device);
  test_allclose<torch::kFloat32, float>(device);
  test_allclose<torch::kFloat64, double>(device);
}
3001 | |
// Same allclose checks on CUDA; additionally covers kFloat16 (half),
// which is only available on the GPU path here.
TEST_F(FunctionalTest, AllClose_CUDA) {
  const at::Device device("cuda");
  test_allclose<torch::kUInt8, uint8_t>(device);
  test_allclose<torch::kInt8, int8_t>(device);
  test_allclose<torch::kInt16, int16_t>(device);
  test_allclose<torch::kInt32, int32_t>(device);
  test_allclose<torch::kInt64, int64_t>(device);
  test_allclose<torch::kFloat32, float>(device);
  test_allclose<torch::kFloat64, double>(device);
  test_allclose<torch::kFloat16, c10::Half>(device);
}
3013 | |
3014 | TEST_F(FunctionalTest, BCEWithLogitsLoss) { |
3015 | {// test BCE with logits raises if target and input are different size |
3016 | {const auto target = torch::rand(5); |
3017 | const auto input = torch::rand({5, 1}); |
3018 | ASSERT_THROWS_WITH( |
3019 | F::binary_cross_entropy_with_logits(input, target), |
3020 | "must be the same as input size" ); |
3021 | } |
3022 | { |
3023 | const auto target = torch::rand({5, 1}); |
3024 | const auto input = torch::rand(5); |
3025 | ASSERT_THROWS_WITH( |
3026 | F::binary_cross_entropy_with_logits(input, target), |
3027 | "must be the same as input size" ); |
3028 | } |
3029 | } |
3030 | { // test BCE with logits gives same result as sigmoid and bce loss |
3031 | auto sigmoid = Sigmoid(); |
3032 | |
3033 | auto target = torch::rand({64, 4}); |
3034 | auto output = torch::rand({64, 4}) - 0.5; |
3035 | |
3036 | ASSERT_TRUE(torch::allclose( |
3037 | F::binary_cross_entropy_with_logits(output, target), |
3038 | F::binary_cross_entropy(sigmoid(output), target))); |
3039 | |
3040 | auto weight = torch::rand(4); |
3041 | ASSERT_TRUE(torch::allclose( |
3042 | F::binary_cross_entropy_with_logits( |
3043 | output, |
3044 | target, |
3045 | F::BinaryCrossEntropyWithLogitsFuncOptions().weight(weight)), |
3046 | F::binary_cross_entropy( |
3047 | sigmoid(output), |
3048 | target, |
3049 | F::BinaryCrossEntropyFuncOptions().weight(weight)))); |
3050 | |
3051 | target = torch::zeros({4, 1}, torch::kFloat); |
3052 | output = torch::empty({4, 1}, torch::kFloat).fill_(-100); |
3053 | |
3054 | ASSERT_TRUE(torch::allclose( |
3055 | F::binary_cross_entropy_with_logits(output, target), |
3056 | F::binary_cross_entropy(sigmoid(output), target))); |
3057 | |
3058 | ASSERT_TRUE(torch::allclose( |
3059 | F::binary_cross_entropy_with_logits( |
3060 | output, |
3061 | target, |
3062 | F::BinaryCrossEntropyWithLogitsFuncOptions().reduction(torch::kNone)), |
3063 | F::binary_cross_entropy( |
3064 | sigmoid(output), |
3065 | target, |
3066 | F::BinaryCrossEntropyFuncOptions().reduction(torch::kNone)))); |
3067 | |
3068 | weight = torch::rand({1}, torch::kFloat); |
3069 | ASSERT_TRUE(torch::allclose( |
3070 | F::binary_cross_entropy_with_logits( |
3071 | output, |
3072 | target, |
3073 | F::BinaryCrossEntropyWithLogitsFuncOptions().weight(weight)), |
3074 | F::binary_cross_entropy( |
3075 | sigmoid(output), |
3076 | target, |
3077 | F::BinaryCrossEntropyFuncOptions().weight(weight)))); |
3078 | } |
3079 | { // test BCE with logits has correct grad at zero |
3080 | const auto output = torch::zeros({3, 1}, torch::requires_grad()); |
3081 | const auto target = torch::zeros({3, 1}); |
3082 | F::binary_cross_entropy_with_logits( |
3083 | output, |
3084 | target, |
3085 | F::BinaryCrossEntropyWithLogitsFuncOptions().reduction(torch::kSum)) |
3086 | .backward(); |
3087 | const auto expected_grad = torch::empty({3, 1}).fill_(0.5); |
3088 | ASSERT_TRUE(torch::allclose(output.grad(), expected_grad)); |
3089 | } |
3090 | { // test BCE with logits broadcasts weights |
3091 | const auto target = torch::rand({16, 4}); |
3092 | const auto output = torch::rand({16, 4}) - 0.5; |
3093 | |
3094 | auto weight = torch::rand(4); |
3095 | auto out1 = F::binary_cross_entropy_with_logits( |
3096 | output, |
3097 | target, |
3098 | F::BinaryCrossEntropyWithLogitsFuncOptions().weight(weight)); |
3099 | |
3100 | weight = weight.expand({16, 4}).contiguous(); |
3101 | auto out2 = F::binary_cross_entropy_with_logits( |
3102 | output, |
3103 | target, |
3104 | F::BinaryCrossEntropyWithLogitsFuncOptions().weight(weight)); |
3105 | |
3106 | ASSERT_TRUE(torch::allclose(out1, out2)); |
3107 | |
3108 | weight = torch::rand({16, 1}); |
3109 | out1 = F::binary_cross_entropy_with_logits( |
3110 | output, |
3111 | target, |
3112 | F::BinaryCrossEntropyWithLogitsFuncOptions().weight(weight)); |
3113 | |
3114 | weight = weight.expand({16, 4}).contiguous(); |
3115 | out2 = F::binary_cross_entropy_with_logits( |
3116 | output, |
3117 | target, |
3118 | F::BinaryCrossEntropyWithLogitsFuncOptions().weight(weight)); |
3119 | |
3120 | ASSERT_TRUE(torch::allclose(out1, out2)); |
3121 | } |
3122 | { // test BCE with logits ones in pos weights are the same as none |
3123 | const auto target = torch::rand({64, 4}); |
3124 | const auto output = torch::rand({64, 4}) - 0.5; |
3125 | const auto pos_weight = torch::ones({64, 4}); |
3126 | |
3127 | ASSERT_TRUE(torch::allclose( |
3128 | F::binary_cross_entropy_with_logits(output, target), |
3129 | F::binary_cross_entropy_with_logits( |
3130 | output, |
3131 | target, |
3132 | F::BinaryCrossEntropyWithLogitsFuncOptions().pos_weight( |
3133 | pos_weight)))); |
3134 | } |
3135 | { // test BCE with logits broadcasts pos weights |
3136 | const auto target = torch::rand({64, 4}); |
3137 | const auto output = torch::rand({64, 4}) - 0.5; |
3138 | const auto pos_weight = torch::rand(4); |
3139 | const auto out1 = F::binary_cross_entropy_with_logits( |
3140 | output, |
3141 | target, |
3142 | F::BinaryCrossEntropyWithLogitsFuncOptions().pos_weight(pos_weight)); |
3143 | |
3144 | const auto pos_weight1 = pos_weight.expand({1, 4}); |
3145 | const auto out2 = F::binary_cross_entropy_with_logits( |
3146 | output, |
3147 | target, |
3148 | F::BinaryCrossEntropyWithLogitsFuncOptions().pos_weight(pos_weight)); |
3149 | |
3150 | const auto pos_weight2 = pos_weight.expand({64, 4}); |
3151 | const auto out3 = F::binary_cross_entropy_with_logits( |
3152 | output, |
3153 | target, |
3154 | F::BinaryCrossEntropyWithLogitsFuncOptions().pos_weight(pos_weight)); |
3155 | |
3156 | ASSERT_TRUE(torch::allclose(out1, out2)); |
3157 | ASSERT_TRUE(torch::allclose(out1, out3)); |
3158 | } |
3159 | { // test BCE with logits with pos weight has correct grad at zero |
3160 | const auto output = torch::zeros({3, 1}, torch::requires_grad()); |
3161 | const auto target = torch::zeros({3, 1}); |
3162 | const auto pos_weight = torch::ones({3, 1}); |
3163 | F::binary_cross_entropy_with_logits( |
3164 | output, |
3165 | target, |
3166 | F::BinaryCrossEntropyWithLogitsFuncOptions() |
3167 | .pos_weight(pos_weight) |
3168 | .reduction(torch::kSum)) |
3169 | .backward(); |
3170 | const auto expected_grad = torch::empty({3, 1}).fill_(0.5); |
3171 | // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) |
3172 | const auto grad = output.grad(); |
3173 | ASSERT_TRUE(torch::allclose(grad, expected_grad)); |
3174 | } |
3175 | { // test BCE with logits stability |
3176 | const auto output = torch::tensor({0., -120.}); |
3177 | const auto target = torch::tensor({0., 1.}); |
3178 | const auto pos_weight = torch::tensor({1., 1.}); |
3179 | |
3180 | const auto out1 = F::binary_cross_entropy_with_logits(output, target); |
3181 | ASSERT_TRUE(torch::isfinite(out1).all().item<bool>()); |
3182 | |
3183 | const auto out2 = F::binary_cross_entropy_with_logits( |
3184 | output, |
3185 | target, |
3186 | F::BinaryCrossEntropyWithLogitsFuncOptions().pos_weight(pos_weight)); |
3187 | ASSERT_TRUE(torch::isfinite(out2).all().item<bool>()); |
3188 | } |
3189 | } |
3190 | |