1#pragma once
2
3#include <ATen/ExpandUtils.h>
4#include <ATen/ScalarOps.h>
5#include <ATen/core/Tensor.h>
6#include <ATen/core/TensorBody.h>
7#include <c10/core/SymInt.h>
8#include <c10/util/Optional.h>
9#include <c10/util/irange.h>
10
11#ifndef AT_PER_OPERATOR_HEADERS
12#include <ATen/Functions.h>
13#include <ATen/NativeFunctions.h>
14#else
15#include <ATen/ops/alias.h>
16#include <ATen/ops/empty.h>
17#include <ATen/ops/scalar_tensor.h>
18#include <ATen/ops/zeros.h>
19#endif
20
21#include <ATen/core/List.h>
22
23#include <utility>
24
25namespace at {
26namespace indexing {
27
// Sentinel bounds used when a slice bound is left unspecified.
// INDEX_MIN comes from `c10::SymInt::min_representable_int()`; INDEX_MAX is
// derived from it as `-(INDEX_MIN + 1)`.
const int64_t INDEX_MIN = c10::SymInt::min_representable_int();
const int64_t INDEX_MAX = -(INDEX_MIN + 1);

// Tag identifying which kind of index a `TensorIndex` holds.
enum class TensorIndexType { None, Ellipsis, Integer, Boolean, Slice, Tensor };

// C++ spelling of the Python `None` index; an alias for `c10::nullopt`.
constexpr c10::nullopt_t None = c10::nullopt;
34
// Empty tag type; `TensorIndex` has a constructor overload taking it to
// build an ellipsis index.
struct TORCH_API EllipsisIndexType final {
  EllipsisIndexType() = default;
};
// Declaration of the shared ellipsis value; the definition lives elsewhere.
TORCH_API extern const EllipsisIndexType Ellipsis;
39
40struct TORCH_API Slice final {
41 public:
42 Slice(
43 c10::optional<c10::SymInt> start_index = c10::nullopt,
44 c10::optional<c10::SymInt> stop_index = c10::nullopt,
45 c10::optional<c10::SymInt> step_index = c10::nullopt) {
46 if (!step_index.has_value()) {
47 step_ = c10::SymInt(1);
48 } else {
49 step_ = std::move(step_index).value();
50 }
51
52 TORCH_CHECK_VALUE(step_ != 0, "slice step cannot be zero");
53
54 if (!start_index.has_value()) {
55 start_ = c10::SymInt(step_ < 0 ? INDEX_MAX : 0);
56 } else {
57 start_ = std::move(start_index).value();
58 }
59
60 if (!stop_index.has_value()) {
61 stop_ = c10::SymInt(step_ < 0 ? INDEX_MIN : INDEX_MAX);
62 } else {
63 stop_ = std::move(stop_index).value();
64 }
65 }
66
67 inline c10::SymInt start() const {
68 return start_;
69 }
70
71 inline c10::SymInt stop() const {
72 return stop_;
73 }
74
75 inline c10::SymInt step() const {
76 return step_;
77 }
78
79 private:
80 c10::SymInt start_;
81 c10::SymInt stop_;
82 c10::SymInt step_;
83};
84
85TORCH_API std::ostream& operator<<(std::ostream& stream, const Slice& slice);
86
87// `at::indexing::TensorIndex` is used for converting C++ tensor indices such as
88// `{None, "...", Ellipsis, 0, true, Slice(1, None, 2), torch::tensor({1, 2})}`
89// into its equivalent `std::vector<TensorIndex>`, so that further tensor
90// indexing operations can be performed using the supplied indices.
91//
92// There is one-to-one correspondence between Python and C++ tensor index types:
93// Python | C++
94// -----------------------------------------------------
95// `None` | `at::indexing::None`
96// `Ellipsis` | `at::indexing::Ellipsis`
97// `...` | `"..."`
98// `123` | `123`
99// `True` / `False` | `true` / `false`
100// `:` | `Slice()` / `Slice(None, None)`
101// `::` | `Slice()` / `Slice(None, None, None)`
102// `1:` | `Slice(1, None)`
103// `1::` | `Slice(1, None, None)`
104// `:3` | `Slice(None, 3)`
105// `:3:` | `Slice(None, 3, None)`
106// `::2` | `Slice(None, None, 2)`
107// `1:3` | `Slice(1, 3)`
108// `1::2` | `Slice(1, None, 2)`
109// `:3:2` | `Slice(None, 3, 2)`
110// `1:3:2` | `Slice(1, 3, 2)`
// `torch.tensor([1, 2])` | `torch::tensor({1, 2})`
112struct TORCH_API TensorIndex final {
113 // Case 1: `at::indexing::None`
114 TensorIndex(c10::nullopt_t) : type_(TensorIndexType::None) {}
115
116 // Case 2: "..." / `at::indexing::Ellipsis`
117 TensorIndex(at::indexing::EllipsisIndexType)
118 : type_(TensorIndexType::Ellipsis) {}
119 TensorIndex(const char* str) : TensorIndex(at::indexing::Ellipsis) {
120 TORCH_CHECK_VALUE(
121 strcmp(str, "...") == 0,
122 "Expected \"...\" to represent an ellipsis index, but got \"",
123 str,
124 "\"");
125 }
126
127 // Case 3: Integer value
128 TensorIndex(int64_t integer)
129 : integer_(integer), type_(TensorIndexType::Integer) {}
130 TensorIndex(int integer) : TensorIndex((int64_t)integer) {}
131
132 // Case 4: Boolean value
133 template <
134 class T,
135 class = typename std::enable_if<std::is_same<bool, T>::value>::type>
136 TensorIndex(T boolean) : boolean_(boolean), type_(TensorIndexType::Boolean) {}
137
138 // Case 5: Slice represented in `at::indexing::Slice` form
139 TensorIndex(Slice slice)
140 : slice_(std::move(slice)), type_(TensorIndexType::Slice) {}
141
142 // Case 6: Tensor value
143 TensorIndex(Tensor tensor)
144 : tensor_(std::move(tensor)), type_(TensorIndexType::Tensor) {}
145
146 inline bool is_none() const {
147 return type_ == TensorIndexType::None;
148 }
149
150 inline bool is_ellipsis() const {
151 return type_ == TensorIndexType::Ellipsis;
152 }
153
154 inline bool is_integer() const {
155 return type_ == TensorIndexType::Integer;
156 }
157
158 inline int64_t integer() const {
159 return integer_;
160 }
161
162 inline bool is_boolean() const {
163 return type_ == TensorIndexType::Boolean;
164 }
165
166 inline bool boolean() const {
167 return boolean_;
168 }
169
170 inline bool is_slice() const {
171 return type_ == TensorIndexType::Slice;
172 }
173
174 inline const Slice& slice() const {
175 return slice_;
176 }
177
178 inline bool is_tensor() const {
179 return type_ == TensorIndexType::Tensor;
180 }
181
182 inline const Tensor& tensor() const {
183 return tensor_;
184 }
185
186 private:
187 int64_t integer_ = 0;
188 bool boolean_ = false;
189 Slice slice_;
190 Tensor tensor_;
191 TensorIndexType type_;
192};
193
// Stream output for a single `TensorIndex`; defined out of line.
TORCH_API std::ostream& operator<<(
    std::ostream& stream,
    const TensorIndex& tensor_index);
// Stream output for a list of `TensorIndex` values; defined out of line.
TORCH_API std::ostream& operator<<(
    std::ostream& stream,
    const std::vector<TensorIndex>& tensor_indices);
200
201namespace impl {
202static inline Tensor applySlice(
203 const Tensor& self,
204 int64_t dim,
205 c10::SymInt start,
206 c10::SymInt stop,
207 c10::SymInt step,
208 bool disable_slice_optimization,
209 const at::Device& self_device,
210 const c10::optional<SymIntArrayRef>& self_sizes) {
211 // TODO: implement negative step
212 TORCH_CHECK_VALUE(step > 0, "step must be greater than zero");
213
214 // See NOTE [nested tensor size for indexing]
215 if (self_sizes.has_value()) {
216 // Skip this optimization if we are tracing, as the trace may be polymorphic
217 // over the shape of the `self` tensor, and we still want to record
218 // the slice.
219 SymInt length = (self_device == at::kCPU || self_device == at::kCUDA)
220 ? (*self_sizes)[dim]
221 : self.sym_size(dim);
222 if (!disable_slice_optimization && start == 0 && length == stop &&
223 step == 1) {
224 return self;
225 }
226 }
227 return self.slice_symint(dim, start, stop, std::move(step));
228}
229
230static inline Tensor applySelect(
231 const Tensor& self,
232 int64_t dim,
233 int64_t index,
234 int64_t real_dim,
235 const at::Device& /*self_device*/,
236 const c10::optional<SymIntArrayRef>& self_sizes) {
237 // See NOTE [nested tensor size for indexing]
238 if (self_sizes.has_value()) {
239 TORCH_CHECK_INDEX(
240 !(index == 0 && dim == 0 && self_sizes->empty()),
241 "invalid index of a 0-dim tensor. ",
242 "Use `tensor.item()` in Python or `tensor.item<T>()` in C++ to convert a 0-dim tensor to a number");
243
244 auto size = (*self_sizes)[dim];
245 TORCH_CHECK_INDEX(
246 size >= -index && size > index,
247 "index ",
248 index,
249 " is out of bounds for dimension ",
250 real_dim,
251 " with size ",
252 size);
253 }
254
255 // if the index is negative, do not normalize it because that would fix the
256 // index on the current tensor size in the tracer. aten::select also works on
257 // negative indices
258 return self.select(dim, index);
259}
260
261static inline Tensor boolToIndexingTensorCPUOrCUDA(
262 const Tensor& self,
263 bool value) {
264 // booleans add a dimension of size 1. true indexes this dimension as if 0:,
265 // false as empty.
266 if (value) {
267 return at::empty({1}, {}, self.options().dtype(kLong)).fill_(0.);
268 } else {
269 return at::empty({0}, {}, self.options().dtype(kLong));
270 }
271}
272
273static inline Tensor boolToIndexingTensorNonNativeDeviceType(
274 const Tensor& self,
275 bool value) {
276 // booleans add a dimension of size 1. true indexes this dimension as if 0:,
277 // false as empty.
278 if (value) {
279 return at::zeros({1}, {}, self.options().dtype(kLong));
280 } else {
281 return at::empty({0}, {}, self.options().dtype(kLong));
282 }
283}
284
285static inline Tensor boolToIndexingTensor(
286 const Tensor& self,
287 bool value,
288 const at::Device& self_device) {
289 if (self_device == at::kCPU || self_device == at::kCUDA) {
290 return boolToIndexingTensorCPUOrCUDA(self, value);
291 } else {
292 return boolToIndexingTensorNonNativeDeviceType(self, value);
293 }
294}
295
// Wraps scalar `v` in a tensor by forwarding to `at::scalar_tensor` with the
// given `options`.
static inline Tensor scalarToTensorNonNativeDeviceType(
    const Scalar& v,
    const TensorOptions& options) {
  return at::scalar_tensor(v, options);
}
301
302static inline void recordTensorIndex(
303 const Tensor& tensor,
304 std::vector<Tensor>& outIndices,
305 int64_t* dim_ptr) {
306 // TODO: check scalarType
307 outIndices.resize(*dim_ptr + 1);
308 outIndices[*dim_ptr] = tensor;
309 (*dim_ptr)++;
310};
311
312static inline c10::List<c10::optional<Tensor>> typeConvertIndices(
313 const Tensor& /*self*/,
314 std::vector<Tensor>&& indices) {
315 c10::List<c10::optional<Tensor>> converted_inds;
316 converted_inds.reserve(indices.size());
317 for (const auto& i : indices) {
318 converted_inds.push_back(std::move(i));
319 }
320 return converted_inds;
321}
322
323// NOTE: Why do we mirror instead of replace the `count_specified_dimensions`
324// function in torch/csrc/autograd/python_variable_indexing.cpp? It's because
325// `count_specified_dimensions` is on the hot path of Python tensor multi-dim
326// indexing (i.e. it's called by `applySlicing` which is called by
327// `THPVariable_getitem` / `THPVariable_setitem` when handling indexing of more
328// than one dimension). If we were to merge the Python/C++
329// `count_specified_dimensions` function, on the Python side we would have to
330// construct a `std::vector` container to be consumed by the C++
331// `count_specified_dimensions` function, which adds 100s of nanoseconds
332// overhead and is undesirable.
333static inline int64_t count_specified_dimensions(
334 const ArrayRef<TensorIndex>& indices) {
335 // Count the number of indexed dimensions (everything but ellipsis and None)
336 int64_t count = 0;
337 for (auto& obj : indices) {
338 if (obj.is_tensor()) {
339 auto& tensor = obj.tensor();
340 if (tensor.scalar_type() == kByte || tensor.scalar_type() == kBool) {
341 count += tensor.dim();
342 } else {
343 count++;
344 }
345 } else if (!obj.is_none() && !obj.is_ellipsis() && !obj.is_boolean()) {
346 count++;
347 }
348 }
349 return count;
350}
351} // namespace impl
352
353// NOTE: Many functions below are only for consumption from Python indexing
354// implementation, they include:
355//
356// - `Tensor scalarToTensor(...)`
// - `SymIntArrayRef slicePrefix1sSize(...)`
358// - `void copy_to(...)`
359// - `Tensor handleDimInMultiDimIndexing(...)`
360// - `Tensor dispatch_index(...)`
361// - `Tensor dispatch_index_put_(...)`
362// - `Tensor get_item(...)`
363// - `void set_item(...)`
364//
365// The rest of the functions are in `at::indexing::impl` namespace, signifying
366// that they shouldn't be used from Python indexing implementation.
367static inline Tensor scalarToTensor(
368 const Scalar& v,
369 const TensorOptions& options,
370 const at::Device& self_device) {
371 if (self_device == at::kCPU) {
372 return at::detail::scalar_tensor_static(
373 v, options.dtype_opt()->toScalarType(), self_device);
374 } else {
375 return impl::scalarToTensorNonNativeDeviceType(v, options);
376 }
377}
378
379// To match numpy semantics:
380// As a special case for backwards compatibility,
381// strip away unit dimensions from the left of 'src'
382static inline SymIntArrayRef slicePrefix1sSize(const SymIntArrayRef& sizes) {
383 size_t first_non1_src = sizes.size();
384 for (const auto i : c10::irange(sizes.size())) {
385 if (sizes[i] != 1) {
386 first_non1_src = i;
387 break;
388 }
389 }
390
391 return sizes.slice(first_non1_src);
392}
393
394static inline void copy_to(const Tensor& dst, const Tensor& src) {
395 if (dst.sym_sizes().equals(src.sym_sizes())) {
396 // A shortcut to avoid generating hard-coded constant sizes during tracing.
397 // This is not a perfect solution: when src & dst have different shapes,
398 // constants will still appear. Users can workaround that case by
399 // dst[index..] = src.reshape(..)
400 dst.copy_(src);
401 return;
402 } else if (src.dim() == 0 && src.device().type() == at::kCPU) {
403 dst.fill_(src);
404 return;
405 }
406 auto src_view = src.view_symint(slicePrefix1sSize(src.sym_sizes()));
407 c10::MaybeOwned<Tensor> b_src = expand_inplace(dst, src_view, "setitem");
408 dst.copy_(*b_src);
409}
410
411// See NOTE [ Setting `disable_slice_optimization` when calling C++ tensor
412// indexing functions from Python ]
413static inline Tensor handleDimInMultiDimIndexing(
414 const Tensor& prev_dim_result,
415 const Tensor& original_tensor,
416 const TensorIndex& index,
417 int64_t* dim_ptr,
418 int64_t* specified_dims_ptr,
419 int64_t real_dim,
420 std::vector<Tensor>& outIndices,
421 bool disable_slice_optimization,
422 const at::Device& original_tensor_device,
423 const c10::optional<SymIntArrayRef>& prev_dim_result_sizes) {
424 if (index.is_integer()) {
425 return impl::applySelect(
426 prev_dim_result,
427 *dim_ptr,
428 index.integer(),
429 real_dim,
430 original_tensor_device,
431 prev_dim_result_sizes);
432 } else if (index.is_slice()) {
433 Tensor result = impl::applySlice(
434 prev_dim_result,
435 *dim_ptr,
436 index.slice().start(),
437 index.slice().stop(),
438 index.slice().step(),
439 /*disable_slice_optimization=*/disable_slice_optimization,
440 original_tensor_device,
441 prev_dim_result_sizes);
442 (*dim_ptr)++;
443 return result;
444 } else if (index.is_ellipsis()) {
445 (*dim_ptr) += original_tensor.dim() - (*specified_dims_ptr);
446 return prev_dim_result;
447 } else if (index.is_none()) {
448 Tensor result = prev_dim_result.unsqueeze(*dim_ptr);
449 (*dim_ptr)++;
450 return result;
451 } else if (index.is_boolean()) {
452 Tensor result = prev_dim_result.unsqueeze(*dim_ptr);
453 impl::recordTensorIndex(
454 impl::boolToIndexingTensor(
455 result, index.boolean(), original_tensor_device),
456 outIndices,
457 dim_ptr);
458 return result;
459 } else if (index.is_tensor()) {
460 Tensor result = prev_dim_result;
461 const Tensor& tensor = index.tensor();
462 auto scalar_type = tensor.scalar_type();
463 if (tensor.dim() == 0 &&
464 at::isIntegralType(scalar_type, /*includeBool=*/true)) {
465 if (scalar_type != at::kByte && scalar_type != at::kBool) {
466 result = impl::applySelect(
467 result,
468 *dim_ptr,
469 tensor.item<int64_t>(),
470 real_dim,
471 original_tensor_device,
472 prev_dim_result_sizes);
473 } else {
474 result = result.unsqueeze(*dim_ptr);
475 if (scalar_type == at::kBool) {
476 impl::recordTensorIndex(
477 impl::boolToIndexingTensor(
478 result, tensor.item<bool>() != 0, original_tensor_device),
479 outIndices,
480 dim_ptr);
481 } else {
482 impl::recordTensorIndex(
483 impl::boolToIndexingTensor(
484 result, tensor.item<uint8_t>() != 0, original_tensor_device),
485 outIndices,
486 dim_ptr);
487 }
488 }
489 } else {
490 impl::recordTensorIndex(tensor, outIndices, dim_ptr);
491 }
492 return result;
493 } else {
494 TORCH_INTERNAL_ASSERT(false, "Invalid TensorIndex type");
495 }
496}
497
498namespace impl {
499// This mirrors `applySlicing` in
500// torch/csrc/autograd/python_variable_indexing.cpp
501static inline Tensor applySlicing(
502 const Tensor& self,
503 const ArrayRef<TensorIndex>& indices,
504 std::vector<Tensor>& outIndices,
505 bool disable_slice_optimization,
506 const at::Device& self_device,
507 const c10::optional<SymIntArrayRef>& self_sizes) {
508 int64_t dim = 0;
509 int64_t specified_dims = impl::count_specified_dimensions(indices);
510
511 // See NOTE [nested tensor size for indexing]
512 if (self_sizes.has_value()) {
513 TORCH_CHECK_INDEX(
514 specified_dims <= (int64_t)self_sizes->size(),
515 "too many indices for tensor of dimension ",
516 (int)self_sizes->size());
517 }
518
519 Tensor result = self;
520 for (const auto i : c10::irange(indices.size())) {
521 auto& obj = indices[i];
522 // See NOTE [nested tensor size for indexing]
523 c10::optional<SymIntArrayRef> result_sizes = result.is_nested()
524 ? c10::optional<SymIntArrayRef>(c10::nullopt)
525 : c10::optional<SymIntArrayRef>(result.sym_sizes());
526 result = handleDimInMultiDimIndexing(
527 /*prev_dim_result=*/result,
528 /*original_tensor=*/self,
529 /*index=*/obj,
530 /*dim=*/&dim,
531 /*specified_dims=*/&specified_dims,
532 /*real_dim=*/i,
533 /*outIndices=*/outIndices,
534 /*disable_slice_optimization=*/disable_slice_optimization,
535 /*original_tensor_device=*/self_device,
536 /*prev_dim_result_sizes=*/result_sizes);
537 }
538 return result;
539}
540} // namespace impl
541
542static inline Tensor dispatch_index(
543 const Tensor& self,
544 std::vector<Tensor>&& indices) {
545 return self.index(impl::typeConvertIndices(self, std::move(indices)));
546}
547
548static inline Tensor dispatch_index_put_(
549 Tensor& self,
550 std::vector<Tensor>&& indices,
551 const Tensor& value) {
552 return self.index_put_(
553 impl::typeConvertIndices(self, std::move(indices)), value);
554}
555
556// NOTE [ Setting `disable_slice_optimization` when calling C++ tensor indexing
557// functions from Python ]
558//
559// Question: When should we set `disable_slice_optimization` to `true` when
560// calling C++ tensor indexing functions from Python indexing code?
561//
562// Answer: What "slice optimization" means: when we have a slicing expression
563// like `x[0:5, 0]`, where the sliced tensor was of size 5 in dimension 0, we
564// would skip dispatching the actual slice call as an optimization. However,
565// here are the cases where we DON'T want this optimization:
566//
567// 1. When we are doing 1-D slicing (e.g. `tensor[:]`).
568// Reason: we always return a shallow copy for expressions such as
569// `tensor[:]` / `tensor[...]` / `tensor[:, :]`. (Note that for `tensor[:,
570// :]`, we return an alias of `tensor` by doing the following:
571// ```
// Tensor sliced = impl::applySlicing(self, indices, tensorIndices,
//     disable_slice_optimization, self_device, self_sizes);
// if (tensorIndices.empty()) {
575// if (sliced.is_same(self)) {
576// // ensure we return a shallow copy for things like x[...]
577// sliced = at::alias(sliced);
578// }
579// return sliced;
580// }
581// ```)
582// 2. When we are doing JIT tracing.
583// Reason: JIT tracing needs the `self.slice(...)` call to properly trace the
584// slice operation.
585
// This mirrors `THPVariable_getitem` in
// torch/csrc/autograd/python_variable_indexing.cpp.
// See NOTE [ Setting `disable_slice_optimization` when calling C++ tensor
// indexing functions from Python ]
590static inline Tensor get_item(
591 const Tensor& self,
592 const ArrayRef<TensorIndex>& indices,
593 bool disable_slice_optimization = false) {
594 at::Device self_device = self.device();
595 // NOTE [nested tensor size for indexing]
596 // nested tensor does not have a size (yet) so for now we represent its size
597 // as null may need to be changed after we reach a better solution for nested
598 // tensor size
599 c10::optional<SymIntArrayRef> self_sizes = self.is_nested()
600 ? c10::optional<SymIntArrayRef>(c10::nullopt)
601 : c10::optional<SymIntArrayRef>(self.sym_sizes());
602
603 // handle simple types: integers, slices, none, ellipsis, bool
604 if (indices.size() == 1) {
605 const TensorIndex& index = indices[0];
606 if (index.is_integer()) {
607 return impl::applySelect(
608 self, 0, index.integer(), 0, self_device, self_sizes);
609 } else if (index.is_slice()) {
610 return impl::applySlice(
611 self,
612 0,
613 index.slice().start(),
614 index.slice().stop(),
615 index.slice().step(),
616 /*disable_slice_optimization=*/true,
617 self_device,
618 self_sizes);
619 } else if (index.is_none()) {
620 return self.unsqueeze(0);
621 } else if (index.is_ellipsis()) {
622 return at::alias(self);
623 } else if (index.is_boolean()) {
624 Tensor result = self.unsqueeze(0);
625 return dispatch_index(
626 result,
627 std::vector<Tensor>{impl::boolToIndexingTensor(
628 result, index.boolean(), self_device)});
629 }
630 }
631
632 std::vector<Tensor> tensorIndices;
633 Tensor sliced = impl::applySlicing(
634 self,
635 indices,
636 tensorIndices,
637 disable_slice_optimization,
638 self_device,
639 self_sizes);
640 if (tensorIndices.empty()) {
641 if (sliced.is_same(self)) {
642 // ensure we return a shallow copy for things like x[...]
643 sliced = at::alias(sliced);
644 }
645 return sliced;
646 }
647
648 // indexing by tensors ("advanced" indexing)
649 return dispatch_index(sliced, std::move(tensorIndices));
650}
651
// This mirrors `THPVariable_setitem` in
// torch/csrc/autograd/python_variable_indexing.cpp for the case where the
// assigned value is a Tensor.
// See NOTE [ Setting `disable_slice_optimization` when calling C++ tensor
// indexing functions from Python ]
656static inline void set_item(
657 const Tensor& self,
658 const ArrayRef<TensorIndex>& indices,
659 const Tensor& value,
660 bool disable_slice_optimization = false) {
661 at::Device self_device = self.device();
662 SymIntArrayRef self_sizes = self.sym_sizes();
663
664 // handle simple types: integers, slices, ellipsis, bool
665 if (indices.size() == 1) {
666 const TensorIndex& index = indices[0];
667 if (index.is_boolean() && !index.boolean()) {
668 // do nothing for false (technically we should check the size, but we
669 // don't have real 0-sized shapes.
670 return;
671 } else if (index.is_ellipsis()) {
672 copy_to(self, value);
673 return;
674 } else if (index.is_none() || (index.is_boolean() && index.boolean())) {
675 copy_to(self.unsqueeze(0), value);
676 return;
677 } else if (index.is_integer()) {
678 copy_to(
679 impl::applySelect(
680 self, 0, index.integer(), 0, self_device, self_sizes),
681 value);
682 return;
683 } else if (index.is_slice()) {
684 copy_to(
685 impl::applySlice(
686 self,
687 0,
688 index.slice().start(),
689 index.slice().stop(),
690 index.slice().step(),
691 /*disable_slice_optimization=*/disable_slice_optimization,
692 self_device,
693 self_sizes),
694 value);
695 return;
696 }
697 }
698
699 std::vector<Tensor> tensorIndices;
700 Tensor sliced = impl::applySlicing(
701 self,
702 indices,
703 tensorIndices,
704 disable_slice_optimization,
705 self_device,
706 self_sizes);
707 if (tensorIndices.empty()) {
708 copy_to(sliced, value);
709 return;
710 }
711
712 SymIntArrayRef valueSizes = value.sym_sizes();
713 SymIntArrayRef slicedValueSizes = slicePrefix1sSize(valueSizes);
714 Tensor valuesSliced;
715 if (!valueSizes.equals(slicedValueSizes)) {
716 valuesSliced = value.view_symint(slicedValueSizes);
717 } else {
718 valuesSliced = value;
719 }
720 dispatch_index_put_(sliced, std::move(tensorIndices), valuesSliced);
721 return;
722}
723
724} // namespace indexing
725} // namespace at
726