1#pragma once
2
3#include <ATen/ATen.h>
4/**
 * WARNING: EValue is a class used by Executorch for its boxed operators. Like
 * `IValue` in PyTorch, it provides APIs to convert boxed values to unboxed
 * values.
 *
 * It mirrors an internal fbcode source file
 * [`EValue.h`](https://www.internalfb.com/code/fbsource/xplat/executorch/core/values/Evalue.h).
 *
 * We mirror this class so that the torchgen logic gets CI job coverage, since
 * torchgen is used by both Executorch and PyTorch.
 *
 * If any of the logic here needs to be changed, please update the fbcode
 * version of `Evalue.h` as well. The two versions will be merged as soon as
 * Executorch is open-sourced (hopefully by Q2 2023).
19 */
20namespace torch {
21namespace executor {
22
// ET_CHECK_MSG(cond, msg...) must *enforce* `cond`. It used to forward to
// TORCH_CHECK_MSG, but that c10 macro only builds the failure-message string
// and never evaluates the condition as a check, so every ET_CHECK_MSG in this
// file was a silent no-op. TORCH_CHECK throws c10::Error when `cond` is false.
#define ET_CHECK_MSG TORCH_CHECK

// X-macro listing every EValue type tag, in declaration order. Pass a macro
// `_` taking one tag name; it is applied to each tag in turn.
#define EXECUTORCH_FORALL_TAGS(_) \
  _(None)                         \
  _(Tensor)                       \
  _(String)                       \
  _(Double)                       \
  _(Int)                          \
  _(Bool)                         \
  _(ListBool)                     \
  _(ListDouble)                   \
  _(ListInt)                      \
  _(ListTensor)                   \
  _(ListScalar)                   \
  _(ListOptionalTensor)
37
38enum class Tag : uint32_t {
39#define DEFINE_TAG(x) x,
40 EXECUTORCH_FORALL_TAGS(DEFINE_TAG)
41#undef DEFINE_TAG
42};
43
44struct EValue;
45
46template <typename T>
47struct evalue_to_const_ref_overload_return {
48 using type = T;
49};
50
51template <>
52struct evalue_to_const_ref_overload_return<at::Tensor> {
53 using type = const at::Tensor&;
54};
55
56template <typename T>
57struct evalue_to_ref_overload_return {
58 using type = T;
59};
60
61template <>
62struct evalue_to_ref_overload_return<at::Tensor> {
63 using type = at::Tensor&;
64};
65
66/*
67 * Helper class used to correlate EValues in the executor table, with the
68 * unwrapped list of the proper type. Because values in the runtime's values
69 * table can change during execution, we cannot statically allocate list of
70 * objects at deserialization. Imagine the serialized list says index 0 in the
71 * value table is element 2 in the list, but during execution the value in
72 * element 2 changes (in the case of tensor this means the TensorImpl* stored in
73 * the tensor changes). To solve this instead they must be created dynamically
74 * whenever they are used.
75 */
76template <typename T>
77class EValObjectList {
78 public:
79 EValObjectList() = default;
80 /*
81 * Wrapped_vals is a list of pointers into the values table of the runtime
82 * whose destinations correlate with the elements of the list, unwrapped_vals
83 * is a container of the same size whose serves as memory to construct the
84 * unwrapped vals.
85 */
86 EValObjectList(EValue** wrapped_vals, T* unwrapped_vals, int size)
87 : wrapped_vals_(wrapped_vals, size), unwrapped_vals_(unwrapped_vals) {}
88 /*
89 * Constructs and returns the list of T specified by the EValue pointers
90 */
91 at::ArrayRef<T> get() const;
92
93 private:
94 // Source of truth for the list
95 at::ArrayRef<EValue*> wrapped_vals_;
96 // Same size as wrapped_vals
97 mutable T* unwrapped_vals_;
98};
99
100// Aggregate typing system similar to IValue only slimmed down with less
101// functionality, no dependencies on atomic, and fewer supported types to better
102// suit embedded systems (ie no intrusive ptr)
103struct EValue {
104 union Payload {
105 // When in ATen mode at::Tensor is not trivially copyable, this nested union
106 // lets us handle tensor as a special case while leaving the rest of the
107 // fields in a simple state instead of requiring a switch on tag everywhere.
108 union TriviallyCopyablePayload {
109 TriviallyCopyablePayload() : as_int(0) {}
110 // Scalar supported through these 3 types
111 int64_t as_int;
112 double as_double;
113 bool as_bool;
114 // TODO(jakeszwe): convert back to pointers to optimize size of this
115 // struct
116 at::ArrayRef<char> as_string;
117 at::ArrayRef<int64_t> as_int_list;
118 at::ArrayRef<double> as_double_list;
119 at::ArrayRef<bool> as_bool_list;
120 EValObjectList<at::Tensor> as_tensor_list;
121 EValObjectList<at::optional<at::Tensor>> as_list_optional_tensor;
122 } copyable_union;
123
124 // Since a Tensor just holds a TensorImpl*, there's no value to use Tensor*
125 // here.
126 at::Tensor as_tensor;
127
128 Payload() {}
129 ~Payload() {}
130 };
131
132 // Data storage and type tag
133 Payload payload;
134 Tag tag;
135
136 // Basic ctors and assignments
137 EValue(const EValue& rhs) : EValue(rhs.payload, rhs.tag) {}
138
139 EValue(EValue&& rhs) noexcept : tag(rhs.tag) {
140 moveFrom(std::move(rhs));
141 }
142
143 EValue& operator=(EValue&& rhs) & noexcept {
144 if (&rhs == this) {
145 return *this;
146 }
147
148 destroy();
149 moveFrom(std::move(rhs));
150 return *this;
151 }
152
153 EValue& operator=(EValue const& rhs) & {
154 // Define copy assignment through copy ctor and move assignment
155 *this = EValue(rhs);
156 return *this;
157 }
158
159 ~EValue() {
160 destroy();
161 }
162
163 /****** None Type ******/
164 EValue() : tag(Tag::None) {
165 payload.copyable_union.as_int = 0;
166 }
167
168 bool isNone() const {
169 return tag == Tag::None;
170 }
171
172 /****** Int Type ******/
173 /*implicit*/ EValue(int64_t i) : tag(Tag::Int) {
174 payload.copyable_union.as_int = i;
175 }
176
177 bool isInt() const {
178 return tag == Tag::Int;
179 }
180
181 int64_t toInt() const {
182 ET_CHECK_MSG(isInt(), "EValue is not an int.");
183 return payload.copyable_union.as_int;
184 }
185
186 /****** Double Type ******/
187 /*implicit*/ EValue(double d) : tag(Tag::Double) {
188 payload.copyable_union.as_double = d;
189 }
190
191 bool isDouble() const {
192 return tag == Tag::Double;
193 }
194
195 double toDouble() const {
196 ET_CHECK_MSG(isDouble(), "EValue is not a Double.");
197 return payload.copyable_union.as_double;
198 }
199
200 /****** Bool Type ******/
201 /*implicit*/ EValue(bool b) : tag(Tag::Bool) {
202 payload.copyable_union.as_bool = b;
203 }
204
205 bool isBool() const {
206 return tag == Tag::Bool;
207 }
208
209 bool toBool() const {
210 ET_CHECK_MSG(isBool(), "EValue is not a Bool.");
211 return payload.copyable_union.as_bool;
212 }
213
214 /****** Scalar Type ******/
215 /// Construct an EValue using the implicit value of a Scalar.
216 /*implicit*/ EValue(at::Scalar s) {
217 if (s.isIntegral(false)) {
218 tag = Tag::Int;
219 payload.copyable_union.as_int = s.to<int64_t>();
220 } else if (s.isFloatingPoint()) {
221 tag = Tag::Double;
222 payload.copyable_union.as_double = s.to<double>();
223 } else if (s.isBoolean()) {
224 tag = Tag::Bool;
225 payload.copyable_union.as_bool = s.to<bool>();
226 } else {
227 ET_CHECK_MSG(false, "Scalar passed to EValue is not initialized.");
228 }
229 }
230
231 bool isScalar() const {
232 return tag == Tag::Int || tag == Tag::Double || tag == Tag::Bool;
233 }
234
235 at::Scalar toScalar() const {
236 // Convert from implicit value to Scalar using implicit constructors.
237
238 if (isDouble()) {
239 return toDouble();
240 } else if (isInt()) {
241 return toInt();
242 } else if (isBool()) {
243 return toBool();
244 } else {
245 ET_CHECK_MSG(false, "EValue is not a Scalar.");
246 return c10::Scalar();
247 }
248 }
249
250 /****** Tensor Type ******/
251 /*implicit*/ EValue(at::Tensor t) : tag(Tag::Tensor) {
252 // When built in aten mode, at::Tensor has a non trivial constructor
253 // destructor, so regular assignment to a union field is UB. Instead we must
254 // go through placement new (which causes a refcount bump).
255 new (&payload.as_tensor) at::Tensor(t);
256 }
257
258 bool isTensor() const {
259 return tag == Tag::Tensor;
260 }
261
262 at::Tensor toTensor() && {
263 ET_CHECK_MSG(isTensor(), "EValue is not a Tensor.");
264 return std::move(payload.as_tensor);
265 }
266
267 at::Tensor& toTensor() & {
268 ET_CHECK_MSG(isTensor(), "EValue is not a Tensor.");
269 return payload.as_tensor;
270 }
271
272 const at::Tensor& toTensor() const& {
273 ET_CHECK_MSG(isTensor(), "EValue is not a Tensor.");
274 return payload.as_tensor;
275 }
276
277 /****** String Type ******/
278 /*implicit*/ EValue(const char* s, size_t size) : tag(Tag::String) {
279 payload.copyable_union.as_string = at::ArrayRef<char>(s, size);
280 }
281
282 bool isString() const {
283 return tag == Tag::String;
284 }
285
286 at::string_view toString() const {
287 ET_CHECK_MSG(isString(), "EValue is not a String.");
288 return at::string_view(
289 payload.copyable_union.as_string.data(),
290 payload.copyable_union.as_string.size());
291 }
292
293 /****** Int List Type ******/
294 /*implicit*/ EValue(at::ArrayRef<int64_t> i) : tag(Tag::ListInt) {
295 payload.copyable_union.as_int_list = i;
296 }
297
298 bool isIntList() const {
299 return tag == Tag::ListInt;
300 }
301
302 at::ArrayRef<int64_t> toIntList() const {
303 ET_CHECK_MSG(isIntList(), "EValue is not an Int List.");
304 return payload.copyable_union.as_int_list;
305 }
306
307 /****** Bool List Type ******/
308 /*implicit*/ EValue(at::ArrayRef<bool> b) : tag(Tag::ListBool) {
309 payload.copyable_union.as_bool_list = b;
310 }
311
312 bool isBoolList() const {
313 return tag == Tag::ListBool;
314 }
315
316 at::ArrayRef<bool> toBoolList() const {
317 ET_CHECK_MSG(isBoolList(), "EValue is not a Bool List.");
318 return payload.copyable_union.as_bool_list;
319 }
320
321 /****** Double List Type ******/
322 /*implicit*/ EValue(at::ArrayRef<double> d) : tag(Tag::ListDouble) {
323 payload.copyable_union.as_double_list = d;
324 }
325
326 bool isDoubleList() const {
327 return tag == Tag::ListDouble;
328 }
329
330 at::ArrayRef<double> toDoubleList() const {
331 ET_CHECK_MSG(isDoubleList(), "EValue is not a Double List.");
332 return payload.copyable_union.as_double_list;
333 }
334
335 /****** Tensor List Type ******/
336 /*implicit*/ EValue(EValObjectList<at::Tensor> t) : tag(Tag::ListTensor) {
337 payload.copyable_union.as_tensor_list = t;
338 }
339
340 bool isTensorList() const {
341 return tag == Tag::ListTensor;
342 }
343
344 at::ArrayRef<at::Tensor> toTensorList() const {
345 ET_CHECK_MSG(isTensorList(), "EValue is not a Tensor List.");
346 return payload.copyable_union.as_tensor_list.get();
347 }
348
349 /****** List Optional Tensor Type ******/
350 /*implicit*/ EValue(EValObjectList<at::optional<at::Tensor>> t)
351 : tag(Tag::ListOptionalTensor) {
352 payload.copyable_union.as_list_optional_tensor = t;
353 }
354
355 bool isListOptionalTensor() const {
356 return tag == Tag::ListOptionalTensor;
357 }
358
359 at::ArrayRef<at::optional<at::Tensor>> toListOptionalTensor() {
360 return payload.copyable_union.as_list_optional_tensor.get();
361 }
362
363 /****** ScalarType Type ******/
364 at::ScalarType toScalarType() const {
365 ET_CHECK_MSG(isInt(), "EValue is not a ScalarType.");
366 return static_cast<at::ScalarType>(payload.copyable_union.as_int);
367 }
368
369 /****** MemoryFormat Type ******/
370 at::MemoryFormat toMemoryFormat() const {
371 ET_CHECK_MSG(isInt(), "EValue is not a MemoryFormat.");
372 return static_cast<at::MemoryFormat>(payload.copyable_union.as_int);
373 }
374
375 template <typename T>
376 T to() &&;
377
378 template <typename T>
379 typename evalue_to_ref_overload_return<T>::type to() &;
380
381 /**
382 * Converts the EValue to an optional object that can represent both T and
383 * an uninitialized state.
384 */
385 template <typename T>
386 inline at::optional<T> toOptional() {
387 if (this->isNone()) {
388 return at::nullopt;
389 }
390 return this->to<T>();
391 }
392
393 private:
394 // Pre cond: the payload value has had its destructor called
395 void clearToNone() noexcept {
396 payload.copyable_union.as_int = 0;
397 tag = Tag::None;
398 }
399
400 // Shared move logic
401 void moveFrom(EValue&& rhs) noexcept {
402 if (rhs.isTensor()) {
403 new (&payload.as_tensor) at::Tensor(std::move(rhs.payload.as_tensor));
404 rhs.payload.as_tensor.~Tensor();
405 } else {
406 payload.copyable_union = rhs.payload.copyable_union;
407 }
408 tag = rhs.tag;
409 rhs.clearToNone();
410 }
411
412 // Destructs stored tensor if there is one
413 void destroy() {
414 // Necessary for ATen tensor to refcount decrement the intrusive_ptr to
415 // tensorimpl that got a refcount increment when we placed it in the evalue,
416 // no-op if executorch tensor #ifdef could have a
417 // minor performance bump for a code maintainability hit
418 if (isTensor()) {
419 payload.as_tensor.~Tensor();
420 } else if (isTensorList()) {
421 for (auto& tensor : toTensorList()) {
422 tensor.~Tensor();
423 }
424 } else if (isListOptionalTensor()) {
425 for (auto& optional_tensor : toListOptionalTensor()) {
426 optional_tensor.~optional();
427 }
428 }
429 }
430
431 EValue(const Payload& p, Tag t) : tag(t) {
432 if (isTensor()) {
433 new (&payload.as_tensor) at::Tensor(p.as_tensor);
434 } else {
435 payload.copyable_union = p.copyable_union;
436 }
437 }
438};
439
440#define EVALUE_DEFINE_TO(T, method_name) \
441 template <> \
442 inline evalue_to_ref_overload_return<T>::type EValue::to<T>()& { \
443 return static_cast<T>(this->method_name()); \
444 }
445
446template <>
447inline at::Tensor& EValue::to<at::Tensor>() & {
448 return this->toTensor();
449}
450
451EVALUE_DEFINE_TO(at::Scalar, toScalar)
452EVALUE_DEFINE_TO(int64_t, toInt)
453EVALUE_DEFINE_TO(bool, toBool)
454EVALUE_DEFINE_TO(double, toDouble)
455EVALUE_DEFINE_TO(at::string_view, toString)
456EVALUE_DEFINE_TO(at::ScalarType, toScalarType)
457EVALUE_DEFINE_TO(at::MemoryFormat, toMemoryFormat)
458EVALUE_DEFINE_TO(at::optional<at::Tensor>, toOptional<at::Tensor>)
459EVALUE_DEFINE_TO(at::ArrayRef<int64_t>, toIntList)
460EVALUE_DEFINE_TO(
461 at::optional<at::ArrayRef<int64_t>>,
462 toOptional<at::ArrayRef<int64_t>>)
463EVALUE_DEFINE_TO(
464 at::optional<at::ArrayRef<double>>,
465 toOptional<at::ArrayRef<double>>)
466EVALUE_DEFINE_TO(at::ArrayRef<at::optional<at::Tensor>>, toListOptionalTensor)
467EVALUE_DEFINE_TO(at::ArrayRef<double>, toDoubleList)
468#undef EVALUE_DEFINE_TO
469
470template <typename T>
471at::ArrayRef<T> EValObjectList<T>::get() const {
472 for (size_t i = 0; i < wrapped_vals_.size(); i++) {
473 unwrapped_vals_[i] = wrapped_vals_[i]->template to<T>();
474 }
475 return at::ArrayRef<T>{unwrapped_vals_, wrapped_vals_.size()};
476}
477
478} // namespace executor
479} // namespace torch
480