1 | #include <gtest/gtest.h> |
2 | |
3 | #include <c10/util/irange.h> |
4 | #include <torch/torch.h> |
5 | |
6 | #include <test/cpp/api/support.h> |
7 | struct OperationTest : torch::test::SeedingFixture { |
8 | protected: |
9 | void SetUp() override {} |
10 | |
11 | const int TEST_AMOUNT = 10; |
12 | }; |
13 | |
14 | TEST_F(OperationTest, Lerp) { |
15 | for (const auto i : c10::irange(TEST_AMOUNT)) { |
16 | (void)i; // Suppress unused variable warning |
17 | // test lerp_kernel_scalar |
18 | auto start = torch::rand({3, 5}); |
19 | auto end = torch::rand({3, 5}); |
20 | auto scalar = 0.5; |
21 | // expected and actual |
22 | auto scalar_expected = start + scalar * (end - start); |
23 | auto out = torch::lerp(start, end, scalar); |
24 | // compare |
25 | ASSERT_EQ(out.dtype(), scalar_expected.dtype()); |
26 | ASSERT_TRUE(out.allclose(scalar_expected)); |
27 | |
28 | // test lerp_kernel_tensor |
29 | auto weight = torch::rand({3, 5}); |
30 | // expected and actual |
31 | auto tensor_expected = start + weight * (end - start); |
32 | out = torch::lerp(start, end, weight); |
33 | // compare |
34 | ASSERT_EQ(out.dtype(), tensor_expected.dtype()); |
35 | ASSERT_TRUE(out.allclose(tensor_expected)); |
36 | } |
37 | } |
38 | |
39 | TEST_F(OperationTest, Cross) { |
40 | for (const auto i : c10::irange(TEST_AMOUNT)) { |
41 | (void)i; // Suppress unused variable warning |
42 | // input |
43 | auto a = torch::rand({10, 3}); |
44 | auto b = torch::rand({10, 3}); |
45 | // expected |
46 | auto exp = torch::empty({10, 3}); |
47 | for (const auto j : c10::irange(10)) { |
48 | auto u1 = a[j][0], u2 = a[j][1], u3 = a[j][2]; |
49 | auto v1 = b[j][0], v2 = b[j][1], v3 = b[j][2]; |
50 | exp[j][0] = u2 * v3 - v2 * u3; |
51 | exp[j][1] = v1 * u3 - u1 * v3; |
52 | exp[j][2] = u1 * v2 - v1 * u2; |
53 | } |
54 | // actual |
55 | auto out = torch::cross(a, b); |
56 | // compare |
57 | ASSERT_EQ(out.dtype(), exp.dtype()); |
58 | ASSERT_TRUE(out.allclose(exp)); |
59 | } |
60 | } |
61 | |
62 | TEST_F(OperationTest, Linear_out) { |
63 | { |
64 | const auto x = torch::arange(100., 118).resize_({3, 3, 2}); |
65 | const auto w = torch::arange(200., 206).resize_({3, 2}); |
66 | const auto b = torch::arange(300., 303); |
67 | auto y = torch::empty({3, 3, 3}); |
68 | at::linear_out(y, x, w, b); |
69 | const auto y_exp = torch::tensor( |
70 | {{{40601, 41004, 41407}, {41403, 41814, 42225}, {42205, 42624, 43043}}, |
71 | {{43007, 43434, 43861}, {43809, 44244, 44679}, {44611, 45054, 45497}}, |
72 | {{45413, 45864, 46315}, {46215, 46674, 47133}, {47017, 47484, 47951}}}, |
73 | torch::kFloat); |
74 | ASSERT_TRUE(torch::allclose(y, y_exp)); |
75 | } |
76 | { |
77 | const auto x = torch::arange(100., 118).resize_({3, 3, 2}); |
78 | const auto w = torch::arange(200., 206).resize_({3, 2}); |
79 | auto y = torch::empty({3, 3, 3}); |
80 | at::linear_out(y, x, w); |
81 | ASSERT_EQ(y.ndimension(), 3); |
82 | ASSERT_EQ(y.sizes(), torch::IntArrayRef({3, 3, 3})); |
83 | const auto y_exp = torch::tensor( |
84 | {{{40301, 40703, 41105}, {41103, 41513, 41923}, {41905, 42323, 42741}}, |
85 | {{42707, 43133, 43559}, {43509, 43943, 44377}, {44311, 44753, 45195}}, |
86 | {{45113, 45563, 46013}, {45915, 46373, 46831}, {46717, 47183, 47649}}}, |
87 | torch::kFloat); |
88 | ASSERT_TRUE(torch::allclose(y, y_exp)); |
89 | } |
90 | } |
91 | |