1#pragma once
2
3#include <ATen/EmptyTensor.h>
4#include <ATen/native/ResizeCommon.h>
5
6#include <c10/cuda/CUDAGuard.h>
7
8namespace at { namespace native {
9
10TORCH_CUDA_CPP_API void resize_bytes_cuda(StorageImpl* storage, size_t size_bytes);
11
12static inline void maybe_resize_storage_cuda(TensorImpl* self, size_t new_size_bytes) {
13 // It does not make sense to try to resize a storage
14 // to hold 0 elements, and this can break
15 // if storage_offset is positive but
16 // new_size is 0, so just bail in that case
17 // (same comment is in Resize.h)
18 if (self->numel() == 0) {
19 return;
20 }
21
22 const Storage &storage = self->unsafe_storage();
23 TORCH_CHECK(storage, "Tensor: invalid null storage");
24 if (new_size_bytes > storage.nbytes()) {
25 resize_bytes_cuda(storage.unsafeGetStorageImpl(), new_size_bytes);
26 }
27}
28
29inline TensorImpl* resize_impl_cuda_(
30 TensorImpl* self,
31 IntArrayRef size,
32 at::OptionalIntArrayRef stride,
33 bool device_guard = true) {
34 if (self->sizes() == size && (!stride || self->strides() == stride)) {
35 return self;
36 }
37
38 // NB: We don't need to hold the device guard when calling from TH
39 cuda::OptionalCUDAGuard guard;
40 if (device_guard) {
41 guard.set_index(self->storage().device().index());
42 }
43
44 const auto itemsize = self->dtype().itemsize();
45 const auto storage_offset = self->storage_offset();
46 size_t storage_size = 1;
47 if (stride) {
48 self->set_sizes_and_strides(size, *stride);
49 storage_size = at::detail::computeStorageNbytes(
50 size, *stride, itemsize, storage_offset);
51 } else {
52 self->set_sizes_contiguous(size);
53 storage_size = at::detail::computeStorageNbytesContiguous(
54 size, itemsize, storage_offset);
55 }
56 maybe_resize_storage_cuda(self, storage_size);
57
58 return self;
59}
60
61}}
62