1 | #pragma once |
2 | |
3 | #include <ATen/core/Tensor.h> |
4 | #include <ATen/native/ResizeCommon.h> |
5 | #include <ATen/EmptyTensor.h> |
6 | #include <ATen/TensorUtils.h> |
7 | |
8 | #include <c10/core/CPUAllocator.h> |
9 | |
10 | #include <utility> |
11 | |
12 | |
13 | namespace at { namespace native { |
14 | |
// TODO: make all operations that resize given outputs use this function
// for consistency and maintainability.
// Some operations like `cat` might not be able to make the use of
// resize_output directly. For more details to understand how it works in `cat`,
// see https://github.com/pytorch/pytorch/pull/62560#discussion_r687363362
// Resizes outputs
// Functions accepting output tensors, like with the "out" kwarg, should
// call this function to handle resizing their output tensor.
// Issues a warning if the output tensor has one or more elements and
// needs resizing
// NOTE: In the future the warning will become an error
// Returns a bool saying whether or not the resize actually happened or not
TORCH_API bool resize_output(const Tensor& output, IntArrayRef shape);
// Symbolic-shape variant of resize_output: `shape` may contain SymInts.
TORCH_API bool resize_output_symint(const Tensor& output, SymIntArrayRef shape);

// Utility for resize_output
// Returns a bool saying resize should happen or not and
// raises a warning if resizing for one or more elements
TORCH_API bool resize_output_check(const Tensor& output, IntArrayRef shape);
// Symbolic-shape variant of resize_output_check.
TORCH_API bool resize_output_check_symint(const Tensor& output, SymIntArrayRef shape);

// Grows/reallocates the given CPU storage so that it holds `size_bytes`
// bytes (defined out of line; used by maybe_resize_storage_cpu below).
TORCH_API void resize_bytes_cpu(StorageImpl* storage, size_t size_bytes);
37 | |
38 | static inline void maybe_resize_storage_cpu(TensorImpl* self, size_t new_size_bytes) { |
39 | // It does not make sense to try to resize a storage |
40 | // to hold 0 elements, and this can break |
41 | // if storage_offset is positive but |
42 | // new_size is 0, so just bail in that case |
43 | // (same comment is in cuda/Resize.h) |
44 | if (self->numel() == 0) { |
45 | return; |
46 | } |
47 | |
48 | const Storage& storage = self->unsafe_storage(); |
49 | if (!storage) { |
50 | auto new_storage = c10::make_intrusive<StorageImpl>( |
51 | StorageImpl::use_byte_size_t(), |
52 | new_size_bytes, |
53 | c10::GetCPUAllocator(), |
54 | true); |
55 | self->set_storage_keep_dtype(std::move(new_storage)); |
56 | } else if (new_size_bytes > storage.nbytes()) { |
57 | resize_bytes_cpu(storage.unsafeGetStorageImpl(), new_size_bytes); |
58 | } |
59 | } |
60 | |
// Out-of-line resize implementation for CPU tensors: sets `self`'s sizes
// (and strides, when `stride` is provided). When `resize_storage` is true
// the backing storage presumably gets grown to fit as well — the
// definition is not in this header; confirm in the .cpp.
TORCH_API TensorImpl* resize_impl_cpu_(
    TensorImpl* self,
    IntArrayRef size,
    at::OptionalIntArrayRef stride,
    bool resize_storage = true);
66 | |
// maybe_convert_symint<T>: convert a c10::SymInt into T. Only the two
// specializations below are usable; every other instantiation is deleted
// at compile time.
template <typename T>
T maybe_convert_symint(c10::SymInt) = delete;

// Identity: the value stays a (possibly symbolic) SymInt.
template <>
inline c10::SymInt maybe_convert_symint(c10::SymInt x) { return x; }

// Concrete conversion via expect_int() — presumably asserts the SymInt
// holds a plain integer; see c10::SymInt for the exact failure behavior.
template <>
inline int64_t maybe_convert_symint(c10::SymInt x) { return x.expect_int(); }
75 | |
76 | template <typename T> |
77 | static inline void checkInBoundsForStorage( |
78 | ArrayRef<T> size, |
79 | ArrayRef<T> stride, |
80 | T storage_offset, |
81 | const caffe2::TypeMeta data_type, |
82 | const Storage& new_storage) { |
83 | T storage_size_bytes = |
84 | at::detail::computeStorageNbytes(size, stride, data_type.itemsize()); |
85 | T storage_offset_bytes = storage_offset * data_type.itemsize(); |
86 | if (storage_size_bytes == 0) { |
87 | // NB: (a tensor with arbitrary 0 dims)'s storage can have any numel. |
88 | return; |
89 | } |
90 | T new_storage_size_bytes = maybe_convert_symint<T>(new_storage.sym_nbytes()); |
91 | TORCH_CHECK( |
92 | storage_size_bytes + storage_offset_bytes <= new_storage_size_bytes, |
93 | "setStorage: sizes " , |
94 | size, |
95 | ", strides " , |
96 | stride, |
97 | "," |
98 | " storage offset " , |
99 | storage_offset, |
100 | ", and itemsize " , |
101 | data_type.itemsize(), |
102 | " requiring a storage size of " , |
103 | storage_size_bytes + storage_offset_bytes, |
104 | " are out of bounds for storage of size " , |
105 | new_storage_size_bytes); |
106 | } |
107 | |
108 | template <typename T> |
109 | static inline void checkSetStorage(Tensor& result, Storage storage, T storage_offset, |
110 | ArrayRef<T> size, ArrayRef<T> stride) { |
111 | // FIXME: stride should be optional |
112 | if (stride.data()) { |
113 | TORCH_CHECK(size.size() == stride.size(), "unequal size length (" , size.size(), |
114 | ") and stride length (" , stride.size(), ")" ); |
115 | } |
116 | |
117 | #ifdef DEBUG |
118 | TORCH_CHECK(size.size() <= INT_MAX, "size length (" , size.size(), ") greater than INT_MAX" ); |
119 | #endif |
120 | |
121 | // storage: note this can't be replaced with result.set_(storage) as the semantics of that |
122 | // function is to set the tensor size to be equal to the size of the storage. |
123 | if (!result.storage().is_alias_of(storage)) { |
124 | // Caffe2 might have tensors whose storages are null, but we |
125 | // don't allow it in PyTorch. |
126 | TORCH_INTERNAL_ASSERT(storage); |
127 | TORCH_INTERNAL_ASSERT(result.storage()); |
128 | |
129 | // We used to allow this, but this breaks device caching. |
130 | // Let's put an actual error message for this one. |
131 | TORCH_CHECK(result.storage().device() == storage.device(), |
132 | "Attempted to set the storage of a tensor on device \"" , result.storage().device(), |
133 | "\" to a storage on different device \"" , storage.device(), |
134 | "\". This is no longer allowed; the devices must match." ); |
135 | result.unsafeGetTensorImpl()->set_storage_keep_dtype(std::move(storage)); |
136 | } |
137 | |
138 | // storageOffset |
139 | TORCH_CHECK(storage_offset >= 0, "Tensor: invalid storage offset " , storage_offset); |
140 | } |
141 | |
142 | /** |
143 | * Set self's sizes, strides, and storage_offset. |
144 | * (size, stride, storage_offset) must be in bounds for self's storage. |
145 | */ |
146 | template <typename T> |
147 | inline void setStrided( |
148 | const Tensor& self, |
149 | ArrayRef<T> size, |
150 | ArrayRef<T> stride, |
151 | T storage_offset) { |
152 | TORCH_CHECK(size.size() == stride.size(), "mismatch in length of strides and shape" ); |
153 | for (const auto& val : stride) { |
154 | TORCH_CHECK(val >= 0, |
155 | "as_strided: Negative strides are not supported at the moment, " |
156 | "got strides: " , stride); |
157 | } |
158 | |
159 | auto* self_ = self.unsafeGetTensorImpl(); |
160 | checkInBoundsForStorage( |
161 | size, stride, storage_offset, self_->dtype(), self_->storage()); |
162 | |
163 | /* storage offset */ |
164 | TORCH_CHECK(storage_offset >= 0, "Tensor: invalid storage offset " , storage_offset); |
165 | self_->set_sizes_and_strides(size, stride, c10::make_optional(storage_offset)); |
166 | } |
167 | |
168 | }} |
169 | |