1 | #pragma once |
2 | |
3 | #include <ATen/EmptyTensor.h> |
4 | #include <ATen/native/ResizeCommon.h> |
5 | |
6 | #include <c10/cuda/CUDAGuard.h> |
7 | |
8 | namespace at { namespace native { |
9 | |
10 | TORCH_CUDA_CPP_API void resize_bytes_cuda(StorageImpl* storage, size_t size_bytes); |
11 | |
12 | static inline void maybe_resize_storage_cuda(TensorImpl* self, size_t new_size_bytes) { |
13 | // It does not make sense to try to resize a storage |
14 | // to hold 0 elements, and this can break |
15 | // if storage_offset is positive but |
16 | // new_size is 0, so just bail in that case |
17 | // (same comment is in Resize.h) |
18 | if (self->numel() == 0) { |
19 | return; |
20 | } |
21 | |
22 | const Storage &storage = self->unsafe_storage(); |
23 | TORCH_CHECK(storage, "Tensor: invalid null storage" ); |
24 | if (new_size_bytes > storage.nbytes()) { |
25 | resize_bytes_cuda(storage.unsafeGetStorageImpl(), new_size_bytes); |
26 | } |
27 | } |
28 | |
29 | inline TensorImpl* resize_impl_cuda_( |
30 | TensorImpl* self, |
31 | IntArrayRef size, |
32 | at::OptionalIntArrayRef stride, |
33 | bool device_guard = true) { |
34 | if (self->sizes() == size && (!stride || self->strides() == stride)) { |
35 | return self; |
36 | } |
37 | |
38 | // NB: We don't need to hold the device guard when calling from TH |
39 | cuda::OptionalCUDAGuard guard; |
40 | if (device_guard) { |
41 | guard.set_index(self->storage().device().index()); |
42 | } |
43 | |
44 | const auto itemsize = self->dtype().itemsize(); |
45 | const auto storage_offset = self->storage_offset(); |
46 | size_t storage_size = 1; |
47 | if (stride) { |
48 | self->set_sizes_and_strides(size, *stride); |
49 | storage_size = at::detail::computeStorageNbytes( |
50 | size, *stride, itemsize, storage_offset); |
51 | } else { |
52 | self->set_sizes_contiguous(size); |
53 | storage_size = at::detail::computeStorageNbytesContiguous( |
54 | size, itemsize, storage_offset); |
55 | } |
56 | maybe_resize_storage_cuda(self, storage_size); |
57 | |
58 | return self; |
59 | } |
60 | |
61 | }} |
62 | |