1 | #include <ATen/Functions.h> |
2 | #include <ATen/core/IListRef.h> |
3 | #include <ATen/core/Tensor.h> |
4 | #include <gtest/gtest.h> |
5 | #include <algorithm> |
6 | #include <iterator> |
7 | |
8 | using namespace c10; |
9 | |
10 | static std::vector<at::Tensor> get_tensor_vector() { |
11 | std::vector<at::Tensor> tensors; |
12 | const size_t SIZE = 5; |
13 | for (size_t i = 0; i < SIZE; i++) { |
14 | tensors.emplace_back(at::empty({0})); |
15 | } |
16 | return tensors; |
17 | } |
18 | |
19 | static std::vector<optional<at::Tensor>> get_boxed_opt_tensor_vector() { |
20 | std::vector<optional<at::Tensor>> optional_tensors; |
21 | const size_t SIZE = 5; |
22 | for (size_t i = 0; i < SIZE * 2; i++) { |
23 | auto opt_tensor = (i % 2 == 0) ? optional<at::Tensor>(at::empty({0})) : nullopt; |
24 | optional_tensors.emplace_back(opt_tensor); |
25 | } |
26 | return optional_tensors; |
27 | } |
28 | |
29 | static std::vector<at::OptionalTensorRef> get_unboxed_opt_tensor_vector() { |
30 | std::vector<at::OptionalTensorRef> optional_tensors; |
31 | const size_t SIZE = 5; |
32 | for (size_t i = 0; i < SIZE * 2; i++) { |
33 | auto opt_tensor = (i % 2 == 0) ? at::OptionalTensorRef(at::empty({0})) |
34 | : at::OptionalTensorRef(); |
35 | optional_tensors.emplace_back(opt_tensor); |
36 | } |
37 | return optional_tensors; |
38 | } |
39 | |
40 | template <typename T> |
41 | void check_elements_same(at::ITensorListRef list, const T& thing, int use_count) { |
42 | EXPECT_EQ(thing.size(), list.size()); |
43 | size_t i = 0; |
44 | for (const auto& t : list) { |
45 | const at::Tensor& other = thing[i]; |
46 | EXPECT_EQ(other.use_count(), use_count); |
47 | EXPECT_TRUE(other.is_same(t)); |
48 | i++; |
49 | } |
50 | } |
51 | |
52 | TEST(ITensorListRefTest, CtorEmpty_IsNone_Throws) { |
53 | at::ITensorListRef list; |
54 | EXPECT_TRUE(list.isNone()); |
55 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) |
56 | EXPECT_THROW(list.size(), c10::Error); |
57 | } |
58 | |
59 | TEST(ITensorListRefTest, CtorBoxed_IsBoxed) { |
60 | auto vec = get_tensor_vector(); |
61 | List<at::Tensor> boxed(vec); |
62 | at::ITensorListRef list(boxed); |
63 | EXPECT_TRUE(list.isBoxed()); |
64 | } |
65 | |
66 | TEST(ITensorListRefTest, CtorUnboxed_IsUnboxed) { |
67 | auto vec = get_tensor_vector(); |
68 | at::ArrayRef<at::Tensor> unboxed(vec); |
69 | at::ITensorListRef list(unboxed); |
70 | EXPECT_TRUE(list.isUnboxed()); |
71 | } |
72 | |
73 | TEST(ITensorListRefTest, CtorUnboxedIndirect_IsUnboxed) { |
74 | auto vec = get_tensor_vector(); |
75 | auto check_is_unboxed = [](at::ITensorListRef list) { |
76 | EXPECT_TRUE(list.isUnboxed()); |
77 | }; |
78 | check_is_unboxed(at::ITensorListRef{vec[0]}); |
79 | check_is_unboxed(at::ITensorListRef{vec.data(), vec.size()}); |
80 | check_is_unboxed(at::ITensorListRef{vec.data(), vec.data() + vec.size()}); |
81 | check_is_unboxed(vec); |
82 | check_is_unboxed({vec[0], vec[1], vec[2]}); |
83 | } |
84 | |
85 | TEST(ITensorListRefTest, CtorTemp_IsUnboxed) { |
86 | auto check_is_unboxed = [](at::ITensorListRef list) { |
87 | EXPECT_TRUE(list.isUnboxed()); |
88 | }; |
89 | |
90 | auto vec = get_tensor_vector(); |
91 | check_is_unboxed({vec[0], vec[1]}); |
92 | } |
93 | |
94 | TEST(ITensorListRefTest, Boxed_GetConstRefTensor) { |
95 | auto vec = get_tensor_vector(); |
96 | // We need 'boxed' to be 'const' here (and some other tests below) |
97 | // because 'List<Tensor>::operator[]' returns a 'ListElementReference' |
98 | // instead of returning a 'Tensor'. On the other hand, |
99 | // 'List<Tensor>::operator[] const' returns a 'const Tensor &'. |
100 | const List<at::Tensor> boxed(vec); |
101 | at::ITensorListRef list(boxed); |
102 | static_assert( |
103 | std::is_same<decltype(*list.begin()), const at::Tensor&>::value, |
104 | "Accessing elements from List<Tensor> through a ITensorListRef should be const references." ); |
105 | EXPECT_TRUE(boxed[0].is_same(*list.begin())); |
106 | EXPECT_TRUE(boxed[1].is_same(*(++list.begin()))); |
107 | } |
108 | |
109 | TEST(ITensorListRefTest, Unboxed_GetConstRefTensor) { |
110 | auto vec = get_tensor_vector(); |
111 | at::ITensorListRef list(vec); |
112 | static_assert( |
113 | std::is_same<decltype(*list.begin()), const at::Tensor&>::value, |
114 | "Accessing elements from ArrayRef<Tensor> through a ITensorListRef should be const references." ); |
115 | EXPECT_TRUE(vec[0].is_same(*list.begin())); |
116 | EXPECT_TRUE(vec[1].is_same(*(++list.begin()))); |
117 | } |
118 | |
119 | TEST(ITensorListRefTest, Boxed_Equal) { |
120 | auto vec = get_tensor_vector(); |
121 | List<at::Tensor> boxed(vec); |
122 | check_elements_same(boxed, vec, /* use_count= */ 2); |
123 | } |
124 | |
125 | TEST(ITensorListRefTest, Unboxed_Equal) { |
126 | auto vec = get_tensor_vector(); |
127 | check_elements_same(at::ArrayRef<at::Tensor>(vec), vec, /* use_count= */ 1); |
128 | } |
129 | |
130 | TEST(ITensorListRefTest, UnboxedIndirect_Equal) { |
131 | // The 4 ref-count locations: |
132 | // 1. `vec` |
133 | // 2. `initializer_list` for `ITensorListRef` |
134 | // 3. `initializer_list` for `std::vector` |
135 | // 4. temporary `std::vector` |
136 | auto vec = get_tensor_vector(); |
137 | // Implicit constructors |
138 | check_elements_same(vec[0], std::vector<at::Tensor>{vec[0]}, /* use_count= */ 3); |
139 | check_elements_same({vec.data(), vec.size()}, vec, /* use_count= */ 1); |
140 | check_elements_same({vec.data(), vec.data() + vec.size()}, vec, /* use_count= */ 1); |
141 | // Vector constructor |
142 | check_elements_same(vec, vec, /* use_count= */ 1); |
143 | // InitializerList constructor |
144 | check_elements_same({vec[0], vec[1], vec[2]}, std::vector<at::Tensor>{vec[0], vec[1], vec[2]}, /* use_count= */ 4); |
145 | } |
146 | |
147 | TEST(ITensorListRefTest, BoxedMaterialize_Equal) { |
148 | auto vec = get_tensor_vector(); |
149 | List<at::Tensor> boxed(vec); |
150 | at::ITensorListRef list(boxed); |
151 | auto materialized = list.materialize(); |
152 | check_elements_same(list, vec, 2); |
153 | check_elements_same(list, materialized, 2); |
154 | check_elements_same(materialized, vec, 2); |
155 | } |
156 | |
157 | TEST(ITensorListRefTest, UnboxedMaterialize_Equal) { |
158 | auto vec = get_tensor_vector(); |
159 | at::ArrayRef<at::Tensor> unboxed(vec); |
160 | at::ITensorListRef list(unboxed); |
161 | auto materialized = list.materialize(); |
162 | check_elements_same(list, vec, 1); |
163 | check_elements_same(list, materialized, 1); |
164 | check_elements_same(materialized, vec, 1); |
165 | } |
166 | |
167 | TEST(ITensorListRefIteratorTest, CtorEmpty_ThrowsError) { |
168 | at::ITensorListRefIterator* it = new at::ITensorListRefIterator(); |
169 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) |
170 | EXPECT_THROW(**it, c10::Error); |
171 | |
172 | #if defined(_MSC_VER) && _ITERATOR_DEBUG_LEVEL == 2 |
173 | EXPECT_THROW({ delete it; }, c10::Error); |
174 | #else |
175 | delete it; |
176 | #endif |
177 | } |
178 | |
179 | TEST(ITensorListRefIteratorTest, Boxed_GetFirstElement) { |
180 | auto vec = get_tensor_vector(); |
181 | const List<at::Tensor> boxed(vec); |
182 | at::ITensorListRef list(boxed); |
183 | EXPECT_TRUE(boxed[0].is_same(*list.begin())); |
184 | } |
185 | |
186 | TEST(ITensorListRefIteratorTest, Unboxed_GetFirstElement) { |
187 | auto vec = get_tensor_vector(); |
188 | at::ITensorListRef list(vec); |
189 | EXPECT_TRUE(vec[0].is_same(*list.begin())); |
190 | } |
191 | |
192 | TEST(ITensorListRefIteratorTest, Boxed_Equality) { |
193 | auto vec = get_tensor_vector(); |
194 | List<at::Tensor> boxed(vec); |
195 | at::ITensorListRef list(boxed); |
196 | EXPECT_EQ(list.begin(), list.begin()); |
197 | EXPECT_NE(list.begin(), list.end()); |
198 | EXPECT_NE(list.end(), list.begin()); |
199 | EXPECT_EQ(list.end(), list.end()); |
200 | } |
201 | |
202 | TEST(ITensorListRefIteratorTest, Unboxed_Equality) { |
203 | auto vec = get_tensor_vector(); |
204 | at::ITensorListRef list(vec); |
205 | EXPECT_EQ(list.begin(), list.begin()); |
206 | EXPECT_NE(list.begin(), list.end()); |
207 | EXPECT_NE(list.end(), list.begin()); |
208 | EXPECT_EQ(list.end(), list.end()); |
209 | } |
210 | |
211 | TEST(ITensorListRefIteratorTest, Boxed_Iterate) { |
212 | auto vec = get_tensor_vector(); |
213 | const List<at::Tensor> boxed(vec); |
214 | at::ITensorListRef list(boxed); |
215 | size_t i = 0; |
216 | for (const auto& t : list) { |
217 | EXPECT_TRUE(boxed[i++].is_same(t)); |
218 | } |
219 | EXPECT_EQ(i, list.size()); |
220 | } |
221 | |
222 | TEST(ITensorListRefIteratorTest, Unboxed_Iterate) { |
223 | auto vec = get_tensor_vector(); |
224 | at::ITensorListRef list(vec); |
225 | size_t i = 0; |
226 | for (const auto& t : list) { |
227 | EXPECT_TRUE(vec[i++].is_same(t)); |
228 | } |
229 | EXPECT_EQ(i, list.size()); |
230 | } |
231 | |
232 | TEST(IOptTensorListRefTest, Boxed_Iterate) { |
233 | auto vec = get_boxed_opt_tensor_vector(); |
234 | const List<optional<at::Tensor>> boxed(vec); |
235 | at::IOptTensorListRef list(boxed); |
236 | size_t i = 0; |
237 | for (const auto t : list) { |
238 | EXPECT_EQ(boxed[i].has_value(), t.has_value()); |
239 | if (t.has_value()) { |
240 | EXPECT_TRUE((*boxed[i]).is_same(*t)); |
241 | } |
242 | i++; |
243 | } |
244 | EXPECT_EQ(i, list.size()); |
245 | } |
246 | |
247 | TEST(IOptTensorListRefTest, Unboxed_Iterate) { |
248 | auto vec = get_unboxed_opt_tensor_vector(); |
249 | at::ArrayRef<at::OptionalTensorRef> unboxed(vec); |
250 | at::IOptTensorListRef list(unboxed); |
251 | size_t i = 0; |
252 | for (const auto t : list) { |
253 | EXPECT_EQ(unboxed[i].has_value(), t.has_value()); |
254 | if (t.has_value()) { |
255 | EXPECT_TRUE((*unboxed[i]).is_same(*t)); |
256 | } |
257 | i++; |
258 | } |
259 | EXPECT_EQ(i, list.size()); |
260 | } |
261 | |