1 | #pragma once |
2 | |
3 | #include <ATen/ATen.h> |
4 | /** |
5 | * WARNING: EValue is a class used by Executorch, for its boxed operators. It |
6 | * contains similar logic as `IValue` in PyTorch, by providing APIs to convert |
7 | * boxed values to unboxed values. |
8 | * |
9 | * It's mirroring a fbcode internal source file |
10 | * [`EValue.h`](https://www.internalfb.com/code/fbsource/xplat/executorch/core/values/Evalue.h). |
11 | * |
 * We mirror this class to make sure we have CI job coverage of the torchgen
 * logic, given that torchgen is used by both Executorch and PyTorch.
 *
 * If any of the logic here needs to be changed, please update the fbcode
 * version of `Evalue.h` as well. These two versions will be merged as soon as
 * Executorch is in OSS (hopefully by Q2 2023).
19 | */ |
20 | namespace torch { |
21 | namespace executor { |
22 | |
// ExecuTorch-style checked assertion: ET_CHECK_MSG(cond, msg...) fails with
// `msg...` when `cond` is false. It forwards to TORCH_CHECK, which throws a
// c10::Error on failure. (TORCH_CHECK_MSG, which this macro used to alias,
// only builds the error-message string and never tests `cond`, so every
// ET_CHECK_MSG in this file was silently a no-op.)
#define ET_CHECK_MSG(cond, ...) TORCH_CHECK(cond, __VA_ARGS__)
// X-macro listing every EValue tag in declaration order. Pass it a
// one-argument macro X; X(name) is expanded once per tag, in this order.
#define EXECUTORCH_FORALL_TAGS(X) \
  X(None)                         \
  X(Tensor)                       \
  X(String)                       \
  X(Double)                       \
  X(Int)                          \
  X(Bool)                         \
  X(ListBool)                     \
  X(ListDouble)                   \
  X(ListInt)                      \
  X(ListTensor)                   \
  X(ListScalar)                   \
  X(ListOptionalTensor)
37 | |
38 | enum class Tag : uint32_t { |
39 | #define DEFINE_TAG(x) x, |
40 | EXECUTORCH_FORALL_TAGS(DEFINE_TAG) |
41 | #undef DEFINE_TAG |
42 | }; |
43 | |
44 | struct EValue; |
45 | |
46 | template <typename T> |
47 | struct evalue_to_const_ref_overload_return { |
48 | using type = T; |
49 | }; |
50 | |
51 | template <> |
52 | struct evalue_to_const_ref_overload_return<at::Tensor> { |
53 | using type = const at::Tensor&; |
54 | }; |
55 | |
56 | template <typename T> |
57 | struct evalue_to_ref_overload_return { |
58 | using type = T; |
59 | }; |
60 | |
61 | template <> |
62 | struct evalue_to_ref_overload_return<at::Tensor> { |
63 | using type = at::Tensor&; |
64 | }; |
65 | |
66 | /* |
67 | * Helper class used to correlate EValues in the executor table, with the |
68 | * unwrapped list of the proper type. Because values in the runtime's values |
69 | * table can change during execution, we cannot statically allocate list of |
70 | * objects at deserialization. Imagine the serialized list says index 0 in the |
71 | * value table is element 2 in the list, but during execution the value in |
72 | * element 2 changes (in the case of tensor this means the TensorImpl* stored in |
73 | * the tensor changes). To solve this instead they must be created dynamically |
74 | * whenever they are used. |
75 | */ |
76 | template <typename T> |
77 | class EValObjectList { |
78 | public: |
79 | EValObjectList() = default; |
80 | /* |
81 | * Wrapped_vals is a list of pointers into the values table of the runtime |
82 | * whose destinations correlate with the elements of the list, unwrapped_vals |
83 | * is a container of the same size whose serves as memory to construct the |
84 | * unwrapped vals. |
85 | */ |
86 | EValObjectList(EValue** wrapped_vals, T* unwrapped_vals, int size) |
87 | : wrapped_vals_(wrapped_vals, size), unwrapped_vals_(unwrapped_vals) {} |
88 | /* |
89 | * Constructs and returns the list of T specified by the EValue pointers |
90 | */ |
91 | at::ArrayRef<T> get() const; |
92 | |
93 | private: |
94 | // Source of truth for the list |
95 | at::ArrayRef<EValue*> wrapped_vals_; |
96 | // Same size as wrapped_vals |
97 | mutable T* unwrapped_vals_; |
98 | }; |
99 | |
100 | // Aggregate typing system similar to IValue only slimmed down with less |
101 | // functionality, no dependencies on atomic, and fewer supported types to better |
102 | // suit embedded systems (ie no intrusive ptr) |
103 | struct EValue { |
104 | union Payload { |
105 | // When in ATen mode at::Tensor is not trivially copyable, this nested union |
106 | // lets us handle tensor as a special case while leaving the rest of the |
107 | // fields in a simple state instead of requiring a switch on tag everywhere. |
108 | union TriviallyCopyablePayload { |
109 | TriviallyCopyablePayload() : as_int(0) {} |
110 | // Scalar supported through these 3 types |
111 | int64_t as_int; |
112 | double as_double; |
113 | bool as_bool; |
114 | // TODO(jakeszwe): convert back to pointers to optimize size of this |
115 | // struct |
116 | at::ArrayRef<char> as_string; |
117 | at::ArrayRef<int64_t> as_int_list; |
118 | at::ArrayRef<double> as_double_list; |
119 | at::ArrayRef<bool> as_bool_list; |
120 | EValObjectList<at::Tensor> as_tensor_list; |
121 | EValObjectList<at::optional<at::Tensor>> as_list_optional_tensor; |
122 | } copyable_union; |
123 | |
124 | // Since a Tensor just holds a TensorImpl*, there's no value to use Tensor* |
125 | // here. |
126 | at::Tensor as_tensor; |
127 | |
128 | Payload() {} |
129 | ~Payload() {} |
130 | }; |
131 | |
132 | // Data storage and type tag |
133 | Payload payload; |
134 | Tag tag; |
135 | |
136 | // Basic ctors and assignments |
137 | EValue(const EValue& rhs) : EValue(rhs.payload, rhs.tag) {} |
138 | |
139 | EValue(EValue&& rhs) noexcept : tag(rhs.tag) { |
140 | moveFrom(std::move(rhs)); |
141 | } |
142 | |
143 | EValue& operator=(EValue&& rhs) & noexcept { |
144 | if (&rhs == this) { |
145 | return *this; |
146 | } |
147 | |
148 | destroy(); |
149 | moveFrom(std::move(rhs)); |
150 | return *this; |
151 | } |
152 | |
153 | EValue& operator=(EValue const& rhs) & { |
154 | // Define copy assignment through copy ctor and move assignment |
155 | *this = EValue(rhs); |
156 | return *this; |
157 | } |
158 | |
159 | ~EValue() { |
160 | destroy(); |
161 | } |
162 | |
163 | /****** None Type ******/ |
164 | EValue() : tag(Tag::None) { |
165 | payload.copyable_union.as_int = 0; |
166 | } |
167 | |
168 | bool isNone() const { |
169 | return tag == Tag::None; |
170 | } |
171 | |
172 | /****** Int Type ******/ |
173 | /*implicit*/ EValue(int64_t i) : tag(Tag::Int) { |
174 | payload.copyable_union.as_int = i; |
175 | } |
176 | |
177 | bool isInt() const { |
178 | return tag == Tag::Int; |
179 | } |
180 | |
181 | int64_t toInt() const { |
182 | ET_CHECK_MSG(isInt(), "EValue is not an int." ); |
183 | return payload.copyable_union.as_int; |
184 | } |
185 | |
186 | /****** Double Type ******/ |
187 | /*implicit*/ EValue(double d) : tag(Tag::Double) { |
188 | payload.copyable_union.as_double = d; |
189 | } |
190 | |
191 | bool isDouble() const { |
192 | return tag == Tag::Double; |
193 | } |
194 | |
195 | double toDouble() const { |
196 | ET_CHECK_MSG(isDouble(), "EValue is not a Double." ); |
197 | return payload.copyable_union.as_double; |
198 | } |
199 | |
200 | /****** Bool Type ******/ |
201 | /*implicit*/ EValue(bool b) : tag(Tag::Bool) { |
202 | payload.copyable_union.as_bool = b; |
203 | } |
204 | |
205 | bool isBool() const { |
206 | return tag == Tag::Bool; |
207 | } |
208 | |
209 | bool toBool() const { |
210 | ET_CHECK_MSG(isBool(), "EValue is not a Bool." ); |
211 | return payload.copyable_union.as_bool; |
212 | } |
213 | |
214 | /****** Scalar Type ******/ |
215 | /// Construct an EValue using the implicit value of a Scalar. |
216 | /*implicit*/ EValue(at::Scalar s) { |
217 | if (s.isIntegral(false)) { |
218 | tag = Tag::Int; |
219 | payload.copyable_union.as_int = s.to<int64_t>(); |
220 | } else if (s.isFloatingPoint()) { |
221 | tag = Tag::Double; |
222 | payload.copyable_union.as_double = s.to<double>(); |
223 | } else if (s.isBoolean()) { |
224 | tag = Tag::Bool; |
225 | payload.copyable_union.as_bool = s.to<bool>(); |
226 | } else { |
227 | ET_CHECK_MSG(false, "Scalar passed to EValue is not initialized." ); |
228 | } |
229 | } |
230 | |
231 | bool isScalar() const { |
232 | return tag == Tag::Int || tag == Tag::Double || tag == Tag::Bool; |
233 | } |
234 | |
235 | at::Scalar toScalar() const { |
236 | // Convert from implicit value to Scalar using implicit constructors. |
237 | |
238 | if (isDouble()) { |
239 | return toDouble(); |
240 | } else if (isInt()) { |
241 | return toInt(); |
242 | } else if (isBool()) { |
243 | return toBool(); |
244 | } else { |
245 | ET_CHECK_MSG(false, "EValue is not a Scalar." ); |
246 | return c10::Scalar(); |
247 | } |
248 | } |
249 | |
250 | /****** Tensor Type ******/ |
251 | /*implicit*/ EValue(at::Tensor t) : tag(Tag::Tensor) { |
252 | // When built in aten mode, at::Tensor has a non trivial constructor |
253 | // destructor, so regular assignment to a union field is UB. Instead we must |
254 | // go through placement new (which causes a refcount bump). |
255 | new (&payload.as_tensor) at::Tensor(t); |
256 | } |
257 | |
258 | bool isTensor() const { |
259 | return tag == Tag::Tensor; |
260 | } |
261 | |
262 | at::Tensor toTensor() && { |
263 | ET_CHECK_MSG(isTensor(), "EValue is not a Tensor." ); |
264 | return std::move(payload.as_tensor); |
265 | } |
266 | |
267 | at::Tensor& toTensor() & { |
268 | ET_CHECK_MSG(isTensor(), "EValue is not a Tensor." ); |
269 | return payload.as_tensor; |
270 | } |
271 | |
272 | const at::Tensor& toTensor() const& { |
273 | ET_CHECK_MSG(isTensor(), "EValue is not a Tensor." ); |
274 | return payload.as_tensor; |
275 | } |
276 | |
277 | /****** String Type ******/ |
278 | /*implicit*/ EValue(const char* s, size_t size) : tag(Tag::String) { |
279 | payload.copyable_union.as_string = at::ArrayRef<char>(s, size); |
280 | } |
281 | |
282 | bool isString() const { |
283 | return tag == Tag::String; |
284 | } |
285 | |
286 | at::string_view toString() const { |
287 | ET_CHECK_MSG(isString(), "EValue is not a String." ); |
288 | return at::string_view( |
289 | payload.copyable_union.as_string.data(), |
290 | payload.copyable_union.as_string.size()); |
291 | } |
292 | |
293 | /****** Int List Type ******/ |
294 | /*implicit*/ EValue(at::ArrayRef<int64_t> i) : tag(Tag::ListInt) { |
295 | payload.copyable_union.as_int_list = i; |
296 | } |
297 | |
298 | bool isIntList() const { |
299 | return tag == Tag::ListInt; |
300 | } |
301 | |
302 | at::ArrayRef<int64_t> toIntList() const { |
303 | ET_CHECK_MSG(isIntList(), "EValue is not an Int List." ); |
304 | return payload.copyable_union.as_int_list; |
305 | } |
306 | |
307 | /****** Bool List Type ******/ |
308 | /*implicit*/ EValue(at::ArrayRef<bool> b) : tag(Tag::ListBool) { |
309 | payload.copyable_union.as_bool_list = b; |
310 | } |
311 | |
312 | bool isBoolList() const { |
313 | return tag == Tag::ListBool; |
314 | } |
315 | |
316 | at::ArrayRef<bool> toBoolList() const { |
317 | ET_CHECK_MSG(isBoolList(), "EValue is not a Bool List." ); |
318 | return payload.copyable_union.as_bool_list; |
319 | } |
320 | |
321 | /****** Double List Type ******/ |
322 | /*implicit*/ EValue(at::ArrayRef<double> d) : tag(Tag::ListDouble) { |
323 | payload.copyable_union.as_double_list = d; |
324 | } |
325 | |
326 | bool isDoubleList() const { |
327 | return tag == Tag::ListDouble; |
328 | } |
329 | |
330 | at::ArrayRef<double> toDoubleList() const { |
331 | ET_CHECK_MSG(isDoubleList(), "EValue is not a Double List." ); |
332 | return payload.copyable_union.as_double_list; |
333 | } |
334 | |
335 | /****** Tensor List Type ******/ |
336 | /*implicit*/ EValue(EValObjectList<at::Tensor> t) : tag(Tag::ListTensor) { |
337 | payload.copyable_union.as_tensor_list = t; |
338 | } |
339 | |
340 | bool isTensorList() const { |
341 | return tag == Tag::ListTensor; |
342 | } |
343 | |
344 | at::ArrayRef<at::Tensor> toTensorList() const { |
345 | ET_CHECK_MSG(isTensorList(), "EValue is not a Tensor List." ); |
346 | return payload.copyable_union.as_tensor_list.get(); |
347 | } |
348 | |
349 | /****** List Optional Tensor Type ******/ |
350 | /*implicit*/ EValue(EValObjectList<at::optional<at::Tensor>> t) |
351 | : tag(Tag::ListOptionalTensor) { |
352 | payload.copyable_union.as_list_optional_tensor = t; |
353 | } |
354 | |
355 | bool isListOptionalTensor() const { |
356 | return tag == Tag::ListOptionalTensor; |
357 | } |
358 | |
359 | at::ArrayRef<at::optional<at::Tensor>> toListOptionalTensor() { |
360 | return payload.copyable_union.as_list_optional_tensor.get(); |
361 | } |
362 | |
363 | /****** ScalarType Type ******/ |
364 | at::ScalarType toScalarType() const { |
365 | ET_CHECK_MSG(isInt(), "EValue is not a ScalarType." ); |
366 | return static_cast<at::ScalarType>(payload.copyable_union.as_int); |
367 | } |
368 | |
369 | /****** MemoryFormat Type ******/ |
370 | at::MemoryFormat toMemoryFormat() const { |
371 | ET_CHECK_MSG(isInt(), "EValue is not a MemoryFormat." ); |
372 | return static_cast<at::MemoryFormat>(payload.copyable_union.as_int); |
373 | } |
374 | |
375 | template <typename T> |
376 | T to() &&; |
377 | |
378 | template <typename T> |
379 | typename evalue_to_ref_overload_return<T>::type to() &; |
380 | |
381 | /** |
382 | * Converts the EValue to an optional object that can represent both T and |
383 | * an uninitialized state. |
384 | */ |
385 | template <typename T> |
386 | inline at::optional<T> toOptional() { |
387 | if (this->isNone()) { |
388 | return at::nullopt; |
389 | } |
390 | return this->to<T>(); |
391 | } |
392 | |
393 | private: |
394 | // Pre cond: the payload value has had its destructor called |
395 | void clearToNone() noexcept { |
396 | payload.copyable_union.as_int = 0; |
397 | tag = Tag::None; |
398 | } |
399 | |
400 | // Shared move logic |
401 | void moveFrom(EValue&& rhs) noexcept { |
402 | if (rhs.isTensor()) { |
403 | new (&payload.as_tensor) at::Tensor(std::move(rhs.payload.as_tensor)); |
404 | rhs.payload.as_tensor.~Tensor(); |
405 | } else { |
406 | payload.copyable_union = rhs.payload.copyable_union; |
407 | } |
408 | tag = rhs.tag; |
409 | rhs.clearToNone(); |
410 | } |
411 | |
412 | // Destructs stored tensor if there is one |
413 | void destroy() { |
414 | // Necessary for ATen tensor to refcount decrement the intrusive_ptr to |
415 | // tensorimpl that got a refcount increment when we placed it in the evalue, |
416 | // no-op if executorch tensor #ifdef could have a |
417 | // minor performance bump for a code maintainability hit |
418 | if (isTensor()) { |
419 | payload.as_tensor.~Tensor(); |
420 | } else if (isTensorList()) { |
421 | for (auto& tensor : toTensorList()) { |
422 | tensor.~Tensor(); |
423 | } |
424 | } else if (isListOptionalTensor()) { |
425 | for (auto& optional_tensor : toListOptionalTensor()) { |
426 | optional_tensor.~optional(); |
427 | } |
428 | } |
429 | } |
430 | |
431 | EValue(const Payload& p, Tag t) : tag(t) { |
432 | if (isTensor()) { |
433 | new (&payload.as_tensor) at::Tensor(p.as_tensor); |
434 | } else { |
435 | payload.copyable_union = p.copyable_union; |
436 | } |
437 | } |
438 | }; |
439 | |
440 | #define EVALUE_DEFINE_TO(T, method_name) \ |
441 | template <> \ |
442 | inline evalue_to_ref_overload_return<T>::type EValue::to<T>()& { \ |
443 | return static_cast<T>(this->method_name()); \ |
444 | } |
445 | |
446 | template <> |
447 | inline at::Tensor& EValue::to<at::Tensor>() & { |
448 | return this->toTensor(); |
449 | } |
450 | |
451 | EVALUE_DEFINE_TO(at::Scalar, toScalar) |
452 | EVALUE_DEFINE_TO(int64_t, toInt) |
453 | EVALUE_DEFINE_TO(bool, toBool) |
454 | EVALUE_DEFINE_TO(double, toDouble) |
455 | EVALUE_DEFINE_TO(at::string_view, toString) |
456 | EVALUE_DEFINE_TO(at::ScalarType, toScalarType) |
457 | EVALUE_DEFINE_TO(at::MemoryFormat, toMemoryFormat) |
458 | EVALUE_DEFINE_TO(at::optional<at::Tensor>, toOptional<at::Tensor>) |
459 | EVALUE_DEFINE_TO(at::ArrayRef<int64_t>, toIntList) |
460 | EVALUE_DEFINE_TO( |
461 | at::optional<at::ArrayRef<int64_t>>, |
462 | toOptional<at::ArrayRef<int64_t>>) |
463 | EVALUE_DEFINE_TO( |
464 | at::optional<at::ArrayRef<double>>, |
465 | toOptional<at::ArrayRef<double>>) |
466 | EVALUE_DEFINE_TO(at::ArrayRef<at::optional<at::Tensor>>, toListOptionalTensor) |
467 | EVALUE_DEFINE_TO(at::ArrayRef<double>, toDoubleList) |
468 | #undef EVALUE_DEFINE_TO |
469 | |
470 | template <typename T> |
471 | at::ArrayRef<T> EValObjectList<T>::get() const { |
472 | for (size_t i = 0; i < wrapped_vals_.size(); i++) { |
473 | unwrapped_vals_[i] = wrapped_vals_[i]->template to<T>(); |
474 | } |
475 | return at::ArrayRef<T>{unwrapped_vals_, wrapped_vals_.size()}; |
476 | } |
477 | |
478 | } // namespace executor |
479 | } // namespace torch |
480 | |