1#include <ATen/TensorIndexing.h>
2
3#include <c10/util/Exception.h>
4#include <c10/util/irange.h>
5
6namespace at {
7namespace indexing {
8
9const EllipsisIndexType Ellipsis = EllipsisIndexType();
10
11std::ostream& operator<<(std::ostream& stream, const Slice& slice) {
12 stream << slice.start() << ":" << slice.stop() << ":" << slice.step();
13 return stream;
14}
15
16std::ostream& operator<<(std::ostream& stream, const TensorIndex& tensor_index) {
17 if (tensor_index.is_none()) {
18 stream << "None";
19 } else if (tensor_index.is_ellipsis()) {
20 stream << "...";
21 } else if (tensor_index.is_integer()) {
22 stream << tensor_index.integer();
23 } else if (tensor_index.is_boolean()) {
24 stream << std::boolalpha << tensor_index.boolean();
25 } else if (tensor_index.is_slice()) {
26 stream << tensor_index.slice();
27 } else if (tensor_index.is_tensor()) {
28 stream << tensor_index.tensor();
29 }
30 return stream;
31}
32
33std::ostream& operator<<(std::ostream& stream, const std::vector<TensorIndex>& tensor_indices) {
34 stream << "(";
35 for (const auto i : c10::irange(tensor_indices.size())) {
36 stream << tensor_indices[i];
37 if (i < tensor_indices.size() - 1) stream << ", ";
38 }
39 stream << ")";
40 return stream;
41}
42
43// This mirrors `THPVariable_setitem` in torch/csrc/autograd/python_variable_indexing.cpp
44// for "the assigned value is a Scalar" case
45static inline void set_item(const Tensor& self, ArrayRef<TensorIndex> indices, const Scalar& v) {
46 Tensor value;
47
48 {
49 at::AutoDispatchBelowADInplaceOrView guard;
50 at::Device self_device = self.device();
51
52 // TODO: This qint special case looks very suspicious...
53 if (isQIntType(self.scalar_type())) {
54 value = at::indexing::scalarToTensor(v, device(kCPU).dtype(kFloat), at::Device(kCPU));
55 } else if (self_device.is_cuda()) {
56 value = at::indexing::scalarToTensor(v, self.options(), at::Device(kCPU));
57 } else {
58 value = at::indexing::scalarToTensor(v, self.options(), self_device);
59 }
60 }
61
62 return set_item(self, indices, value);
63}
64
65} // namespace indexing
66
67Tensor Tensor::index(ArrayRef<at::indexing::TensorIndex> indices) const {
68 TORCH_CHECK(!indices.empty(), "Passing an empty index list to Tensor::index() is not valid syntax");
69 OptionalDeviceGuard device_guard(device_of(*this));
70 return at::indexing::get_item(*this, indices);
71}
72Tensor Tensor::index(std::initializer_list<at::indexing::TensorIndex> indices) const {
73 return index(ArrayRef<at::indexing::TensorIndex>(indices));
74}
75
76Tensor & Tensor::index_put_(ArrayRef<at::indexing::TensorIndex> indices, Tensor const & rhs) {
77 TORCH_CHECK(!indices.empty(), "Passing an empty index list to Tensor::index_put_() is not valid syntax");
78 OptionalDeviceGuard device_guard(device_of(*this));
79 at::indexing::set_item(*this, indices, rhs);
80 return *this;
81}
82Tensor & Tensor::index_put_(ArrayRef<at::indexing::TensorIndex> indices, const Scalar& v) {
83 TORCH_CHECK(!indices.empty(), "Passing an empty index list to Tensor::index_put_() is not valid syntax");
84 OptionalDeviceGuard device_guard(device_of(*this));
85 at::indexing::set_item(*this, indices, v);
86 return *this;
87}
88Tensor & Tensor::index_put_(std::initializer_list<at::indexing::TensorIndex> indices, Tensor const & rhs) {
89 return index_put_(ArrayRef<at::indexing::TensorIndex>(indices), rhs);
90}
91Tensor & Tensor::index_put_(std::initializer_list<at::indexing::TensorIndex> indices, const Scalar& v) {
92 return index_put_(ArrayRef<at::indexing::TensorIndex>(indices), v);
93}
94
95} // namespace at
96