1#include <gtest/gtest.h>
2#include <torch/torch.h>
3#include <algorithm>
4#include <memory>
5#include <vector>
6
7#include <test/cpp/api/support.h>
8
9using namespace torch::nn;
10using namespace torch::test;
11
12struct ParameterDictTest : torch::test::SeedingFixture {};
13
14TEST_F(ParameterDictTest, ConstructFromTensor) {
15 ParameterDict dict;
16 torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true));
17 torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false));
18 torch::Tensor tc = torch::randn({1, 2});
19 ASSERT_TRUE(ta.requires_grad());
20 ASSERT_FALSE(tb.requires_grad());
21 dict->insert("A", ta);
22 dict->insert("B", tb);
23 dict->insert("C", tc);
24 ASSERT_EQ(dict->size(), 3);
25 ASSERT_TRUE(torch::all(torch::eq(dict["A"], ta)).item<bool>());
26 ASSERT_TRUE(dict["A"].requires_grad());
27 ASSERT_TRUE(torch::all(torch::eq(dict["B"], tb)).item<bool>());
28 ASSERT_FALSE(dict["B"].requires_grad());
29}
30
31TEST_F(ParameterDictTest, ConstructFromOrderedDict) {
32 torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true));
33 torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false));
34 torch::Tensor tc = torch::randn({1, 2});
35 torch::OrderedDict<std::string, torch::Tensor> params = {
36 {"A", ta}, {"B", tb}, {"C", tc}};
37 auto dict = torch::nn::ParameterDict(params);
38 ASSERT_EQ(dict->size(), 3);
39 ASSERT_TRUE(torch::all(torch::eq(dict["A"], ta)).item<bool>());
40 ASSERT_TRUE(dict["A"].requires_grad());
41 ASSERT_TRUE(torch::all(torch::eq(dict["B"], tb)).item<bool>());
42 ASSERT_FALSE(dict["B"].requires_grad());
43}
44
45TEST_F(ParameterDictTest, InsertAndContains) {
46 ParameterDict dict;
47 dict->insert("A", torch::tensor({1.0}));
48 ASSERT_EQ(dict->size(), 1);
49 ASSERT_TRUE(dict->contains("A"));
50 ASSERT_FALSE(dict->contains("C"));
51}
52
53TEST_F(ParameterDictTest, InsertAndClear) {
54 ParameterDict dict;
55 dict->insert("A", torch::tensor({1.0}));
56 ASSERT_EQ(dict->size(), 1);
57 dict->clear();
58 ASSERT_EQ(dict->size(), 0);
59}
60
61TEST_F(ParameterDictTest, InsertAndPop) {
62 ParameterDict dict;
63 dict->insert("A", torch::tensor({1.0}));
64 ASSERT_EQ(dict->size(), 1);
65 ASSERT_THROWS_WITH(dict->pop("B"), "Parameter 'B' is not defined");
66 torch::Tensor p = dict->pop("A");
67 ASSERT_EQ(dict->size(), 0);
68 ASSERT_TRUE(torch::eq(p, torch::tensor({1.0})).item<bool>());
69}
70
71TEST_F(ParameterDictTest, SimpleUpdate) {
72 ParameterDict dict;
73 ParameterDict wrongDict;
74 ParameterDict rightDict;
75 dict->insert("A", torch::tensor({1.0}));
76 dict->insert("B", torch::tensor({2.0}));
77 dict->insert("C", torch::tensor({3.0}));
78 wrongDict->insert("A", torch::tensor({5.0}));
79 wrongDict->insert("D", torch::tensor({5.0}));
80 ASSERT_THROWS_WITH(dict->update(*wrongDict), "Parameter 'D' is not defined");
81 rightDict->insert("A", torch::tensor({5.0}));
82 dict->update(*rightDict);
83 ASSERT_EQ(dict->size(), 3);
84 ASSERT_TRUE(torch::eq(dict["A"], torch::tensor({5.0})).item<bool>());
85}
86
87TEST_F(ParameterDictTest, Keys) {
88 torch::OrderedDict<std::string, torch::Tensor> params = {
89 {"a", torch::tensor({1.0})},
90 {"b", torch::tensor({2.0})},
91 {"c", torch::tensor({1.0, 2.0})}};
92 auto dict = torch::nn::ParameterDict(params);
93 std::vector<std::string> keys = dict->keys();
94 std::vector<std::string> true_keys{"a", "b", "c"};
95 ASSERT_EQ(keys, true_keys);
96}
97
98TEST_F(ParameterDictTest, Values) {
99 torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true));
100 torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false));
101 torch::Tensor tc = torch::randn({1, 2});
102 torch::OrderedDict<std::string, torch::Tensor> params = {
103 {"a", ta}, {"b", tb}, {"c", tc}};
104 auto dict = torch::nn::ParameterDict(params);
105 std::vector<torch::Tensor> values = dict->values();
106 std::vector<torch::Tensor> true_values{ta, tb, tc};
107 for (auto i = 0U; i < values.size(); i += 1) {
108 ASSERT_TRUE(torch::all(torch::eq(values[i], true_values[i])).item<bool>());
109 }
110}
111
112TEST_F(ParameterDictTest, Get) {
113 ParameterDict dict;
114 torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true));
115 torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false));
116 torch::Tensor tc = torch::randn({1, 2});
117 ASSERT_TRUE(ta.requires_grad());
118 ASSERT_FALSE(tb.requires_grad());
119 dict->insert("A", ta);
120 dict->insert("B", tb);
121 dict->insert("C", tc);
122 ASSERT_EQ(dict->size(), 3);
123 ASSERT_TRUE(torch::all(torch::eq(dict->get("A"), ta)).item<bool>());
124 ASSERT_TRUE(dict->get("A").requires_grad());
125 ASSERT_TRUE(torch::all(torch::eq(dict->get("B"), tb)).item<bool>());
126 ASSERT_FALSE(dict->get("B").requires_grad());
127}
128
129TEST_F(ParameterDictTest, PrettyPrintParameterDict) {
130 torch::OrderedDict<std::string, torch::Tensor> params = {
131 {"a", torch::tensor({1.0})},
132 {"b", torch::tensor({2.0, 1.0})},
133 {"c", torch::tensor({{3.0}, {2.1}})},
134 {"d", torch::tensor({{3.0, 1.3}, {1.2, 2.1}})}};
135 auto dict = torch::nn::ParameterDict(params);
136 ASSERT_EQ(
137 c10::str(dict),
138 "torch::nn::ParameterDict(\n"
139 "(a): Parameter containing: [Float of size [1]]\n"
140 "(b): Parameter containing: [Float of size [2]]\n"
141 "(c): Parameter containing: [Float of size [2, 1]]\n"
142 "(d): Parameter containing: [Float of size [2, 2]]\n"
143 ")");
144}
145