1#include <array>
2
3#include <ATen/Functions.h>
4#include <ATen/Utils.h>
5
6namespace at {
7
8Tensor TensorMaker::make_tensor() {
9 AutoDispatchBelowADInplaceOrView guard{}; // TODO: Remove.
10 tracer::impl::NoTracerDispatchMode tracer_guard{};
11
12 check_size_nonnegative(sizes_);
13
14 TORCH_CHECK_VALUE(
15 !deleter_ || !ctx_,
16 "The deleter and context arguments are mutually exclusive.");
17
18 if (device_ == nullopt) {
19 device_ = globalContext().getDeviceFromPtr(data_, opts_.device().type());
20 }
21
22 if (opts_.device().has_index()) {
23 // clang-format off
24 TORCH_CHECK_VALUE(
25 opts_.device() == *device_,
26 "Specified device ", opts_.device(), " does not match device of data ", *device_);
27 // clang-format on
28 }
29
30 std::size_t size_bytes = computeStorageSize();
31
32 DataPtr data_ptr{};
33 if (deleter_) {
34 data_ptr = makeDataPtrFromDeleter();
35 } else {
36 data_ptr = makeDataPtrFromContext();
37 }
38
39 Storage storage{Storage::use_byte_size_t{}, size_bytes, std::move(data_ptr)};
40
41 Tensor tensor = detail::make_tensor<TensorImpl>(
42 std::move(storage), opts_.computeDispatchKey(), opts_.dtype());
43
44 TensorImpl* tensor_impl = tensor.unsafeGetTensorImpl();
45 if (strides_) {
46 tensor_impl->set_sizes_and_strides(sizes_, *strides_);
47 } else {
48 tensor_impl->set_sizes_contiguous(sizes_);
49 }
50 if (storage_offset_) {
51 tensor_impl->set_storage_offset(*storage_offset_);
52 }
53
54 return tensor;
55 }
56
57 std::size_t TensorMaker::computeStorageSize() const noexcept {
58 std::size_t itemsize = opts_.dtype().itemsize();
59
60 if (strides_) {
61 auto storage_size = detail::computeStorageNbytes(sizes_, *strides_, itemsize);
62 if (storage_offset_) {
63 storage_size += storage_offset_.value();
64 }
65 return storage_size;
66 }
67
68 std::size_t size = 1;
69 for (std::int64_t s : sizes_) {
70 size *= static_cast<std::size_t>(s);
71 }
72 auto storage_size = size * itemsize;
73 if (storage_offset_) {
74 storage_size += storage_offset_.value();
75 }
76 return storage_size;
77 }
78
79 inline DataPtr TensorMaker::makeDataPtrFromDeleter() const {
80 return InefficientStdFunctionContext::makeDataPtr(data_, deleter_, *device_);
81 }
82
83 inline DataPtr TensorMaker::makeDataPtrFromContext() noexcept {
84 return DataPtr{data_, ctx_.release(), ctx_.get_deleter(), *device_};
85 }
86
87 IntArrayRef TensorMaker::makeTempSizes() const noexcept {
88 static std::int64_t zeros[5] = {0, 0, 0, 0, 0};
89 if (opts_.has_memory_format()) {
90 MemoryFormat format = *opts_.memory_format_opt();
91 if (format == MemoryFormat::ChannelsLast) {
92 return IntArrayRef(zeros, 4);
93 }
94 if (format == MemoryFormat::ChannelsLast3d) {
95 return IntArrayRef(zeros, 5);
96 }
97 }
98 return IntArrayRef(zeros, 1);
99 }
100
101} // namespace at
102