1 | #include <array> |
2 | |
3 | #include <ATen/Functions.h> |
4 | #include <ATen/Utils.h> |
5 | |
6 | namespace at { |
7 | |
8 | Tensor TensorMaker::make_tensor() { |
9 | AutoDispatchBelowADInplaceOrView guard{}; // TODO: Remove. |
10 | tracer::impl::NoTracerDispatchMode tracer_guard{}; |
11 | |
12 | check_size_nonnegative(sizes_); |
13 | |
14 | TORCH_CHECK_VALUE( |
15 | !deleter_ || !ctx_, |
16 | "The deleter and context arguments are mutually exclusive." ); |
17 | |
18 | if (device_ == nullopt) { |
19 | device_ = globalContext().getDeviceFromPtr(data_, opts_.device().type()); |
20 | } |
21 | |
22 | if (opts_.device().has_index()) { |
23 | // clang-format off |
24 | TORCH_CHECK_VALUE( |
25 | opts_.device() == *device_, |
26 | "Specified device " , opts_.device(), " does not match device of data " , *device_); |
27 | // clang-format on |
28 | } |
29 | |
30 | std::size_t size_bytes = computeStorageSize(); |
31 | |
32 | DataPtr data_ptr{}; |
33 | if (deleter_) { |
34 | data_ptr = makeDataPtrFromDeleter(); |
35 | } else { |
36 | data_ptr = makeDataPtrFromContext(); |
37 | } |
38 | |
39 | Storage storage{Storage::use_byte_size_t{}, size_bytes, std::move(data_ptr)}; |
40 | |
41 | Tensor tensor = detail::make_tensor<TensorImpl>( |
42 | std::move(storage), opts_.computeDispatchKey(), opts_.dtype()); |
43 | |
44 | TensorImpl* tensor_impl = tensor.unsafeGetTensorImpl(); |
45 | if (strides_) { |
46 | tensor_impl->set_sizes_and_strides(sizes_, *strides_); |
47 | } else { |
48 | tensor_impl->set_sizes_contiguous(sizes_); |
49 | } |
50 | if (storage_offset_) { |
51 | tensor_impl->set_storage_offset(*storage_offset_); |
52 | } |
53 | |
54 | return tensor; |
55 | } |
56 | |
57 | std::size_t TensorMaker::computeStorageSize() const noexcept { |
58 | std::size_t itemsize = opts_.dtype().itemsize(); |
59 | |
60 | if (strides_) { |
61 | auto storage_size = detail::computeStorageNbytes(sizes_, *strides_, itemsize); |
62 | if (storage_offset_) { |
63 | storage_size += storage_offset_.value(); |
64 | } |
65 | return storage_size; |
66 | } |
67 | |
68 | std::size_t size = 1; |
69 | for (std::int64_t s : sizes_) { |
70 | size *= static_cast<std::size_t>(s); |
71 | } |
72 | auto storage_size = size * itemsize; |
73 | if (storage_offset_) { |
74 | storage_size += storage_offset_.value(); |
75 | } |
76 | return storage_size; |
77 | } |
78 | |
79 | inline DataPtr TensorMaker::makeDataPtrFromDeleter() const { |
80 | return InefficientStdFunctionContext::makeDataPtr(data_, deleter_, *device_); |
81 | } |
82 | |
83 | inline DataPtr TensorMaker::makeDataPtrFromContext() noexcept { |
84 | return DataPtr{data_, ctx_.release(), ctx_.get_deleter(), *device_}; |
85 | } |
86 | |
87 | IntArrayRef TensorMaker::makeTempSizes() const noexcept { |
88 | static std::int64_t zeros[5] = {0, 0, 0, 0, 0}; |
89 | if (opts_.has_memory_format()) { |
90 | MemoryFormat format = *opts_.memory_format_opt(); |
91 | if (format == MemoryFormat::ChannelsLast) { |
92 | return IntArrayRef(zeros, 4); |
93 | } |
94 | if (format == MemoryFormat::ChannelsLast3d) { |
95 | return IntArrayRef(zeros, 5); |
96 | } |
97 | } |
98 | return IntArrayRef(zeros, 1); |
99 | } |
100 | |
101 | } // namespace at |
102 | |