#pragma once

#include <ATen/core/Tensor.h>
#include <ATen/native/ResizeCommon.h>
#include <ATen/EmptyTensor.h>
#include <ATen/TensorUtils.h>

#include <c10/core/CPUAllocator.h>

#include <utility>


namespace at { namespace native {

// TODO: make all operations that resize given outputs use this function
// for consistency and maintainability.
// Some operations like `cat` might not be able to make use of
// resize_output directly. For more details on how it works in `cat`,
// see https://github.com/pytorch/pytorch/pull/62560#discussion_r687363362
// Resizes outputs
// Functions accepting output tensors, like those with the "out" kwarg, should
// call this function to handle resizing their output tensor.
// Issues a warning if the output tensor has one or more elements and
// needs resizing
// NOTE: In the future the warning will become an error
// Returns a bool indicating whether or not the resize actually happened
TORCH_API bool resize_output(const Tensor& output, IntArrayRef shape);
TORCH_API bool resize_output_symint(const Tensor& output, SymIntArrayRef shape);
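
// Illustrative sketch (not part of this header's API): an out-variant kernel
// would typically call resize_output on its result before writing into it,
// e.g.
//
//   // `my_op_out` is a hypothetical out-variant kernel.
//   Tensor& my_op_out(const Tensor& self, Tensor& result) {
//     at::native::resize_output(result, self.sizes());
//     // ... compute into result ...
//     return result;
//   }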

// Utility for resize_output
// Returns a bool indicating whether the resize should happen, and
// raises a warning if the output has one or more elements and needs resizing
TORCH_API bool resize_output_check(const Tensor& output, IntArrayRef shape);
TORCH_API bool resize_output_check_symint(const Tensor& output, SymIntArrayRef shape);
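
// Illustrative sketch: a caller that manages its own resizing can use the
// check variant to get the warning (and the decision) without resizing here,
// e.g.
//
//   // hypothetical caller-managed resize
//   if (at::native::resize_output_check(out, shape)) {
//     out.resize_(shape);
//   }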

TORCH_API void resize_bytes_cpu(StorageImpl* storage, size_t size_bytes);

static inline void maybe_resize_storage_cpu(TensorImpl* self, size_t new_size_bytes) {
  // It does not make sense to try to resize a storage
  // to hold 0 elements, and this can break
  // if storage_offset is positive but
  // new_size is 0, so just bail in that case
  // (same comment is in cuda/Resize.h)
  if (self->numel() == 0) {
    return;
  }

  const Storage& storage = self->unsafe_storage();
  if (!storage) {
    auto new_storage = c10::make_intrusive<StorageImpl>(
        StorageImpl::use_byte_size_t(),
        new_size_bytes,
        c10::GetCPUAllocator(),
        /*resizable=*/true);
    self->set_storage_keep_dtype(std::move(new_storage));
  } else if (new_size_bytes > storage.nbytes()) {
    resize_bytes_cpu(storage.unsafeGetStorageImpl(), new_size_bytes);
  }
}
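
// Illustrative note: a CPU tensor resized to hold 6 float elements with a zero
// storage offset needs new_size_bytes = 6 * sizeof(float) = 24; the helper
// above grows (or lazily allocates) the storage only if it currently holds
// fewer than 24 bytes.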

TORCH_API TensorImpl* resize_impl_cpu_(
    TensorImpl* self,
    IntArrayRef size,
    at::OptionalIntArrayRef stride,
    bool resize_storage = true);
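
// Illustrative sketch: callers pass the new sizes and, optionally, explicit
// strides (passing c10::nullopt requests a contiguous layout), e.g.
//
//   // hypothetical contiguous resize to {2, 3}
//   resize_impl_cpu_(tensor.unsafeGetTensorImpl(), {2, 3}, c10::nullopt);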

template <typename T>
T maybe_convert_symint(c10::SymInt) = delete;

template <>
inline c10::SymInt maybe_convert_symint(c10::SymInt x) { return x; }

template <>
inline int64_t maybe_convert_symint(c10::SymInt x) { return x.expect_int(); }
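
// Illustrative summary: maybe_convert_symint<c10::SymInt>(x) passes the
// possibly-symbolic value through unchanged, maybe_convert_symint<int64_t>(x)
// asserts that x holds a concrete integer and returns it, and any other T
// fails to compile because the primary template is deleted.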

template <typename T>
static inline void checkInBoundsForStorage(
    ArrayRef<T> size,
    ArrayRef<T> stride,
    T storage_offset,
    const caffe2::TypeMeta data_type,
    const Storage& new_storage) {
  T storage_size_bytes =
      at::detail::computeStorageNbytes(size, stride, data_type.itemsize());
  T storage_offset_bytes = storage_offset * data_type.itemsize();
  if (storage_size_bytes == 0) {
    // NB: (a tensor with arbitrary 0 dims)'s storage can have any numel.
    return;
  }
  T new_storage_size_bytes = maybe_convert_symint<T>(new_storage.sym_nbytes());
  TORCH_CHECK(
      storage_size_bytes + storage_offset_bytes <= new_storage_size_bytes,
      "setStorage: sizes ",
      size,
      ", strides ",
      stride,
      ","
      " storage offset ",
      storage_offset,
      ", and itemsize ",
      data_type.itemsize(),
      " requiring a storage size of ",
      storage_size_bytes + storage_offset_bytes,
      " are out of bounds for storage of size ",
      new_storage_size_bytes);
}
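
// Worked example (illustrative): size {2, 3}, stride {3, 1}, storage_offset 2,
// itemsize 4 gives storage_size_bytes = (1 + 1*3 + 2*1) * 4 = 24 and
// storage_offset_bytes = 2 * 4 = 8, so new_storage must provide at least
// 32 bytes.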

template <typename T>
static inline void checkSetStorage(Tensor& result, Storage storage, T storage_offset,
                                   ArrayRef<T> size, ArrayRef<T> stride) {
  // FIXME: stride should be optional
  if (stride.data()) {
    TORCH_CHECK(size.size() == stride.size(), "unequal size length (", size.size(),
                ") and stride length (", stride.size(), ")");
  }

#ifdef DEBUG
  TORCH_CHECK(size.size() <= INT_MAX, "size length (", size.size(), ") greater than INT_MAX");
#endif

  // storage: note this can't be replaced with result.set_(storage) as the semantics of that
  // function are to set the tensor size to be equal to the size of the storage.
  if (!result.storage().is_alias_of(storage)) {
    // Caffe2 might have tensors whose storages are null, but we
    // don't allow it in PyTorch.
    TORCH_INTERNAL_ASSERT(storage);
    TORCH_INTERNAL_ASSERT(result.storage());

    // We used to allow this, but this breaks device caching.
    // Let's put an actual error message for this one.
    TORCH_CHECK(result.storage().device() == storage.device(),
                "Attempted to set the storage of a tensor on device \"", result.storage().device(),
                "\" to a storage on different device \"", storage.device(),
                "\". This is no longer allowed; the devices must match.");
    result.unsafeGetTensorImpl()->set_storage_keep_dtype(std::move(storage));
  }

  // storageOffset
  TORCH_CHECK(storage_offset >= 0, "Tensor: invalid storage offset ", storage_offset);
}
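
// Illustrative sketch: Tensor::set_ overloads that take an explicit storage
// are expected to funnel through this check before sizes and strides are
// applied, e.g.
//
//   // hypothetical call re-binding `result` to `storage` at offset 0
//   result.set_(storage, /*storage_offset=*/0, /*size=*/{4}, /*stride=*/{1});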

/**
 * Set self's sizes, strides, and storage_offset.
 * (size, stride, storage_offset) must be in bounds for self's storage.
 */
template <typename T>
inline void setStrided(
    const Tensor& self,
    ArrayRef<T> size,
    ArrayRef<T> stride,
    T storage_offset) {
  TORCH_CHECK(size.size() == stride.size(), "mismatch in length of strides and shape");
  for (const auto& val : stride) {
    TORCH_CHECK(val >= 0,
                "as_strided: Negative strides are not supported at the moment, "
                "got strides: ", stride);
  }

  auto* self_ = self.unsafeGetTensorImpl();
  checkInBoundsForStorage(
      size, stride, storage_offset, self_->dtype(), self_->storage());

  /* storage offset */
  TORCH_CHECK(storage_offset >= 0, "Tensor: invalid storage offset ", storage_offset);
  self_->set_sizes_and_strides(size, stride, c10::make_optional(storage_offset));
}
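
// Illustrative sketch: an as_strided-style kernel validates its arguments and
// then delegates to setStrided, e.g.
//
//   // hypothetical caller producing a 2x2 view with strides {1, 2}
//   at::native::setStrided(self, IntArrayRef{2, 2}, IntArrayRef{1, 2},
//                          self.storage_offset());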

}}  // namespace at::native