1 | #pragma once |
2 | #include <ATen/core/Tensor.h> |
3 | |
4 | namespace at { |
5 | |
6 | namespace detail { |
7 | |
8 | TORCH_API inline void noopDelete(void*) {} |
9 | |
10 | } // namespace detail |
11 | |
12 | /// Provides a fluent API to construct tensors from external data. |
13 | /// |
14 | /// The fluent API can be used instead of `from_blob` functions in case the |
15 | /// required set of parameters does not align with the existing overloads. |
16 | /// |
17 | /// at::Tensor tensor = at::for_blob(data, sizes) |
18 | /// .strides(strides) |
19 | /// .context(context, [](void *ctx) { delete static_cast<Ctx*>(ctx); }) |
20 | /// .options(...) |
21 | /// .make_tensor(); |
22 | /// |
23 | class TORCH_API TensorMaker { |
24 | friend TensorMaker for_blob(void* data, IntArrayRef sizes) noexcept; |
25 | |
26 | public: |
27 | using ContextDeleter = DeleterFnPtr; |
28 | |
29 | TensorMaker& strides(OptionalIntArrayRef value) noexcept { |
30 | strides_ = value; |
31 | |
32 | return *this; |
33 | } |
34 | |
35 | TensorMaker& storage_offset(optional<int64_t> value) noexcept { |
36 | storage_offset_ = value; |
37 | |
38 | return *this; |
39 | } |
40 | |
41 | TensorMaker& deleter(std::function<void(void*)> value) noexcept { |
42 | deleter_ = std::move(value); |
43 | |
44 | return *this; |
45 | } |
46 | |
47 | TensorMaker& context(void* value, ContextDeleter deleter = nullptr) noexcept { |
48 | ctx_ = std::unique_ptr<void, ContextDeleter>{ |
49 | value, deleter != nullptr ? deleter : detail::noopDelete}; |
50 | |
51 | return *this; |
52 | } |
53 | |
54 | TensorMaker& target_device(optional<Device> value) noexcept { |
55 | device_ = value; |
56 | |
57 | return *this; |
58 | } |
59 | |
60 | TensorMaker& options(TensorOptions value) noexcept { |
61 | opts_ = value; |
62 | |
63 | return *this; |
64 | } |
65 | |
66 | Tensor make_tensor(); |
67 | |
68 | private: |
69 | explicit TensorMaker(void* data, IntArrayRef sizes) noexcept |
70 | : data_{data}, sizes_{sizes} {} |
71 | |
72 | std::size_t computeStorageSize() const noexcept; |
73 | |
74 | DataPtr makeDataPtrFromDeleter() const; |
75 | |
76 | DataPtr makeDataPtrFromContext() noexcept; |
77 | |
78 | IntArrayRef makeTempSizes() const noexcept; |
79 | |
80 | void* data_; |
81 | IntArrayRef sizes_; |
82 | OptionalIntArrayRef strides_{}; |
83 | optional<int64_t> storage_offset_{}; |
84 | std::function<void(void*)> deleter_{}; |
85 | std::unique_ptr<void, ContextDeleter> ctx_{nullptr, detail::noopDelete}; |
86 | optional<Device> device_{}; |
87 | TensorOptions opts_{}; |
88 | }; |
89 | |
90 | inline TensorMaker for_blob(void* data, IntArrayRef sizes) noexcept { |
91 | return TensorMaker{data, sizes}; |
92 | } |
93 | |
94 | inline Tensor from_blob( |
95 | void* data, |
96 | IntArrayRef sizes, |
97 | IntArrayRef strides, |
98 | const std::function<void(void*)>& deleter, |
99 | const TensorOptions& options = {}, |
100 | const c10::optional<Device> target_device = c10::nullopt) { |
101 | return for_blob(data, sizes) |
102 | .strides(strides) |
103 | .deleter(deleter) |
104 | .options(options) |
105 | .target_device(target_device) |
106 | .make_tensor(); |
107 | } |
108 | |
109 | inline Tensor from_blob( |
110 | void* data, |
111 | IntArrayRef sizes, |
112 | IntArrayRef strides, |
113 | int64_t storage_offset, |
114 | const std::function<void(void*)>& deleter, |
115 | const TensorOptions& options = {}, |
116 | const c10::optional<Device> target_device = c10::nullopt) { |
117 | return for_blob(data, sizes) |
118 | .strides(strides) |
119 | .storage_offset(storage_offset) |
120 | .deleter(deleter) |
121 | .options(options) |
122 | .target_device(target_device) |
123 | .make_tensor(); |
124 | } |
125 | |
126 | inline Tensor from_blob( |
127 | void* data, |
128 | IntArrayRef sizes, |
129 | const std::function<void(void*)>& deleter, |
130 | const TensorOptions& options = {}) { |
131 | return for_blob(data, sizes) |
132 | .deleter(deleter) |
133 | .options(options) |
134 | .make_tensor(); |
135 | } |
136 | |
137 | inline Tensor from_blob( |
138 | void* data, |
139 | IntArrayRef sizes, |
140 | IntArrayRef strides, |
141 | const TensorOptions& options = {}) { |
142 | return for_blob(data, sizes) |
143 | .strides(strides) |
144 | .options(options) |
145 | .make_tensor(); |
146 | } |
147 | |
148 | inline Tensor from_blob( |
149 | void* data, |
150 | IntArrayRef sizes, |
151 | const TensorOptions& options = {}) { |
152 | return for_blob(data, sizes).options(options).make_tensor(); |
153 | } |
154 | |
155 | } // namespace at |
156 | |