1#pragma once
2#include <ATen/core/Tensor.h>
3
4namespace at {
5
6namespace detail {
7
8TORCH_API inline void noopDelete(void*) {}
9
10} // namespace detail
11
12/// Provides a fluent API to construct tensors from external data.
13///
14/// The fluent API can be used instead of `from_blob` functions in case the
15/// required set of parameters does not align with the existing overloads.
16///
17/// at::Tensor tensor = at::for_blob(data, sizes)
18/// .strides(strides)
19/// .context(context, [](void *ctx) { delete static_cast<Ctx*>(ctx); })
20/// .options(...)
21/// .make_tensor();
22///
23class TORCH_API TensorMaker {
24 friend TensorMaker for_blob(void* data, IntArrayRef sizes) noexcept;
25
26 public:
27 using ContextDeleter = DeleterFnPtr;
28
29 TensorMaker& strides(OptionalIntArrayRef value) noexcept {
30 strides_ = value;
31
32 return *this;
33 }
34
35 TensorMaker& storage_offset(optional<int64_t> value) noexcept {
36 storage_offset_ = value;
37
38 return *this;
39 }
40
41 TensorMaker& deleter(std::function<void(void*)> value) noexcept {
42 deleter_ = std::move(value);
43
44 return *this;
45 }
46
47 TensorMaker& context(void* value, ContextDeleter deleter = nullptr) noexcept {
48 ctx_ = std::unique_ptr<void, ContextDeleter>{
49 value, deleter != nullptr ? deleter : detail::noopDelete};
50
51 return *this;
52 }
53
54 TensorMaker& target_device(optional<Device> value) noexcept {
55 device_ = value;
56
57 return *this;
58 }
59
60 TensorMaker& options(TensorOptions value) noexcept {
61 opts_ = value;
62
63 return *this;
64 }
65
66 Tensor make_tensor();
67
68 private:
69 explicit TensorMaker(void* data, IntArrayRef sizes) noexcept
70 : data_{data}, sizes_{sizes} {}
71
72 std::size_t computeStorageSize() const noexcept;
73
74 DataPtr makeDataPtrFromDeleter() const;
75
76 DataPtr makeDataPtrFromContext() noexcept;
77
78 IntArrayRef makeTempSizes() const noexcept;
79
80 void* data_;
81 IntArrayRef sizes_;
82 OptionalIntArrayRef strides_{};
83 optional<int64_t> storage_offset_{};
84 std::function<void(void*)> deleter_{};
85 std::unique_ptr<void, ContextDeleter> ctx_{nullptr, detail::noopDelete};
86 optional<Device> device_{};
87 TensorOptions opts_{};
88};
89
90inline TensorMaker for_blob(void* data, IntArrayRef sizes) noexcept {
91 return TensorMaker{data, sizes};
92}
93
94inline Tensor from_blob(
95 void* data,
96 IntArrayRef sizes,
97 IntArrayRef strides,
98 const std::function<void(void*)>& deleter,
99 const TensorOptions& options = {},
100 const c10::optional<Device> target_device = c10::nullopt) {
101 return for_blob(data, sizes)
102 .strides(strides)
103 .deleter(deleter)
104 .options(options)
105 .target_device(target_device)
106 .make_tensor();
107}
108
109inline Tensor from_blob(
110 void* data,
111 IntArrayRef sizes,
112 IntArrayRef strides,
113 int64_t storage_offset,
114 const std::function<void(void*)>& deleter,
115 const TensorOptions& options = {},
116 const c10::optional<Device> target_device = c10::nullopt) {
117 return for_blob(data, sizes)
118 .strides(strides)
119 .storage_offset(storage_offset)
120 .deleter(deleter)
121 .options(options)
122 .target_device(target_device)
123 .make_tensor();
124}
125
126inline Tensor from_blob(
127 void* data,
128 IntArrayRef sizes,
129 const std::function<void(void*)>& deleter,
130 const TensorOptions& options = {}) {
131 return for_blob(data, sizes)
132 .deleter(deleter)
133 .options(options)
134 .make_tensor();
135}
136
137inline Tensor from_blob(
138 void* data,
139 IntArrayRef sizes,
140 IntArrayRef strides,
141 const TensorOptions& options = {}) {
142 return for_blob(data, sizes)
143 .strides(strides)
144 .options(options)
145 .make_tensor();
146}
147
148inline Tensor from_blob(
149 void* data,
150 IntArrayRef sizes,
151 const TensorOptions& options = {}) {
152 return for_blob(data, sizes).options(options).make_tensor();
153}
154
155} // namespace at
156