1 | #include <ATen/TensorIndexing.h> |
2 | |
3 | #include <c10/util/Exception.h> |
4 | #include <c10/util/irange.h> |
5 | |
6 | namespace at { |
7 | namespace indexing { |
8 | |
9 | const EllipsisIndexType Ellipsis = EllipsisIndexType(); |
10 | |
11 | std::ostream& operator<<(std::ostream& stream, const Slice& slice) { |
12 | stream << slice.start() << ":" << slice.stop() << ":" << slice.step(); |
13 | return stream; |
14 | } |
15 | |
16 | std::ostream& operator<<(std::ostream& stream, const TensorIndex& tensor_index) { |
17 | if (tensor_index.is_none()) { |
18 | stream << "None" ; |
19 | } else if (tensor_index.is_ellipsis()) { |
20 | stream << "..." ; |
21 | } else if (tensor_index.is_integer()) { |
22 | stream << tensor_index.integer(); |
23 | } else if (tensor_index.is_boolean()) { |
24 | stream << std::boolalpha << tensor_index.boolean(); |
25 | } else if (tensor_index.is_slice()) { |
26 | stream << tensor_index.slice(); |
27 | } else if (tensor_index.is_tensor()) { |
28 | stream << tensor_index.tensor(); |
29 | } |
30 | return stream; |
31 | } |
32 | |
33 | std::ostream& operator<<(std::ostream& stream, const std::vector<TensorIndex>& tensor_indices) { |
34 | stream << "(" ; |
35 | for (const auto i : c10::irange(tensor_indices.size())) { |
36 | stream << tensor_indices[i]; |
37 | if (i < tensor_indices.size() - 1) stream << ", " ; |
38 | } |
39 | stream << ")" ; |
40 | return stream; |
41 | } |
42 | |
43 | // This mirrors `THPVariable_setitem` in torch/csrc/autograd/python_variable_indexing.cpp |
44 | // for "the assigned value is a Scalar" case |
45 | static inline void set_item(const Tensor& self, ArrayRef<TensorIndex> indices, const Scalar& v) { |
46 | Tensor value; |
47 | |
48 | { |
49 | at::AutoDispatchBelowADInplaceOrView guard; |
50 | at::Device self_device = self.device(); |
51 | |
52 | // TODO: This qint special case looks very suspicious... |
53 | if (isQIntType(self.scalar_type())) { |
54 | value = at::indexing::scalarToTensor(v, device(kCPU).dtype(kFloat), at::Device(kCPU)); |
55 | } else if (self_device.is_cuda()) { |
56 | value = at::indexing::scalarToTensor(v, self.options(), at::Device(kCPU)); |
57 | } else { |
58 | value = at::indexing::scalarToTensor(v, self.options(), self_device); |
59 | } |
60 | } |
61 | |
62 | return set_item(self, indices, value); |
63 | } |
64 | |
65 | } // namespace indexing |
66 | |
67 | Tensor Tensor::index(ArrayRef<at::indexing::TensorIndex> indices) const { |
68 | TORCH_CHECK(!indices.empty(), "Passing an empty index list to Tensor::index() is not valid syntax" ); |
69 | OptionalDeviceGuard device_guard(device_of(*this)); |
70 | return at::indexing::get_item(*this, indices); |
71 | } |
72 | Tensor Tensor::index(std::initializer_list<at::indexing::TensorIndex> indices) const { |
73 | return index(ArrayRef<at::indexing::TensorIndex>(indices)); |
74 | } |
75 | |
76 | Tensor & Tensor::index_put_(ArrayRef<at::indexing::TensorIndex> indices, Tensor const & rhs) { |
77 | TORCH_CHECK(!indices.empty(), "Passing an empty index list to Tensor::index_put_() is not valid syntax" ); |
78 | OptionalDeviceGuard device_guard(device_of(*this)); |
79 | at::indexing::set_item(*this, indices, rhs); |
80 | return *this; |
81 | } |
82 | Tensor & Tensor::index_put_(ArrayRef<at::indexing::TensorIndex> indices, const Scalar& v) { |
83 | TORCH_CHECK(!indices.empty(), "Passing an empty index list to Tensor::index_put_() is not valid syntax" ); |
84 | OptionalDeviceGuard device_guard(device_of(*this)); |
85 | at::indexing::set_item(*this, indices, v); |
86 | return *this; |
87 | } |
88 | Tensor & Tensor::index_put_(std::initializer_list<at::indexing::TensorIndex> indices, Tensor const & rhs) { |
89 | return index_put_(ArrayRef<at::indexing::TensorIndex>(indices), rhs); |
90 | } |
91 | Tensor & Tensor::index_put_(std::initializer_list<at::indexing::TensorIndex> indices, const Scalar& v) { |
92 | return index_put_(ArrayRef<at::indexing::TensorIndex>(indices), v); |
93 | } |
94 | |
95 | } // namespace at |
96 | |