#include <torch/csrc/autograd/python_variable_indexing.h>

#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/utils/wrap_outputs.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/utils/python_arg_parser.h>
#include <torch/csrc/utils/python_compat.h>
#include <torch/csrc/utils/python_numbers.h>
#include <torch/csrc/utils/tensor_new.h>
#include <torch/csrc/utils/tensor_types.h>

#include <ATen/DeviceGuard.h>
#include <ATen/ExpandUtils.h>
#include <ATen/Functions.h>
#include <ATen/TensorIndexing.h>
#include <ATen/TracerMode.h>
#include <ATen/core/LegacyTypeDispatch.h>
#include <c10/core/TensorOptions.h>
#include <c10/util/irange.h>

#include <c10/core/Layout.h>
#include <tuple>
#include <vector>

using namespace at;
using namespace torch::autograd::utils;

namespace torch {
namespace autograd {

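// Implements Python `len(tensor)` (the `__len__` slot): returns the size of
// the first dimension, or 0 for a 0-dim tensor.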
Py_ssize_t THPVariable_length(PyObject* self) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function(self)) {
    py::object ret = py::reinterpret_steal<py::object>(
        handle_torch_function(self, "__len__"));
    Py_ssize_t length = PyLong_AsSsize_t(ret.ptr());
    if (PyErr_Occurred()) {
      throw python_error();
    }
    return length;
  }
  const auto& self_ = THPVariable_Unpack(self);
  if (self_.dim() == 0) {
    return 0;
  }
  // TODO: Maybe this should return a SymInt directly?
  // Add the guard to get a nice error message if/when we hit this.
  return (Py_ssize_t)self_.sym_size(0).guard_int(__FILE__, __LINE__);
  END_HANDLE_TH_ERRORS_RET(-1)
}

// We allow indexing by integers, slices, ellipsis, None, Variables,
// and tuples of those types. We also handle bools as if they were a
// Variable[ByteTensor].

static inline int64_t count_specified_dimensions(PyObject* index) {
  // Count the number of indexed dimensions (everything but ellipsis and None).
  // -1 is a sentinel for __torch_function__.
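  // e.g. for x[0, None, ..., mask] with a 2-D bool mask this returns
  // 1 + mask.dim() == 3; None and Ellipsis are not counted.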
  int64_t count = 0;
  auto size =
      PyTuple_GET_SIZE(index); // NOLINT(cppcoreguidelines-pro-type-cstyle-cast)
  for (Py_ssize_t i = 0; i < size; i++) {
    PyObject* obj = PyTuple_GET_ITEM(
        index, i); // NOLINT(cppcoreguidelines-pro-type-cstyle-cast)
    if (!THPVariable_CheckExact(obj) && check_has_torch_function(obj))
      return -1;
    if (THPVariable_Check(obj)) {
      const auto& var = THPVariable_Unpack(obj);
      const auto& var_scalar_type = var.scalar_type();
      if (var_scalar_type == kByte || var_scalar_type == kBool) {
        count += var.dim();
      } else {
        count++;
      }
    } else if (
        obj != Py_None && obj != Py_Ellipsis && obj != Py_True &&
        obj != Py_False) { // NOLINT(cppcoreguidelines-pro-type-cstyle-cast)
      count++;
    }
  }
  return count;
}

[[noreturn]] static inline void invalid_index(PyObject* obj) {
  throw IndexError(
      "only integers, slices (`:`), ellipsis (`...`), None and long or byte "
      "Variables are valid indices (got %s)",
      Py_TYPE(obj)->tp_name);
}

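// Converts an arbitrary Python sequence (e.g. a nested list of ints) into a
// long tensor suitable for advanced indexing.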
static inline Variable sequenceToVariable(
    c10::TensorOptions options,
    PyObject* seq) {
  return torch::utils::indexing_tensor_from_data(
      options, kLong, c10::nullopt, seq);
}

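// Converts the right-hand side of an indexed assignment to a Tensor: tensors
// pass through unchanged, while Python numbers, bools, and complex values are
// wrapped as 0-dim tensors with the given options.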
inline Variable valueToTensor(
    c10::TensorOptions options,
    PyObject* value,
    const at::Device& device) {
  if (THPVariable_Check(value)) {
    return THPVariable_Unpack(value);
  }
  at::AutoDispatchBelowADInplaceOrView guard; // TODO: remove
  at::tracer::impl::NoTracerDispatchMode tracer_guard;
  Scalar scalar;
  if (THPUtils_checkLong(value) || PyBool_Check(value)) {
    scalar = Scalar(THPUtils_unpackLong(value));
  } else if (PyFloat_Check(value)) {
    scalar = Scalar(THPUtils_unpackDouble(value));
  } else if (PyComplex_Check(value)) {
    scalar = Scalar(THPUtils_unpackComplexDouble(value));
  } else {
    throw TypeError(
        "can't assign a %s to a %s",
        Py_TYPE(value)->tp_name,
        torch::utils::options_to_string(options).c_str());
  }
  // lift_fresh is supposed to be used in situations where you are guaranteed
  // to get a plain Tensor, which is true for the CPU device but not for
  // non-CPU devices.
  if (device == at::kCPU) {
    return at::lift_fresh(
        at::indexing::scalarToTensor(scalar, options, device));
  } else {
    return at::indexing::scalarToTensor(scalar, options, device);
  }
}

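// When tracing, stash slice bounds that are themselves Variables so the JIT
// tracer records them as dynamic values rather than baked-in constants.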
static inline void recordSliceTrace(PyObject* obj) {
  PySliceObject* sliceobj = (PySliceObject*)obj;
  if (THPVariable_Check(sliceobj->start)) {
    torch::jit::tracer::ArgumentStash::stashValue(
        std::string("start"),
        1,
        THPVariable_Unpack(sliceobj->start),
        torch::jit::IntType::get());
  }
  if (THPVariable_Check(sliceobj->stop)) {
    torch::jit::tracer::ArgumentStash::stashValue(
        std::string("end"),
        1,
        THPVariable_Unpack(sliceobj->stop),
        torch::jit::IntType::get());
  }
  if (THPVariable_Check(sliceobj->step)) {
    torch::jit::tracer::ArgumentStash::stashValue(
        std::string("step"),
        1,
        THPVariable_Unpack(sliceobj->step),
        torch::jit::IntType::get());
  }
}

static inline void recordSelectTrace(const Tensor& index_tensor) {
  torch::jit::tracer::ArgumentStash::stashValue(
      std::string("index"), 1, index_tensor, torch::jit::IntType::get());
}

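// Applies every index in the (tuple) `index` to `self`, one dimension at a
// time. Basic indices (ints, slices, None, Ellipsis, bools) are applied
// eagerly via select/slice/unsqueeze; tensor indices are collected into
// `outIndices` for a later advanced-indexing dispatch. For example,
// x[1, 2:5, None, idx] eagerly applies the select, the slice, and the
// unsqueeze, and defers the tensor `idx` to dispatch_index(_put_).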
static inline Variable applySlicing(
    const Variable& self,
    PyObject* index,
    variable_list& outIndices,
    bool is_tracing,
    const at::Device& self_device,
    const c10::optional<int64_t>& self_ndim,
    int64_t specified_dims) {
  int64_t size =
      PyTuple_GET_SIZE(index); // NOLINT(cppcoreguidelines-pro-type-cstyle-cast)
  int64_t dim = 0;

  // See NOTE [nested tensor size for indexing]
  if (self_ndim.has_value()) {
    TORCH_CHECK_INDEX(
        specified_dims <= self_ndim.value(),
        "too many indices for tensor of dimension ",
        self_ndim.value());
  }

  Variable result = self;
  for (const auto i : c10::irange(size)) {
    PyObject* obj = PyTuple_GET_ITEM(
        index, i); // NOLINT(cppcoreguidelines-pro-type-cstyle-cast)
    // NOTE [nested tensor size for indexing]
    // Nested tensors do not have a well-defined size (yet), so for now we
    // represent their size as null. This may need to change once we reach a
    // better solution for nested tensor sizes.
    c10::optional<SymIntArrayRef> result_sizes = result.is_nested()
        ? c10::optional<SymIntArrayRef>(c10::nullopt)
        : c10::optional<SymIntArrayRef>(result.sym_sizes());
    result = at::indexing::handleDimInMultiDimIndexing(
        /*prev_dim_result=*/result,
        /*original_tensor=*/self,
        /*index=*/([&]() {
          if (THPUtils_checkLong(obj)) {
            if (is_tracing && THPVariable_Check(obj)) {
              recordSelectTrace(THPVariable_Unpack(obj));
            }
            return at::indexing::TensorIndex(THPUtils_unpackLong(obj));
          } else if (PySlice_Check(obj)) {
            auto val = __PySlice_Unpack(obj);
            if (is_tracing) {
              recordSliceTrace(obj);
            }
            return at::indexing::TensorIndex(
                at::indexing::Slice(val.start, val.stop, val.step));
          } else if (obj == Py_Ellipsis) {
            return at::indexing::TensorIndex(at::indexing::Ellipsis);
          } else if (obj == Py_None) {
            return at::indexing::TensorIndex(at::indexing::None);
          } else if (PyBool_Check(obj)) {
            return at::indexing::TensorIndex(obj == Py_True);
          } else if (THPVariable_Check(obj)) {
            Tensor tensor = THPVariable_Unpack(obj);
            if (is_tracing) {
              auto scalar_type = tensor.scalar_type();
              if (tensor.dim() == 0 &&
                  at::isIntegralType(scalar_type, /*includeBool=*/false) &&
                  scalar_type != at::kByte) {
                recordSelectTrace(tensor);
              }
            }
            return at::indexing::TensorIndex(std::move(tensor));
          } else if (PySequence_Check(obj)) {
            return at::indexing::TensorIndex(
                sequenceToVariable(self.options(), obj));
          } else {
            auto idx = THPObjectPtr(PyNumber_Index(obj));
            if (!idx) {
              PyErr_Clear();
              invalid_index(obj);
            }
            if (is_tracing && THPVariable_Check(idx)) {
              recordSelectTrace(THPVariable_Unpack(idx));
            }
            return at::indexing::TensorIndex(THPUtils_unpackLong(idx));
          }
        })(),
        /*dim_ptr=*/&dim,
        /*specified_dims_ptr=*/&specified_dims,
        /*real_dim=*/i,
        /*outIndices=*/outIndices,
        // See NOTE [ Setting `disable_slice_optimization` when calling C++
        // tensor indexing functions from Python ]
        /*disable_slice_optimization=*/is_tracing,
        /*original_tensor_device=*/self_device,
        /*prev_dim_result_sizes=*/result_sizes);
  }
  return result;
}

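// Decides whether a non-tuple sequence index should be treated as if it were
// a tuple of indices. For example, x[[0, slice(None)]] behaves like x[0, :],
// while x[[0, 1]] indexes dimension 0 with the index tensor [0, 1].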
static inline bool treatSequenceAsTuple(PyObject* index) {
  if (PyTuple_Check(index)) {
    return true;
  }
  if (THPVariable_Check(index)) {
    return false;
  }
  if (!PySequence_Check(index)) {
    return false;
  }
  // This uses a heuristic from NumPy for determining whether to treat
  // non-tuple sequences as if they were a tuple. From the NumPy code comments:
  //
  // "At this point, we're left with a non-tuple, non-array, sequence:
  // typically, a list. We use some somewhat-arbitrary heuristics from here
  // onwards to decide whether to treat that list as a single index, or a
  // list of indices. Backwards compatibility only takes effect for short
  // sequences - otherwise we treat it like any other scalar."
  auto n = PySequence_Size(index);
  if (n < 0) {
    // Negative size indicates a Python error in the PySequence_Size call.
    PyErr_Clear();
    return false;
  }
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  if (n >= 32) {
    return false;
  }
  for (Py_ssize_t i = 0; i < n; i++) {
    auto obj = THPObjectPtr{PySequence_GetItem(index, i)};
    if (!obj.get()) {
      PyErr_Clear();
      return false;
    }
    if (THPVariable_Check(obj.get()) || PySequence_Check(obj.get()) ||
        PySlice_Check(obj.get())) {
      return true;
    }
    if (obj.get() == Py_Ellipsis || obj.get() == Py_None) {
      return true;
    }
  }
  return false;
}

static inline THPObjectPtr wrapTuple(PyObject* index) {
  THPObjectPtr res;
  if (treatSequenceAsTuple(index)) {
    res = PySequence_Tuple(index);
  } else {
    res = PyTuple_Pack(
        1, index); // NOLINT(cppcoreguidelines-pro-type-cstyle-cast)
  }
  if (!res)
    throw python_error();
  return res;
}

// NOTE: Here is the dispatch structure for `THPVariable_getitem`:
//
// 1. Python 1-D getter calls C++ `at::indexing::get_item` after
// converting Python index to C++ TensorIndex.
//
// 2. Python N-D getter calls C++ `at::indexing::handleDimInMultiDimIndexing`
// for each dim, after converting Python index to C++ TensorIndex. If advanced
// indexing is needed, it calls C++ `at::indexing::dispatch_index`.
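//
// For example, x[1] takes path 1, while x[1, 2:, idx] takes path 2 and then
// dispatches the tensor index `idx` through `at::indexing::dispatch_index`.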
PyObject* THPVariable_getitem(PyObject* self, PyObject* index) {
  HANDLE_TH_ERRORS
  if (!THPVariable_CheckExact(self) && check_has_torch_function(self)) {
    return handle_torch_function_indexing(self, index);
  }
  const auto& self_ = THPVariable_Unpack(self);
  OptionalDeviceGuard device_guard(device_of(self_));

  // handle simple types: none, ellipsis
  if (index == Py_None) {
    return THPVariable_Wrap(at::indexing::get_item(
        self_, {at::indexing::TensorIndex(at::indexing::None)}));
  } else if (index == Py_Ellipsis) {
    return THPVariable_Wrap(at::indexing::get_item(
        self_, {at::indexing::TensorIndex(at::indexing::Ellipsis)}));
  }

  bool is_tracing = torch::jit::tracer::isTracing();

  // handle simple types: integers, slices, bool
  if (THPUtils_checkLong(index)) {
    if (is_tracing && THPVariable_Check(index)) {
      recordSelectTrace(THPVariable_Unpack(index));
    }
    return THPVariable_Wrap(at::indexing::get_item(
        self_, {at::indexing::TensorIndex(THPUtils_unpackLong(index))}));
  } else if (PySlice_Check(index)) {
    auto val = __PySlice_Unpack(index);
    if (is_tracing) {
      recordSliceTrace(index);
    }
    return THPVariable_Wrap(at::indexing::get_item(
        self_,
        {at::indexing::TensorIndex(
            at::indexing::Slice(val.start, val.stop, val.step))}));
  } else if (index == Py_False || index == Py_True) {
    return THPVariable_Wrap(([&]() {
      pybind11::gil_scoped_release no_gil;
      return at::indexing::get_item(
          self_, {at::indexing::TensorIndex(index == Py_True)});
    })());
  }

  // wrap index in a tuple if it's not already one
  THPObjectPtr holder = wrapTuple(index);

  variable_list variableIndices;
  int64_t specified_dims = count_specified_dimensions(holder.get());
  if (specified_dims == -1) {
    return handle_torch_function_indexing(self, holder.get());
  }
  Variable sliced = applySlicing(
      self_,
      holder.get(),
      variableIndices,
      /*is_tracing=*/is_tracing,
      self_.device(),
      self_.ndimension(),
      specified_dims);
  if (variableIndices.empty()) {
    if (sliced.is_same(self_)) {
      // ensure we return a shallow copy for things like x[...]
      sliced = at::alias(sliced);
    }
    return THPVariable_Wrap(std::move(sliced));
  }

  // indexing by tensors ("advanced" indexing)
  return THPVariable_Wrap(([&]() {
    pybind11::gil_scoped_release no_gil;
    return at::indexing::dispatch_index(sliced, std::move(variableIndices));
  })());
  END_HANDLE_TH_ERRORS
}

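// Releases the GIL and forwards to the C++ setter `at::indexing::set_item`;
// shared by the non-advanced THPVariable_setitem paths below.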
void dispatch_set_item(
    const Tensor& self,
    ArrayRef<at::indexing::TensorIndex> indices,
    const Tensor& value,
    bool disable_slice_optimization = false) {
  pybind11::gil_scoped_release no_gil;
  at::indexing::set_item(self, indices, value, disable_slice_optimization);
}

// NOTE: Here is the dispatch structure for `THPVariable_setitem`:
//
// 1. Python 1-D setter calls C++ `at::indexing::set_item` after
// converting Python index to C++ TensorIndex.
//
// 2. Python N-D setter calls C++ `at::indexing::handleDimInMultiDimIndexing`
// for each dim, after converting Python index to C++ TensorIndex. If advanced
// indexing is needed, it calls C++ `at::indexing::dispatch_index_put_`.
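//
// For example, x[0] = v takes path 1, while x[:, idx] = v takes path 2 and
// then dispatches through `at::indexing::dispatch_index_put_`.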
int THPVariable_setitem(PyObject* self, PyObject* index, PyObject* py_value) {
  HANDLE_TH_ERRORS
  if (py_value == nullptr) {
    throw TypeError("Tensor does not support deleting items");
  }
  if ((!THPVariable_CheckExact(self) && check_has_torch_function(self)) ||
      (!THPVariable_CheckExact(py_value) &&
       check_has_torch_function(py_value))) {
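    // reinterpret_steal takes ownership of the returned reference so that it
    // is released when `ret` goes out of scope.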
    py::object ret = py::reinterpret_steal<py::object>(
        handle_torch_function_indexing(self, index, py_value));
    return 0;
  }

  const auto& self_ = THPVariable_Unpack(self);
  if (self_.layout() == kSparse || self_.layout() == kSparseCsr ||
      self_.layout() == kSparseCsc || self_.layout() == kSparseBsr ||
      self_.layout() == kSparseBsc) {
    throw TypeError("Cannot assign to a sparse tensor");
  }
  OptionalDeviceGuard device_guard(device_of(self_));
  at::Device self_device = self_.device();
  Variable value;
  // TODO: This qint special case looks very suspicious...
  if (isQIntType(self_.scalar_type())) {
    value =
        valueToTensor(device(kCPU).dtype(kFloat), py_value, at::Device(kCPU));
  } else if (self_device.is_cuda()) {
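    // Note the at::Device(kCPU) argument: the wrapped scalar takes the CPU
    // path in valueToTensor here, presumably to avoid a needless device
    // round-trip just to wrap a Python number.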
    value = valueToTensor(self_.options(), py_value, at::Device(kCPU));
  } else {
    value = valueToTensor(self_.options(), py_value, self_device);
  }

  // handle simple types: ellipsis, none, bool
  if (index == Py_False) { // NOLINT(cppcoreguidelines-pro-type-cstyle-cast)
    // do nothing for false (technically we should check the size, but we
    // don't have real 0-sized shapes).
    return 0;
  } else if (index == Py_Ellipsis) {
    dispatch_set_item(
        self_, {at::indexing::TensorIndex(at::indexing::Ellipsis)}, value);
    return 0;
  } else if (index == Py_None) {
    dispatch_set_item(
        self_, {at::indexing::TensorIndex(at::indexing::None)}, value);
    return 0;
  } else if (index == Py_True) {
    dispatch_set_item(self_, {at::indexing::TensorIndex(true)}, value);
    return 0;
  }

  bool is_tracing = torch::jit::tracer::isTracing();

  // handle simple types: integers, slices
  if (THPUtils_checkLong(index)) {
    if (is_tracing && THPVariable_Check(index)) {
      recordSelectTrace(THPVariable_Unpack(index));
    }
    dispatch_set_item(
        self_, {at::indexing::TensorIndex(THPUtils_unpackLong(index))}, value);
    return 0;
  } else if (PySlice_Check(index)) {
    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    auto val = __PySlice_Unpack(index);
    if (is_tracing) {
      recordSliceTrace(index);
    }
    // See NOTE [ Setting `disable_slice_optimization` when calling C++ tensor
    // indexing functions from Python ]
    dispatch_set_item(
        self_,
        {at::indexing::TensorIndex(
            at::indexing::Slice(val.start, val.stop, val.step))},
        value,
        /*disable_slice_optimization=*/is_tracing);
    return 0;
  }

  // wrap index in a tuple if it's not already one
  THPObjectPtr holder = wrapTuple(index);

  variable_list variableIndices;
  int64_t specified_dims = count_specified_dimensions(holder.get());
  if (specified_dims == -1) {
    py::object val = py::reinterpret_steal<py::object>(
        handle_torch_function_indexing(self, index, py_value));
    return 0;
  }
  Variable sliced = applySlicing(
      self_,
      holder.get(),
      variableIndices,
      /*is_tracing=*/is_tracing,
      self_device,
      self_.ndimension(),
      specified_dims);
  if (variableIndices.empty()) {
    pybind11::gil_scoped_release no_gil;
    at::indexing::copy_to(sliced, value);
    return 0;
  }

  {
    pybind11::gil_scoped_release no_gil;
    SymIntArrayRef valueSizes = value.sym_sizes();
    SymIntArrayRef slicedValueSizes =
        at::indexing::slicePrefix1sSize(valueSizes);
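    // slicePrefix1sSize strips leading size-1 dims, e.g. a value of shape
    // [1, 1, 3] is viewed as [3] before being dispatched to index_put_.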
    torch::autograd::Variable valuesSliced;
    if (!valueSizes.equals(slicedValueSizes)) {
      valuesSliced = value.view_symint(slicedValueSizes);
    } else {
      valuesSliced = value;
    }
    at::indexing::dispatch_index_put_(
        sliced, std::move(variableIndices), valuesSliced);
    return 0;
  }
  END_HANDLE_TH_ERRORS_RET(-1)
}

} // namespace autograd
} // namespace torch