1#include <gtest/gtest.h>
2
3#include <torch/torch.h>
4
5#include <test/cpp/api/support.h>
6
7using namespace torch::indexing;
8using namespace torch::test;
9
10TEST(TensorIndexingTest, Slice) {
11 Slice slice(1, 2, 3);
12 ASSERT_EQ(slice.start(), 1);
13 ASSERT_EQ(slice.stop(), 2);
14 ASSERT_EQ(slice.step(), 3);
15
16 ASSERT_EQ(c10::str(slice), "1:2:3");
17}
18
19TEST(TensorIndexingTest, TensorIndex) {
20 {
21 std::vector<TensorIndex> indices = {
22 None,
23 "...",
24 Ellipsis,
25 0,
26 true,
27 Slice(1, None, 2),
28 torch::tensor({1, 2})};
29 ASSERT_TRUE(indices[0].is_none());
30 ASSERT_TRUE(indices[1].is_ellipsis());
31 ASSERT_TRUE(indices[2].is_ellipsis());
32 ASSERT_TRUE(indices[3].is_integer());
33 ASSERT_TRUE(indices[3].integer() == 0);
34 ASSERT_TRUE(indices[4].is_boolean());
35 ASSERT_TRUE(indices[4].boolean() == true);
36 ASSERT_TRUE(indices[5].is_slice());
37 ASSERT_TRUE(indices[5].slice().start() == 1);
38 ASSERT_TRUE(indices[5].slice().stop() == INDEX_MAX);
39 ASSERT_TRUE(indices[5].slice().step() == 2);
40 ASSERT_TRUE(indices[6].is_tensor());
41 ASSERT_TRUE(torch::equal(indices[6].tensor(), torch::tensor({1, 2})));
42 }
43
44 ASSERT_THROWS_WITH(
45 TensorIndex(".."),
46 "Expected \"...\" to represent an ellipsis index, but got \"..\"");
47
48 {
49 std::vector<TensorIndex> indices = {
50 None, "...", Ellipsis, 0, true, Slice(1, None, 2)};
51 ASSERT_EQ(
52 c10::str(indices),
53 c10::str("(None, ..., ..., 0, true, 1:", INDEX_MAX, ":2)"));
54 ASSERT_EQ(c10::str(indices[0]), "None");
55 ASSERT_EQ(c10::str(indices[1]), "...");
56 ASSERT_EQ(c10::str(indices[2]), "...");
57 ASSERT_EQ(c10::str(indices[3]), "0");
58 ASSERT_EQ(c10::str(indices[4]), "true");
59 ASSERT_EQ(c10::str(indices[5]), c10::str("1:", INDEX_MAX, ":2"));
60 }
61
62 ASSERT_EQ(
63 c10::str(std::vector<TensorIndex>({Slice()})),
64 c10::str("(0:", INDEX_MAX, ":1)"));
65 ASSERT_EQ(
66 c10::str(std::vector<TensorIndex>({Slice(None, None)})),
67 c10::str("(0:", INDEX_MAX, ":1)"));
68 ASSERT_EQ(
69 c10::str(std::vector<TensorIndex>({Slice(None, None, None)})),
70 c10::str("(0:", INDEX_MAX, ":1)"));
71
72 ASSERT_EQ(
73 c10::str(std::vector<TensorIndex>({Slice(1, None)})),
74 c10::str("(1:", INDEX_MAX, ":1)"));
75 ASSERT_EQ(
76 c10::str(std::vector<TensorIndex>({Slice(1, None, None)})),
77 c10::str("(1:", INDEX_MAX, ":1)"));
78 ASSERT_EQ(
79 c10::str(std::vector<TensorIndex>({Slice(None, 3)})),
80 c10::str("(0:3:1)"));
81 ASSERT_EQ(
82 c10::str(std::vector<TensorIndex>({Slice(None, 3, None)})),
83 c10::str("(0:3:1)"));
84 ASSERT_EQ(
85 c10::str(std::vector<TensorIndex>({Slice(None, None, 2)})),
86 c10::str("(0:", INDEX_MAX, ":2)"));
87 ASSERT_EQ(
88 c10::str(std::vector<TensorIndex>({Slice(None, None, -1)})),
89 c10::str("(", INDEX_MAX, ":", INDEX_MIN, ":-1)"));
90
91 ASSERT_EQ(
92 c10::str(std::vector<TensorIndex>({Slice(1, 3)})), c10::str("(1:3:1)"));
93 ASSERT_EQ(
94 c10::str(std::vector<TensorIndex>({Slice(1, None, 2)})),
95 c10::str("(1:", INDEX_MAX, ":2)"));
96 ASSERT_EQ(
97 c10::str(std::vector<TensorIndex>({Slice(1, None, -1)})),
98 c10::str("(1:", INDEX_MIN, ":-1)"));
99 ASSERT_EQ(
100 c10::str(std::vector<TensorIndex>({Slice(None, 3, 2)})),
101 c10::str("(0:3:2)"));
102 ASSERT_EQ(
103 c10::str(std::vector<TensorIndex>({Slice(None, 3, -1)})),
104 c10::str("(", INDEX_MAX, ":3:-1)"));
105
106 ASSERT_EQ(
107 c10::str(std::vector<TensorIndex>({Slice(1, 3, 2)})),
108 c10::str("(1:3:2)"));
109}
110
111TEST(TensorIndexingTest, TestNoIndices) {
112 torch::Tensor tensor = torch::randn({20, 20});
113 torch::Tensor value = torch::randn({20, 20});
114 std::vector<TensorIndex> indices;
115
116 ASSERT_THROWS_WITH(
117 tensor.index({}),
118 "Passing an empty index list to Tensor::index() is not valid syntax");
119 ASSERT_THROWS_WITH(
120 tensor.index_put_({}, 1),
121 "Passing an empty index list to Tensor::index_put_() is not valid syntax");
122 ASSERT_THROWS_WITH(
123 tensor.index_put_({}, value),
124 "Passing an empty index list to Tensor::index_put_() is not valid syntax");
125
126 ASSERT_THROWS_WITH(
127 tensor.index(indices),
128 "Passing an empty index list to Tensor::index() is not valid syntax");
129 ASSERT_THROWS_WITH(
130 tensor.index_put_(indices, 1),
131 "Passing an empty index list to Tensor::index_put_() is not valid syntax");
132 ASSERT_THROWS_WITH(
133 tensor.index_put_(indices, value),
134 "Passing an empty index list to Tensor::index_put_() is not valid syntax");
135}
136
137TEST(TensorIndexingTest, TestAdvancedIndexingWithListOfTensor) {
138 {
139 torch::Tensor tensor = torch::randn({20, 20});
140 torch::Tensor index = torch::arange(10, torch::kLong).cpu();
141 torch::Tensor result = at::index(tensor, {index});
142 torch::Tensor result_with_init_list = tensor.index({index});
143 ASSERT_TRUE(result.equal(result_with_init_list));
144 }
145 {
146 torch::Tensor tensor = torch::randn({20, 20});
147 torch::Tensor index = torch::arange(10, torch::kLong).cpu();
148 torch::Tensor result = at::index_put_(tensor, {index}, torch::ones({20}));
149 torch::Tensor result_with_init_list =
150 tensor.index_put_({index}, torch::ones({20}));
151 ASSERT_TRUE(result.equal(result_with_init_list));
152 }
153 {
154 torch::Tensor tensor = torch::randn({20, 20});
155 torch::Tensor index = torch::arange(10, torch::kLong).cpu();
156 torch::Tensor result =
157 at::index_put_(tensor, {index}, torch::ones({1, 20}));
158 torch::Tensor result_with_init_list =
159 tensor.index_put_({index}, torch::ones({1, 20}));
160 ASSERT_TRUE(result.equal(result_with_init_list));
161 }
162}
163
164TEST(TensorIndexingTest, TestSingleInt) {
165 auto v = torch::randn({5, 7, 3});
166 ASSERT_EQ(v.index({4}).sizes(), torch::IntArrayRef({7, 3}));
167}
168
169TEST(TensorIndexingTest, TestMultipleInt) {
170 auto v = torch::randn({5, 7, 3});
171 ASSERT_EQ(v.index({4}).sizes(), torch::IntArrayRef({7, 3}));
172 ASSERT_EQ(v.index({4, Slice(), 1}).sizes(), torch::IntArrayRef({7}));
173
174 // To show that `.index_put_` works
175 v.index_put_({4, 3, 1}, 0);
176 ASSERT_EQ(v.index({4, 3, 1}).item<double>(), 0);
177}
178
179TEST(TensorIndexingTest, TestNone) {
180 auto v = torch::randn({5, 7, 3});
181 ASSERT_EQ(v.index({None}).sizes(), torch::IntArrayRef({1, 5, 7, 3}));
182 ASSERT_EQ(v.index({Slice(), None}).sizes(), torch::IntArrayRef({5, 1, 7, 3}));
183 ASSERT_EQ(
184 v.index({Slice(), None, None}).sizes(),
185 torch::IntArrayRef({5, 1, 1, 7, 3}));
186 ASSERT_EQ(v.index({"...", None}).sizes(), torch::IntArrayRef({5, 7, 3, 1}));
187}
188
189TEST(TensorIndexingTest, TestStep) {
190 auto v = torch::arange(10);
191 assert_tensor_equal(v.index({Slice(None, None, 1)}), v);
192 assert_tensor_equal(
193 v.index({Slice(None, None, 2)}), torch::tensor({0, 2, 4, 6, 8}));
194 assert_tensor_equal(
195 v.index({Slice(None, None, 3)}), torch::tensor({0, 3, 6, 9}));
196 assert_tensor_equal(v.index({Slice(None, None, 11)}), torch::tensor({0}));
197 assert_tensor_equal(v.index({Slice(1, 6, 2)}), torch::tensor({1, 3, 5}));
198}
199
200TEST(TensorIndexingTest, TestStepAssignment) {
201 auto v = torch::zeros({4, 4});
202 v.index_put_({0, Slice(1, None, 2)}, torch::tensor({3., 4.}));
203 assert_tensor_equal(v.index({0}), torch::tensor({0., 3., 0., 4.}));
204 assert_tensor_equal(v.index({Slice(1, None)}).sum(), torch::tensor(0));
205}
206
207TEST(TensorIndexingTest, TestBoolIndices) {
208 {
209 auto v = torch::randn({5, 7, 3});
210 auto boolIndices =
211 torch::tensor({true, false, true, true, false}, torch::kBool);
212 ASSERT_EQ(v.index({boolIndices}).sizes(), torch::IntArrayRef({3, 7, 3}));
213 assert_tensor_equal(
214 v.index({boolIndices}),
215 torch::stack({v.index({0}), v.index({2}), v.index({3})}));
216 }
217 {
218 auto v = torch::tensor({true, false, true}, torch::kBool);
219 auto boolIndices = torch::tensor({true, false, false}, torch::kBool);
220 auto uint8Indices = torch::tensor({1, 0, 0}, torch::kUInt8);
221
222 {
223 WarningCapture warnings;
224
225 ASSERT_EQ(
226 v.index({boolIndices}).sizes(), v.index({uint8Indices}).sizes());
227 assert_tensor_equal(v.index({boolIndices}), v.index({uint8Indices}));
228 assert_tensor_equal(
229 v.index({boolIndices}), torch::tensor({true}, torch::kBool));
230
231 ASSERT_EQ(
232 count_substr_occurrences(
233 warnings.str(),
234 "indexing with dtype torch.uint8 is now deprecated"),
235 2);
236 }
237 }
238}
239
240TEST(TensorIndexingTest, TestBoolIndicesAccumulate) {
241 auto mask = torch::zeros({10}, torch::kBool);
242 auto y = torch::ones({10, 10});
243 y.index_put_({mask}, {y.index({mask})}, /*accumulate=*/true);
244 assert_tensor_equal(y, torch::ones({10, 10}));
245}
246
247TEST(TensorIndexingTest, TestMultipleBoolIndices) {
248 auto v = torch::randn({5, 7, 3});
249 // note: these broadcast together and are transposed to the first dim
250 auto mask1 = torch::tensor({1, 0, 1, 1, 0}, torch::kBool);
251 auto mask2 = torch::tensor({1, 1, 1}, torch::kBool);
252 ASSERT_EQ(
253 v.index({mask1, Slice(), mask2}).sizes(), torch::IntArrayRef({3, 7}));
254}
255
256TEST(TensorIndexingTest, TestByteMask) {
257 {
258 auto v = torch::randn({5, 7, 3});
259 auto mask = torch::tensor({1, 0, 1, 1, 0}, torch::kByte);
260 {
261 WarningCapture warnings;
262
263 ASSERT_EQ(v.index({mask}).sizes(), torch::IntArrayRef({3, 7, 3}));
264 assert_tensor_equal(v.index({mask}), torch::stack({v[0], v[2], v[3]}));
265
266 ASSERT_EQ(
267 count_substr_occurrences(
268 warnings.str(),
269 "indexing with dtype torch.uint8 is now deprecated"),
270 2);
271 }
272 }
273 {
274 auto v = torch::tensor({1.});
275 assert_tensor_equal(v.index({v == 0}), torch::randn({0}));
276 }
277}
278
279TEST(TensorIndexingTest, TestByteMaskAccumulate) {
280 auto mask = torch::zeros({10}, torch::kUInt8);
281 auto y = torch::ones({10, 10});
282 {
283 WarningCapture warnings;
284
285 y.index_put_({mask}, y.index({mask}), /*accumulate=*/true);
286 assert_tensor_equal(y, torch::ones({10, 10}));
287
288 ASSERT_EQ(
289 count_substr_occurrences(
290 warnings.str(),
291 "indexing with dtype torch.uint8 is now deprecated"),
292 2);
293 }
294}
295
296TEST(TensorIndexingTest, TestMultipleByteMask) {
297 auto v = torch::randn({5, 7, 3});
298 // note: these broadcast together and are transposed to the first dim
299 auto mask1 = torch::tensor({1, 0, 1, 1, 0}, torch::kByte);
300 auto mask2 = torch::tensor({1, 1, 1}, torch::kByte);
301 {
302 WarningCapture warnings;
303
304 ASSERT_EQ(
305 v.index({mask1, Slice(), mask2}).sizes(), torch::IntArrayRef({3, 7}));
306
307 ASSERT_EQ(
308 count_substr_occurrences(
309 warnings.str(),
310 "indexing with dtype torch.uint8 is now deprecated"),
311 2);
312 }
313}
314
315TEST(TensorIndexingTest, TestByteMask2d) {
316 auto v = torch::randn({5, 7, 3});
317 auto c = torch::randn({5, 7});
318 int64_t num_ones = (c > 0).sum().item().to<int64_t>();
319 auto r = v.index({c > 0});
320 ASSERT_EQ(r.sizes(), torch::IntArrayRef({num_ones, 3}));
321}
322
323TEST(TensorIndexingTest, TestIntIndices) {
324 auto v = torch::randn({5, 7, 3});
325 ASSERT_EQ(
326 v.index({torch::tensor({0, 4, 2})}).sizes(),
327 torch::IntArrayRef({3, 7, 3}));
328 ASSERT_EQ(
329 v.index({Slice(), torch::tensor({0, 4, 2})}).sizes(),
330 torch::IntArrayRef({5, 3, 3}));
331 ASSERT_EQ(
332 v.index({Slice(), torch::tensor({{0, 1}, {4, 3}})}).sizes(),
333 torch::IntArrayRef({5, 2, 2, 3}));
334}
335
336TEST(TensorIndexingTest, TestIntIndices2d) {
337 // From the NumPy indexing example
338 auto x = torch::arange(0, 12, torch::kLong).view({4, 3});
339 auto rows = torch::tensor({{0, 0}, {3, 3}});
340 auto columns = torch::tensor({{0, 2}, {0, 2}});
341 assert_tensor_equal(
342 x.index({rows, columns}), torch::tensor({{0, 2}, {9, 11}}));
343}
344
345TEST(TensorIndexingTest, TestIntIndicesBroadcast) {
346 // From the NumPy indexing example
347 auto x = torch::arange(0, 12, torch::kLong).view({4, 3});
348 auto rows = torch::tensor({0, 3});
349 auto columns = torch::tensor({0, 2});
350 auto result = x.index({rows.index({Slice(), None}), columns});
351 assert_tensor_equal(result, torch::tensor({{0, 2}, {9, 11}}));
352}
353
354TEST(TensorIndexingTest, TestEmptyIndex) {
355 auto x = torch::arange(0, 12).view({4, 3});
356 auto idx = torch::tensor({}, torch::kLong);
357 ASSERT_EQ(x.index({idx}).numel(), 0);
358
359 // empty assignment should have no effect but not throw an exception
360 auto y = x.clone();
361 y.index_put_({idx}, -1);
362 assert_tensor_equal(x, y);
363
364 auto mask = torch::zeros({4, 3}, torch::kBool);
365 y.index_put_({mask}, -1);
366 assert_tensor_equal(x, y);
367}
368
369TEST(TensorIndexingTest, TestEmptyNdimIndex) {
370 torch::Device device(torch::kCPU);
371 {
372 auto x = torch::randn({5}, device);
373 assert_tensor_equal(
374 torch::empty({0, 2}, device),
375 x.index({torch::empty(
376 {0, 2}, torch::TensorOptions(torch::kInt64).device(device))}));
377 }
378 {
379 auto x = torch::randn({2, 3, 4, 5}, device);
380 assert_tensor_equal(
381 torch::empty({2, 0, 6, 4, 5}, device),
382 x.index(
383 {Slice(),
384 torch::empty(
385 {0, 6}, torch::TensorOptions(torch::kInt64).device(device))}));
386 }
387 {
388 auto x = torch::empty({10, 0});
389 ASSERT_EQ(
390 x.index({torch::tensor({1, 2})}).sizes(), torch::IntArrayRef({2, 0}));
391 ASSERT_EQ(
392 x.index(
393 {torch::tensor({}, torch::kLong), torch::tensor({}, torch::kLong)})
394 .sizes(),
395 torch::IntArrayRef({0}));
396 ASSERT_THROWS_WITH(
397 x.index({Slice(), torch::tensor({0, 1})}), "for dimension with size 0");
398 }
399}
400
401TEST(TensorIndexingTest, TestEmptyNdimIndex_CUDA) {
402 torch::Device device(torch::kCUDA);
403 {
404 auto x = torch::randn({5}, device);
405 assert_tensor_equal(
406 torch::empty({0, 2}, device),
407 x.index({torch::empty(
408 {0, 2}, torch::TensorOptions(torch::kInt64).device(device))}));
409 }
410 {
411 auto x = torch::randn({2, 3, 4, 5}, device);
412 assert_tensor_equal(
413 torch::empty({2, 0, 6, 4, 5}, device),
414 x.index(
415 {Slice(),
416 torch::empty(
417 {0, 6}, torch::TensorOptions(torch::kInt64).device(device))}));
418 }
419}
420
421TEST(TensorIndexingTest, TestEmptyNdimIndexBool) {
422 torch::Device device(torch::kCPU);
423 auto x = torch::randn({5}, device);
424 // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
425 ASSERT_THROW(
426 x.index({torch::empty(
427 {0, 2}, torch::TensorOptions(torch::kUInt8).device(device))}),
428 c10::Error);
429}
430
431TEST(TensorIndexingTest, TestEmptyNdimIndexBool_CUDA) {
432 torch::Device device(torch::kCUDA);
433 auto x = torch::randn({5}, device);
434 // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
435 ASSERT_THROW(
436 x.index({torch::empty(
437 {0, 2}, torch::TensorOptions(torch::kUInt8).device(device))}),
438 c10::Error);
439}
440
441TEST(TensorIndexingTest, TestEmptySlice) {
442 torch::Device device(torch::kCPU);
443 auto x = torch::randn({2, 3, 4, 5}, device);
444 auto y = x.index({Slice(), Slice(), Slice(), 1});
445 auto z = y.index({Slice(), Slice(1, 1), Slice()});
446 ASSERT_EQ(z.sizes(), torch::IntArrayRef({2, 0, 4}));
447 // this isn't technically necessary, but matches NumPy stride calculations.
448 ASSERT_EQ(z.strides(), torch::IntArrayRef({60, 20, 5}));
449 ASSERT_TRUE(z.is_contiguous());
450}
451
452TEST(TensorIndexingTest, TestEmptySlice_CUDA) {
453 torch::Device device(torch::kCUDA);
454 auto x = torch::randn({2, 3, 4, 5}, device);
455 auto y = x.index({Slice(), Slice(), Slice(), 1});
456 auto z = y.index({Slice(), Slice(1, 1), Slice()});
457 ASSERT_EQ(z.sizes(), torch::IntArrayRef({2, 0, 4}));
458 // this isn't technically necessary, but matches NumPy stride calculations.
459 ASSERT_EQ(z.strides(), torch::IntArrayRef({60, 20, 5}));
460 ASSERT_TRUE(z.is_contiguous());
461}
462
463TEST(TensorIndexingTest, TestIndexGetitemCopyBoolsSlices) {
464 auto true_tensor = torch::tensor(1, torch::kUInt8);
465 auto false_tensor = torch::tensor(0, torch::kUInt8);
466
467 std::vector<torch::Tensor> tensors = {torch::randn({2, 3}), torch::tensor(3)};
468
469 for (auto& a : tensors) {
470 ASSERT_NE(a.data_ptr(), a.index({true}).data_ptr());
471 {
472 std::vector<int64_t> sizes = {0};
473 sizes.insert(sizes.end(), a.sizes().begin(), a.sizes().end());
474 assert_tensor_equal(torch::empty(sizes), a.index({false}));
475 }
476 ASSERT_NE(a.data_ptr(), a.index({true_tensor}).data_ptr());
477 {
478 std::vector<int64_t> sizes = {0};
479 sizes.insert(sizes.end(), a.sizes().begin(), a.sizes().end());
480 assert_tensor_equal(torch::empty(sizes), a.index({false_tensor}));
481 }
482 ASSERT_EQ(a.data_ptr(), a.index({None}).data_ptr());
483 ASSERT_EQ(a.data_ptr(), a.index({"..."}).data_ptr());
484 }
485}
486
487TEST(TensorIndexingTest, TestIndexSetitemBoolsSlices) {
488 auto true_tensor = torch::tensor(1, torch::kUInt8);
489 auto false_tensor = torch::tensor(0, torch::kUInt8);
490
491 std::vector<torch::Tensor> tensors = {torch::randn({2, 3}), torch::tensor(3)};
492
493 for (auto& a : tensors) {
494 // prefix with a 1,1, to ensure we are compatible with numpy which cuts off
495 // prefix 1s (some of these ops already prefix a 1 to the size)
496 auto neg_ones = torch::ones_like(a) * -1;
497 auto neg_ones_expanded = neg_ones.unsqueeze(0).unsqueeze(0);
498 a.index_put_({true}, neg_ones_expanded);
499 assert_tensor_equal(a, neg_ones);
500 a.index_put_({false}, 5);
501 assert_tensor_equal(a, neg_ones);
502 a.index_put_({true_tensor}, neg_ones_expanded * 2);
503 assert_tensor_equal(a, neg_ones * 2);
504 a.index_put_({false_tensor}, 5);
505 assert_tensor_equal(a, neg_ones * 2);
506 a.index_put_({None}, neg_ones_expanded * 3);
507 assert_tensor_equal(a, neg_ones * 3);
508 a.index_put_({"..."}, neg_ones_expanded * 4);
509 assert_tensor_equal(a, neg_ones * 4);
510 if (a.dim() == 0) {
511 // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
512 ASSERT_THROW(a.index_put_({Slice()}, neg_ones_expanded * 5), c10::Error);
513 }
514 }
515}
516
517TEST(TensorIndexingTest, TestIndexScalarWithBoolMask) {
518 torch::Device device(torch::kCPU);
519
520 auto a = torch::tensor(1, device);
521 auto uintMask =
522 torch::tensor(true, torch::TensorOptions(torch::kUInt8).device(device));
523 auto boolMask =
524 torch::tensor(true, torch::TensorOptions(torch::kBool).device(device));
525 assert_tensor_equal(a.index({uintMask}), a.index({boolMask}));
526 ASSERT_EQ(a.index({uintMask}).dtype(), a.index({boolMask}).dtype());
527
528 a = torch::tensor(true, torch::TensorOptions(torch::kBool).device(device));
529 assert_tensor_equal(a.index({uintMask}), a.index({boolMask}));
530 ASSERT_EQ(a.index({uintMask}).dtype(), a.index({boolMask}).dtype());
531}
532
533TEST(TensorIndexingTest, TestIndexScalarWithBoolMask_CUDA) {
534 torch::Device device(torch::kCUDA);
535
536 auto a = torch::tensor(1, device);
537 auto uintMask =
538 torch::tensor(true, torch::TensorOptions(torch::kUInt8).device(device));
539 auto boolMask =
540 torch::tensor(true, torch::TensorOptions(torch::kBool).device(device));
541 assert_tensor_equal(a.index({uintMask}), a.index({boolMask}));
542 ASSERT_EQ(a.index({uintMask}).dtype(), a.index({boolMask}).dtype());
543
544 a = torch::tensor(true, torch::TensorOptions(torch::kBool).device(device));
545 assert_tensor_equal(a.index({uintMask}), a.index({boolMask}));
546 ASSERT_EQ(a.index({uintMask}).dtype(), a.index({boolMask}).dtype());
547}
548
549TEST(TensorIndexingTest, TestSetitemExpansionError) {
550 auto true_tensor = torch::tensor(true);
551 auto a = torch::randn({2, 3});
552 // check prefix with non-1s doesn't work
553 std::vector<int64_t> tensor_sizes{5, 1};
554 tensor_sizes.insert(tensor_sizes.end(), a.sizes().begin(), a.sizes().end());
555 auto a_expanded = a.expand(tensor_sizes);
556 // NumPy: ValueError
557 // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
558 ASSERT_THROW(a.index_put_({true}, a_expanded), c10::Error);
559 // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
560 ASSERT_THROW(a.index_put_({true_tensor}, a_expanded), c10::Error);
561}
562
563TEST(TensorIndexingTest, TestGetitemScalars) {
564 auto zero = torch::tensor(0, torch::kInt64);
565 auto one = torch::tensor(1, torch::kInt64);
566
567 // non-scalar indexed with scalars
568 auto a = torch::randn({2, 3});
569 assert_tensor_equal(a.index({0}), a.index({zero}));
570 assert_tensor_equal(a.index({0}).index({1}), a.index({zero}).index({one}));
571 assert_tensor_equal(a.index({0, 1}), a.index({zero, one}));
572 assert_tensor_equal(a.index({0, one}), a.index({zero, 1}));
573
574 // indexing by a scalar should slice (not copy)
575 ASSERT_EQ(a.index({0, 1}).data_ptr(), a.index({zero, one}).data_ptr());
576 ASSERT_EQ(a.index({1}).data_ptr(), a.index({one.to(torch::kInt)}).data_ptr());
577 ASSERT_EQ(
578 a.index({1}).data_ptr(), a.index({one.to(torch::kShort)}).data_ptr());
579
580 // scalar indexed with scalar
581 auto r = torch::randn({});
582 // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
583 ASSERT_THROW(r.index({Slice()}), c10::Error);
584 // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
585 ASSERT_THROW(r.index({zero}), c10::Error);
586 assert_tensor_equal(r, r.index({"..."}));
587}
588
589TEST(TensorIndexingTest, TestSetitemScalars) {
590 auto zero = torch::tensor(0, torch::kInt64);
591
592 // non-scalar indexed with scalars
593 auto a = torch::randn({2, 3});
594 auto a_set_with_number = a.clone();
595 auto a_set_with_scalar = a.clone();
596 auto b = torch::randn({3});
597
598 a_set_with_number.index_put_({0}, b);
599 a_set_with_scalar.index_put_({zero}, b);
600 assert_tensor_equal(a_set_with_number, a_set_with_scalar);
601 a.index_put_({1, zero}, 7.7);
602 ASSERT_TRUE(a.index({1, 0}).allclose(torch::tensor(7.7)));
603
604 // scalar indexed with scalars
605 auto r = torch::randn({});
606 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
607 ASSERT_THROW(r.index_put_({Slice()}, 8.8), c10::Error);
608 // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
609 ASSERT_THROW(r.index_put_({zero}, 8.8), c10::Error);
610 r.index_put_({"..."}, 9.9);
611 ASSERT_TRUE(r.allclose(torch::tensor(9.9)));
612}
613
614TEST(TensorIndexingTest, TestBasicAdvancedCombined) {
615 // From the NumPy indexing example
616 auto x = torch::arange(0, 12).to(torch::kLong).view({4, 3});
617 assert_tensor_equal(
618 x.index({Slice(1, 2), Slice(1, 3)}),
619 x.index({Slice(1, 2), torch::tensor({1, 2})}));
620 assert_tensor_equal(
621 x.index({Slice(1, 2), Slice(1, 3)}), torch::tensor({{4, 5}}));
622
623 // Check that it is a copy
624 {
625 auto unmodified = x.clone();
626 x.index({Slice(1, 2), torch::tensor({1, 2})}).zero_();
627 assert_tensor_equal(x, unmodified);
628 }
629
630 // But assignment should modify the original
631 {
632 auto unmodified = x.clone();
633 x.index_put_({Slice(1, 2), torch::tensor({1, 2})}, 0);
634 assert_tensor_not_equal(x, unmodified);
635 }
636}
637
638TEST(TensorIndexingTest, TestIntAssignment) {
639 {
640 auto x = torch::arange(0, 4).to(torch::kLong).view({2, 2});
641 x.index_put_({1}, 5);
642 assert_tensor_equal(x, torch::tensor({{0, 1}, {5, 5}}));
643 }
644
645 {
646 auto x = torch::arange(0, 4).to(torch::kLong).view({2, 2});
647 x.index_put_({1}, torch::arange(5, 7).to(torch::kLong));
648 assert_tensor_equal(x, torch::tensor({{0, 1}, {5, 6}}));
649 }
650}
651
652TEST(TensorIndexingTest, TestByteTensorAssignment) {
653 auto x = torch::arange(0., 16).to(torch::kFloat).view({4, 4});
654 auto b = torch::tensor({true, false, true, false}, torch::kByte);
655 auto value = torch::tensor({3., 4., 5., 6.});
656
657 {
658 WarningCapture warnings;
659
660 x.index_put_({b}, value);
661
662 ASSERT_EQ(
663 count_substr_occurrences(
664 warnings.str(),
665 "indexing with dtype torch.uint8 is now deprecated"),
666 1);
667 }
668
669 assert_tensor_equal(x.index({0}), value);
670 assert_tensor_equal(x.index({1}), torch::arange(4, 8).to(torch::kLong));
671 assert_tensor_equal(x.index({2}), value);
672 assert_tensor_equal(x.index({3}), torch::arange(12, 16).to(torch::kLong));
673}
674
675TEST(TensorIndexingTest, TestVariableSlicing) {
676 auto x = torch::arange(0, 16).view({4, 4});
677 auto indices = torch::tensor({0, 1}, torch::kInt);
678 int i = indices[0].item<int>();
679 int j = indices[1].item<int>();
680 assert_tensor_equal(x.index({Slice(i, j)}), x.index({Slice(0, 1)}));
681}
682
683TEST(TensorIndexingTest, TestEllipsisTensor) {
684 auto x = torch::arange(0, 9).to(torch::kLong).view({3, 3});
685 auto idx = torch::tensor({0, 2});
686 assert_tensor_equal(
687 x.index({"...", idx}), torch::tensor({{0, 2}, {3, 5}, {6, 8}}));
688 assert_tensor_equal(
689 x.index({idx, "..."}), torch::tensor({{0, 1, 2}, {6, 7, 8}}));
690}
691
692TEST(TensorIndexingTest, TestOutOfBoundIndex) {
693 auto x = torch::arange(0, 100).view({2, 5, 10});
694 ASSERT_THROWS_WITH(
695 x.index({0, 5}), "index 5 is out of bounds for dimension 1 with size 5");
696 ASSERT_THROWS_WITH(
697 x.index({4, 5}), "index 4 is out of bounds for dimension 0 with size 2");
698 ASSERT_THROWS_WITH(
699 x.index({0, 1, 15}),
700 "index 15 is out of bounds for dimension 2 with size 10");
701 ASSERT_THROWS_WITH(
702 x.index({Slice(), Slice(), 12}),
703 "index 12 is out of bounds for dimension 2 with size 10");
704}
705
706TEST(TensorIndexingTest, TestZeroDimIndex) {
707 auto x = torch::tensor(10);
708
709 auto runner = [&]() -> torch::Tensor {
710 std::cout << x.index({0}) << std::endl;
711 return x.index({0});
712 };
713
714 ASSERT_THROWS_WITH(runner(), "invalid index");
715}
716
717// The tests below are from NumPy test_indexing.py with some modifications to
// make them compatible with libtorch. It's licensed under the BSD license
719// below:
720//
721// Copyright (c) 2005-2017, NumPy Developers.
722// All rights reserved.
723//
724// Redistribution and use in source and binary forms, with or without
725// modification, are permitted provided that the following conditions are
726// met:
727//
728// * Redistributions of source code must retain the above copyright
729// notice, this list of conditions and the following disclaimer.
730//
731// * Redistributions in binary form must reproduce the above
732// copyright notice, this list of conditions and the following
733// disclaimer in the documentation and/or other materials provided
734// with the distribution.
735//
736// * Neither the name of the NumPy Developers nor the names of any
737// contributors may be used to endorse or promote products derived
738// from this software without specific prior written permission.
739//
740// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
741// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
742// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
743// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
744// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
745// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
746// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
747// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
748// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
749// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
750// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
751
752TEST(NumpyTests, TestNoneIndex) {
753 // `None` index adds newaxis
754 auto a = torch::tensor({1, 2, 3});
755 ASSERT_EQ(a.index({None}).dim(), a.dim() + 1);
756}
757
758TEST(NumpyTests, TestEmptyFancyIndex) {
759 // Empty list index creates an empty array
760 auto a = torch::tensor({1, 2, 3});
761 assert_tensor_equal(
762 a.index({torch::tensor({}, torch::kLong)}), torch::tensor({}));
763
764 auto b = torch::tensor({}).to(torch::kLong);
765 assert_tensor_equal(
766 a.index({torch::tensor({}, torch::kLong)}),
767 torch::tensor({}, torch::kLong));
768
769 b = torch::tensor({}).to(torch::kFloat);
770 // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
771 ASSERT_THROW(a.index({b}), c10::Error);
772}
773
774TEST(NumpyTests, TestEllipsisIndex) {
775 auto a = torch::tensor({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
776 ASSERT_FALSE(a.index({"..."}).is_same(a));
777 assert_tensor_equal(a.index({"..."}), a);
778 // `a[...]` was `a` in numpy <1.9.
779 ASSERT_EQ(a.index({"..."}).data_ptr(), a.data_ptr());
780
781 // Slicing with ellipsis can skip an
782 // arbitrary number of dimensions
783 assert_tensor_equal(a.index({0, "..."}), a.index({0}));
784 assert_tensor_equal(a.index({0, "..."}), a.index({0, Slice()}));
785 assert_tensor_equal(a.index({"...", 0}), a.index({Slice(), 0}));
786
787 // In NumPy, slicing with ellipsis results in a 0-dim array. In PyTorch
788 // we don't have separate 0-dim arrays and scalars.
789 assert_tensor_equal(a.index({0, "...", 1}), torch::tensor(2));
790
791 // Assignment with `Ellipsis` on 0-d arrays
792 auto b = torch::tensor(1);
793 b.index_put_({Ellipsis}, 2);
794 ASSERT_EQ(b.item<int64_t>(), 2);
795}
796
797TEST(NumpyTests, TestSingleIntIndex) {
798 // Single integer index selects one row
799 auto a = torch::tensor({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
800
801 assert_tensor_equal(a.index({0}), torch::tensor({1, 2, 3}));
802 assert_tensor_equal(a.index({-1}), torch::tensor({7, 8, 9}));
803
804 // Index out of bounds produces IndexError
805 // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
806 ASSERT_THROW(a.index({1 << 30}), c10::Error);
807 // NOTE: According to the standard
808 // (http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2017/p0543r0.html), for
809 // signed integers, if during the evaluation of an expression, the result is
810 // not mathematically defined or not in the range of representable values for
811 // its type, the behavior is undefined. Therefore, there is no way to check
812 // for index overflow case because it might not throw exception.
813 // ASSERT_THROW(a(1 << 64), c10::Error);
814}
815
816TEST(NumpyTests, TestSingleBoolIndex) {
817 // Single boolean index
818 auto a = torch::tensor({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
819
820 assert_tensor_equal(a.index({true}), a.index({None}));
821 assert_tensor_equal(a.index({false}), a.index({None}).index({Slice(0, 0)}));
822}
823
824TEST(NumpyTests, TestBooleanShapeMismatch) {
825 auto arr = torch::ones({5, 4, 3});
826
827 auto index = torch::tensor({true});
828 ASSERT_THROWS_WITH(arr.index({index}), "mask");
829
830 index = torch::tensor({false, false, false, false, false, false});
831 ASSERT_THROWS_WITH(arr.index({index}), "mask");
832
833 {
834 WarningCapture warnings;
835
836 index = torch::empty({4, 4}, torch::kByte).zero_();
837 ASSERT_THROWS_WITH(arr.index({index}), "mask");
838 ASSERT_THROWS_WITH(arr.index({Slice(), index}), "mask");
839
840 ASSERT_EQ(
841 count_substr_occurrences(
842 warnings.str(),
843 "indexing with dtype torch.uint8 is now deprecated"),
844 2);
845 }
846}
847
848TEST(NumpyTests, TestBooleanIndexingOnedim) {
849 // Indexing a 2-dimensional array with
850 // boolean array of length one
851 auto a = torch::tensor({{0., 0., 0.}});
852 auto b = torch::tensor({true});
853 assert_tensor_equal(a.index({b}), a);
854 // boolean assignment
855 a.index_put_({b}, 1.);
856 assert_tensor_equal(a, torch::tensor({{1., 1., 1.}}));
857}
858
859TEST(NumpyTests, TestBooleanAssignmentValueMismatch) {
860 // A boolean assignment should fail when the shape of the values
861 // cannot be broadcast to the subscription. (see also gh-3458)
862 auto a = torch::arange(0, 4);
863
864 auto f = [](torch::Tensor a, std::vector<int64_t> v) -> void {
865 a.index_put_({a > -1}, torch::tensor(v));
866 };
867
868 ASSERT_THROWS_WITH(f(a, {}), "shape mismatch");
869 ASSERT_THROWS_WITH(f(a, {1, 2, 3}), "shape mismatch");
870 ASSERT_THROWS_WITH(f(a.index({Slice(None, 1)}), {1, 2, 3}), "shape mismatch");
871}
872
873TEST(NumpyTests, TestBooleanIndexingTwodim) {
874 // Indexing a 2-dimensional array with
875 // 2-dimensional boolean array
876 auto a = torch::tensor({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
877 auto b = torch::tensor(
878 {{true, false, true}, {false, true, false}, {true, false, true}});
879 assert_tensor_equal(a.index({b}), torch::tensor({1, 3, 5, 7, 9}));
880 assert_tensor_equal(a.index({b.index({1})}), torch::tensor({{4, 5, 6}}));
881 assert_tensor_equal(a.index({b.index({0})}), a.index({b.index({2})}));
882
883 // boolean assignment
884 a.index_put_({b}, 0);
885 assert_tensor_equal(a, torch::tensor({{0, 2, 0}, {4, 0, 6}, {0, 8, 0}}));
886}
887
888TEST(NumpyTests, TestBooleanIndexingWeirdness) {
889 // Weird boolean indexing things
890 auto a = torch::ones({2, 3, 4});
891 ASSERT_EQ(
892 a.index({false, true, "..."}).sizes(), torch::IntArrayRef({0, 2, 3, 4}));
893 assert_tensor_equal(
894 torch::ones({1, 2}),
895 a.index(
896 {true,
897 torch::tensor({0, 1}),
898 true,
899 true,
900 torch::tensor({1}),
901 torch::tensor({{2}})}));
902 // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
903 ASSERT_THROW(a.index({false, torch::tensor({0, 1}), "..."}), c10::Error);
904}
905
906TEST(NumpyTests, TestBooleanIndexingWeirdnessTensors) {
907 // Weird boolean indexing things
908 auto false_tensor = torch::tensor(false);
909 auto true_tensor = torch::tensor(true);
910 auto a = torch::ones({2, 3, 4});
911 ASSERT_EQ(
912 a.index({false, true, "..."}).sizes(), torch::IntArrayRef({0, 2, 3, 4}));
913 assert_tensor_equal(
914 torch::ones({1, 2}),
915 a.index(
916 {true_tensor,
917 torch::tensor({0, 1}),
918 true_tensor,
919 true_tensor,
920 torch::tensor({1}),
921 torch::tensor({{2}})}));
922 // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
923 ASSERT_THROW(
924 a.index({false_tensor, torch::tensor({0, 1}), "..."}), c10::Error);
925}
926
927TEST(NumpyTests, TestBooleanIndexingAlldims) {
928 auto true_tensor = torch::tensor(true);
929 auto a = torch::ones({2, 3});
930 ASSERT_EQ(a.index({true, true}).sizes(), torch::IntArrayRef({1, 2, 3}));
931 ASSERT_EQ(
932 a.index({true_tensor, true_tensor}).sizes(),
933 torch::IntArrayRef({1, 2, 3}));
934}
935
936TEST(NumpyTests, TestBooleanListIndexing) {
937 // Indexing a 2-dimensional array with
938 // boolean lists
939 auto a = torch::tensor({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
940 auto b = torch::tensor({true, false, false});
941 auto c = torch::tensor({true, true, false});
942 assert_tensor_equal(a.index({b}), torch::tensor({{1, 2, 3}}));
943 assert_tensor_equal(a.index({b, b}), torch::tensor({1}));
944 assert_tensor_equal(a.index({c}), torch::tensor({{1, 2, 3}, {4, 5, 6}}));
945 assert_tensor_equal(a.index({c, c}), torch::tensor({1, 5}));
946}
947
948TEST(NumpyTests, TestEverythingReturnsViews) {
949 // Before `...` would return a itself.
950 auto a = torch::tensor({5});
951
952 ASSERT_FALSE(a.is_same(a.index({"..."})));
953 ASSERT_FALSE(a.is_same(a.index({Slice()})));
954}
955
956TEST(NumpyTests, TestBroaderrorsIndexing) {
957 auto a = torch::zeros({5, 5});
958 // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
959 ASSERT_THROW(
960 a.index({torch::tensor({0, 1}), torch::tensor({0, 1, 2})}), c10::Error);
961 // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
962 ASSERT_THROW(
963 a.index_put_({torch::tensor({0, 1}), torch::tensor({0, 1, 2})}, 0),
964 c10::Error);
965}
966
967TEST(NumpyTests, TestTrivialFancyOutOfBounds) {
968 auto a = torch::zeros({5});
969 auto ind = torch::ones({20}, torch::kInt64);
970 ind.index_put_({-1}, 10);
971 // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
972 ASSERT_THROW(a.index({ind}), c10::Error);
973 // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
974 ASSERT_THROW(a.index_put_({ind}, 0), c10::Error);
975 ind = torch::ones({20}, torch::kInt64);
976 ind.index_put_({0}, 11);
977 // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
978 ASSERT_THROW(a.index({ind}), c10::Error);
979 // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
980 ASSERT_THROW(a.index_put_({ind}, 0), c10::Error);
981}
982
983TEST(NumpyTests, TestIndexIsLarger) {
984 // Simple case of fancy index broadcasting of the index.
985 auto a = torch::zeros({5, 5});
986 a.index_put_(
987 {torch::tensor({{0}, {1}, {2}}), torch::tensor({0, 1, 2})},
988 torch::tensor({2., 3., 4.}));
989
990 ASSERT_TRUE(
991 (a.index({Slice(None, 3), Slice(None, 3)}) == torch::tensor({2., 3., 4.}))
992 .all()
993 .item<bool>());
994}
995
996TEST(NumpyTests, TestBroadcastSubspace) {
997 auto a = torch::zeros({100, 100});
998 auto v = torch::arange(0., 100).index({Slice(), None});
999 auto b = torch::arange(99, -1, -1).to(torch::kLong);
1000 a.index_put_({b}, v);
1001 auto expected = b.to(torch::kDouble).unsqueeze(1).expand({100, 100});
1002 assert_tensor_equal(a, expected);
1003}
1004