1#include <ATen/Functions.h>
2#include <ATen/core/IListRef.h>
3#include <ATen/core/Tensor.h>
4#include <gtest/gtest.h>
5#include <algorithm>
6#include <iterator>
7
8using namespace c10;
9
10static std::vector<at::Tensor> get_tensor_vector() {
11 std::vector<at::Tensor> tensors;
12 const size_t SIZE = 5;
13 for (size_t i = 0; i < SIZE; i++) {
14 tensors.emplace_back(at::empty({0}));
15 }
16 return tensors;
17}
18
19static std::vector<optional<at::Tensor>> get_boxed_opt_tensor_vector() {
20 std::vector<optional<at::Tensor>> optional_tensors;
21 const size_t SIZE = 5;
22 for (size_t i = 0; i < SIZE * 2; i++) {
23 auto opt_tensor = (i % 2 == 0) ? optional<at::Tensor>(at::empty({0})) : nullopt;
24 optional_tensors.emplace_back(opt_tensor);
25 }
26 return optional_tensors;
27}
28
29static std::vector<at::OptionalTensorRef> get_unboxed_opt_tensor_vector() {
30 std::vector<at::OptionalTensorRef> optional_tensors;
31 const size_t SIZE = 5;
32 for (size_t i = 0; i < SIZE * 2; i++) {
33 auto opt_tensor = (i % 2 == 0) ? at::OptionalTensorRef(at::empty({0}))
34 : at::OptionalTensorRef();
35 optional_tensors.emplace_back(opt_tensor);
36 }
37 return optional_tensors;
38}
39
40template <typename T>
41void check_elements_same(at::ITensorListRef list, const T& thing, int use_count) {
42 EXPECT_EQ(thing.size(), list.size());
43 size_t i = 0;
44 for (const auto& t : list) {
45 const at::Tensor& other = thing[i];
46 EXPECT_EQ(other.use_count(), use_count);
47 EXPECT_TRUE(other.is_same(t));
48 i++;
49 }
50}
51
52TEST(ITensorListRefTest, CtorEmpty_IsNone_Throws) {
53 at::ITensorListRef list;
54 EXPECT_TRUE(list.isNone());
55 // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
56 EXPECT_THROW(list.size(), c10::Error);
57}
58
59TEST(ITensorListRefTest, CtorBoxed_IsBoxed) {
60 auto vec = get_tensor_vector();
61 List<at::Tensor> boxed(vec);
62 at::ITensorListRef list(boxed);
63 EXPECT_TRUE(list.isBoxed());
64}
65
66TEST(ITensorListRefTest, CtorUnboxed_IsUnboxed) {
67 auto vec = get_tensor_vector();
68 at::ArrayRef<at::Tensor> unboxed(vec);
69 at::ITensorListRef list(unboxed);
70 EXPECT_TRUE(list.isUnboxed());
71}
72
73TEST(ITensorListRefTest, CtorUnboxedIndirect_IsUnboxed) {
74 auto vec = get_tensor_vector();
75 auto check_is_unboxed = [](at::ITensorListRef list) {
76 EXPECT_TRUE(list.isUnboxed());
77 };
78 check_is_unboxed(at::ITensorListRef{vec[0]});
79 check_is_unboxed(at::ITensorListRef{vec.data(), vec.size()});
80 check_is_unboxed(at::ITensorListRef{vec.data(), vec.data() + vec.size()});
81 check_is_unboxed(vec);
82 check_is_unboxed({vec[0], vec[1], vec[2]});
83}
84
85TEST(ITensorListRefTest, CtorTemp_IsUnboxed) {
86 auto check_is_unboxed = [](at::ITensorListRef list) {
87 EXPECT_TRUE(list.isUnboxed());
88 };
89
90 auto vec = get_tensor_vector();
91 check_is_unboxed({vec[0], vec[1]});
92}
93
94TEST(ITensorListRefTest, Boxed_GetConstRefTensor) {
95 auto vec = get_tensor_vector();
96 // We need 'boxed' to be 'const' here (and some other tests below)
97 // because 'List<Tensor>::operator[]' returns a 'ListElementReference'
98 // instead of returning a 'Tensor'. On the other hand,
99 // 'List<Tensor>::operator[] const' returns a 'const Tensor &'.
100 const List<at::Tensor> boxed(vec);
101 at::ITensorListRef list(boxed);
102 static_assert(
103 std::is_same<decltype(*list.begin()), const at::Tensor&>::value,
104 "Accessing elements from List<Tensor> through a ITensorListRef should be const references.");
105 EXPECT_TRUE(boxed[0].is_same(*list.begin()));
106 EXPECT_TRUE(boxed[1].is_same(*(++list.begin())));
107}
108
109TEST(ITensorListRefTest, Unboxed_GetConstRefTensor) {
110 auto vec = get_tensor_vector();
111 at::ITensorListRef list(vec);
112 static_assert(
113 std::is_same<decltype(*list.begin()), const at::Tensor&>::value,
114 "Accessing elements from ArrayRef<Tensor> through a ITensorListRef should be const references.");
115 EXPECT_TRUE(vec[0].is_same(*list.begin()));
116 EXPECT_TRUE(vec[1].is_same(*(++list.begin())));
117}
118
119TEST(ITensorListRefTest, Boxed_Equal) {
120 auto vec = get_tensor_vector();
121 List<at::Tensor> boxed(vec);
122 check_elements_same(boxed, vec, /* use_count= */ 2);
123}
124
125TEST(ITensorListRefTest, Unboxed_Equal) {
126 auto vec = get_tensor_vector();
127 check_elements_same(at::ArrayRef<at::Tensor>(vec), vec, /* use_count= */ 1);
128}
129
// Verifies each indirect way of building an unboxed ITensorListRef (single
// tensor, pointer + length, pointer range, std::vector, braced list) against
// an equivalent container, including each tensor's reference count during the
// check.
TEST(ITensorListRefTest, UnboxedIndirect_Equal) {
  // The 4 ref-count locations:
  // 1. `vec`
  // 2. `initializer_list` for `ITensorListRef`
  // 3. `initializer_list` for `std::vector`
  // 4. temporary `std::vector`
  // Not every location exists in every check. The expected counts below are:
  //   - 3 for the single-tensor case (locations 1, 3, 4);
  //   - 1 for the pointer, range and vector cases (only `vec`; the list adds
  //     no owner);
  //   - 4 for the braced-list case (all four).
  // Temporaries in the argument list live until the end of the
  // full-expression, i.e. for the whole `check_elements_same` call.
  auto vec = get_tensor_vector();
  // Implicit constructors
  check_elements_same(vec[0], std::vector<at::Tensor>{vec[0]}, /* use_count= */ 3);
  check_elements_same({vec.data(), vec.size()}, vec, /* use_count= */ 1);
  check_elements_same({vec.data(), vec.data() + vec.size()}, vec, /* use_count= */ 1);
  // Vector constructor
  check_elements_same(vec, vec, /* use_count= */ 1);
  // InitializerList constructor
  check_elements_same({vec[0], vec[1], vec[2]}, std::vector<at::Tensor>{vec[0], vec[1], vec[2]}, /* use_count= */ 4);
}
146
147TEST(ITensorListRefTest, BoxedMaterialize_Equal) {
148 auto vec = get_tensor_vector();
149 List<at::Tensor> boxed(vec);
150 at::ITensorListRef list(boxed);
151 auto materialized = list.materialize();
152 check_elements_same(list, vec, 2);
153 check_elements_same(list, materialized, 2);
154 check_elements_same(materialized, vec, 2);
155}
156
157TEST(ITensorListRefTest, UnboxedMaterialize_Equal) {
158 auto vec = get_tensor_vector();
159 at::ArrayRef<at::Tensor> unboxed(vec);
160 at::ITensorListRef list(unboxed);
161 auto materialized = list.materialize();
162 check_elements_same(list, vec, 1);
163 check_elements_same(list, materialized, 1);
164 check_elements_same(materialized, vec, 1);
165}
166
// Dereferencing a default-constructed iterator (one not obtained from any
// list) must raise c10::Error.
TEST(ITensorListRefIteratorTest, CtorEmpty_ThrowsError) {
  // Heap-allocated rather than a local so the test controls exactly where the
  // destructor runs. Under MSVC with _ITERATOR_DEBUG_LEVEL == 2 the
  // destruction itself is expected to throw c10::Error, so it must happen
  // inside EXPECT_THROW (see the #if below).
  at::ITensorListRefIterator* it = new at::ITensorListRefIterator();
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
  EXPECT_THROW(**it, c10::Error);

#if defined(_MSC_VER) && _ITERATOR_DEBUG_LEVEL == 2
  EXPECT_THROW({ delete it; }, c10::Error);
#else
  delete it;
#endif
}
178
179TEST(ITensorListRefIteratorTest, Boxed_GetFirstElement) {
180 auto vec = get_tensor_vector();
181 const List<at::Tensor> boxed(vec);
182 at::ITensorListRef list(boxed);
183 EXPECT_TRUE(boxed[0].is_same(*list.begin()));
184}
185
186TEST(ITensorListRefIteratorTest, Unboxed_GetFirstElement) {
187 auto vec = get_tensor_vector();
188 at::ITensorListRef list(vec);
189 EXPECT_TRUE(vec[0].is_same(*list.begin()));
190}
191
192TEST(ITensorListRefIteratorTest, Boxed_Equality) {
193 auto vec = get_tensor_vector();
194 List<at::Tensor> boxed(vec);
195 at::ITensorListRef list(boxed);
196 EXPECT_EQ(list.begin(), list.begin());
197 EXPECT_NE(list.begin(), list.end());
198 EXPECT_NE(list.end(), list.begin());
199 EXPECT_EQ(list.end(), list.end());
200}
201
202TEST(ITensorListRefIteratorTest, Unboxed_Equality) {
203 auto vec = get_tensor_vector();
204 at::ITensorListRef list(vec);
205 EXPECT_EQ(list.begin(), list.begin());
206 EXPECT_NE(list.begin(), list.end());
207 EXPECT_NE(list.end(), list.begin());
208 EXPECT_EQ(list.end(), list.end());
209}
210
211TEST(ITensorListRefIteratorTest, Boxed_Iterate) {
212 auto vec = get_tensor_vector();
213 const List<at::Tensor> boxed(vec);
214 at::ITensorListRef list(boxed);
215 size_t i = 0;
216 for (const auto& t : list) {
217 EXPECT_TRUE(boxed[i++].is_same(t));
218 }
219 EXPECT_EQ(i, list.size());
220}
221
222TEST(ITensorListRefIteratorTest, Unboxed_Iterate) {
223 auto vec = get_tensor_vector();
224 at::ITensorListRef list(vec);
225 size_t i = 0;
226 for (const auto& t : list) {
227 EXPECT_TRUE(vec[i++].is_same(t));
228 }
229 EXPECT_EQ(i, list.size());
230}
231
232TEST(IOptTensorListRefTest, Boxed_Iterate) {
233 auto vec = get_boxed_opt_tensor_vector();
234 const List<optional<at::Tensor>> boxed(vec);
235 at::IOptTensorListRef list(boxed);
236 size_t i = 0;
237 for (const auto t : list) {
238 EXPECT_EQ(boxed[i].has_value(), t.has_value());
239 if (t.has_value()) {
240 EXPECT_TRUE((*boxed[i]).is_same(*t));
241 }
242 i++;
243 }
244 EXPECT_EQ(i, list.size());
245}
246
247TEST(IOptTensorListRefTest, Unboxed_Iterate) {
248 auto vec = get_unboxed_opt_tensor_vector();
249 at::ArrayRef<at::OptionalTensorRef> unboxed(vec);
250 at::IOptTensorListRef list(unboxed);
251 size_t i = 0;
252 for (const auto t : list) {
253 EXPECT_EQ(unboxed[i].has_value(), t.has_value());
254 if (t.has_value()) {
255 EXPECT_TRUE((*unboxed[i]).is_same(*t));
256 }
257 i++;
258 }
259 EXPECT_EQ(i, list.size());
260}
261