1 | #include <gtest/gtest.h> |
2 | |
3 | #include <torch/torch.h> |
4 | |
5 | #include <test/cpp/api/support.h> |
6 | |
7 | using namespace torch::indexing; |
8 | using namespace torch::test; |
9 | |
10 | TEST(TensorIndexingTest, Slice) { |
11 | Slice slice(1, 2, 3); |
12 | ASSERT_EQ(slice.start(), 1); |
13 | ASSERT_EQ(slice.stop(), 2); |
14 | ASSERT_EQ(slice.step(), 3); |
15 | |
16 | ASSERT_EQ(c10::str(slice), "1:2:3" ); |
17 | } |
18 | |
19 | TEST(TensorIndexingTest, TensorIndex) { |
20 | { |
21 | std::vector<TensorIndex> indices = { |
22 | None, |
23 | "..." , |
24 | Ellipsis, |
25 | 0, |
26 | true, |
27 | Slice(1, None, 2), |
28 | torch::tensor({1, 2})}; |
29 | ASSERT_TRUE(indices[0].is_none()); |
30 | ASSERT_TRUE(indices[1].is_ellipsis()); |
31 | ASSERT_TRUE(indices[2].is_ellipsis()); |
32 | ASSERT_TRUE(indices[3].is_integer()); |
33 | ASSERT_TRUE(indices[3].integer() == 0); |
34 | ASSERT_TRUE(indices[4].is_boolean()); |
35 | ASSERT_TRUE(indices[4].boolean() == true); |
36 | ASSERT_TRUE(indices[5].is_slice()); |
37 | ASSERT_TRUE(indices[5].slice().start() == 1); |
38 | ASSERT_TRUE(indices[5].slice().stop() == INDEX_MAX); |
39 | ASSERT_TRUE(indices[5].slice().step() == 2); |
40 | ASSERT_TRUE(indices[6].is_tensor()); |
41 | ASSERT_TRUE(torch::equal(indices[6].tensor(), torch::tensor({1, 2}))); |
42 | } |
43 | |
44 | ASSERT_THROWS_WITH( |
45 | TensorIndex(".." ), |
46 | "Expected \"...\" to represent an ellipsis index, but got \"..\"" ); |
47 | |
48 | { |
49 | std::vector<TensorIndex> indices = { |
50 | None, "..." , Ellipsis, 0, true, Slice(1, None, 2)}; |
51 | ASSERT_EQ( |
52 | c10::str(indices), |
53 | c10::str("(None, ..., ..., 0, true, 1:" , INDEX_MAX, ":2)" )); |
54 | ASSERT_EQ(c10::str(indices[0]), "None" ); |
55 | ASSERT_EQ(c10::str(indices[1]), "..." ); |
56 | ASSERT_EQ(c10::str(indices[2]), "..." ); |
57 | ASSERT_EQ(c10::str(indices[3]), "0" ); |
58 | ASSERT_EQ(c10::str(indices[4]), "true" ); |
59 | ASSERT_EQ(c10::str(indices[5]), c10::str("1:" , INDEX_MAX, ":2" )); |
60 | } |
61 | |
62 | ASSERT_EQ( |
63 | c10::str(std::vector<TensorIndex>({Slice()})), |
64 | c10::str("(0:" , INDEX_MAX, ":1)" )); |
65 | ASSERT_EQ( |
66 | c10::str(std::vector<TensorIndex>({Slice(None, None)})), |
67 | c10::str("(0:" , INDEX_MAX, ":1)" )); |
68 | ASSERT_EQ( |
69 | c10::str(std::vector<TensorIndex>({Slice(None, None, None)})), |
70 | c10::str("(0:" , INDEX_MAX, ":1)" )); |
71 | |
72 | ASSERT_EQ( |
73 | c10::str(std::vector<TensorIndex>({Slice(1, None)})), |
74 | c10::str("(1:" , INDEX_MAX, ":1)" )); |
75 | ASSERT_EQ( |
76 | c10::str(std::vector<TensorIndex>({Slice(1, None, None)})), |
77 | c10::str("(1:" , INDEX_MAX, ":1)" )); |
78 | ASSERT_EQ( |
79 | c10::str(std::vector<TensorIndex>({Slice(None, 3)})), |
80 | c10::str("(0:3:1)" )); |
81 | ASSERT_EQ( |
82 | c10::str(std::vector<TensorIndex>({Slice(None, 3, None)})), |
83 | c10::str("(0:3:1)" )); |
84 | ASSERT_EQ( |
85 | c10::str(std::vector<TensorIndex>({Slice(None, None, 2)})), |
86 | c10::str("(0:" , INDEX_MAX, ":2)" )); |
87 | ASSERT_EQ( |
88 | c10::str(std::vector<TensorIndex>({Slice(None, None, -1)})), |
89 | c10::str("(" , INDEX_MAX, ":" , INDEX_MIN, ":-1)" )); |
90 | |
91 | ASSERT_EQ( |
92 | c10::str(std::vector<TensorIndex>({Slice(1, 3)})), c10::str("(1:3:1)" )); |
93 | ASSERT_EQ( |
94 | c10::str(std::vector<TensorIndex>({Slice(1, None, 2)})), |
95 | c10::str("(1:" , INDEX_MAX, ":2)" )); |
96 | ASSERT_EQ( |
97 | c10::str(std::vector<TensorIndex>({Slice(1, None, -1)})), |
98 | c10::str("(1:" , INDEX_MIN, ":-1)" )); |
99 | ASSERT_EQ( |
100 | c10::str(std::vector<TensorIndex>({Slice(None, 3, 2)})), |
101 | c10::str("(0:3:2)" )); |
102 | ASSERT_EQ( |
103 | c10::str(std::vector<TensorIndex>({Slice(None, 3, -1)})), |
104 | c10::str("(" , INDEX_MAX, ":3:-1)" )); |
105 | |
106 | ASSERT_EQ( |
107 | c10::str(std::vector<TensorIndex>({Slice(1, 3, 2)})), |
108 | c10::str("(1:3:2)" )); |
109 | } |
110 | |
111 | TEST(TensorIndexingTest, TestNoIndices) { |
112 | torch::Tensor tensor = torch::randn({20, 20}); |
113 | torch::Tensor value = torch::randn({20, 20}); |
114 | std::vector<TensorIndex> indices; |
115 | |
116 | ASSERT_THROWS_WITH( |
117 | tensor.index({}), |
118 | "Passing an empty index list to Tensor::index() is not valid syntax" ); |
119 | ASSERT_THROWS_WITH( |
120 | tensor.index_put_({}, 1), |
121 | "Passing an empty index list to Tensor::index_put_() is not valid syntax" ); |
122 | ASSERT_THROWS_WITH( |
123 | tensor.index_put_({}, value), |
124 | "Passing an empty index list to Tensor::index_put_() is not valid syntax" ); |
125 | |
126 | ASSERT_THROWS_WITH( |
127 | tensor.index(indices), |
128 | "Passing an empty index list to Tensor::index() is not valid syntax" ); |
129 | ASSERT_THROWS_WITH( |
130 | tensor.index_put_(indices, 1), |
131 | "Passing an empty index list to Tensor::index_put_() is not valid syntax" ); |
132 | ASSERT_THROWS_WITH( |
133 | tensor.index_put_(indices, value), |
134 | "Passing an empty index list to Tensor::index_put_() is not valid syntax" ); |
135 | } |
136 | |
137 | TEST(TensorIndexingTest, TestAdvancedIndexingWithListOfTensor) { |
138 | { |
139 | torch::Tensor tensor = torch::randn({20, 20}); |
140 | torch::Tensor index = torch::arange(10, torch::kLong).cpu(); |
141 | torch::Tensor result = at::index(tensor, {index}); |
142 | torch::Tensor result_with_init_list = tensor.index({index}); |
143 | ASSERT_TRUE(result.equal(result_with_init_list)); |
144 | } |
145 | { |
146 | torch::Tensor tensor = torch::randn({20, 20}); |
147 | torch::Tensor index = torch::arange(10, torch::kLong).cpu(); |
148 | torch::Tensor result = at::index_put_(tensor, {index}, torch::ones({20})); |
149 | torch::Tensor result_with_init_list = |
150 | tensor.index_put_({index}, torch::ones({20})); |
151 | ASSERT_TRUE(result.equal(result_with_init_list)); |
152 | } |
153 | { |
154 | torch::Tensor tensor = torch::randn({20, 20}); |
155 | torch::Tensor index = torch::arange(10, torch::kLong).cpu(); |
156 | torch::Tensor result = |
157 | at::index_put_(tensor, {index}, torch::ones({1, 20})); |
158 | torch::Tensor result_with_init_list = |
159 | tensor.index_put_({index}, torch::ones({1, 20})); |
160 | ASSERT_TRUE(result.equal(result_with_init_list)); |
161 | } |
162 | } |
163 | |
164 | TEST(TensorIndexingTest, TestSingleInt) { |
165 | auto v = torch::randn({5, 7, 3}); |
166 | ASSERT_EQ(v.index({4}).sizes(), torch::IntArrayRef({7, 3})); |
167 | } |
168 | |
169 | TEST(TensorIndexingTest, TestMultipleInt) { |
170 | auto v = torch::randn({5, 7, 3}); |
171 | ASSERT_EQ(v.index({4}).sizes(), torch::IntArrayRef({7, 3})); |
172 | ASSERT_EQ(v.index({4, Slice(), 1}).sizes(), torch::IntArrayRef({7})); |
173 | |
174 | // To show that `.index_put_` works |
175 | v.index_put_({4, 3, 1}, 0); |
176 | ASSERT_EQ(v.index({4, 3, 1}).item<double>(), 0); |
177 | } |
178 | |
179 | TEST(TensorIndexingTest, TestNone) { |
180 | auto v = torch::randn({5, 7, 3}); |
181 | ASSERT_EQ(v.index({None}).sizes(), torch::IntArrayRef({1, 5, 7, 3})); |
182 | ASSERT_EQ(v.index({Slice(), None}).sizes(), torch::IntArrayRef({5, 1, 7, 3})); |
183 | ASSERT_EQ( |
184 | v.index({Slice(), None, None}).sizes(), |
185 | torch::IntArrayRef({5, 1, 1, 7, 3})); |
186 | ASSERT_EQ(v.index({"..." , None}).sizes(), torch::IntArrayRef({5, 7, 3, 1})); |
187 | } |
188 | |
189 | TEST(TensorIndexingTest, TestStep) { |
190 | auto v = torch::arange(10); |
191 | assert_tensor_equal(v.index({Slice(None, None, 1)}), v); |
192 | assert_tensor_equal( |
193 | v.index({Slice(None, None, 2)}), torch::tensor({0, 2, 4, 6, 8})); |
194 | assert_tensor_equal( |
195 | v.index({Slice(None, None, 3)}), torch::tensor({0, 3, 6, 9})); |
196 | assert_tensor_equal(v.index({Slice(None, None, 11)}), torch::tensor({0})); |
197 | assert_tensor_equal(v.index({Slice(1, 6, 2)}), torch::tensor({1, 3, 5})); |
198 | } |
199 | |
200 | TEST(TensorIndexingTest, TestStepAssignment) { |
201 | auto v = torch::zeros({4, 4}); |
202 | v.index_put_({0, Slice(1, None, 2)}, torch::tensor({3., 4.})); |
203 | assert_tensor_equal(v.index({0}), torch::tensor({0., 3., 0., 4.})); |
204 | assert_tensor_equal(v.index({Slice(1, None)}).sum(), torch::tensor(0)); |
205 | } |
206 | |
207 | TEST(TensorIndexingTest, TestBoolIndices) { |
208 | { |
209 | auto v = torch::randn({5, 7, 3}); |
210 | auto boolIndices = |
211 | torch::tensor({true, false, true, true, false}, torch::kBool); |
212 | ASSERT_EQ(v.index({boolIndices}).sizes(), torch::IntArrayRef({3, 7, 3})); |
213 | assert_tensor_equal( |
214 | v.index({boolIndices}), |
215 | torch::stack({v.index({0}), v.index({2}), v.index({3})})); |
216 | } |
217 | { |
218 | auto v = torch::tensor({true, false, true}, torch::kBool); |
219 | auto boolIndices = torch::tensor({true, false, false}, torch::kBool); |
220 | auto uint8Indices = torch::tensor({1, 0, 0}, torch::kUInt8); |
221 | |
222 | { |
223 | WarningCapture warnings; |
224 | |
225 | ASSERT_EQ( |
226 | v.index({boolIndices}).sizes(), v.index({uint8Indices}).sizes()); |
227 | assert_tensor_equal(v.index({boolIndices}), v.index({uint8Indices})); |
228 | assert_tensor_equal( |
229 | v.index({boolIndices}), torch::tensor({true}, torch::kBool)); |
230 | |
231 | ASSERT_EQ( |
232 | count_substr_occurrences( |
233 | warnings.str(), |
234 | "indexing with dtype torch.uint8 is now deprecated" ), |
235 | 2); |
236 | } |
237 | } |
238 | } |
239 | |
240 | TEST(TensorIndexingTest, TestBoolIndicesAccumulate) { |
241 | auto mask = torch::zeros({10}, torch::kBool); |
242 | auto y = torch::ones({10, 10}); |
243 | y.index_put_({mask}, {y.index({mask})}, /*accumulate=*/true); |
244 | assert_tensor_equal(y, torch::ones({10, 10})); |
245 | } |
246 | |
247 | TEST(TensorIndexingTest, TestMultipleBoolIndices) { |
248 | auto v = torch::randn({5, 7, 3}); |
249 | // note: these broadcast together and are transposed to the first dim |
250 | auto mask1 = torch::tensor({1, 0, 1, 1, 0}, torch::kBool); |
251 | auto mask2 = torch::tensor({1, 1, 1}, torch::kBool); |
252 | ASSERT_EQ( |
253 | v.index({mask1, Slice(), mask2}).sizes(), torch::IntArrayRef({3, 7})); |
254 | } |
255 | |
256 | TEST(TensorIndexingTest, TestByteMask) { |
257 | { |
258 | auto v = torch::randn({5, 7, 3}); |
259 | auto mask = torch::tensor({1, 0, 1, 1, 0}, torch::kByte); |
260 | { |
261 | WarningCapture warnings; |
262 | |
263 | ASSERT_EQ(v.index({mask}).sizes(), torch::IntArrayRef({3, 7, 3})); |
264 | assert_tensor_equal(v.index({mask}), torch::stack({v[0], v[2], v[3]})); |
265 | |
266 | ASSERT_EQ( |
267 | count_substr_occurrences( |
268 | warnings.str(), |
269 | "indexing with dtype torch.uint8 is now deprecated" ), |
270 | 2); |
271 | } |
272 | } |
273 | { |
274 | auto v = torch::tensor({1.}); |
275 | assert_tensor_equal(v.index({v == 0}), torch::randn({0})); |
276 | } |
277 | } |
278 | |
279 | TEST(TensorIndexingTest, TestByteMaskAccumulate) { |
280 | auto mask = torch::zeros({10}, torch::kUInt8); |
281 | auto y = torch::ones({10, 10}); |
282 | { |
283 | WarningCapture warnings; |
284 | |
285 | y.index_put_({mask}, y.index({mask}), /*accumulate=*/true); |
286 | assert_tensor_equal(y, torch::ones({10, 10})); |
287 | |
288 | ASSERT_EQ( |
289 | count_substr_occurrences( |
290 | warnings.str(), |
291 | "indexing with dtype torch.uint8 is now deprecated" ), |
292 | 2); |
293 | } |
294 | } |
295 | |
296 | TEST(TensorIndexingTest, TestMultipleByteMask) { |
297 | auto v = torch::randn({5, 7, 3}); |
298 | // note: these broadcast together and are transposed to the first dim |
299 | auto mask1 = torch::tensor({1, 0, 1, 1, 0}, torch::kByte); |
300 | auto mask2 = torch::tensor({1, 1, 1}, torch::kByte); |
301 | { |
302 | WarningCapture warnings; |
303 | |
304 | ASSERT_EQ( |
305 | v.index({mask1, Slice(), mask2}).sizes(), torch::IntArrayRef({3, 7})); |
306 | |
307 | ASSERT_EQ( |
308 | count_substr_occurrences( |
309 | warnings.str(), |
310 | "indexing with dtype torch.uint8 is now deprecated" ), |
311 | 2); |
312 | } |
313 | } |
314 | |
315 | TEST(TensorIndexingTest, TestByteMask2d) { |
316 | auto v = torch::randn({5, 7, 3}); |
317 | auto c = torch::randn({5, 7}); |
318 | int64_t num_ones = (c > 0).sum().item().to<int64_t>(); |
319 | auto r = v.index({c > 0}); |
320 | ASSERT_EQ(r.sizes(), torch::IntArrayRef({num_ones, 3})); |
321 | } |
322 | |
323 | TEST(TensorIndexingTest, TestIntIndices) { |
324 | auto v = torch::randn({5, 7, 3}); |
325 | ASSERT_EQ( |
326 | v.index({torch::tensor({0, 4, 2})}).sizes(), |
327 | torch::IntArrayRef({3, 7, 3})); |
328 | ASSERT_EQ( |
329 | v.index({Slice(), torch::tensor({0, 4, 2})}).sizes(), |
330 | torch::IntArrayRef({5, 3, 3})); |
331 | ASSERT_EQ( |
332 | v.index({Slice(), torch::tensor({{0, 1}, {4, 3}})}).sizes(), |
333 | torch::IntArrayRef({5, 2, 2, 3})); |
334 | } |
335 | |
336 | TEST(TensorIndexingTest, TestIntIndices2d) { |
337 | // From the NumPy indexing example |
338 | auto x = torch::arange(0, 12, torch::kLong).view({4, 3}); |
339 | auto rows = torch::tensor({{0, 0}, {3, 3}}); |
340 | auto columns = torch::tensor({{0, 2}, {0, 2}}); |
341 | assert_tensor_equal( |
342 | x.index({rows, columns}), torch::tensor({{0, 2}, {9, 11}})); |
343 | } |
344 | |
345 | TEST(TensorIndexingTest, TestIntIndicesBroadcast) { |
346 | // From the NumPy indexing example |
347 | auto x = torch::arange(0, 12, torch::kLong).view({4, 3}); |
348 | auto rows = torch::tensor({0, 3}); |
349 | auto columns = torch::tensor({0, 2}); |
350 | auto result = x.index({rows.index({Slice(), None}), columns}); |
351 | assert_tensor_equal(result, torch::tensor({{0, 2}, {9, 11}})); |
352 | } |
353 | |
354 | TEST(TensorIndexingTest, TestEmptyIndex) { |
355 | auto x = torch::arange(0, 12).view({4, 3}); |
356 | auto idx = torch::tensor({}, torch::kLong); |
357 | ASSERT_EQ(x.index({idx}).numel(), 0); |
358 | |
359 | // empty assignment should have no effect but not throw an exception |
360 | auto y = x.clone(); |
361 | y.index_put_({idx}, -1); |
362 | assert_tensor_equal(x, y); |
363 | |
364 | auto mask = torch::zeros({4, 3}, torch::kBool); |
365 | y.index_put_({mask}, -1); |
366 | assert_tensor_equal(x, y); |
367 | } |
368 | |
369 | TEST(TensorIndexingTest, TestEmptyNdimIndex) { |
370 | torch::Device device(torch::kCPU); |
371 | { |
372 | auto x = torch::randn({5}, device); |
373 | assert_tensor_equal( |
374 | torch::empty({0, 2}, device), |
375 | x.index({torch::empty( |
376 | {0, 2}, torch::TensorOptions(torch::kInt64).device(device))})); |
377 | } |
378 | { |
379 | auto x = torch::randn({2, 3, 4, 5}, device); |
380 | assert_tensor_equal( |
381 | torch::empty({2, 0, 6, 4, 5}, device), |
382 | x.index( |
383 | {Slice(), |
384 | torch::empty( |
385 | {0, 6}, torch::TensorOptions(torch::kInt64).device(device))})); |
386 | } |
387 | { |
388 | auto x = torch::empty({10, 0}); |
389 | ASSERT_EQ( |
390 | x.index({torch::tensor({1, 2})}).sizes(), torch::IntArrayRef({2, 0})); |
391 | ASSERT_EQ( |
392 | x.index( |
393 | {torch::tensor({}, torch::kLong), torch::tensor({}, torch::kLong)}) |
394 | .sizes(), |
395 | torch::IntArrayRef({0})); |
396 | ASSERT_THROWS_WITH( |
397 | x.index({Slice(), torch::tensor({0, 1})}), "for dimension with size 0" ); |
398 | } |
399 | } |
400 | |
401 | TEST(TensorIndexingTest, TestEmptyNdimIndex_CUDA) { |
402 | torch::Device device(torch::kCUDA); |
403 | { |
404 | auto x = torch::randn({5}, device); |
405 | assert_tensor_equal( |
406 | torch::empty({0, 2}, device), |
407 | x.index({torch::empty( |
408 | {0, 2}, torch::TensorOptions(torch::kInt64).device(device))})); |
409 | } |
410 | { |
411 | auto x = torch::randn({2, 3, 4, 5}, device); |
412 | assert_tensor_equal( |
413 | torch::empty({2, 0, 6, 4, 5}, device), |
414 | x.index( |
415 | {Slice(), |
416 | torch::empty( |
417 | {0, 6}, torch::TensorOptions(torch::kInt64).device(device))})); |
418 | } |
419 | } |
420 | |
421 | TEST(TensorIndexingTest, TestEmptyNdimIndexBool) { |
422 | torch::Device device(torch::kCPU); |
423 | auto x = torch::randn({5}, device); |
424 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) |
425 | ASSERT_THROW( |
426 | x.index({torch::empty( |
427 | {0, 2}, torch::TensorOptions(torch::kUInt8).device(device))}), |
428 | c10::Error); |
429 | } |
430 | |
431 | TEST(TensorIndexingTest, TestEmptyNdimIndexBool_CUDA) { |
432 | torch::Device device(torch::kCUDA); |
433 | auto x = torch::randn({5}, device); |
434 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) |
435 | ASSERT_THROW( |
436 | x.index({torch::empty( |
437 | {0, 2}, torch::TensorOptions(torch::kUInt8).device(device))}), |
438 | c10::Error); |
439 | } |
440 | |
441 | TEST(TensorIndexingTest, TestEmptySlice) { |
442 | torch::Device device(torch::kCPU); |
443 | auto x = torch::randn({2, 3, 4, 5}, device); |
444 | auto y = x.index({Slice(), Slice(), Slice(), 1}); |
445 | auto z = y.index({Slice(), Slice(1, 1), Slice()}); |
446 | ASSERT_EQ(z.sizes(), torch::IntArrayRef({2, 0, 4})); |
447 | // this isn't technically necessary, but matches NumPy stride calculations. |
448 | ASSERT_EQ(z.strides(), torch::IntArrayRef({60, 20, 5})); |
449 | ASSERT_TRUE(z.is_contiguous()); |
450 | } |
451 | |
452 | TEST(TensorIndexingTest, TestEmptySlice_CUDA) { |
453 | torch::Device device(torch::kCUDA); |
454 | auto x = torch::randn({2, 3, 4, 5}, device); |
455 | auto y = x.index({Slice(), Slice(), Slice(), 1}); |
456 | auto z = y.index({Slice(), Slice(1, 1), Slice()}); |
457 | ASSERT_EQ(z.sizes(), torch::IntArrayRef({2, 0, 4})); |
458 | // this isn't technically necessary, but matches NumPy stride calculations. |
459 | ASSERT_EQ(z.strides(), torch::IntArrayRef({60, 20, 5})); |
460 | ASSERT_TRUE(z.is_contiguous()); |
461 | } |
462 | |
463 | TEST(TensorIndexingTest, TestIndexGetitemCopyBoolsSlices) { |
464 | auto true_tensor = torch::tensor(1, torch::kUInt8); |
465 | auto false_tensor = torch::tensor(0, torch::kUInt8); |
466 | |
467 | std::vector<torch::Tensor> tensors = {torch::randn({2, 3}), torch::tensor(3)}; |
468 | |
469 | for (auto& a : tensors) { |
470 | ASSERT_NE(a.data_ptr(), a.index({true}).data_ptr()); |
471 | { |
472 | std::vector<int64_t> sizes = {0}; |
473 | sizes.insert(sizes.end(), a.sizes().begin(), a.sizes().end()); |
474 | assert_tensor_equal(torch::empty(sizes), a.index({false})); |
475 | } |
476 | ASSERT_NE(a.data_ptr(), a.index({true_tensor}).data_ptr()); |
477 | { |
478 | std::vector<int64_t> sizes = {0}; |
479 | sizes.insert(sizes.end(), a.sizes().begin(), a.sizes().end()); |
480 | assert_tensor_equal(torch::empty(sizes), a.index({false_tensor})); |
481 | } |
482 | ASSERT_EQ(a.data_ptr(), a.index({None}).data_ptr()); |
483 | ASSERT_EQ(a.data_ptr(), a.index({"..." }).data_ptr()); |
484 | } |
485 | } |
486 | |
487 | TEST(TensorIndexingTest, TestIndexSetitemBoolsSlices) { |
488 | auto true_tensor = torch::tensor(1, torch::kUInt8); |
489 | auto false_tensor = torch::tensor(0, torch::kUInt8); |
490 | |
491 | std::vector<torch::Tensor> tensors = {torch::randn({2, 3}), torch::tensor(3)}; |
492 | |
493 | for (auto& a : tensors) { |
494 | // prefix with a 1,1, to ensure we are compatible with numpy which cuts off |
495 | // prefix 1s (some of these ops already prefix a 1 to the size) |
496 | auto neg_ones = torch::ones_like(a) * -1; |
497 | auto neg_ones_expanded = neg_ones.unsqueeze(0).unsqueeze(0); |
498 | a.index_put_({true}, neg_ones_expanded); |
499 | assert_tensor_equal(a, neg_ones); |
500 | a.index_put_({false}, 5); |
501 | assert_tensor_equal(a, neg_ones); |
502 | a.index_put_({true_tensor}, neg_ones_expanded * 2); |
503 | assert_tensor_equal(a, neg_ones * 2); |
504 | a.index_put_({false_tensor}, 5); |
505 | assert_tensor_equal(a, neg_ones * 2); |
506 | a.index_put_({None}, neg_ones_expanded * 3); |
507 | assert_tensor_equal(a, neg_ones * 3); |
508 | a.index_put_({"..." }, neg_ones_expanded * 4); |
509 | assert_tensor_equal(a, neg_ones * 4); |
510 | if (a.dim() == 0) { |
511 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) |
512 | ASSERT_THROW(a.index_put_({Slice()}, neg_ones_expanded * 5), c10::Error); |
513 | } |
514 | } |
515 | } |
516 | |
517 | TEST(TensorIndexingTest, TestIndexScalarWithBoolMask) { |
518 | torch::Device device(torch::kCPU); |
519 | |
520 | auto a = torch::tensor(1, device); |
521 | auto uintMask = |
522 | torch::tensor(true, torch::TensorOptions(torch::kUInt8).device(device)); |
523 | auto boolMask = |
524 | torch::tensor(true, torch::TensorOptions(torch::kBool).device(device)); |
525 | assert_tensor_equal(a.index({uintMask}), a.index({boolMask})); |
526 | ASSERT_EQ(a.index({uintMask}).dtype(), a.index({boolMask}).dtype()); |
527 | |
528 | a = torch::tensor(true, torch::TensorOptions(torch::kBool).device(device)); |
529 | assert_tensor_equal(a.index({uintMask}), a.index({boolMask})); |
530 | ASSERT_EQ(a.index({uintMask}).dtype(), a.index({boolMask}).dtype()); |
531 | } |
532 | |
533 | TEST(TensorIndexingTest, TestIndexScalarWithBoolMask_CUDA) { |
534 | torch::Device device(torch::kCUDA); |
535 | |
536 | auto a = torch::tensor(1, device); |
537 | auto uintMask = |
538 | torch::tensor(true, torch::TensorOptions(torch::kUInt8).device(device)); |
539 | auto boolMask = |
540 | torch::tensor(true, torch::TensorOptions(torch::kBool).device(device)); |
541 | assert_tensor_equal(a.index({uintMask}), a.index({boolMask})); |
542 | ASSERT_EQ(a.index({uintMask}).dtype(), a.index({boolMask}).dtype()); |
543 | |
544 | a = torch::tensor(true, torch::TensorOptions(torch::kBool).device(device)); |
545 | assert_tensor_equal(a.index({uintMask}), a.index({boolMask})); |
546 | ASSERT_EQ(a.index({uintMask}).dtype(), a.index({boolMask}).dtype()); |
547 | } |
548 | |
549 | TEST(TensorIndexingTest, TestSetitemExpansionError) { |
550 | auto true_tensor = torch::tensor(true); |
551 | auto a = torch::randn({2, 3}); |
552 | // check prefix with non-1s doesn't work |
553 | std::vector<int64_t> tensor_sizes{5, 1}; |
554 | tensor_sizes.insert(tensor_sizes.end(), a.sizes().begin(), a.sizes().end()); |
555 | auto a_expanded = a.expand(tensor_sizes); |
556 | // NumPy: ValueError |
557 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) |
558 | ASSERT_THROW(a.index_put_({true}, a_expanded), c10::Error); |
559 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) |
560 | ASSERT_THROW(a.index_put_({true_tensor}, a_expanded), c10::Error); |
561 | } |
562 | |
563 | TEST(TensorIndexingTest, TestGetitemScalars) { |
564 | auto zero = torch::tensor(0, torch::kInt64); |
565 | auto one = torch::tensor(1, torch::kInt64); |
566 | |
567 | // non-scalar indexed with scalars |
568 | auto a = torch::randn({2, 3}); |
569 | assert_tensor_equal(a.index({0}), a.index({zero})); |
570 | assert_tensor_equal(a.index({0}).index({1}), a.index({zero}).index({one})); |
571 | assert_tensor_equal(a.index({0, 1}), a.index({zero, one})); |
572 | assert_tensor_equal(a.index({0, one}), a.index({zero, 1})); |
573 | |
574 | // indexing by a scalar should slice (not copy) |
575 | ASSERT_EQ(a.index({0, 1}).data_ptr(), a.index({zero, one}).data_ptr()); |
576 | ASSERT_EQ(a.index({1}).data_ptr(), a.index({one.to(torch::kInt)}).data_ptr()); |
577 | ASSERT_EQ( |
578 | a.index({1}).data_ptr(), a.index({one.to(torch::kShort)}).data_ptr()); |
579 | |
580 | // scalar indexed with scalar |
581 | auto r = torch::randn({}); |
582 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) |
583 | ASSERT_THROW(r.index({Slice()}), c10::Error); |
584 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) |
585 | ASSERT_THROW(r.index({zero}), c10::Error); |
586 | assert_tensor_equal(r, r.index({"..." })); |
587 | } |
588 | |
589 | TEST(TensorIndexingTest, TestSetitemScalars) { |
590 | auto zero = torch::tensor(0, torch::kInt64); |
591 | |
592 | // non-scalar indexed with scalars |
593 | auto a = torch::randn({2, 3}); |
594 | auto a_set_with_number = a.clone(); |
595 | auto a_set_with_scalar = a.clone(); |
596 | auto b = torch::randn({3}); |
597 | |
598 | a_set_with_number.index_put_({0}, b); |
599 | a_set_with_scalar.index_put_({zero}, b); |
600 | assert_tensor_equal(a_set_with_number, a_set_with_scalar); |
601 | a.index_put_({1, zero}, 7.7); |
602 | ASSERT_TRUE(a.index({1, 0}).allclose(torch::tensor(7.7))); |
603 | |
604 | // scalar indexed with scalars |
605 | auto r = torch::randn({}); |
606 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-goto,hicpp-avoid-goto) |
607 | ASSERT_THROW(r.index_put_({Slice()}, 8.8), c10::Error); |
608 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-goto,hicpp-avoid-goto) |
609 | ASSERT_THROW(r.index_put_({zero}, 8.8), c10::Error); |
610 | r.index_put_({"..." }, 9.9); |
611 | ASSERT_TRUE(r.allclose(torch::tensor(9.9))); |
612 | } |
613 | |
614 | TEST(TensorIndexingTest, TestBasicAdvancedCombined) { |
615 | // From the NumPy indexing example |
616 | auto x = torch::arange(0, 12).to(torch::kLong).view({4, 3}); |
617 | assert_tensor_equal( |
618 | x.index({Slice(1, 2), Slice(1, 3)}), |
619 | x.index({Slice(1, 2), torch::tensor({1, 2})})); |
620 | assert_tensor_equal( |
621 | x.index({Slice(1, 2), Slice(1, 3)}), torch::tensor({{4, 5}})); |
622 | |
623 | // Check that it is a copy |
624 | { |
625 | auto unmodified = x.clone(); |
626 | x.index({Slice(1, 2), torch::tensor({1, 2})}).zero_(); |
627 | assert_tensor_equal(x, unmodified); |
628 | } |
629 | |
630 | // But assignment should modify the original |
631 | { |
632 | auto unmodified = x.clone(); |
633 | x.index_put_({Slice(1, 2), torch::tensor({1, 2})}, 0); |
634 | assert_tensor_not_equal(x, unmodified); |
635 | } |
636 | } |
637 | |
638 | TEST(TensorIndexingTest, TestIntAssignment) { |
639 | { |
640 | auto x = torch::arange(0, 4).to(torch::kLong).view({2, 2}); |
641 | x.index_put_({1}, 5); |
642 | assert_tensor_equal(x, torch::tensor({{0, 1}, {5, 5}})); |
643 | } |
644 | |
645 | { |
646 | auto x = torch::arange(0, 4).to(torch::kLong).view({2, 2}); |
647 | x.index_put_({1}, torch::arange(5, 7).to(torch::kLong)); |
648 | assert_tensor_equal(x, torch::tensor({{0, 1}, {5, 6}})); |
649 | } |
650 | } |
651 | |
652 | TEST(TensorIndexingTest, TestByteTensorAssignment) { |
653 | auto x = torch::arange(0., 16).to(torch::kFloat).view({4, 4}); |
654 | auto b = torch::tensor({true, false, true, false}, torch::kByte); |
655 | auto value = torch::tensor({3., 4., 5., 6.}); |
656 | |
657 | { |
658 | WarningCapture warnings; |
659 | |
660 | x.index_put_({b}, value); |
661 | |
662 | ASSERT_EQ( |
663 | count_substr_occurrences( |
664 | warnings.str(), |
665 | "indexing with dtype torch.uint8 is now deprecated" ), |
666 | 1); |
667 | } |
668 | |
669 | assert_tensor_equal(x.index({0}), value); |
670 | assert_tensor_equal(x.index({1}), torch::arange(4, 8).to(torch::kLong)); |
671 | assert_tensor_equal(x.index({2}), value); |
672 | assert_tensor_equal(x.index({3}), torch::arange(12, 16).to(torch::kLong)); |
673 | } |
674 | |
675 | TEST(TensorIndexingTest, TestVariableSlicing) { |
676 | auto x = torch::arange(0, 16).view({4, 4}); |
677 | auto indices = torch::tensor({0, 1}, torch::kInt); |
678 | int i = indices[0].item<int>(); |
679 | int j = indices[1].item<int>(); |
680 | assert_tensor_equal(x.index({Slice(i, j)}), x.index({Slice(0, 1)})); |
681 | } |
682 | |
683 | TEST(TensorIndexingTest, TestEllipsisTensor) { |
684 | auto x = torch::arange(0, 9).to(torch::kLong).view({3, 3}); |
685 | auto idx = torch::tensor({0, 2}); |
686 | assert_tensor_equal( |
687 | x.index({"..." , idx}), torch::tensor({{0, 2}, {3, 5}, {6, 8}})); |
688 | assert_tensor_equal( |
689 | x.index({idx, "..." }), torch::tensor({{0, 1, 2}, {6, 7, 8}})); |
690 | } |
691 | |
692 | TEST(TensorIndexingTest, TestOutOfBoundIndex) { |
693 | auto x = torch::arange(0, 100).view({2, 5, 10}); |
694 | ASSERT_THROWS_WITH( |
695 | x.index({0, 5}), "index 5 is out of bounds for dimension 1 with size 5" ); |
696 | ASSERT_THROWS_WITH( |
697 | x.index({4, 5}), "index 4 is out of bounds for dimension 0 with size 2" ); |
698 | ASSERT_THROWS_WITH( |
699 | x.index({0, 1, 15}), |
700 | "index 15 is out of bounds for dimension 2 with size 10" ); |
701 | ASSERT_THROWS_WITH( |
702 | x.index({Slice(), Slice(), 12}), |
703 | "index 12 is out of bounds for dimension 2 with size 10" ); |
704 | } |
705 | |
706 | TEST(TensorIndexingTest, TestZeroDimIndex) { |
707 | auto x = torch::tensor(10); |
708 | |
709 | auto runner = [&]() -> torch::Tensor { |
710 | std::cout << x.index({0}) << std::endl; |
711 | return x.index({0}); |
712 | }; |
713 | |
714 | ASSERT_THROWS_WITH(runner(), "invalid index" ); |
715 | } |
716 | |
// The tests below are from NumPy test_indexing.py with some modifications to
// make them compatible with libtorch. It's licensed under the BSD license
// below:
720 | // |
721 | // Copyright (c) 2005-2017, NumPy Developers. |
722 | // All rights reserved. |
723 | // |
724 | // Redistribution and use in source and binary forms, with or without |
725 | // modification, are permitted provided that the following conditions are |
726 | // met: |
727 | // |
728 | // * Redistributions of source code must retain the above copyright |
729 | // notice, this list of conditions and the following disclaimer. |
730 | // |
731 | // * Redistributions in binary form must reproduce the above |
732 | // copyright notice, this list of conditions and the following |
733 | // disclaimer in the documentation and/or other materials provided |
734 | // with the distribution. |
735 | // |
736 | // * Neither the name of the NumPy Developers nor the names of any |
737 | // contributors may be used to endorse or promote products derived |
738 | // from this software without specific prior written permission. |
739 | // |
740 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS |
741 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT |
742 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR |
743 | // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT |
744 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, |
745 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT |
746 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, |
747 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY |
748 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT |
749 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE |
750 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
751 | |
752 | TEST(NumpyTests, TestNoneIndex) { |
753 | // `None` index adds newaxis |
754 | auto a = torch::tensor({1, 2, 3}); |
755 | ASSERT_EQ(a.index({None}).dim(), a.dim() + 1); |
756 | } |
757 | |
758 | TEST(NumpyTests, TestEmptyFancyIndex) { |
759 | // Empty list index creates an empty array |
760 | auto a = torch::tensor({1, 2, 3}); |
761 | assert_tensor_equal( |
762 | a.index({torch::tensor({}, torch::kLong)}), torch::tensor({})); |
763 | |
764 | auto b = torch::tensor({}).to(torch::kLong); |
765 | assert_tensor_equal( |
766 | a.index({torch::tensor({}, torch::kLong)}), |
767 | torch::tensor({}, torch::kLong)); |
768 | |
769 | b = torch::tensor({}).to(torch::kFloat); |
770 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) |
771 | ASSERT_THROW(a.index({b}), c10::Error); |
772 | } |
773 | |
774 | TEST(NumpyTests, TestEllipsisIndex) { |
775 | auto a = torch::tensor({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); |
776 | ASSERT_FALSE(a.index({"..." }).is_same(a)); |
777 | assert_tensor_equal(a.index({"..." }), a); |
778 | // `a[...]` was `a` in numpy <1.9. |
779 | ASSERT_EQ(a.index({"..." }).data_ptr(), a.data_ptr()); |
780 | |
781 | // Slicing with ellipsis can skip an |
782 | // arbitrary number of dimensions |
783 | assert_tensor_equal(a.index({0, "..." }), a.index({0})); |
784 | assert_tensor_equal(a.index({0, "..." }), a.index({0, Slice()})); |
785 | assert_tensor_equal(a.index({"..." , 0}), a.index({Slice(), 0})); |
786 | |
787 | // In NumPy, slicing with ellipsis results in a 0-dim array. In PyTorch |
788 | // we don't have separate 0-dim arrays and scalars. |
789 | assert_tensor_equal(a.index({0, "..." , 1}), torch::tensor(2)); |
790 | |
791 | // Assignment with `Ellipsis` on 0-d arrays |
792 | auto b = torch::tensor(1); |
793 | b.index_put_({Ellipsis}, 2); |
794 | ASSERT_EQ(b.item<int64_t>(), 2); |
795 | } |
796 | |
797 | TEST(NumpyTests, TestSingleIntIndex) { |
798 | // Single integer index selects one row |
799 | auto a = torch::tensor({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); |
800 | |
801 | assert_tensor_equal(a.index({0}), torch::tensor({1, 2, 3})); |
802 | assert_tensor_equal(a.index({-1}), torch::tensor({7, 8, 9})); |
803 | |
804 | // Index out of bounds produces IndexError |
805 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) |
806 | ASSERT_THROW(a.index({1 << 30}), c10::Error); |
807 | // NOTE: According to the standard |
808 | // (http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2017/p0543r0.html), for |
809 | // signed integers, if during the evaluation of an expression, the result is |
810 | // not mathematically defined or not in the range of representable values for |
811 | // its type, the behavior is undefined. Therefore, there is no way to check |
812 | // for index overflow case because it might not throw exception. |
813 | // ASSERT_THROW(a(1 << 64), c10::Error); |
814 | } |
815 | |
816 | TEST(NumpyTests, TestSingleBoolIndex) { |
817 | // Single boolean index |
818 | auto a = torch::tensor({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); |
819 | |
820 | assert_tensor_equal(a.index({true}), a.index({None})); |
821 | assert_tensor_equal(a.index({false}), a.index({None}).index({Slice(0, 0)})); |
822 | } |
823 | |
824 | TEST(NumpyTests, TestBooleanShapeMismatch) { |
825 | auto arr = torch::ones({5, 4, 3}); |
826 | |
827 | auto index = torch::tensor({true}); |
828 | ASSERT_THROWS_WITH(arr.index({index}), "mask" ); |
829 | |
830 | index = torch::tensor({false, false, false, false, false, false}); |
831 | ASSERT_THROWS_WITH(arr.index({index}), "mask" ); |
832 | |
833 | { |
834 | WarningCapture warnings; |
835 | |
836 | index = torch::empty({4, 4}, torch::kByte).zero_(); |
837 | ASSERT_THROWS_WITH(arr.index({index}), "mask" ); |
838 | ASSERT_THROWS_WITH(arr.index({Slice(), index}), "mask" ); |
839 | |
840 | ASSERT_EQ( |
841 | count_substr_occurrences( |
842 | warnings.str(), |
843 | "indexing with dtype torch.uint8 is now deprecated" ), |
844 | 2); |
845 | } |
846 | } |
847 | |
848 | TEST(NumpyTests, TestBooleanIndexingOnedim) { |
849 | // Indexing a 2-dimensional array with |
850 | // boolean array of length one |
851 | auto a = torch::tensor({{0., 0., 0.}}); |
852 | auto b = torch::tensor({true}); |
853 | assert_tensor_equal(a.index({b}), a); |
854 | // boolean assignment |
855 | a.index_put_({b}, 1.); |
856 | assert_tensor_equal(a, torch::tensor({{1., 1., 1.}})); |
857 | } |
858 | |
859 | TEST(NumpyTests, TestBooleanAssignmentValueMismatch) { |
860 | // A boolean assignment should fail when the shape of the values |
861 | // cannot be broadcast to the subscription. (see also gh-3458) |
862 | auto a = torch::arange(0, 4); |
863 | |
864 | auto f = [](torch::Tensor a, std::vector<int64_t> v) -> void { |
865 | a.index_put_({a > -1}, torch::tensor(v)); |
866 | }; |
867 | |
868 | ASSERT_THROWS_WITH(f(a, {}), "shape mismatch" ); |
869 | ASSERT_THROWS_WITH(f(a, {1, 2, 3}), "shape mismatch" ); |
870 | ASSERT_THROWS_WITH(f(a.index({Slice(None, 1)}), {1, 2, 3}), "shape mismatch" ); |
871 | } |
872 | |
873 | TEST(NumpyTests, TestBooleanIndexingTwodim) { |
874 | // Indexing a 2-dimensional array with |
875 | // 2-dimensional boolean array |
876 | auto a = torch::tensor({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); |
877 | auto b = torch::tensor( |
878 | {{true, false, true}, {false, true, false}, {true, false, true}}); |
879 | assert_tensor_equal(a.index({b}), torch::tensor({1, 3, 5, 7, 9})); |
880 | assert_tensor_equal(a.index({b.index({1})}), torch::tensor({{4, 5, 6}})); |
881 | assert_tensor_equal(a.index({b.index({0})}), a.index({b.index({2})})); |
882 | |
883 | // boolean assignment |
884 | a.index_put_({b}, 0); |
885 | assert_tensor_equal(a, torch::tensor({{0, 2, 0}, {4, 0, 6}, {0, 8, 0}})); |
886 | } |
887 | |
888 | TEST(NumpyTests, TestBooleanIndexingWeirdness) { |
889 | // Weird boolean indexing things |
890 | auto a = torch::ones({2, 3, 4}); |
891 | ASSERT_EQ( |
892 | a.index({false, true, "..." }).sizes(), torch::IntArrayRef({0, 2, 3, 4})); |
893 | assert_tensor_equal( |
894 | torch::ones({1, 2}), |
895 | a.index( |
896 | {true, |
897 | torch::tensor({0, 1}), |
898 | true, |
899 | true, |
900 | torch::tensor({1}), |
901 | torch::tensor({{2}})})); |
902 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) |
903 | ASSERT_THROW(a.index({false, torch::tensor({0, 1}), "..." }), c10::Error); |
904 | } |
905 | |
906 | TEST(NumpyTests, TestBooleanIndexingWeirdnessTensors) { |
907 | // Weird boolean indexing things |
908 | auto false_tensor = torch::tensor(false); |
909 | auto true_tensor = torch::tensor(true); |
910 | auto a = torch::ones({2, 3, 4}); |
911 | ASSERT_EQ( |
912 | a.index({false, true, "..." }).sizes(), torch::IntArrayRef({0, 2, 3, 4})); |
913 | assert_tensor_equal( |
914 | torch::ones({1, 2}), |
915 | a.index( |
916 | {true_tensor, |
917 | torch::tensor({0, 1}), |
918 | true_tensor, |
919 | true_tensor, |
920 | torch::tensor({1}), |
921 | torch::tensor({{2}})})); |
922 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) |
923 | ASSERT_THROW( |
924 | a.index({false_tensor, torch::tensor({0, 1}), "..." }), c10::Error); |
925 | } |
926 | |
927 | TEST(NumpyTests, TestBooleanIndexingAlldims) { |
928 | auto true_tensor = torch::tensor(true); |
929 | auto a = torch::ones({2, 3}); |
930 | ASSERT_EQ(a.index({true, true}).sizes(), torch::IntArrayRef({1, 2, 3})); |
931 | ASSERT_EQ( |
932 | a.index({true_tensor, true_tensor}).sizes(), |
933 | torch::IntArrayRef({1, 2, 3})); |
934 | } |
935 | |
936 | TEST(NumpyTests, TestBooleanListIndexing) { |
937 | // Indexing a 2-dimensional array with |
938 | // boolean lists |
939 | auto a = torch::tensor({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); |
940 | auto b = torch::tensor({true, false, false}); |
941 | auto c = torch::tensor({true, true, false}); |
942 | assert_tensor_equal(a.index({b}), torch::tensor({{1, 2, 3}})); |
943 | assert_tensor_equal(a.index({b, b}), torch::tensor({1})); |
944 | assert_tensor_equal(a.index({c}), torch::tensor({{1, 2, 3}, {4, 5, 6}})); |
945 | assert_tensor_equal(a.index({c, c}), torch::tensor({1, 5})); |
946 | } |
947 | |
948 | TEST(NumpyTests, TestEverythingReturnsViews) { |
949 | // Before `...` would return a itself. |
950 | auto a = torch::tensor({5}); |
951 | |
952 | ASSERT_FALSE(a.is_same(a.index({"..." }))); |
953 | ASSERT_FALSE(a.is_same(a.index({Slice()}))); |
954 | } |
955 | |
956 | TEST(NumpyTests, TestBroaderrorsIndexing) { |
957 | auto a = torch::zeros({5, 5}); |
958 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) |
959 | ASSERT_THROW( |
960 | a.index({torch::tensor({0, 1}), torch::tensor({0, 1, 2})}), c10::Error); |
961 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) |
962 | ASSERT_THROW( |
963 | a.index_put_({torch::tensor({0, 1}), torch::tensor({0, 1, 2})}, 0), |
964 | c10::Error); |
965 | } |
966 | |
967 | TEST(NumpyTests, TestTrivialFancyOutOfBounds) { |
968 | auto a = torch::zeros({5}); |
969 | auto ind = torch::ones({20}, torch::kInt64); |
970 | ind.index_put_({-1}, 10); |
971 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) |
972 | ASSERT_THROW(a.index({ind}), c10::Error); |
973 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) |
974 | ASSERT_THROW(a.index_put_({ind}, 0), c10::Error); |
975 | ind = torch::ones({20}, torch::kInt64); |
976 | ind.index_put_({0}, 11); |
977 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) |
978 | ASSERT_THROW(a.index({ind}), c10::Error); |
979 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) |
980 | ASSERT_THROW(a.index_put_({ind}, 0), c10::Error); |
981 | } |
982 | |
983 | TEST(NumpyTests, TestIndexIsLarger) { |
984 | // Simple case of fancy index broadcasting of the index. |
985 | auto a = torch::zeros({5, 5}); |
986 | a.index_put_( |
987 | {torch::tensor({{0}, {1}, {2}}), torch::tensor({0, 1, 2})}, |
988 | torch::tensor({2., 3., 4.})); |
989 | |
990 | ASSERT_TRUE( |
991 | (a.index({Slice(None, 3), Slice(None, 3)}) == torch::tensor({2., 3., 4.})) |
992 | .all() |
993 | .item<bool>()); |
994 | } |
995 | |
996 | TEST(NumpyTests, TestBroadcastSubspace) { |
997 | auto a = torch::zeros({100, 100}); |
998 | auto v = torch::arange(0., 100).index({Slice(), None}); |
999 | auto b = torch::arange(99, -1, -1).to(torch::kLong); |
1000 | a.index_put_({b}, v); |
1001 | auto expected = b.to(torch::kDouble).unsqueeze(1).expand({100, 100}); |
1002 | assert_tensor_equal(a, expected); |
1003 | } |
1004 | |