#include <gtest/gtest.h>

#include <c10/util/irange.h>
#include <torch/torch.h>

#include <test/cpp/api/init_baseline.h>
#include <test/cpp/api/support.h>

#include <cmath>
#include <cstdint>
#include <functional>
#include <utility>
#include <vector>
11
12void check_exact_values(
13 const std::vector<torch::Tensor>& parameters,
14 const std::vector<std::vector<torch::Tensor>>& expected_parameters) {
15 ASSERT_EQ(parameters.size(), expected_parameters.size());
16
17 for (const auto i : c10::irange(parameters.size())) {
18 auto layerParameters = parameters[i];
19 auto expectedLayerParameters = expected_parameters[i];
20
21 if (static_cast<size_t>(layerParameters.size(0)) !=
22 expectedLayerParameters.size()) {
23 std::cout << "layer #" << i
24 << " layerParameters size: " << layerParameters.size(0)
25 << " != "
26 << " expectedLayerParameters size: "
27 << expectedLayerParameters.size() << std::endl;
28 ASSERT_TRUE(false);
29 }
30
31 for (const auto p : c10::irange(layerParameters.size(0))) {
32 // Always compare using double dtype, regardless of the original dtype of
33 // the tensors
34 auto tensor = layerParameters[p].to(torch::kFloat64);
35 auto expectedTensor = expectedLayerParameters[p].to(torch::kFloat64);
36
37 if (!tensor.allclose(expectedTensor, /*rtol=*/1e-3, /*atol=*/5e-4)) {
38 std::cout << "layer " << i << ": " << tensor << " != " << expectedTensor
39 << " (parameter " << p << ")" << std::endl;
40 ASSERT_TRUE(false);
41 }
42 }
43 }
44}
45
46void check_initializer_against_baseline(
47 std::function<void(torch::Tensor)> initializer,
48 std::vector<std::vector<torch::Tensor>> expected) {
49 torch::manual_seed(0);
50
51 auto layer1 = torch::nn::Linear(7, 15);
52 initializer(layer1->weight);
53 layer1->to(torch::kFloat64);
54
55 auto layer2 = torch::nn::Linear(15, 15);
56 initializer(layer2->weight);
57 layer2->to(torch::kFloat64);
58
59 auto layer3 = torch::nn::Linear(15, 2);
60 initializer(layer3->weight);
61 layer3->to(torch::kFloat64);
62
63 auto parameters = std::vector<torch::Tensor>{
64 layer1->weight,
65 layer2->weight,
66 layer3->weight,
67 };
68
69 check_exact_values(parameters, expected);
70}
71
72TEST(InitTest, ProducesPyTorchValues_XavierUniform) {
73 auto expected = expected_parameters::Xavier_Uniform();
74 auto initializer = [](torch::Tensor tensor) {
75 torch::nn::init::xavier_uniform_(tensor);
76 };
77 check_initializer_against_baseline(initializer, expected);
78}
79
80TEST(InitTest, ProducesPyTorchValues_XavierNormal) {
81 auto expected = expected_parameters::Xavier_Normal();
82 auto initializer = [](torch::Tensor tensor) {
83 torch::nn::init::xavier_normal_(tensor);
84 };
85 check_initializer_against_baseline(initializer, expected);
86}
87
88TEST(InitTest, ProducesPyTorchValues_KaimingNormal) {
89 auto expected = expected_parameters::Kaiming_Normal();
90 auto initializer = [](torch::Tensor tensor) {
91 torch::nn::init::kaiming_normal_(tensor);
92 };
93 check_initializer_against_baseline(initializer, expected);
94}
95
96TEST(InitTest, ProducesPyTorchValues_KaimingUniform) {
97 auto expected = expected_parameters::Kaiming_Uniform();
98 auto initializer = [](torch::Tensor tensor) {
99 torch::nn::init::kaiming_uniform_(tensor);
100 };
101 check_initializer_against_baseline(initializer, expected);
102}
103
104TEST(InitTest, CanInitializeTensorThatRequiresGrad) {
105 auto tensor = torch::empty({3, 4}, torch::requires_grad());
106 ASSERT_THROWS_WITH(
107 tensor.fill_(1),
108 "a leaf Variable that requires grad "
109 "is being used in an in-place operation");
110 ASSERT_EQ(torch::nn::init::ones_(tensor).sum().item<int32_t>(), 12);
111}
112
113TEST(InitTest, CalculateGainWithTanh) {
114 double gain = torch::nn::init::calculate_gain(torch::kTanh);
115 ASSERT_DOUBLE_EQ(gain, 5.0 / 3.0);
116}
117
118TEST(InitTest, CalculateGainWithRelu) {
119 double gain = torch::nn::init::calculate_gain(torch::kReLU);
120 ASSERT_DOUBLE_EQ(gain, std::sqrt(2.0));
121}
122
123TEST(InitTest, CalculateGainWithLeakyRelu) {
124 double gain = torch::nn::init::calculate_gain(torch::kLeakyReLU);
125 ASSERT_DOUBLE_EQ(gain, std::sqrt(2.0 / (1 + pow(0.01, 2))));
126}
127
128TEST(InitTest, CanInitializeCnnWithOrthogonal) {
129 torch::nn::Conv2d conv_layer(torch::nn::Conv2dOptions(3, 2, 3).stride(2));
130 torch::nn::init::orthogonal_(conv_layer->named_parameters()["weight"]);
131}
132