1 | #pragma once |
2 | |
3 | #include <ATen/ExpandUtils.h> |
4 | #include <ATen/ScalarOps.h> |
5 | #include <ATen/core/Tensor.h> |
6 | #include <ATen/core/TensorBody.h> |
7 | #include <c10/core/SymInt.h> |
8 | #include <c10/util/Optional.h> |
9 | #include <c10/util/irange.h> |
10 | |
11 | #ifndef AT_PER_OPERATOR_HEADERS |
12 | #include <ATen/Functions.h> |
13 | #include <ATen/NativeFunctions.h> |
14 | #else |
15 | #include <ATen/ops/alias.h> |
16 | #include <ATen/ops/empty.h> |
17 | #include <ATen/ops/scalar_tensor.h> |
18 | #include <ATen/ops/zeros.h> |
19 | #endif |
20 | |
21 | #include <ATen/core/List.h> |
22 | |
#include <cstring>
#include <ostream>
#include <utility>
24 | |
25 | namespace at { |
26 | namespace indexing { |
27 | |
28 | const int64_t INDEX_MIN = c10::SymInt::min_representable_int(); |
29 | const int64_t INDEX_MAX = -(INDEX_MIN + 1); |
30 | |
31 | enum class TensorIndexType { None, Ellipsis, Integer, Boolean, Slice, Tensor }; |
32 | |
33 | constexpr c10::nullopt_t None = c10::nullopt; |
34 | |
35 | struct TORCH_API EllipsisIndexType final { |
36 | EllipsisIndexType() = default; |
37 | }; |
38 | TORCH_API extern const EllipsisIndexType Ellipsis; |
39 | |
40 | struct TORCH_API Slice final { |
41 | public: |
42 | Slice( |
43 | c10::optional<c10::SymInt> start_index = c10::nullopt, |
44 | c10::optional<c10::SymInt> stop_index = c10::nullopt, |
45 | c10::optional<c10::SymInt> step_index = c10::nullopt) { |
46 | if (!step_index.has_value()) { |
47 | step_ = c10::SymInt(1); |
48 | } else { |
49 | step_ = std::move(step_index).value(); |
50 | } |
51 | |
    TORCH_CHECK_VALUE(step_ != 0, "slice step cannot be zero");
53 | |
54 | if (!start_index.has_value()) { |
55 | start_ = c10::SymInt(step_ < 0 ? INDEX_MAX : 0); |
56 | } else { |
57 | start_ = std::move(start_index).value(); |
58 | } |
59 | |
60 | if (!stop_index.has_value()) { |
61 | stop_ = c10::SymInt(step_ < 0 ? INDEX_MIN : INDEX_MAX); |
62 | } else { |
63 | stop_ = std::move(stop_index).value(); |
64 | } |
65 | } |
66 | |
67 | inline c10::SymInt start() const { |
68 | return start_; |
69 | } |
70 | |
71 | inline c10::SymInt stop() const { |
72 | return stop_; |
73 | } |
74 | |
75 | inline c10::SymInt step() const { |
76 | return step_; |
77 | } |
78 | |
79 | private: |
80 | c10::SymInt start_; |
81 | c10::SymInt stop_; |
82 | c10::SymInt step_; |
83 | }; |
84 | |
85 | TORCH_API std::ostream& operator<<(std::ostream& stream, const Slice& slice); |
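
// A brief usage sketch (illustrative, assuming an already-defined `at::Tensor
// t`). `Slice` mirrors Python's slice defaults: an omitted `start`/`stop`
// defaults to the beginning/end of the dimension (respecting the sign of
// `step`), and an omitted `step` defaults to 1:
// ```
// using namespace at::indexing;
// Tensor a = t.index({Slice(1, None, 2)});  // Python: t[1::2]
// Tensor b = t.index({Slice()});            // Python: t[:]
// ```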
86 | |
// `at::indexing::TensorIndex` is used for converting C++ tensor indices such as
// `{None, "...", Ellipsis, 0, true, Slice(1, None, 2), torch::tensor({1, 2})}`
// into their equivalent `std::vector<TensorIndex>`, so that further tensor
// indexing operations can be performed using the supplied indices.
//
// There is a one-to-one correspondence between Python and C++ tensor index
// types:
// Python                  | C++
// -----------------------------------------------------
// `None`                  | `at::indexing::None`
// `Ellipsis`              | `at::indexing::Ellipsis`
// `...`                   | `"..."`
// `123`                   | `123`
// `True` / `False`        | `true` / `false`
// `:`                     | `Slice()` / `Slice(None, None)`
// `::`                    | `Slice()` / `Slice(None, None, None)`
// `1:`                    | `Slice(1, None)`
// `1::`                   | `Slice(1, None, None)`
// `:3`                    | `Slice(None, 3)`
// `:3:`                   | `Slice(None, 3, None)`
// `::2`                   | `Slice(None, None, 2)`
// `1:3`                   | `Slice(1, 3)`
// `1::2`                  | `Slice(1, None, 2)`
// `:3:2`                  | `Slice(None, 3, 2)`
// `1:3:2`                 | `Slice(1, 3, 2)`
// `torch.tensor([1, 2])`  | `torch::tensor({1, 2})`
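//
// For example (an illustrative sketch, assuming a tensor `x` and a boolean
// tensor `mask` are already defined), the Python expression
// `x[None, ..., 0, True, 1:3:2, mask]` would be written in C++ as:
// ```
// using namespace at::indexing;
// Tensor y = x.index({None, "...", 0, true, Slice(1, 3, 2), mask});
// ```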
112 | struct TORCH_API TensorIndex final { |
113 | // Case 1: `at::indexing::None` |
114 | TensorIndex(c10::nullopt_t) : type_(TensorIndexType::None) {} |
115 | |
116 | // Case 2: "..." / `at::indexing::Ellipsis` |
117 | TensorIndex(at::indexing::EllipsisIndexType) |
118 | : type_(TensorIndexType::Ellipsis) {} |
119 | TensorIndex(const char* str) : TensorIndex(at::indexing::Ellipsis) { |
120 | TORCH_CHECK_VALUE( |
        strcmp(str, "...") == 0,
        "Expected \"...\" to represent an ellipsis index, but got \"",
        str,
        "\"");
125 | } |
126 | |
127 | // Case 3: Integer value |
128 | TensorIndex(int64_t integer) |
129 | : integer_(integer), type_(TensorIndexType::Integer) {} |
130 | TensorIndex(int integer) : TensorIndex((int64_t)integer) {} |
131 | |
132 | // Case 4: Boolean value |
133 | template < |
134 | class T, |
135 | class = typename std::enable_if<std::is_same<bool, T>::value>::type> |
136 | TensorIndex(T boolean) : boolean_(boolean), type_(TensorIndexType::Boolean) {} |
137 | |
138 | // Case 5: Slice represented in `at::indexing::Slice` form |
139 | TensorIndex(Slice slice) |
140 | : slice_(std::move(slice)), type_(TensorIndexType::Slice) {} |
141 | |
142 | // Case 6: Tensor value |
143 | TensorIndex(Tensor tensor) |
144 | : tensor_(std::move(tensor)), type_(TensorIndexType::Tensor) {} |
145 | |
146 | inline bool is_none() const { |
147 | return type_ == TensorIndexType::None; |
148 | } |
149 | |
150 | inline bool is_ellipsis() const { |
151 | return type_ == TensorIndexType::Ellipsis; |
152 | } |
153 | |
154 | inline bool is_integer() const { |
155 | return type_ == TensorIndexType::Integer; |
156 | } |
157 | |
158 | inline int64_t integer() const { |
159 | return integer_; |
160 | } |
161 | |
162 | inline bool is_boolean() const { |
163 | return type_ == TensorIndexType::Boolean; |
164 | } |
165 | |
166 | inline bool boolean() const { |
167 | return boolean_; |
168 | } |
169 | |
170 | inline bool is_slice() const { |
171 | return type_ == TensorIndexType::Slice; |
172 | } |
173 | |
174 | inline const Slice& slice() const { |
175 | return slice_; |
176 | } |
177 | |
178 | inline bool is_tensor() const { |
179 | return type_ == TensorIndexType::Tensor; |
180 | } |
181 | |
182 | inline const Tensor& tensor() const { |
183 | return tensor_; |
184 | } |
185 | |
186 | private: |
187 | int64_t integer_ = 0; |
188 | bool boolean_ = false; |
189 | Slice slice_; |
190 | Tensor tensor_; |
191 | TensorIndexType type_; |
192 | }; |
193 | |
194 | TORCH_API std::ostream& operator<<( |
195 | std::ostream& stream, |
196 | const TensorIndex& tensor_index); |
197 | TORCH_API std::ostream& operator<<( |
198 | std::ostream& stream, |
199 | const std::vector<TensorIndex>& tensor_indices); |
200 | |
201 | namespace impl { |
202 | static inline Tensor applySlice( |
203 | const Tensor& self, |
204 | int64_t dim, |
205 | c10::SymInt start, |
206 | c10::SymInt stop, |
207 | c10::SymInt step, |
208 | bool disable_slice_optimization, |
209 | const at::Device& self_device, |
210 | const c10::optional<SymIntArrayRef>& self_sizes) { |
211 | // TODO: implement negative step |
  TORCH_CHECK_VALUE(step > 0, "step must be greater than zero");
213 | |
214 | // See NOTE [nested tensor size for indexing] |
215 | if (self_sizes.has_value()) { |
216 | // Skip this optimization if we are tracing, as the trace may be polymorphic |
217 | // over the shape of the `self` tensor, and we still want to record |
218 | // the slice. |
219 | SymInt length = (self_device == at::kCPU || self_device == at::kCUDA) |
220 | ? (*self_sizes)[dim] |
221 | : self.sym_size(dim); |
222 | if (!disable_slice_optimization && start == 0 && length == stop && |
223 | step == 1) { |
224 | return self; |
225 | } |
226 | } |
227 | return self.slice_symint(dim, start, stop, std::move(step)); |
228 | } |
229 | |
230 | static inline Tensor applySelect( |
231 | const Tensor& self, |
232 | int64_t dim, |
233 | int64_t index, |
234 | int64_t real_dim, |
235 | const at::Device& /*self_device*/, |
236 | const c10::optional<SymIntArrayRef>& self_sizes) { |
237 | // See NOTE [nested tensor size for indexing] |
238 | if (self_sizes.has_value()) { |
239 | TORCH_CHECK_INDEX( |
240 | !(index == 0 && dim == 0 && self_sizes->empty()), |
241 | "invalid index of a 0-dim tensor. " , |
242 | "Use `tensor.item()` in Python or `tensor.item<T>()` in C++ to convert a 0-dim tensor to a number" ); |
243 | |
244 | auto size = (*self_sizes)[dim]; |
245 | TORCH_CHECK_INDEX( |
246 | size >= -index && size > index, |
247 | "index " , |
248 | index, |
249 | " is out of bounds for dimension " , |
250 | real_dim, |
251 | " with size " , |
252 | size); |
253 | } |
254 | |
  // If the index is negative, do not normalize it, because that would fix the
  // index on the current tensor size in the tracer. `aten::select` also works
  // on negative indices.
258 | return self.select(dim, index); |
259 | } |
260 | |
261 | static inline Tensor boolToIndexingTensorCPUOrCUDA( |
262 | const Tensor& self, |
263 | bool value) { |
  // Booleans add a dimension of size 1. `true` indexes this dimension as if
  // with `0:`; `false` indexes it as empty.
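  // E.g. `true` yields tensor([0]) (dtype long), selecting the single element
  // of the new dimension, while `false` yields an empty long tensor.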
266 | if (value) { |
267 | return at::empty({1}, {}, self.options().dtype(kLong)).fill_(0.); |
268 | } else { |
269 | return at::empty({0}, {}, self.options().dtype(kLong)); |
270 | } |
271 | } |
272 | |
273 | static inline Tensor boolToIndexingTensorNonNativeDeviceType( |
274 | const Tensor& self, |
275 | bool value) { |
  // Booleans add a dimension of size 1. `true` indexes this dimension as if
  // with `0:`; `false` indexes it as empty.
278 | if (value) { |
279 | return at::zeros({1}, {}, self.options().dtype(kLong)); |
280 | } else { |
281 | return at::empty({0}, {}, self.options().dtype(kLong)); |
282 | } |
283 | } |
284 | |
285 | static inline Tensor boolToIndexingTensor( |
286 | const Tensor& self, |
287 | bool value, |
288 | const at::Device& self_device) { |
289 | if (self_device == at::kCPU || self_device == at::kCUDA) { |
290 | return boolToIndexingTensorCPUOrCUDA(self, value); |
291 | } else { |
292 | return boolToIndexingTensorNonNativeDeviceType(self, value); |
293 | } |
294 | } |
295 | |
296 | static inline Tensor scalarToTensorNonNativeDeviceType( |
297 | const Scalar& v, |
298 | const TensorOptions& options) { |
299 | return at::scalar_tensor(v, options); |
300 | } |
301 | |
302 | static inline void recordTensorIndex( |
303 | const Tensor& tensor, |
304 | std::vector<Tensor>& outIndices, |
305 | int64_t* dim_ptr) { |
306 | // TODO: check scalarType |
307 | outIndices.resize(*dim_ptr + 1); |
308 | outIndices[*dim_ptr] = tensor; |
309 | (*dim_ptr)++; |
310 | }; |
311 | |
312 | static inline c10::List<c10::optional<Tensor>> typeConvertIndices( |
313 | const Tensor& /*self*/, |
314 | std::vector<Tensor>&& indices) { |
315 | c10::List<c10::optional<Tensor>> converted_inds; |
316 | converted_inds.reserve(indices.size()); |
  // Note: a non-const reference is needed so that `std::move` below actually
  // moves rather than copies.
  for (auto& i : indices) {
318 | converted_inds.push_back(std::move(i)); |
319 | } |
320 | return converted_inds; |
321 | } |
322 | |
323 | // NOTE: Why do we mirror instead of replace the `count_specified_dimensions` |
324 | // function in torch/csrc/autograd/python_variable_indexing.cpp? It's because |
325 | // `count_specified_dimensions` is on the hot path of Python tensor multi-dim |
326 | // indexing (i.e. it's called by `applySlicing` which is called by |
327 | // `THPVariable_getitem` / `THPVariable_setitem` when handling indexing of more |
328 | // than one dimension). If we were to merge the Python/C++ |
329 | // `count_specified_dimensions` function, on the Python side we would have to |
330 | // construct a `std::vector` container to be consumed by the C++ |
// `count_specified_dimensions` function, which adds hundreds of nanoseconds of
// overhead and is undesirable.
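// For example (illustrative): for indices `{None, "...", 0, Slice(), mask2d}`,
// where `mask2d` is a hypothetical 2-D boolean tensor, the count is
// 0 + 0 + 1 + 1 + 2 = 4.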
333 | static inline int64_t count_specified_dimensions( |
334 | const ArrayRef<TensorIndex>& indices) { |
  // Count the number of indexed dimensions (everything but ellipsis, None, and
  // scalar booleans)
336 | int64_t count = 0; |
337 | for (auto& obj : indices) { |
338 | if (obj.is_tensor()) { |
339 | auto& tensor = obj.tensor(); |
340 | if (tensor.scalar_type() == kByte || tensor.scalar_type() == kBool) { |
341 | count += tensor.dim(); |
342 | } else { |
343 | count++; |
344 | } |
345 | } else if (!obj.is_none() && !obj.is_ellipsis() && !obj.is_boolean()) { |
346 | count++; |
347 | } |
348 | } |
349 | return count; |
350 | } |
351 | } // namespace impl |
352 | |
// NOTE: Many functions below are only meant for consumption by the Python
// indexing implementation; they include:
355 | // |
356 | // - `Tensor scalarToTensor(...)` |
357 | // - `IntArrayRef slicePrefix1sSize(...)` |
358 | // - `void copy_to(...)` |
359 | // - `Tensor handleDimInMultiDimIndexing(...)` |
360 | // - `Tensor dispatch_index(...)` |
361 | // - `Tensor dispatch_index_put_(...)` |
362 | // - `Tensor get_item(...)` |
363 | // - `void set_item(...)` |
364 | // |
// The rest of the functions are in the `at::indexing::impl` namespace,
// signifying that they shouldn't be used from the Python indexing
// implementation.
367 | static inline Tensor scalarToTensor( |
368 | const Scalar& v, |
369 | const TensorOptions& options, |
370 | const at::Device& self_device) { |
371 | if (self_device == at::kCPU) { |
372 | return at::detail::scalar_tensor_static( |
373 | v, options.dtype_opt()->toScalarType(), self_device); |
374 | } else { |
375 | return impl::scalarToTensorNonNativeDeviceType(v, options); |
376 | } |
377 | } |
378 | |
379 | // To match numpy semantics: |
380 | // As a special case for backwards compatibility, |
381 | // strip away unit dimensions from the left of 'src' |
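// (e.g. sizes `[1, 1, 2, 3]` become `[2, 3]`)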
382 | static inline SymIntArrayRef slicePrefix1sSize(const SymIntArrayRef& sizes) { |
383 | size_t first_non1_src = sizes.size(); |
384 | for (const auto i : c10::irange(sizes.size())) { |
385 | if (sizes[i] != 1) { |
386 | first_non1_src = i; |
387 | break; |
388 | } |
389 | } |
390 | |
391 | return sizes.slice(first_non1_src); |
392 | } |
393 | |
394 | static inline void copy_to(const Tensor& dst, const Tensor& src) { |
395 | if (dst.sym_sizes().equals(src.sym_sizes())) { |
396 | // A shortcut to avoid generating hard-coded constant sizes during tracing. |
397 | // This is not a perfect solution: when src & dst have different shapes, |
    // constants will still appear. Users can work around that case by
    // dst[index..] = src.reshape(..)
400 | dst.copy_(src); |
401 | return; |
402 | } else if (src.dim() == 0 && src.device().type() == at::kCPU) { |
403 | dst.fill_(src); |
404 | return; |
405 | } |
406 | auto src_view = src.view_symint(slicePrefix1sSize(src.sym_sizes())); |
  c10::MaybeOwned<Tensor> b_src = expand_inplace(dst, src_view, "setitem");
408 | dst.copy_(*b_src); |
409 | } |
410 | |
411 | // See NOTE [ Setting `disable_slice_optimization` when calling C++ tensor |
412 | // indexing functions from Python ] |
413 | static inline Tensor handleDimInMultiDimIndexing( |
414 | const Tensor& prev_dim_result, |
415 | const Tensor& original_tensor, |
416 | const TensorIndex& index, |
417 | int64_t* dim_ptr, |
418 | int64_t* specified_dims_ptr, |
419 | int64_t real_dim, |
420 | std::vector<Tensor>& outIndices, |
421 | bool disable_slice_optimization, |
422 | const at::Device& original_tensor_device, |
423 | const c10::optional<SymIntArrayRef>& prev_dim_result_sizes) { |
424 | if (index.is_integer()) { |
425 | return impl::applySelect( |
426 | prev_dim_result, |
427 | *dim_ptr, |
428 | index.integer(), |
429 | real_dim, |
430 | original_tensor_device, |
431 | prev_dim_result_sizes); |
432 | } else if (index.is_slice()) { |
433 | Tensor result = impl::applySlice( |
434 | prev_dim_result, |
435 | *dim_ptr, |
436 | index.slice().start(), |
437 | index.slice().stop(), |
438 | index.slice().step(), |
439 | /*disable_slice_optimization=*/disable_slice_optimization, |
440 | original_tensor_device, |
441 | prev_dim_result_sizes); |
442 | (*dim_ptr)++; |
443 | return result; |
444 | } else if (index.is_ellipsis()) { |
445 | (*dim_ptr) += original_tensor.dim() - (*specified_dims_ptr); |
446 | return prev_dim_result; |
447 | } else if (index.is_none()) { |
448 | Tensor result = prev_dim_result.unsqueeze(*dim_ptr); |
449 | (*dim_ptr)++; |
450 | return result; |
451 | } else if (index.is_boolean()) { |
452 | Tensor result = prev_dim_result.unsqueeze(*dim_ptr); |
453 | impl::recordTensorIndex( |
454 | impl::boolToIndexingTensor( |
455 | result, index.boolean(), original_tensor_device), |
456 | outIndices, |
457 | dim_ptr); |
458 | return result; |
459 | } else if (index.is_tensor()) { |
460 | Tensor result = prev_dim_result; |
461 | const Tensor& tensor = index.tensor(); |
462 | auto scalar_type = tensor.scalar_type(); |
463 | if (tensor.dim() == 0 && |
464 | at::isIntegralType(scalar_type, /*includeBool=*/true)) { |
465 | if (scalar_type != at::kByte && scalar_type != at::kBool) { |
466 | result = impl::applySelect( |
467 | result, |
468 | *dim_ptr, |
469 | tensor.item<int64_t>(), |
470 | real_dim, |
471 | original_tensor_device, |
472 | prev_dim_result_sizes); |
473 | } else { |
474 | result = result.unsqueeze(*dim_ptr); |
475 | if (scalar_type == at::kBool) { |
476 | impl::recordTensorIndex( |
477 | impl::boolToIndexingTensor( |
478 | result, tensor.item<bool>() != 0, original_tensor_device), |
479 | outIndices, |
480 | dim_ptr); |
481 | } else { |
482 | impl::recordTensorIndex( |
483 | impl::boolToIndexingTensor( |
484 | result, tensor.item<uint8_t>() != 0, original_tensor_device), |
485 | outIndices, |
486 | dim_ptr); |
487 | } |
488 | } |
489 | } else { |
490 | impl::recordTensorIndex(tensor, outIndices, dim_ptr); |
491 | } |
492 | return result; |
493 | } else { |
    TORCH_INTERNAL_ASSERT(false, "Invalid TensorIndex type");
495 | } |
496 | } |
497 | |
498 | namespace impl { |
499 | // This mirrors `applySlicing` in |
500 | // torch/csrc/autograd/python_variable_indexing.cpp |
501 | static inline Tensor applySlicing( |
502 | const Tensor& self, |
503 | const ArrayRef<TensorIndex>& indices, |
504 | std::vector<Tensor>& outIndices, |
505 | bool disable_slice_optimization, |
506 | const at::Device& self_device, |
507 | const c10::optional<SymIntArrayRef>& self_sizes) { |
508 | int64_t dim = 0; |
509 | int64_t specified_dims = impl::count_specified_dimensions(indices); |
510 | |
511 | // See NOTE [nested tensor size for indexing] |
512 | if (self_sizes.has_value()) { |
513 | TORCH_CHECK_INDEX( |
514 | specified_dims <= (int64_t)self_sizes->size(), |
515 | "too many indices for tensor of dimension " , |
516 | (int)self_sizes->size()); |
517 | } |
518 | |
519 | Tensor result = self; |
520 | for (const auto i : c10::irange(indices.size())) { |
521 | auto& obj = indices[i]; |
522 | // See NOTE [nested tensor size for indexing] |
523 | c10::optional<SymIntArrayRef> result_sizes = result.is_nested() |
524 | ? c10::optional<SymIntArrayRef>(c10::nullopt) |
525 | : c10::optional<SymIntArrayRef>(result.sym_sizes()); |
526 | result = handleDimInMultiDimIndexing( |
527 | /*prev_dim_result=*/result, |
528 | /*original_tensor=*/self, |
529 | /*index=*/obj, |
530 | /*dim=*/&dim, |
531 | /*specified_dims=*/&specified_dims, |
532 | /*real_dim=*/i, |
533 | /*outIndices=*/outIndices, |
534 | /*disable_slice_optimization=*/disable_slice_optimization, |
535 | /*original_tensor_device=*/self_device, |
536 | /*prev_dim_result_sizes=*/result_sizes); |
537 | } |
538 | return result; |
539 | } |
540 | } // namespace impl |
541 | |
542 | static inline Tensor dispatch_index( |
543 | const Tensor& self, |
544 | std::vector<Tensor>&& indices) { |
545 | return self.index(impl::typeConvertIndices(self, std::move(indices))); |
546 | } |
547 | |
548 | static inline Tensor dispatch_index_put_( |
549 | Tensor& self, |
550 | std::vector<Tensor>&& indices, |
551 | const Tensor& value) { |
552 | return self.index_put_( |
553 | impl::typeConvertIndices(self, std::move(indices)), value); |
554 | } |
555 | |
556 | // NOTE [ Setting `disable_slice_optimization` when calling C++ tensor indexing |
557 | // functions from Python ] |
558 | // |
559 | // Question: When should we set `disable_slice_optimization` to `true` when |
560 | // calling C++ tensor indexing functions from Python indexing code? |
561 | // |
// Answer: First, what "slice optimization" means: given a slicing expression
// like `x[0:5, 0]` where `x` has size 5 in dimension 0, the slice `0:5` spans
// the whole dimension, so we would skip dispatching the actual slice call as
// an optimization. However, here are the cases where we DON'T want this
// optimization:
566 | // |
567 | // 1. When we are doing 1-D slicing (e.g. `tensor[:]`). |
568 | // Reason: we always return a shallow copy for expressions such as |
569 | // `tensor[:]` / `tensor[...]` / `tensor[:, :]`. (Note that for `tensor[:, |
570 | // :]`, we return an alias of `tensor` by doing the following: |
571 | // ``` |
572 | // Tensor sliced = impl::applySlicing(self, indices, tensorIndices, |
573 | // disable_slice_optimization, self_device, self_sizes); if |
574 | // (tensorIndices.empty()) { |
575 | // if (sliced.is_same(self)) { |
576 | // // ensure we return a shallow copy for things like x[...] |
577 | // sliced = at::alias(sliced); |
578 | // } |
579 | // return sliced; |
580 | // } |
581 | // ```) |
582 | // 2. When we are doing JIT tracing. |
583 | // Reason: JIT tracing needs the `self.slice(...)` call to properly trace the |
584 | // slice operation. |
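//
// As a sketch (illustrative, assuming an already-defined 2-D `at::Tensor x`
// whose dimension 0 has size 5), a tracer would therefore call:
// ```
// // Forces the 0:5 slice to be dispatched (and hence recorded in the trace)
// // even though it spans the whole of dimension 0:
// Tensor y = at::indexing::get_item(
//     x, {at::indexing::Slice(0, 5), 0}, /*disable_slice_optimization=*/true);
// ```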
585 | |
// This mirrors `THPVariable_getitem` in
// torch/csrc/autograd/python_variable_indexing.cpp.
// See NOTE [ Setting `disable_slice_optimization` when calling C++ tensor
// indexing functions from Python ]
590 | static inline Tensor get_item( |
591 | const Tensor& self, |
592 | const ArrayRef<TensorIndex>& indices, |
593 | bool disable_slice_optimization = false) { |
594 | at::Device self_device = self.device(); |
595 | // NOTE [nested tensor size for indexing] |
  // Nested tensors do not have sizes (yet), so for now we represent the size
  // as null; this may need to change once we reach a better solution for
  // nested tensor sizes.
599 | c10::optional<SymIntArrayRef> self_sizes = self.is_nested() |
600 | ? c10::optional<SymIntArrayRef>(c10::nullopt) |
601 | : c10::optional<SymIntArrayRef>(self.sym_sizes()); |
602 | |
603 | // handle simple types: integers, slices, none, ellipsis, bool |
604 | if (indices.size() == 1) { |
605 | const TensorIndex& index = indices[0]; |
606 | if (index.is_integer()) { |
607 | return impl::applySelect( |
608 | self, 0, index.integer(), 0, self_device, self_sizes); |
609 | } else if (index.is_slice()) { |
610 | return impl::applySlice( |
611 | self, |
612 | 0, |
613 | index.slice().start(), |
614 | index.slice().stop(), |
615 | index.slice().step(), |
616 | /*disable_slice_optimization=*/true, |
617 | self_device, |
618 | self_sizes); |
619 | } else if (index.is_none()) { |
620 | return self.unsqueeze(0); |
621 | } else if (index.is_ellipsis()) { |
622 | return at::alias(self); |
623 | } else if (index.is_boolean()) { |
624 | Tensor result = self.unsqueeze(0); |
625 | return dispatch_index( |
626 | result, |
627 | std::vector<Tensor>{impl::boolToIndexingTensor( |
628 | result, index.boolean(), self_device)}); |
629 | } |
630 | } |
631 | |
632 | std::vector<Tensor> tensorIndices; |
633 | Tensor sliced = impl::applySlicing( |
634 | self, |
635 | indices, |
636 | tensorIndices, |
637 | disable_slice_optimization, |
638 | self_device, |
639 | self_sizes); |
640 | if (tensorIndices.empty()) { |
641 | if (sliced.is_same(self)) { |
642 | // ensure we return a shallow copy for things like x[...] |
643 | sliced = at::alias(sliced); |
644 | } |
645 | return sliced; |
646 | } |
647 | |
648 | // indexing by tensors ("advanced" indexing) |
649 | return dispatch_index(sliced, std::move(tensorIndices)); |
650 | } |
651 | |
// This mirrors `THPVariable_setitem` in
// torch/csrc/autograd/python_variable_indexing.cpp for the "assigned value is
// a Tensor" case.
// See NOTE [ Setting `disable_slice_optimization` when calling C++ tensor
// indexing functions from Python ]
656 | static inline void set_item( |
657 | const Tensor& self, |
658 | const ArrayRef<TensorIndex>& indices, |
659 | const Tensor& value, |
660 | bool disable_slice_optimization = false) { |
661 | at::Device self_device = self.device(); |
662 | SymIntArrayRef self_sizes = self.sym_sizes(); |
663 | |
664 | // handle simple types: integers, slices, ellipsis, bool |
665 | if (indices.size() == 1) { |
666 | const TensorIndex& index = indices[0]; |
667 | if (index.is_boolean() && !index.boolean()) { |
      // Do nothing for `false` (technically we should check the size, but we
      // don't have real 0-sized shapes).
670 | return; |
671 | } else if (index.is_ellipsis()) { |
672 | copy_to(self, value); |
673 | return; |
674 | } else if (index.is_none() || (index.is_boolean() && index.boolean())) { |
675 | copy_to(self.unsqueeze(0), value); |
676 | return; |
677 | } else if (index.is_integer()) { |
678 | copy_to( |
679 | impl::applySelect( |
680 | self, 0, index.integer(), 0, self_device, self_sizes), |
681 | value); |
682 | return; |
683 | } else if (index.is_slice()) { |
684 | copy_to( |
685 | impl::applySlice( |
686 | self, |
687 | 0, |
688 | index.slice().start(), |
689 | index.slice().stop(), |
690 | index.slice().step(), |
691 | /*disable_slice_optimization=*/disable_slice_optimization, |
692 | self_device, |
693 | self_sizes), |
694 | value); |
695 | return; |
696 | } |
697 | } |
698 | |
699 | std::vector<Tensor> tensorIndices; |
700 | Tensor sliced = impl::applySlicing( |
701 | self, |
702 | indices, |
703 | tensorIndices, |
704 | disable_slice_optimization, |
705 | self_device, |
706 | self_sizes); |
707 | if (tensorIndices.empty()) { |
708 | copy_to(sliced, value); |
709 | return; |
710 | } |
711 | |
712 | SymIntArrayRef valueSizes = value.sym_sizes(); |
713 | SymIntArrayRef slicedValueSizes = slicePrefix1sSize(valueSizes); |
714 | Tensor valuesSliced; |
715 | if (!valueSizes.equals(slicedValueSizes)) { |
716 | valuesSliced = value.view_symint(slicedValueSizes); |
717 | } else { |
718 | valuesSliced = value; |
719 | } |
720 | dispatch_index_put_(sliced, std::move(tensorIndices), valuesSliced); |
721 | return; |
722 | } |
723 | |
724 | } // namespace indexing |
725 | } // namespace at |
726 | |