1#include <gtest/gtest.h>
2
3#include <c10/util/irange.h>
4#include <torch/torch.h>
5
6#include <test/cpp/api/support.h>
7struct OperationTest : torch::test::SeedingFixture {
8 protected:
9 void SetUp() override {}
10
11 const int TEST_AMOUNT = 10;
12};
13
14TEST_F(OperationTest, Lerp) {
15 for (const auto i : c10::irange(TEST_AMOUNT)) {
16 (void)i; // Suppress unused variable warning
17 // test lerp_kernel_scalar
18 auto start = torch::rand({3, 5});
19 auto end = torch::rand({3, 5});
20 auto scalar = 0.5;
21 // expected and actual
22 auto scalar_expected = start + scalar * (end - start);
23 auto out = torch::lerp(start, end, scalar);
24 // compare
25 ASSERT_EQ(out.dtype(), scalar_expected.dtype());
26 ASSERT_TRUE(out.allclose(scalar_expected));
27
28 // test lerp_kernel_tensor
29 auto weight = torch::rand({3, 5});
30 // expected and actual
31 auto tensor_expected = start + weight * (end - start);
32 out = torch::lerp(start, end, weight);
33 // compare
34 ASSERT_EQ(out.dtype(), tensor_expected.dtype());
35 ASSERT_TRUE(out.allclose(tensor_expected));
36 }
37}
38
39TEST_F(OperationTest, Cross) {
40 for (const auto i : c10::irange(TEST_AMOUNT)) {
41 (void)i; // Suppress unused variable warning
42 // input
43 auto a = torch::rand({10, 3});
44 auto b = torch::rand({10, 3});
45 // expected
46 auto exp = torch::empty({10, 3});
47 for (const auto j : c10::irange(10)) {
48 auto u1 = a[j][0], u2 = a[j][1], u3 = a[j][2];
49 auto v1 = b[j][0], v2 = b[j][1], v3 = b[j][2];
50 exp[j][0] = u2 * v3 - v2 * u3;
51 exp[j][1] = v1 * u3 - u1 * v3;
52 exp[j][2] = u1 * v2 - v1 * u2;
53 }
54 // actual
55 auto out = torch::cross(a, b);
56 // compare
57 ASSERT_EQ(out.dtype(), exp.dtype());
58 ASSERT_TRUE(out.allclose(exp));
59 }
60}
61
62TEST_F(OperationTest, Linear_out) {
63 {
64 const auto x = torch::arange(100., 118).resize_({3, 3, 2});
65 const auto w = torch::arange(200., 206).resize_({3, 2});
66 const auto b = torch::arange(300., 303);
67 auto y = torch::empty({3, 3, 3});
68 at::linear_out(y, x, w, b);
69 const auto y_exp = torch::tensor(
70 {{{40601, 41004, 41407}, {41403, 41814, 42225}, {42205, 42624, 43043}},
71 {{43007, 43434, 43861}, {43809, 44244, 44679}, {44611, 45054, 45497}},
72 {{45413, 45864, 46315}, {46215, 46674, 47133}, {47017, 47484, 47951}}},
73 torch::kFloat);
74 ASSERT_TRUE(torch::allclose(y, y_exp));
75 }
76 {
77 const auto x = torch::arange(100., 118).resize_({3, 3, 2});
78 const auto w = torch::arange(200., 206).resize_({3, 2});
79 auto y = torch::empty({3, 3, 3});
80 at::linear_out(y, x, w);
81 ASSERT_EQ(y.ndimension(), 3);
82 ASSERT_EQ(y.sizes(), torch::IntArrayRef({3, 3, 3}));
83 const auto y_exp = torch::tensor(
84 {{{40301, 40703, 41105}, {41103, 41513, 41923}, {41905, 42323, 42741}},
85 {{42707, 43133, 43559}, {43509, 43943, 44377}, {44311, 44753, 45195}},
86 {{45113, 45563, 46013}, {45915, 46373, 46831}, {46717, 47183, 47649}}},
87 torch::kFloat);
88 ASSERT_TRUE(torch::allclose(y, y_exp));
89 }
90}
91