#include <torch/csrc/autograd/python_variable_indexing.h>

#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/utils/wrap_outputs.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/utils/python_arg_parser.h>
#include <torch/csrc/utils/python_compat.h>
#include <torch/csrc/utils/python_numbers.h>
#include <torch/csrc/utils/tensor_new.h>
#include <torch/csrc/utils/tensor_types.h>

#include <ATen/DeviceGuard.h>
#include <ATen/ExpandUtils.h>
#include <ATen/Functions.h>
#include <ATen/TensorIndexing.h>
#include <ATen/TracerMode.h>
#include <ATen/core/LegacyTypeDispatch.h>
#include <c10/core/TensorOptions.h>
#include <c10/util/irange.h>

#include <c10/core/Layout.h>
#include <tuple>
#include <vector>

using namespace at;
using namespace torch::autograd::utils;

namespace torch {
namespace autograd {

Py_ssize_t THPVariable_length(PyObject* self) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function(self)) {
    py::object ret = py::reinterpret_steal<py::object>(
        handle_torch_function(self, "__len__"));
    Py_ssize_t length = PyLong_AsSsize_t(ret.ptr());
    if (PyErr_Occurred()) {
      throw python_error();
    }
    return length;
  }
  const auto& self_ = THPVariable_Unpack(self);
  if (self_.dim() == 0) {
    return 0;
  }
  // TODO: Maybe this should return a SymInt directly?
  // Add the guard to get a nice error message if/when we hit this.
  return (Py_ssize_t)self_.sym_size(0).guard_int(__FILE__, __LINE__);
  END_HANDLE_TH_ERRORS_RET(-1)
}

// We allow indexing by integers, slices, ellipsis, None, Variables,
// and tuples of those types. We also handle bools as if they were a
// Variable[ByteTensor].
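//
// As a rough, non-exhaustive illustration, Python indexing expressions map
// onto `at::indexing::TensorIndex` values along these lines (the actual
// per-dimension conversion happens in `applySlicing` below):
//
//   x[0]         -> TensorIndex(0)
//   x[True]      -> TensorIndex(true)
//   x[None]      -> TensorIndex(at::indexing::None)
//   x[...]       -> TensorIndex(at::indexing::Ellipsis)
//   x[1:10:2]    -> TensorIndex(Slice(1, 10, 2))
//   x[mask]      -> TensorIndex(Tensor)  // bool/byte mask, advanced indexing
//   x[[0, 2, 5]] -> TensorIndex(Tensor)  // list lowered via sequenceToVariable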

static inline int64_t count_specified_dimensions(PyObject* index) {
  // Count the number of indexed dimensions (everything but ellipsis, None,
  // and bare bools). Returns -1 as a sentinel when __torch_function__
  // handling is needed.
  int64_t count = 0;
  auto size =
      PyTuple_GET_SIZE(index); // NOLINT(cppcoreguidelines-pro-type-cstyle-cast)
  for (Py_ssize_t i = 0; i < size; i++) {
    PyObject* obj = PyTuple_GET_ITEM(
        index, i); // NOLINT(cppcoreguidelines-pro-type-cstyle-cast)
    if (!THPVariable_CheckExact(obj) && check_has_torch_function(obj))
      return -1;
    if (THPVariable_Check(obj)) {
      const auto& var = THPVariable_Unpack(obj);
      const auto& var_scalar_type = var.scalar_type();
      if (var_scalar_type == kByte || var_scalar_type == kBool) {
        count += var.dim();
      } else {
        count++;
      }
    } else if (
        obj != Py_None && obj != Py_Ellipsis && obj != Py_True &&
        obj != Py_False) { // NOLINT(cppcoreguidelines-pro-type-cstyle-cast)
      count++;
    }
  }
  return count;
}
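
// Informal example of the count above: assuming `mask` is a 2-D bool tensor,
// an index tuple equivalent to `x[0, :, None, ..., mask]` yields
// 1 (integer) + 1 (slice) + 0 (None) + 0 (ellipsis) + 2 (mask.dim()) = 4
// specified dimensions.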

[[noreturn]] static inline void invalid_index(PyObject* obj) {
  throw IndexError(
      "only integers, slices (`:`), ellipsis (`...`), None and long or byte "
      "Variables are valid indices (got %s)",
      Py_TYPE(obj)->tp_name);
}

static inline Variable sequenceToVariable(
    c10::TensorOptions options,
    PyObject* seq) {
  return torch::utils::indexing_tensor_from_data(
      options, kLong, c10::nullopt, seq);
}

inline Variable valueToTensor(
    c10::TensorOptions options,
    PyObject* value,
    const at::Device& device) {
  if (THPVariable_Check(value)) {
    return THPVariable_Unpack(value);
  }
  at::AutoDispatchBelowADInplaceOrView guard; // TODO: remove
  at::tracer::impl::NoTracerDispatchMode tracer_guard;
  Scalar scalar;
  if (THPUtils_checkLong(value) || PyBool_Check(value)) {
    scalar = Scalar(THPUtils_unpackLong(value));
  } else if (PyFloat_Check(value)) {
    scalar = Scalar(THPUtils_unpackDouble(value));
  } else if (PyComplex_Check(value)) {
    scalar = Scalar(THPUtils_unpackComplexDouble(value));
  } else {
    throw TypeError(
        "can't assign a %s to a %s",
        Py_TYPE(value)->tp_name,
        torch::utils::options_to_string(options).c_str());
  }
  // lift_fresh is supposed to be used in situations where you are guaranteed
  // to get a plain Tensor; that holds for the CPU device but not for non-CPU
  // devices, so we only call it on the CPU path.
  if (device == at::kCPU) {
    return at::lift_fresh(
        at::indexing::scalarToTensor(scalar, options, device));
  } else {
    return at::indexing::scalarToTensor(scalar, options, device);
  }
}
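
// Sketch of how the conversion above is exercised (informal, hypothetical
// values): for a float CPU tensor `x`, an assignment like `x[0] = 3` reaches
// valueToTensor with the Python int 3, which becomes a Scalar and then a
// 0-dim tensor via scalarToTensor (wrapped in lift_fresh on CPU), whereas
// `x[0] = other_tensor` simply unpacks and reuses the existing Variable.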

static inline void recordSliceTrace(PyObject* obj) {
  PySliceObject* sliceobj = (PySliceObject*)obj;
  if (THPVariable_Check(sliceobj->start)) {
    torch::jit::tracer::ArgumentStash::stashValue(
        std::string("start"),
        1,
        THPVariable_Unpack(sliceobj->start),
        torch::jit::IntType::get());
  }
  if (THPVariable_Check(sliceobj->stop)) {
    torch::jit::tracer::ArgumentStash::stashValue(
        std::string("end"),
        1,
        THPVariable_Unpack(sliceobj->stop),
        torch::jit::IntType::get());
  }
  if (THPVariable_Check(sliceobj->step)) {
    torch::jit::tracer::ArgumentStash::stashValue(
        std::string("step"),
        1,
        THPVariable_Unpack(sliceobj->step),
        torch::jit::IntType::get());
  }
}

static inline void recordSelectTrace(const Tensor& index_tensor) {
  torch::jit::tracer::ArgumentStash::stashValue(
      std::string("index"), 1, index_tensor, torch::jit::IntType::get());
}

static inline Variable applySlicing(
    const Variable& self,
    PyObject* index,
    variable_list& outIndices,
    bool is_tracing,
    const at::Device& self_device,
    const c10::optional<int64_t>& self_ndim,
    int64_t specified_dims) {
  int64_t size =
      PyTuple_GET_SIZE(index); // NOLINT(cppcoreguidelines-pro-type-cstyle-cast)
  int64_t dim = 0;

  // See NOTE [nested tensor size for indexing]
  if (self_ndim.has_value()) {
    TORCH_CHECK_INDEX(
        specified_dims <= self_ndim.value(),
        "too many indices for tensor of dimension ",
        self_ndim.value());
  }

  Variable result = self;
  for (const auto i : c10::irange(size)) {
    PyObject* obj = PyTuple_GET_ITEM(
        index, i); // NOLINT(cppcoreguidelines-pro-type-cstyle-cast)
    // NOTE [nested tensor size for indexing]
    // A nested tensor does not have a size (yet), so for now we represent its
    // size as null. This may need to change once we have a better solution
    // for nested tensor sizes.
    c10::optional<SymIntArrayRef> result_sizes = result.is_nested()
        ? c10::optional<SymIntArrayRef>(c10::nullopt)
        : c10::optional<SymIntArrayRef>(result.sym_sizes());
    result = at::indexing::handleDimInMultiDimIndexing(
        /*prev_dim_result=*/result,
        /*original_tensor=*/self,
        /*index=*/([&]() {
          if (THPUtils_checkLong(obj)) {
            if (is_tracing && THPVariable_Check(obj)) {
              recordSelectTrace(THPVariable_Unpack(obj));
            }
            return at::indexing::TensorIndex(THPUtils_unpackLong(obj));
          } else if (PySlice_Check(obj)) {
            auto val = __PySlice_Unpack(obj);
            if (is_tracing) {
              recordSliceTrace(obj);
            }
            return at::indexing::TensorIndex(
                at::indexing::Slice(val.start, val.stop, val.step));
          } else if (obj == Py_Ellipsis) {
            return at::indexing::TensorIndex(at::indexing::Ellipsis);
          } else if (obj == Py_None) {
            return at::indexing::TensorIndex(at::indexing::None);
          } else if (PyBool_Check(obj)) {
            return at::indexing::TensorIndex(obj == Py_True);
          } else if (THPVariable_Check(obj)) {
            Tensor tensor = THPVariable_Unpack(obj);
            if (is_tracing) {
              auto scalar_type = tensor.scalar_type();
              if (tensor.dim() == 0 &&
                  at::isIntegralType(scalar_type, /*includeBool=*/false) &&
                  scalar_type != at::kByte) {
                recordSelectTrace(tensor);
              }
            }
            return at::indexing::TensorIndex(std::move(tensor));
          } else if (PySequence_Check(obj)) {
            return at::indexing::TensorIndex(
                sequenceToVariable(self.options(), obj));
          } else {
            auto idx = THPObjectPtr(PyNumber_Index(obj));
            if (!idx) {
              PyErr_Clear();
              invalid_index(obj);
            }
            if (is_tracing && THPVariable_Check(idx)) {
              recordSelectTrace(THPVariable_Unpack(idx));
            }
            return at::indexing::TensorIndex(THPUtils_unpackLong(idx));
          }
        })(),
        /*dim_ptr=*/&dim,
        /*specified_dims_ptr=*/&specified_dims,
        /*real_dim=*/i,
        /*outIndices=*/outIndices,
        // See NOTE [ Setting `disable_slice_optimization` when calling C++
        // tensor indexing functions from Python ]
        /*disable_slice_optimization=*/is_tracing,
        /*original_tensor_device=*/self_device,
        /*prev_dim_result_sizes=*/result_sizes);
  }
  return result;
}
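
// Informal example of the loop above: a Python expression such as
// `x[1, 2:5, None, t]` (with `t` an index tensor) is folded in dim-by-dim,
// roughly as the public C++ indexing call
//
//   x.index({1, at::indexing::Slice(2, 5), at::indexing::None, t})
//
// would be, except that `t` is collected into `outIndices` here so the caller
// can run the advanced-indexing dispatch as a separate step.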

static inline bool treatSequenceAsTuple(PyObject* index) {
  if (PyTuple_Check(index)) {
    return true;
  }
  if (THPVariable_Check(index)) {
    return false;
  }
  if (!PySequence_Check(index)) {
    return false;
  }
  // This uses a heuristic from NumPy for determining whether to treat
  // non-tuple sequences as if they were a tuple. From the NumPy code comments:
  //
  // "At this point, we're left with a non-tuple, non-array, sequence:
  // typically, a list. We use some somewhat-arbitrary heuristics from here
  // onwards to decide whether to treat that list as a single index, or a
  // list of indices. Backwards compatibility only takes effect for short
  // sequences - otherwise we treat it like any other scalar."
  auto n = PySequence_Size(index);
  if (n < 0) {
    // Negative size indicates a Python error in the PySequence_Size call.
    PyErr_Clear();
    return false;
  }
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  if (n >= 32) {
    return false;
  }
  for (Py_ssize_t i = 0; i < n; i++) {
    auto obj = THPObjectPtr{PySequence_GetItem(index, i)};
    if (!obj.get()) {
      PyErr_Clear();
      return false;
    }
    if (THPVariable_Check(obj.get()) || PySequence_Check(obj.get()) ||
        PySlice_Check(obj.get())) {
      return true;
    }
    if (obj.get() == Py_Ellipsis || obj.get() == Py_None) {
      return true;
    }
  }
  return false;
}
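
// Informal examples of the heuristic above: `x[[0, slice(None)]]` and
// `x[[0, None]]` are short lists containing a slice / None, so they are
// re-interpreted as the tuple indices `x[0, :]` / `x[0, None]`; a plain list
// of integers such as `x[[0, 2, 5]]`, or any sequence of length >= 32, is
// left as a single index and later converted to an index tensor (advanced
// indexing).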

static inline THPObjectPtr wrapTuple(PyObject* index) {
  THPObjectPtr res;
  if (treatSequenceAsTuple(index)) {
    res = PySequence_Tuple(index);
  } else {
    res = PyTuple_Pack(
        1, index); // NOLINT(cppcoreguidelines-pro-type-cstyle-cast)
  }
  if (!res)
    throw python_error();
  return res;
}

// NOTE: Here is the dispatch structure for `THPVariable_getitem`:
//
// 1. Python 1-D getter calls C++ `at::indexing::get_item` after
// converting Python index to C++ TensorIndex.
//
// 2. Python N-D getter calls C++ `at::indexing::handleDimInMultiDimIndexing`
// for each dim, after converting Python index to C++ TensorIndex. If advanced
// indexing is needed, it calls C++ `at::indexing::dispatch_index`.
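//
// For illustration (informal, assuming a 2-D tensor `x`): `x[1]` takes
// path 1 and becomes `at::indexing::get_item(x, {TensorIndex(1)})`, while
// `x[1, 2:5, None]` takes path 2, converting each tuple element to a
// TensorIndex and folding it in via `handleDimInMultiDimIndexing`.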
PyObject* THPVariable_getitem(PyObject* self, PyObject* index) {
  HANDLE_TH_ERRORS
  if (!THPVariable_CheckExact(self) && check_has_torch_function(self)) {
    return handle_torch_function_indexing(self, index);
  }
  const auto& self_ = THPVariable_Unpack(self);
  OptionalDeviceGuard device_guard(device_of(self_));

  // handle simple types: none, ellipsis
  if (index == Py_None) {
    return THPVariable_Wrap(at::indexing::get_item(
        self_, {at::indexing::TensorIndex(at::indexing::None)}));
  } else if (index == Py_Ellipsis) {
    return THPVariable_Wrap(at::indexing::get_item(
        self_, {at::indexing::TensorIndex(at::indexing::Ellipsis)}));
  }

  bool is_tracing = torch::jit::tracer::isTracing();

  // handle simple types: integers, slices, bool
  if (THPUtils_checkLong(index)) {
    if (is_tracing && THPVariable_Check(index)) {
      recordSelectTrace(THPVariable_Unpack(index));
    }
    return THPVariable_Wrap(at::indexing::get_item(
        self_, {at::indexing::TensorIndex(THPUtils_unpackLong(index))}));
  } else if (PySlice_Check(index)) {
    auto val = __PySlice_Unpack(index);
    if (is_tracing) {
      recordSliceTrace(index);
    }
    return THPVariable_Wrap(at::indexing::get_item(
        self_,
        {at::indexing::TensorIndex(
            at::indexing::Slice(val.start, val.stop, val.step))}));
  } else if (index == Py_False || index == Py_True) {
    return THPVariable_Wrap(([&]() {
      pybind11::gil_scoped_release no_gil;
      return at::indexing::get_item(
          self_, {at::indexing::TensorIndex(index == Py_True)});
    })());
  }

  // wrap index in a tuple if it's not already one
  THPObjectPtr holder = wrapTuple(index);

  variable_list variableIndices;
  int64_t specified_dims = count_specified_dimensions(holder.get());
  if (specified_dims == -1) {
    return handle_torch_function_indexing(self, holder.get());
  }
  Variable sliced = applySlicing(
      self_,
      holder.get(),
      variableIndices,
      /*is_tracing=*/is_tracing,
      self_.device(),
      self_.ndimension(),
      specified_dims);
  if (variableIndices.empty()) {
    if (sliced.is_same(self_)) {
      // ensure we return a shallow copy for things like x[...]
      sliced = at::alias(sliced);
    }
    return THPVariable_Wrap(std::move(sliced));
  }

  // indexing by tensors ("advanced" indexing)
  return THPVariable_Wrap(([&]() {
    pybind11::gil_scoped_release no_gil;
    return at::indexing::dispatch_index(sliced, std::move(variableIndices));
  })());

  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
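
// Note that the public C++ counterpart of the getter above is
// `Tensor::index` from ATen/TensorIndexing.h, e.g. (illustrative only):
//
//   using namespace at::indexing;
//   Tensor y = x.index({Slice(None, 2), Ellipsis, 0}); // roughly x[:2, ..., 0]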

void dispatch_set_item(
    const Tensor& self,
    ArrayRef<at::indexing::TensorIndex> indices,
    const Tensor& value,
    bool disable_slice_optimization = false) {
  pybind11::gil_scoped_release no_gil;
  at::indexing::set_item(self, indices, value, disable_slice_optimization);
}

// NOTE: Here is the dispatch structure for `THPVariable_setitem`:
//
// 1. Python 1-D setter calls C++ `at::indexing::set_item` after
// converting Python index to C++ TensorIndex.
//
// 2. Python N-D setter calls C++ `at::indexing::handleDimInMultiDimIndexing`
// for each dim, after converting Python index to C++ TensorIndex. If advanced
// indexing is needed, it calls C++ `at::indexing::dispatch_index_put_`.
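//
// For illustration (informal): `x[0] = v` takes path 1 and becomes
// `at::indexing::set_item(x, {TensorIndex(0)}, v)`, while `x[:, idx] = v`
// (with `idx` an index tensor) takes path 2: it slices first and then calls
// `at::indexing::dispatch_index_put_` with the collected tensor indices.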
int THPVariable_setitem(PyObject* self, PyObject* index, PyObject* py_value) {
  HANDLE_TH_ERRORS
  if (py_value == nullptr) {
    throw TypeError("Tensor does not support deleting items");
  }
  if ((!THPVariable_CheckExact(self) && check_has_torch_function(self)) ||
      (!THPVariable_CheckExact(py_value) &&
       check_has_torch_function(py_value))) {
    py::object ret = py::reinterpret_steal<py::object>(
        handle_torch_function_indexing(self, index, py_value));
    return 0;
  }

  const auto& self_ = THPVariable_Unpack(self);
  if (self_.layout() == kSparse || self_.layout() == kSparseCsr ||
      self_.layout() == kSparseCsc || self_.layout() == kSparseBsr ||
      self_.layout() == kSparseBsc) {
    throw TypeError("Cannot assign to a sparse tensor");
  }
  OptionalDeviceGuard device_guard(device_of(self_));
  at::Device self_device = self_.device();
  Variable value;
  // TODO: This qint special case looks very suspicious...
  if (isQIntType(self_.scalar_type())) {
    value =
        valueToTensor(device(kCPU).dtype(kFloat), py_value, at::Device(kCPU));
  } else if (self_device.is_cuda()) {
    value = valueToTensor(self_.options(), py_value, at::Device(kCPU));
  } else {
    value = valueToTensor(self_.options(), py_value, self_device);
  }

  // handle simple types: ellipsis, none, bool
  if (index == Py_False) { // NOLINT(cppcoreguidelines-pro-type-cstyle-cast)
    // do nothing for false (technically we should check the size, but we
    // don't have real 0-sized shapes)
    return 0;
  } else if (index == Py_Ellipsis) {
    dispatch_set_item(
        self_, {at::indexing::TensorIndex(at::indexing::Ellipsis)}, value);
    return 0;
  } else if (index == Py_None) {
    dispatch_set_item(
        self_, {at::indexing::TensorIndex(at::indexing::None)}, value);
    return 0;
  } else if (index == Py_True) {
    dispatch_set_item(self_, {at::indexing::TensorIndex(true)}, value);
    return 0;
  }

  bool is_tracing = torch::jit::tracer::isTracing();

  // handle simple types: integers, slices
  if (THPUtils_checkLong(index)) {
    if (is_tracing && THPVariable_Check(index)) {
      recordSelectTrace(THPVariable_Unpack(index));
    }
    dispatch_set_item(
        self_, {at::indexing::TensorIndex(THPUtils_unpackLong(index))}, value);
    return 0;
  } else if (PySlice_Check(index)) {
    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    auto val = __PySlice_Unpack(index);
    if (is_tracing) {
      recordSliceTrace(index);
    }
    // See NOTE [ Setting `disable_slice_optimization` when calling C++ tensor
    // indexing functions from Python ]
    dispatch_set_item(
        self_,
        {at::indexing::TensorIndex(
            at::indexing::Slice(val.start, val.stop, val.step))},
        value,
        /*disable_slice_optimization=*/is_tracing);
    return 0;
  }

  // wrap index in a tuple if it's not already one
  THPObjectPtr holder = wrapTuple(index);

  variable_list variableIndices;
  int64_t specified_dims = count_specified_dimensions(holder.get());
  if (specified_dims == -1) {
    py::object val = py::reinterpret_steal<py::object>(
        handle_torch_function_indexing(self, index, py_value));
    return 0;
  }
  Variable sliced = applySlicing(
      self_,
      holder.get(),
      variableIndices,
      /*is_tracing=*/is_tracing,
      self_device,
      self_.ndimension(),
      specified_dims);
  if (variableIndices.empty()) {
    pybind11::gil_scoped_release no_gil;
    at::indexing::copy_to(sliced, value);
    return 0;
  }

  {
    pybind11::gil_scoped_release no_gil;
    SymIntArrayRef valueSizes = value.sym_sizes();
    SymIntArrayRef slicedValueSizes =
        at::indexing::slicePrefix1sSize(valueSizes);
    torch::autograd::Variable valuesSliced;
    if (!valueSizes.equals(slicedValueSizes)) {
      valuesSliced = value.view_symint(slicedValueSizes);
    } else {
      valuesSliced = value;
    }
    at::indexing::dispatch_index_put_(
        sliced, std::move(variableIndices), valuesSliced);
    return 0;
  }
  END_HANDLE_TH_ERRORS_RET(-1)
}
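
// Note that the public C++ counterpart of the setter above is
// `Tensor::index_put_` from ATen/TensorIndexing.h, e.g. (illustrative only):
//
//   using namespace at::indexing;
//   x.index_put_({Slice(), 0}, value); // roughly x[:, 0] = value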

} // namespace autograd
} // namespace torch