1 | #pragma once |
2 | |
3 | #include <ATen/core/DimVector.h> |
4 | #include <ATen/core/TensorBody.h> |
5 | #include <ATen/core/blob.h> |
6 | #include <ATen/core/custom_class.h> |
7 | #include <ATen/core/ivalue_to.h> |
8 | #include <ATen/core/jit_type_base.h> |
9 | #include <ATen/core/type_factory.h> |
10 | #include <c10/core/SymFloat.h> |
11 | #include <c10/macros/Export.h> |
12 | #include <c10/util/C++17.h> |
13 | #include <c10/util/MaybeOwned.h> |
14 | #include <c10/util/intrusive_ptr.h> |
15 | #include <typeindex> |
16 | #include <utility> |
17 | |
18 | namespace torch { |
19 | class TORCH_API CustomClassHolder : public c10::intrusive_ptr_target {}; |
20 | namespace jit { |
21 | using ::torch::CustomClassHolder; |
22 | struct Function; |
23 | struct CompilationUnit; |
24 | struct Module; |
25 | } // namespace jit |
26 | } // namespace torch |
27 | namespace c10 { |
28 | template <class Key, class Value> |
29 | class Dict; |
30 | template <class T> |
31 | class List; |
32 | template <class T> |
33 | class IListRef; |
34 | struct IValue; |
35 | struct ClassType; |
36 | struct Type; |
37 | class RRefInterface; |
38 | |
39 | struct ClassType; |
40 | using ClassTypePtr = std::shared_ptr<ClassType>; |
41 | |
42 | TORCH_API bool _fastEqualsForContainer(const IValue& lhs, const IValue& rhs); |
43 | |
44 | TORCH_API torch::jit::Function* checkObjectSortSchema( |
45 | const c10::ClassTypePtr& t, |
46 | std::stringstream& why_not); |
47 | |
// A comparator that checks the ordering of two IValues of the same type.
49 | typedef std::function<bool(const IValue& a, const IValue& b)> IValueComparator; |
50 | |
51 | TORCH_API IValueComparator getLessThanComparator(const IValue& v); |
52 | TORCH_API IValueComparator getGreaterThanComparator(const IValue& v); |
53 | |
54 | namespace ivalue { |
55 | struct Tuple; |
56 | struct Future; |
57 | struct Await; |
58 | struct ConstantString; |
59 | struct GenericDict; |
60 | struct Object; |
61 | struct PyObjectHolder; |
62 | struct EnumHolder; |
// We need a ComplexHolder because currently the payloads in the Union
// only take 64 bits. Since ComplexDouble takes up 128 bits and is too big
// to fit in the IValue directly, we indirect complex numbers through an
// intrusive pointer to ComplexHolder (which contains a c10::complex).
67 | struct ComplexHolder : c10::intrusive_ptr_target { |
68 | public: |
69 | template <typename T> |
70 | ComplexHolder(c10::complex<T> c) { |
71 | val = convert<decltype(val), c10::complex<T>>(c); |
72 | } |
73 | ComplexHolder() = default; |
74 | c10::complex<double> val; |
75 | }; |
76 | |
77 | // Similar to ComplexHolder, for StreamData3 |
78 | struct StreamData3Holder : c10::intrusive_ptr_target { |
79 | public: |
80 | StreamData3Holder(struct c10::StreamData3 d) { |
81 | val = d; |
82 | } |
83 | StreamData3Holder() = delete; |
84 | struct c10::StreamData3 val; |
85 | }; |
86 | |
87 | } // namespace ivalue |
88 | |
89 | // This is an owning wrapper for a c10::optional<std::vector<T>> |
90 | // that can be implicitly converted to a (non-owning) optional<ArrayRef<T>>. |
91 | // Its purpose is to be used in generated code to keep the vector alive |
92 | // either until the end of a statement (as a temporary), or as a saved arg |
93 | // in autograd. |
94 | template <typename T> |
95 | struct OptionalArray { |
96 | c10::optional<std::vector<T>> list; |
97 | |
  OptionalArray() = default;
99 | OptionalArray(std::vector<T> val) : list(std::move(val)) {} |
100 | |
101 | // Used when saving an argument for the backwards pass. |
102 | OptionalArray& operator=(c10::optional<ArrayRef<T>> ref) { |
103 | if (ref) { |
104 | list = std::vector<T>(ref->begin(), ref->end()); |
105 | } else { |
106 | list = nullopt; |
107 | } |
108 | return *this; |
109 | } |
110 | |
111 | // Used when saving an argument for the backwards pass. |
112 | OptionalArray& operator=(c10::OptionalArrayRef<T> ref) { |
113 | if (ref) { |
114 | list = std::vector<T>(ref->begin(), ref->end()); |
115 | } else { |
116 | list = nullopt; |
117 | } |
118 | return *this; |
119 | } |
120 | |
121 | operator c10::optional<c10::ArrayRef<T>>() { |
122 | if (!list) { |
123 | return nullopt; |
124 | } |
125 | return *list; |
126 | } |
127 | |
128 | operator c10::OptionalArrayRef<T>() { |
129 | if (!list) { |
130 | return nullopt; |
131 | } |
132 | return *list; |
133 | } |
134 | }; |
135 | |
// Capsule is an internal implementation detail of custom C++ classes. We
// define it as an owning wrapper for
// c10::intrusive_ptr<torch::CustomClassHolder>. This wrapper serves as an
// abstraction over the type-erased custom class object pointer. It also
// allows pybind11 to treat this as a standalone class and register a
// separate type caster for it, instead of treating it as a custom pointer
// holder, whose type caster would try to "unwrap" the pointer automatically.
143 | struct Capsule { |
144 | c10::intrusive_ptr<torch::CustomClassHolder> obj_ptr; |
145 | explicit Capsule(c10::intrusive_ptr<torch::CustomClassHolder> ptr) |
146 | : obj_ptr(std::move(ptr)) {} |
147 | }; |
148 | |
149 | // IValue is the generic tagged union used by the interpreter to hold |
150 | // all value types. |
// It is a 16-byte object with an 8-byte payload and a 4-byte tag
// (plus padding). The tag determines the type; whether that type is a
// subtype of c10::intrusive_ptr_target (and thus needs retain/release
// calls) is derived from the tag (see isIntrusivePtr() below).
155 | |
156 | #define TORCH_FORALL_TAGS(_) \ |
157 | _(None) \ |
158 | _(Tensor) \ |
159 | _(Storage) \ |
160 | _(Double) \ |
161 | _(ComplexDouble) \ |
162 | _(Int) \ |
163 | _(SymInt) \ |
164 | _(SymFloat) \ |
165 | _(Bool) \ |
166 | _(Tuple) \ |
167 | _(String) \ |
168 | _(Blob) \ |
169 | _(GenericList) \ |
170 | _(GenericDict) \ |
171 | _(Future) \ |
172 | _(Await) \ |
173 | _(Device) \ |
174 | _(Stream) \ |
175 | _(Object) \ |
176 | _(PyObject) \ |
177 | _(Uninitialized) \ |
178 | _(Capsule) \ |
179 | _(RRef) \ |
180 | _(Quantizer) \ |
181 | _(Generator) \ |
182 | _(Enum) |
183 | |
184 | // [doxygen private] |
185 | // These methods are not actually private but we don't want to document them, so |
186 | // they are marked `@private`, which hides them on the doxygen documentation for |
187 | // this page. |
188 | |
189 | /// IValue (Interpreter Value) is a tagged union over the types |
190 | /// supported by the TorchScript interpreter. IValues contain their |
191 | /// values as an `IValue::Payload`, which holds primitive types |
192 | /// (`int64_t`, `bool`, `double`, `Device`) and `Tensor` as values, |
193 | /// and all other types as a `c10::intrusive_ptr`. In order to |
194 | /// optimize performance of the destructor and related operations by |
195 | /// making the `Tensor` and `c10::intrusive_ptr` paths generate the |
196 | /// same code, we represent a null `c10::intrusive_ptr` as |
197 | /// `UndefinedTensorImpl::singleton()`, *not* `nullptr`. |
198 | /// |
199 | /// IValues are used as inputs to and outputs from the TorchScript interpreter. |
200 | /// To retrieve the value contained within an IValue, use the `.toX()` methods, |
201 | /// where `X` is the type you are trying to get. Note that neither the `.toX()` |
202 | /// methods nor the templated `.to<T>` functions do any kind of casting, they |
203 | /// only unwrap the contained value. For example: |
204 | /// |
205 | /// \rst |
206 | /// .. code-block:: cpp |
207 | /// |
208 | /// // Make the IValue |
209 | /// torch::IValue my_ivalue(26); |
210 | /// std::cout << my_ivalue << "\n"; |
211 | /// |
212 | /// // Unwrap the IValue |
213 | /// int64_t my_int = my_ivalue.toInt(); |
214 | /// std::cout << my_int << "\n"; |
215 | /// |
216 | /// // This will throw an error! |
217 | /// // `my_ivalue` is tagged as an int and cannot be used as another type |
218 | /// torch::Tensor my_tensor = my_ivalue.toTensor(); |
219 | /// \endrst |
220 | struct TORCH_API IValue final { |
221 | IValue(const IValue& rhs) |
222 | : IValue(rhs.payload, rhs.tag) { |
223 | if (isIntrusivePtr() && payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()) { |
224 | c10::raw::intrusive_ptr::incref(payload.u.as_intrusive_ptr); |
225 | } |
226 | } |
227 | |
228 | IValue(IValue&& rhs) noexcept : tag(rhs.tag) { |
229 | moveFrom(std::move(rhs)); |
230 | } |
231 | |
232 | /// @private [doxygen private] |
233 | ~IValue() { |
234 | destroy(); |
235 | } |
236 | |
237 | C10_ALWAYS_INLINE IValue& operator=(IValue&& rhs) & noexcept { |
238 | if (&rhs == this) { |
239 | return *this; |
240 | } |
241 | |
242 | destroy(); |
243 | moveFrom(std::move(rhs)); |
244 | return *this; |
245 | } |
246 | |
247 | IValue& operator=(IValue const& rhs) & { |
248 | *this = IValue(rhs); |
249 | return *this; |
250 | } |
251 | |
252 | void dump() const; |
253 | |
254 | /** |
255 | * Equality comparison. The semantics are the same as Python's `==`: |
256 | * 1. Numerical types are compared by value. |
257 | * 2. Tensors compute element-wise equality, returning a BoolTensor (see: |
258 | * `torch.eq()`) |
259 | * 3. Strings are compared by value. |
260 | * 4. Sequence types (list, tuple) are compared lexicographically by |
261 | * comparing their elements. Different sequence types never compare equal. |
262 | * 5. Mappings (dict) must have equal (key, value) pairs. |
   * 6. If not listed above, the default behavior is to test identity
264 | * equality (e.g. pointer equality). |
265 | * |
266 | * Why does this return an IValue instead of a bool? Because in PyTorch, |
267 | * `tensor1 == tensor2` returns a `BoolTensor`, not a bool. |
268 | * |
269 | * NOTE: we (like Python) assume that identity equality implies value equality |
270 | * for efficiency. |
271 | * TODO: need to support customizing equality |
272 | */ |
273 | IValue equals(const IValue& rhs) const; |
274 | /** |
   * This implements the same semantics as `bool(lhs == rhs)` in Python, which
276 | * is the same as `equals()` except for Tensor types. |
277 | */ |
278 | TORCH_API friend bool operator==(const IValue& lhs, const IValue& rhs); |
279 | TORCH_API friend bool operator!=(const IValue& lhs, const IValue& rhs); |
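
  // For example (illustrative; `t` stands for some at::Tensor):
  //
  //   IValue(1).equals(IValue(1));      // IValue(true): numbers by value
  //   IValue("a").equals(IValue("a"));  // IValue(true): strings by value
  //   IValue(t).equals(IValue(t));      // element-wise: an IValue holding a
  //                                     // BoolTensor, not a plain bool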
280 | |
281 | /** |
282 | * Identity comparison. Checks if `this` is the same object as `rhs`. The |
283 | * semantics are the same as Python's `is` operator. |
284 | * |
285 | * NOTE: Like in Python, this operation is poorly defined for primitive types |
286 | * like numbers and strings. Prefer to use `==` unless you really want to |
287 | * check identity equality. |
288 | */ |
289 | bool is(const IValue& rhs) const; |
290 | |
291 | /** |
292 | * Hashing for IValues. Returns an IValue-boxed int. |
293 | * |
294 | * Some notes: |
295 | * - Like eager, Tensors are hashed by looking at the pointer. This is not |
296 | * strictly correct because two value-equal tensors with different tensor |
297 | * pointers will hash differently, but we choose to reproduce the eager |
298 | * semantics. |
299 | * - Hashing is not defined on all built-in IValue types (e.g. list and |
300 | * dict), following Python. Calling `hash()` on these types will throw. |
301 | */ |
302 | IValue hash() const { |
303 | return (int64_t)IValue::hash(*this); |
304 | } |
305 | // This is defined because `c10::hash` dispatches to a function of this |
306 | // signature. See the member function `hash()`. |
307 | static size_t hash(const IValue& iv); |
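
  // For example (illustrative):
  //
  //   IValue(42).hash();                       // an IValue-boxed int
  //   IValue(c10::List<int64_t>({1})).hash();  // throws: lists, like in
  //                                            // Python, are unhashable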
308 | |
309 | /** |
310 | * @private [doxygen private] |
311 | * [container equality] |
312 | * This is an equality implementation that assumes objects with the same |
313 | * identity equal themselves, for efficiency reasons. We primarily have this |
   * for consistency, because Python does the same thing. This actually
   * provokes user-visible changes in behavior due to quirks in torch:
   *   [tensor1] == [tensor1]      -> True (because container equality
   *                                  first compares identity)
   *   [tensor1] == [tensor1_copy] -> RuntimeError: Boolean value of Tensor
   *                                  with more than one value is ambiguous
319 | */ |
320 | TORCH_API friend bool _fastEqualsForContainer( |
321 | const IValue& lhs, |
322 | const IValue& rhs); |
323 | |
324 | private: |
325 | static bool isAliasOf(const at::Tensor& a, const at::Tensor& b) { |
326 | if (a.is_sparse()) { |
327 | return isAliasOf(a._values(), b) || isAliasOf(a._indices(), b); |
328 | } |
329 | if (b.is_sparse()) { |
330 | return isAliasOf(a, b._values()) || isAliasOf(a, b._indices()); |
331 | } |
332 | if (a.is_sparse_csr()) { |
333 | return isAliasOf(a.values(), b) || |
334 | isAliasOf(a.crow_indices(), b) || |
335 | isAliasOf(a.col_indices(), b); |
336 | } |
337 | if (b.is_sparse_csr()) { |
338 | return isAliasOf(a, b.values()) || |
339 | isAliasOf(a, b.crow_indices()) || |
340 | isAliasOf(a, b.col_indices()); |
341 | } |
342 | |
343 | // Opaque tensors such as the ones constructed by the MKL-DNN backend |
344 | // don't have storage so we just compare their TensorImpls. |
345 | // TODO: Find way to expose alias info for opaque tensors. |
346 | if (!a.has_storage() || !b.has_storage()) { |
347 | return a.unsafeGetTensorImpl() == b.unsafeGetTensorImpl(); |
348 | } |
349 | |
350 | return a.is_alias_of(b); |
351 | } |
352 | |
353 | template <typename T> |
354 | bool isListOf() const; |
355 | |
356 | public: |
357 | /// @private [doxygen private] |
358 | bool isAliasOf(const IValue& rhs) const { |
359 | if (this->tag != rhs.tag) { |
360 | // Trivially don't alias if the type is different |
361 | return false; |
362 | } |
363 | |
364 | // Tensors should be compared based on internal storage |
365 | if (this->isTensor()) { |
366 | return isAliasOf(this->toTensor(), rhs.toTensor()); |
367 | } |
368 | |
369 | if (!isIntrusivePtr()) { |
370 | // Primitive types don't alias anything |
371 | return false; |
372 | } |
373 | |
374 | AT_ASSERT(rhs.isIntrusivePtr()); |
375 | |
376 | // Other types can be compared by their ptr value |
377 | return this->payload.u.as_intrusive_ptr == rhs.payload.u.as_intrusive_ptr; |
378 | } |
379 | |
380 | /// @private [doxygen private] |
381 | size_t use_count() const noexcept { |
382 | if (isTensor()) { |
383 | return payload.as_tensor.use_count(); |
384 | } |
385 | |
386 | if (!isIntrusivePtrLegacyBehavior()) { |
387 | return 1; |
388 | } |
389 | |
390 | if (payload.u.as_intrusive_ptr == c10::UndefinedTensorImpl::singleton()) { |
391 | return 0; |
392 | } |
393 | return c10::raw::intrusive_ptr::use_count(payload.u.as_intrusive_ptr); |
394 | } |
395 | |
396 | /// @private [doxygen private] |
397 | void swap(IValue& rhs) noexcept { |
398 | if (isTensor() && rhs.isTensor()) { |
399 | std::swap(payload.as_tensor, rhs.payload.as_tensor); |
400 | } else if (isTensor()) { |
401 | at::Tensor t = std::move(payload.as_tensor); |
402 | // As far as I can tell, omitting the usual explicit destructor call |
403 | // is not UB in and of itself, and it's a slight perf win. The |
404 | // destructor is a no-op, because the moved-from Tensor is |
405 | // effectively an intrusive_ptr in the null state, so we don't need |
406 | // the behavior for correctness reasons either. Leaving this |
407 | // explanatory comment, including commented-out destructor call, to |
408 | // make this abundantly clear. |
409 | // |
410 | // payload.as_tensor.~Tensor(); |
411 | payload.u = rhs.payload.u; |
412 | new (&rhs.payload.as_tensor) at::Tensor(std::move(t)); |
413 | } else if (rhs.isTensor()) { |
414 | rhs.swap(*this); |
415 | return; |
416 | } else { |
417 | std::swap(payload.u, rhs.payload.u); |
418 | } |
419 | std::swap(tag, rhs.tag); |
420 | } |
421 | |
  // Accessors for subtypes are arranged together below.
  // While some of these accessors could be generated through templates,
  // we prefer to write them manually for clarity.
425 | |
426 | IValue(at::TensorBase t) : tag(Tag::Tensor) { |
427 | new (&payload.as_tensor) at::Tensor(std::move(t)); |
428 | } |
429 | bool isTensor() const { |
430 | return Tag::Tensor == tag; |
431 | } |
432 | |
433 | private: |
434 | // Outlined error path so that toTensor() can be inlined. |
435 | [[noreturn]] void reportToTensorTypeError() const; |
436 | |
437 | public: |
438 | at::Tensor toTensor() &&; |
439 | at::Tensor& toTensor() &; |
440 | const at::Tensor& toTensor() const&; |
441 | at::TensorImpl* unsafeToTensorImpl() const { |
442 | TORCH_INTERNAL_ASSERT(isTensor()); |
443 | return payload.as_tensor.unsafeGetTensorImpl(); |
444 | } |
445 | |
446 | IValue(at::Storage s) : tag(Tag::Storage) { |
447 | payload.u.as_intrusive_ptr = null_to_undefined_tensor(s.unsafeReleaseStorageImpl()); |
448 | } |
449 | bool isStorage() const { |
450 | return Tag::Storage == tag; |
451 | } |
452 | c10::Storage toStorage() &&; |
453 | c10::Storage toStorage() const&; |
454 | |
455 | const IValue& toIValue() const { |
456 | return *this; |
457 | } |
458 | IValue& toIValue() { |
459 | return *this; |
460 | } |
461 | |
462 | /// @private [doxygen private] |
463 | IValue(intrusive_ptr<caffe2::Blob> blob) |
464 | : tag(Tag::Blob) { |
465 | // TODO (after Tensor merge) If we pass in a Blob holding a Tensor, extract |
466 | // and store it as a Tensor instead. |
467 | payload.u.as_intrusive_ptr = null_to_undefined_tensor(blob.release()); |
468 | } |
469 | |
470 | /// @private [doxygen private] |
471 | bool isBlob() const { |
472 | return Tag::Blob == tag; |
473 | } |
474 | |
475 | /// @private [doxygen private] |
476 | c10::intrusive_ptr<caffe2::Blob> toBlob() &&; |
477 | |
478 | /// @private [doxygen private] |
479 | c10::intrusive_ptr<caffe2::Blob> toBlob() const&; |
480 | |
481 | // Capsule. No new callsites of these APIs should |
482 | // be introduced. |
483 | static inline IValue make_capsule( |
484 | intrusive_ptr<torch::CustomClassHolder> blob); |
485 | bool isCapsule() const { |
486 | return Tag::Capsule == tag; |
487 | } |
488 | c10::intrusive_ptr<torch::CustomClassHolder> toCapsule() &&; |
489 | c10::intrusive_ptr<torch::CustomClassHolder> toCapsule() const&; |
490 | |
491 | // Custom C++ classes |
492 | template < |
493 | typename T, |
494 | std::enable_if_t< |
495 | std::is_base_of<torch::CustomClassHolder, T>::value, |
496 | int> = 0> |
497 | IValue(intrusive_ptr<T> custom_class); |
498 | bool isCustomClass() const; |
499 | template <typename T> |
500 | c10::intrusive_ptr<T> toCustomClass() &&; |
501 | template <typename T> |
502 | c10::intrusive_ptr<T> toCustomClass() const&; |
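
  // A minimal sketch (illustrative): `MyHolder` is a hypothetical type that
  // derives from torch::CustomClassHolder and has been registered as a
  // custom class (registration is required for these conversions to work):
  //
  //   struct MyHolder : torch::CustomClassHolder { int64_t x = 0; };
  //   IValue iv(c10::make_intrusive<MyHolder>());
  //   iv.isCustomClass();                        // true once registered
  //   auto back = iv.toCustomClass<MyHolder>();  // intrusive_ptr<MyHolder>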
503 | |
504 | // Tuple |
505 | IValue(c10::intrusive_ptr<ivalue::Tuple> v); |
506 | |
507 | template < |
508 | typename... Args, |
509 | std::enable_if_t< |
510 | !guts::disjunction< |
511 | std::is_lvalue_reference<Args>..., |
512 | guts::negation<std::is_constructible<IValue, Args>>...>::value, |
513 | std::nullptr_t> = nullptr> |
514 | IValue(const std::tuple<Args...>& t); |
515 | template < |
516 | typename... Args, |
517 | std::enable_if_t< |
518 | !guts::disjunction< |
519 | std::is_lvalue_reference<Args>..., |
520 | guts::negation<std::is_constructible<IValue, Args>>...>::value, |
521 | std::nullptr_t> = nullptr> |
522 | IValue(std::tuple<Args...>&& t); |
523 | bool isTuple() const { |
524 | return Tag::Tuple == tag; |
525 | } |
526 | c10::intrusive_ptr<ivalue::Tuple> toTuple() &&; |
527 | c10::intrusive_ptr<ivalue::Tuple> toTuple() const&; |
528 | C10_NODISCARD ivalue::Tuple& toTupleRef() const; |
529 | |
530 | // Double |
531 | IValue(double d) : tag(Tag::Double) { |
532 | payload.u.as_double = d; |
533 | } |
534 | bool isDouble() const { |
535 | return Tag::Double == tag; |
536 | } |
537 | double toDouble() const { |
538 | AT_ASSERT(isDouble()); |
539 | return payload.u.as_double; |
540 | } |
541 | |
542 | // ComplexDouble |
543 | template <typename T> |
544 | IValue(c10::complex<T> c); |
545 | bool isComplexDouble() const { return Tag::ComplexDouble == tag; } |
546 | c10::complex<double> toComplexDouble() const; |
547 | |
548 | // Future |
549 | IValue(c10::intrusive_ptr<ivalue::Future> v); |
550 | bool isFuture() const { |
551 | return Tag::Future == tag; |
552 | } |
553 | c10::intrusive_ptr<ivalue::Future> toFuture() &&; |
554 | c10::intrusive_ptr<ivalue::Future> toFuture() const&; |
555 | |
556 | IValue(c10::intrusive_ptr<ivalue::Await> v); |
557 | bool isAwait() const { |
558 | return Tag::Await == tag; |
559 | } |
560 | c10::intrusive_ptr<ivalue::Await> toAwait() &&; |
561 | c10::intrusive_ptr<ivalue::Await> toAwait() const&; |
562 | |
563 | // RRef |
564 | IValue(c10::intrusive_ptr<c10::RRefInterface> v); |
565 | bool isRRef() const { |
566 | return Tag::RRef == tag; |
567 | } |
568 | c10::intrusive_ptr<c10::RRefInterface> toRRef() &&; |
569 | c10::intrusive_ptr<c10::RRefInterface> toRRef() const&; |
570 | |
571 | // Quantizer |
572 | IValue(c10::intrusive_ptr<at::Quantizer> v); |
573 | bool isQuantizer() const { |
574 | return Tag::Quantizer == tag; |
575 | } |
576 | c10::intrusive_ptr<at::Quantizer> toQuantizer() &&; |
577 | c10::intrusive_ptr<at::Quantizer> toQuantizer() const&; |
578 | |
579 | // Int |
580 | IValue(int64_t i) : tag(Tag::Int) { |
581 | payload.u.as_int = i; |
582 | } |
583 | |
584 | IValue(c10::SymInt i) { |
585 | if (i.is_symbolic()) { |
586 | tag = Tag::SymInt; |
587 | payload.u.as_intrusive_ptr = i.toSymNodeImpl().release(); |
588 | } else { |
589 | tag = Tag::Int; |
590 | payload.u.as_int = i.as_int_unchecked(); |
591 | } |
592 | } |
593 | |
594 | bool isSymInt() const { |
595 | return Tag::SymInt == tag; |
596 | } |
597 | |
598 | c10::SymInt toSymInt() &&; |
599 | c10::SymInt toSymInt() const&; |
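
  // For example (illustrative): a SymInt that is not actually symbolic
  // decays to a plain Int at construction time:
  //
  //   IValue iv(c10::SymInt(7));
  //   iv.isInt();     // true: 7 is not symbolic, so it is stored as Tag::Int
  //   iv.isSymInt();  // false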
600 | |
601 | IValue(c10::SymFloat i) { |
602 | if (i.is_symbolic()) { |
603 | tag = Tag::SymFloat; |
604 | payload.u.as_intrusive_ptr = i.toSymNodeImpl().release(); |
605 | } else { |
606 | tag = Tag::Double; |
607 | payload.u.as_double = i.as_float_unchecked(); |
608 | } |
609 | } |
610 | |
611 | bool isSymFloat() const { |
612 | return Tag::SymFloat == tag; |
613 | } |
614 | |
615 | c10::SymFloat toSymFloat() &&; |
616 | c10::SymFloat toSymFloat() const&; |
617 | |
  // allows you to pass literals (3, 4) without ambiguity
619 | IValue(int32_t i) : IValue(static_cast<int64_t>(i)) {} |
620 | |
621 | bool isInt() const { |
622 | return Tag::Int == tag; |
623 | } |
624 | |
625 | int64_t toInt() const { |
626 | AT_ASSERT(isInt()); |
627 | return payload.u.as_int; |
628 | } |
629 | |
630 | // Bool |
631 | IValue(bool b) : tag(Tag::Bool) { |
632 | #if defined(__clang__) && defined(__x86_64__) |
    // Initializing the entire payload stops valgrind from reporting
    // "jump or move depends on uninitialised value" in the IValue copy
    // constructor.
635 | // See https://github.com/pytorch/pytorch/issues/37117 |
636 | payload.u.as_int = b; |
637 | #else |
638 | payload.u.as_bool = b; |
639 | #endif |
640 | } |
641 | bool isBool() const { |
642 | return Tag::Bool == tag; |
643 | } |
644 | bool toBool() const { |
645 | AT_ASSERT(isBool()); |
646 | return payload.u.as_bool; |
647 | } |
648 | |
649 | // IntList |
650 | bool isIntList() const; |
651 | c10::List<int64_t> toIntList() &&; |
652 | c10::List<int64_t> toIntList() const&; |
653 | std::vector<int64_t> toIntVector() const; |
654 | at::DimVector toDimVector() const; |
655 | |
656 | // ConstantString |
657 | IValue(c10::intrusive_ptr<ivalue::ConstantString> v); |
658 | IValue(std::string v); |
659 | IValue(const char* v) : IValue(std::string(v)) {} |
  IValue(c10::string_view v) : IValue(std::string(v)) {}
661 | bool isString() const { |
662 | return Tag::String == tag; |
663 | } |
664 | c10::intrusive_ptr<ivalue::ConstantString> toString() &&; |
665 | c10::intrusive_ptr<ivalue::ConstantString> toString() const&; |
666 | const std::string& toStringRef() const; |
667 | c10::optional<std::reference_wrapper<const std::string>> toOptionalStringRef() |
668 | const; |
669 | c10::string_view toStringView() const; |
670 | |
671 | // DoubleList |
672 | bool isDoubleList() const; |
673 | c10::List<double> toDoubleList() &&; |
674 | c10::List<double> toDoubleList() const&; |
675 | std::vector<double> toDoubleVector() const; |
676 | |
677 | // ComplexDoubleList |
678 | bool isComplexDoubleList() const; |
679 | c10::List<c10::complex<double>> toComplexDoubleList() &&; |
680 | c10::List<c10::complex<double>> toComplexDoubleList() const&; |
681 | std::vector<c10::complex<double>> toComplexDoubleVector() const; |
682 | |
683 | // BoolList |
684 | bool isBoolList() const; |
685 | c10::List<bool> toBoolList() &&; |
686 | c10::List<bool> toBoolList() const&; |
687 | |
688 | // TensorList |
689 | bool isTensorList() const; |
690 | c10::List<at::Tensor> toTensorList() &&; |
691 | c10::List<at::Tensor> toTensorList() const&; |
692 | std::vector<at::Tensor> toTensorVector() const; |
693 | |
694 | // OptionalTensorList |
695 | bool isOptionalTensorList() const; |
696 | c10::List<c10::optional<at::Tensor>> toOptionalTensorList() &&; |
697 | c10::List<c10::optional<at::Tensor>> toOptionalTensorList() const&; |
698 | std::vector<c10::optional<at::Tensor>> toOptionalTensorVector() const; |
699 | |
700 | // GenericList |
701 | IValue(c10::List<IValue> v); |
702 | bool isList() const { |
703 | return Tag::GenericList == tag; |
704 | } |
705 | c10::List<IValue> toList() &&; |
706 | c10::List<IValue> toList() const&; |
707 | c10::ArrayRef<IValue> toListRef() const; |
708 | |
  // Some template constructors of IValue call other constructors recursively.
  // This SFINAE check ensures that the called constructor exists.
711 | template <class T> |
712 | using enable_if_ivalue_constructible = |
713 | std::enable_if_t<std::is_constructible<IValue, T>::value, std::nullptr_t>; |
714 | |
715 | // The rule for lists is more complicated; the generic constructor is only |
716 | // acceptable if your element isn't SymInt. If you do have a SymInt element, |
717 | // then you must also, at construction time, check if you can decay the list |
718 | // into an int list (this is MANDATORY, as at a use site we may expect |
719 | // toIntList to work even if at the call site you had a SymIntArrayRef |
  // argument). In practice, only SymIntArrayRef is used this way, so we
  // didn't bother making it work for the other constructors; we just make
  // sure they're not selectable.
723 | template <class T> |
724 | using enable_if_list_is_ivalue_constructible = |
725 | std::enable_if_t<std::is_constructible<IValue, T>::value && |
726 | !std::is_same<T, c10::SymInt>::value, std::nullptr_t>; |
727 | |
728 | template <class T, enable_if_list_is_ivalue_constructible<T> = nullptr> |
729 | IValue(c10::List<T>&& v); |
730 | template <class T, enable_if_list_is_ivalue_constructible<T> = nullptr> |
731 | IValue(const c10::List<T>& v); |
732 | template <class T, enable_if_list_is_ivalue_constructible<T> = nullptr> |
733 | IValue(at::ArrayRef<T> v); |
734 | template <class T, enable_if_list_is_ivalue_constructible<T> = nullptr> |
735 | IValue(const std::vector<T>& v); |
736 | template <class T, size_t N> |
737 | IValue(std::array<T, N> v); |
738 | |
739 | // Manual constructors for lists of symints, which decay to int list if |
740 | // possible. To avoid ambiguous overload situations, we template them |
741 | // to prevent implicit conversions |
742 | template <class T> |
743 | using enable_if_symint = |
744 | std::enable_if_t<std::is_same<T, c10::SymInt>::value, std::nullptr_t>; |
745 | |
746 | template <class T, enable_if_symint<T> = nullptr> |
747 | IValue(at::ArrayRef<T> v); |
748 | template <class T, enable_if_symint<T> = nullptr> |
749 | IValue(at::OptionalArrayRef<T> v); |
750 | template <class T, enable_if_symint<T> = nullptr> |
751 | IValue(const std::vector<T>& v); |
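
  // For example (illustrative): with no symbolic elements, the list is
  // expected to decay to an int list per the rule above, so toIntList()
  // works at use sites:
  //
  //   std::vector<c10::SymInt> v = {c10::SymInt(1), c10::SymInt(2)};
  //   IValue iv(v);
  //   iv.isIntList();  // expected true: nothing symbolic to preserve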
752 | |
753 | template <class T> |
754 | using enable_if_ilist_is_ivalue_constructible = std::enable_if_t< |
755 | std::is_constructible<IValue, T>::value && |
756 | std::is_constructible<IValue, typename IListRef<T>::boxed_type>::value && |
757 | !std::is_same<T, c10::SymInt>::value, |
758 | std::nullptr_t>; |
759 | |
760 | template <class T, enable_if_ilist_is_ivalue_constructible<T> = nullptr> |
761 | IValue(c10::IListRef<T> v); |
762 | |
763 | // GenericDict |
764 | IValue(c10::Dict<IValue, IValue> v); |
765 | bool isGenericDict() const { |
766 | return Tag::GenericDict == tag; |
767 | } |
768 | c10::Dict<IValue, IValue> toGenericDict() &&; |
769 | c10::Dict<IValue, IValue> toGenericDict() const&; |
770 | |
771 | template <class Key, class Value> |
772 | IValue(c10::Dict<Key, Value> v); |
773 | |
774 | template <class Key, class Value> |
775 | /// \cond |
776 | /// DOXYGEN_CANNOT_HANDLE_CONSTRUCTORS_WITH_MACROS_SO_EXCLUDE_THIS_LINE_FROM_DOXYGEN |
777 | C10_DEPRECATED_MESSAGE( |
778 | "IValues based on std::unordered_map<K, V> are slow and deprecated. Please use c10::Dict<K, V> instead." ) |
779 | /// \endcond |
780 | IValue(std::unordered_map<Key, Value> v); |
781 | |
782 | template <class T, enable_if_ivalue_constructible<T> = nullptr> |
783 | IValue(c10::optional<T> v); |
784 | template <class T, enable_if_list_is_ivalue_constructible<T> = nullptr> |
785 | IValue(c10::OptionalArrayRef<T> v); |
786 | IValue(c10::nullopt_t); |
787 | |
788 | // ClassType |
789 | IValue(c10::intrusive_ptr<ivalue::Object> v); |
790 | bool isObject() const { |
791 | return tag == Tag::Object; |
792 | } |
793 | c10::intrusive_ptr<ivalue::Object> toObject() &&; |
794 | c10::intrusive_ptr<ivalue::Object> toObject() const&; |
795 | ivalue::Object& toObjectRef() const; |
796 | |
797 | torch::jit::Module toModule() const; |
798 | bool isModule() const; |
799 | |
800 | // PyObject |
801 | IValue(c10::intrusive_ptr<ivalue::PyObjectHolder> v); |
802 | bool isPyObject() const { |
803 | return tag == Tag::PyObject; |
804 | } |
805 | c10::intrusive_ptr<ivalue::PyObjectHolder> toPyObjectHolder() &&; |
806 | c10::intrusive_ptr<ivalue::PyObjectHolder> toPyObjectHolder() const&; |
807 | PyObject* toPyObject() const; |
808 | |
809 | // Enum |
810 | explicit IValue(c10::intrusive_ptr<ivalue::EnumHolder> v); |
811 | bool isEnum() const { |
812 | return tag == Tag::Enum; |
813 | } |
814 | c10::intrusive_ptr<ivalue::EnumHolder> toEnumHolder() &&; |
815 | c10::intrusive_ptr<ivalue::EnumHolder> toEnumHolder() const&; |
816 | |
817 | // None |
818 | IValue() : tag(Tag::None) {} |
819 | bool isNone() const { |
820 | return Tag::None == tag; |
821 | } |
822 | std::string toNone() const { |
823 | AT_ASSERT(isNone()); |
824 | return "None" ; |
825 | } |
826 | |
827 | static IValue uninitialized() { |
828 | auto i = IValue(); |
829 | i.tag = Tag::Uninitialized; |
830 | return i; |
831 | } |
832 | |
  // Scalar, which gets encoded as an Int, Double, ComplexDouble, Bool,
  // SymInt, or SymFloat
834 | IValue(const at::Scalar& s) : IValue() { |
835 | // NB: do the symbolic versions first, as isFloatingPoint is true |
836 | // for both SymFloat and double |
837 | if (s.isSymInt()) { |
838 | tag = Tag::SymInt; |
839 | payload.u.as_intrusive_ptr = s.toSymInt().toSymNodeImpl().release(); |
840 | } else if (s.isSymFloat()) { |
841 | tag = Tag::SymFloat; |
842 | payload.u.as_intrusive_ptr = s.toSymFloat().toSymNodeImpl().release(); |
843 | } else if (s.isFloatingPoint()) { |
844 | tag = Tag::Double; |
845 | payload.u.as_double = s.toDouble(); |
846 | } else if (s.isComplex()) { |
847 | *this = s.toComplexDouble(); |
848 | } else if (s.isBoolean()) { |
849 | tag = Tag::Bool; |
850 | payload.u.as_bool = s.toBool(); |
851 | } else { |
      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(s.isIntegral(false), "Unknown type in Scalar");
853 | tag = Tag::Int; |
854 | payload.u.as_int = s.toLong(); |
855 | } |
856 | } |
857 | |
858 | bool isScalar() const { |
859 | return isDouble() || isInt() || isComplexDouble() || isBool() || isSymInt() || isSymFloat(); |
860 | } |
861 | |
862 | at::Scalar toScalar() const { |
863 | if (isDouble()) |
864 | return toDouble(); |
865 | else if (isInt()) |
866 | return toInt(); |
867 | else if (isComplexDouble()) |
868 | return toComplexDouble(); |
869 | else if (isBool()) |
870 | return toBool(); |
871 | else if (isSymInt()) |
872 | return toSymInt(); |
873 | else if (isSymFloat()) |
874 | return toSymFloat(); |
    throw std::runtime_error("IValue is not a Scalar");
876 | } |
877 | |
878 | // Device |
879 | IValue(c10::Device d) : tag(Tag::Device) { |
880 | payload.u.as_device.type = d.type(); |
881 | payload.u.as_device.index = d.index(); |
882 | } |
883 | bool isDevice() const { |
884 | return Tag::Device == tag; |
885 | } |
886 | c10::Device toDevice() const { |
887 | AT_ASSERT(isDevice()); |
888 | return c10::Device(payload.u.as_device.type, payload.u.as_device.index); |
889 | } |
890 | |
891 | // Stream |
892 | IValue(c10::Stream s) |
893 | : tag(Tag::Stream) { |
894 | auto v = c10::make_intrusive<ivalue::StreamData3Holder>(s.pack3()); |
895 | payload.u.as_intrusive_ptr = v.release(); |
896 | } |
897 | c10::Stream toStream() &&; |
898 | c10::Stream toStream() const &; |
899 | bool isStream() const { return Tag::Stream == tag; } |
900 | |
901 | // ScalarType |
902 | IValue(ScalarType t) |
903 | : IValue(static_cast<std::underlying_type<ScalarType>::type>(t)) {} |
904 | at::ScalarType toScalarType() const { |
905 | return static_cast<at::ScalarType>(toInt()); |
906 | } |
907 | |
908 | // Layout |
909 | IValue(Layout l) |
910 | : IValue(static_cast<std::underlying_type<Layout>::type>(l)) {} |
911 | at::Layout toLayout() const { |
912 | return static_cast<at::Layout>(toInt()); |
913 | } |
914 | |
915 | // MemoryFormat |
916 | IValue(MemoryFormat m) |
917 | : IValue(static_cast<std::underlying_type<MemoryFormat>::type>(m)) {} |
918 | at::MemoryFormat toMemoryFormat() const { |
919 | return static_cast<at::MemoryFormat>(toInt()); |
920 | } |
921 | |
922 | // QScheme |
923 | IValue(at::QScheme qscheme) : tag(Tag::Int) { |
924 | payload.u.as_int = static_cast<int64_t>(qscheme); |
925 | } |
926 | |
927 | at::QScheme toQScheme() const { |
928 | return static_cast<at::QScheme>(toInt()); |
929 | } |
930 | |
931 | // Dimname |
932 | IValue(at::Dimname dimname) : IValue(dimname.symbol().toQualString()) {} |
933 | |
934 | at::Dimname toDimname() const { |
935 | return at::Dimname::fromSymbol(Symbol::fromQualString(toStringRef())); |
936 | } |
937 | |
938 | // Generator |
939 | IValue(at::Generator g) : tag(Tag::Generator) { |
940 | payload.u.as_intrusive_ptr = null_to_undefined_tensor(g.unsafeReleaseGeneratorImpl()); |
941 | } |
942 | bool isGenerator() const { |
943 | return Tag::Generator == tag; |
944 | } |
945 | at::Generator toGenerator() &&; |
946 | at::Generator toGenerator() const&; |
947 | |
948 | // for debugging |
949 | std::string tagKind() const { |
950 | switch (tag) { |
951 | #define DEFINE_CASE(x) \ |
952 | case Tag::x: \ |
953 | return #x; |
954 | TORCH_FORALL_TAGS(DEFINE_CASE) |
955 | #undef DEFINE_CASE |
956 | } |
957 | return "InvalidTag(" + c10::guts::to_string(static_cast<int>(tag)) + ")" ; |
958 | } |
959 | |
960 | // generic v.to<at::Tensor>() implementations |
961 | // that can be used in special functions like pop/push |
962 | // that use template meta-programming. |
963 | // prefer the directly named methods when you can, |
964 | // since they are simpler to understand |
965 | |
  // Note: if you get linker errors saying one of these is missing,
  // change it to `... && = delete;` and you will see better error messages
  // explaining why. However, we cannot commit that change because some
  // compiler versions barf on it.
970 | template <typename T> |
971 | T to() &&; |
972 | template <typename T> |
973 | typename c10::detail::ivalue_to_const_ref_overload_return<T>::type to() const&; |
974 | |
  // ToOptional: convert an IValue to an optional object that accepts both T
  // and None
977 | template <typename T> |
978 | optional<T> toOptional(); |
979 | template <typename T> |
980 | optional<T> toOptional() const; |
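
  // For example (illustrative): None maps to nullopt, anything else unwraps:
  //
  //   IValue none;                      // default-constructed, i.e. None
  //   none.toOptional<int64_t>();       // c10::nullopt
  //   IValue(3).toOptional<int64_t>();  // c10::optional<int64_t>(3)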
981 | |
982 | /// @private [doxygen private] |
983 | /// this is a shallow comparison of two IValues to test the object identity |
984 | bool isSameIdentity(const IValue& rhs) const; |
985 | |
986 | // Computes the "official" string representation of an IValue. This produces a |
987 | // TorchScript expression that can be used to recreate an IValue with the same |
988 | // value (e.g. when we are printing constants in the serializer). |
989 | // |
990 | // Callers can use `customFormatter` to override how `repr()` prints out an |
991 | // IValue. This is useful if you have some other environment where you can |
992 | // look up values, and you want to print a reference to that environment (like |
993 | // the serializer's constant table). |
994 | // |
995 | // repr() is not necessarily defined on all objects! |
996 | std::ostream& repr( |
997 | std::ostream& stream, |
998 | std::function<bool(std::ostream&, const IValue& v)> customFormatter) |
999 | const; |
1000 | |
1001 | // Computes an "informal" string representation of an IValue. This should be |
1002 | // used for debugging, or servicing `print()`-like functions. |
1003 | // This is different from `repr()` in that there is no expectation that we can |
1004 | // exactly reconstruct an IValue from the output; feel free to use a |
  // concise/pretty form.
1006 | TORCH_API friend std::ostream& operator<<( |
1007 | std::ostream& out, |
1008 | const IValue& v); |
1009 | |
1010 | bool isPtrType() const { |
1011 | if (isTensor()) { |
1012 | return payload.as_tensor.defined(); |
1013 | } |
1014 | return isIntrusivePtrLegacyBehavior(); |
1015 | } |
1016 | |
1017 | /// @private [doxygen private] |
1018 | const void* internalToPointer() const { |
1019 | TORCH_INTERNAL_ASSERT( |
1020 | isPtrType(), "Can only call internalToPointer() for pointer types" ); |
1021 | if (isTensor()) { |
1022 | return payload.as_tensor.unsafeGetTensorImpl(); |
1023 | } else { |
1024 | return payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton() |
1025 | ? payload.u.as_intrusive_ptr : nullptr; |
1026 | } |
1027 | } |
1028 | |
1029 | template <typename T = c10::PlatformType> |
1030 | TypePtr type() const; |
1031 | |
1032 | // Detect aliased tensors. |
1033 | struct HashAliasedIValue { |
1034 | size_t hashTensor(const at::Tensor& ten) const { |
1035 | if (ten.is_sparse()) { |
1036 | // COO sparse tensors have a "values" tensor and an "indices" tensor |
1037 | // so this will detect overlap of sparse tensors that share a values |
1038 | // tensor, but not sparse tensors that share an indices tensor. |
1039 | return hashTensor(ten._values()); |
      } else if (ten.is_sparse_csr()) {
        // Similarly, CSR sparse tensors have a "values" tensor plus index
        // tensors, so this will detect overlap of sparse tensors that share
        // a values tensor, but not ones that only share an index tensor.
1044 | return hashTensor(ten.values()); |
1045 | } else if (!ten.has_storage()) { |
1046 | // Opaque tensors such as the ones constructed by the MKL-DNN backend |
1047 | // don't have storage so we just use their TensorImpls. |
1048 | // TODO: Find way to expose alias info for opaque tensors. |
1049 | return reinterpret_cast<size_t>(ten.unsafeGetTensorImpl()); |
1050 | } else { |
1051 | return reinterpret_cast<size_t>( |
1052 | ten.storage().unsafeGetStorageImpl()); |
1053 | } |
1054 | } |
1055 | size_t operator()(const IValue& val) const { |
1056 | if (val.isTensor()) { |
1057 | return hashTensor(val.toTensor()); |
1058 | } |
1059 | // If it is not a Tensor, then two mutable IValues alias each other only |
1060 | // if they are the same pointer. |
1061 | return val.payload.u.as_int; |
1062 | } |
1063 | }; |
1064 | |
1065 | struct CompAliasedIValues { |
1066 | bool operator()(const IValue& lhs, const IValue& rhs) const { |
1067 | return lhs.isAliasOf(rhs); |
1068 | } |
1069 | }; |
1070 | |
1071 | using HashAliasedIValues = |
1072 | std::unordered_set<IValue, HashAliasedIValue, CompAliasedIValues>; |
1073 | using HashAliasedIValueMap = |
1074 | std::unordered_map<IValue, IValue, HashAliasedIValue, CompAliasedIValues>; |
1075 | |
  // Checks whether this and rhs have a subvalue in common.
  // For example, [t1, t2] and [t2, t3] returns true.
1078 | bool overlaps(const IValue& rhs) const; |
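
  // For example (illustrative; `t` is some at::Tensor):
  //
  //   IValue a(c10::List<at::Tensor>({t}));
  //   IValue b(c10::List<at::Tensor>({t}));
  //   a.overlaps(b);  // true: both lists contain an alias of `t`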
1079 | |
1080 | // Inserts all subvalues of this in subValues. |
1081 | void getSubValues(HashAliasedIValues& subValues) const; |
1082 | |
1083 | // Apply visitor to every subvalue. |
1084 | // TODO: There are several places that recurse over IValue. This is fragile. |
1085 | // This visitor should be used to recurse over ivalues. |
1086 | void visit(const std::function<bool(const IValue&)>& visitor) const; |
1087 | IValue deepcopy() const; |
1088 | IValue deepcopy(HashAliasedIValueMap& memo) const; |
1089 | |
1090 | private: |
1091 | static c10::intrusive_ptr_target* null_to_undefined_tensor(c10::intrusive_ptr_target* p) { |
1092 | return p ? p : static_cast<c10::intrusive_ptr_target*>(c10::UndefinedTensorImpl::singleton()); |
1093 | } |
1094 | |
1095 | static bool ptrEqual(const IValue& lhs, const IValue& rhs); |
  // NOTE: IValue tags are intentionally private. In the future we may encode
  // this value differently (e.g. using NaN boxing), and this would make it more
1098 | // costly to determine the tag for all types vs just determining if something |
1099 | // is a particular type. Instead we want clients to use the `isX` methods when |
1100 | // possible. If for perf. reasons you really, absolutely, must have a jump |
1101 | // table, then we can revisit this. |
1102 | enum class Tag : uint32_t { |
1103 | #define DEFINE_TAG(x) x, |
1104 | TORCH_FORALL_TAGS(DEFINE_TAG) |
1105 | #undef DEFINE_TAG |
1106 | }; |
1107 | |
1108 | template < |
1109 | class T, |
1110 | class NullType = c10::detail::intrusive_target_default_null_type<T>> |
1111 | c10::intrusive_ptr<T, NullType> moveToIntrusivePtr(); |
1112 | template < |
1113 | typename T, |
1114 | class NullType = c10::detail::intrusive_target_default_null_type<T>> |
1115 | c10::intrusive_ptr<T, NullType> toIntrusivePtr() const; |
1116 | |
1117 | void destroy() { |
    // We carefully construct this call to both 1) avoid UB from using
    // the "wrong" one of as_tensor and as_intrusive_ptr and 2) enable
1120 | // the compiler to generate the same code for each case. It is |
1121 | // surprisingly difficult to get this right. |
1122 | if (isTensor() || isIntrusivePtr()) { |
1123 | c10::intrusive_ptr_target* p = isTensor() ? payload.as_tensor.unsafeGetTensorImpl() : payload.u.as_intrusive_ptr; |
1124 | c10::intrusive_ptr<intrusive_ptr_target, c10::UndefinedTensorImpl>::reclaim(p); |
1125 | // No need to make this destructor call! |
1126 | // payload.as_tensor.~Tensor(); |
1127 | } |
1128 | } |
1129 | |
1130 | C10_ALWAYS_INLINE void moveFrom(IValue&& rhs) noexcept { |
1131 | if (rhs.isTensor()) { |
1132 | new (&payload.as_tensor) at::Tensor(std::move(rhs.payload.as_tensor)); |
1133 | // As far as I can tell, omitting the usual explicit destructor call |
1134 | // is not UB in and of itself, and it's a slight perf win. The |
1135 | // destructor is a no-op, because the moved-from Tensor is |
1136 | // effectively an intrusive_ptr in the null state, so we don't need |
1137 | // the behavior for correctness reasons either. Leaving this |
1138 | // explanatory comment, including commented-out destructor call, to |
1139 | // make this abundantly clear. |
1140 | // |
1141 | // rhs.payload.as_tensor.~Tensor(); |
1142 | } else { |
1143 | payload.u = rhs.payload.u; |
1144 | } |
1145 | tag = rhs.tag; |
1146 | rhs.clearToNone(); |
1147 | } |
1148 | |
1149 | void clearToNone() noexcept { |
1150 | payload.u.as_int = 0; |
1151 | tag = Tag::None; |
1152 | } |
1153 | |
1154 | bool isIntrusivePtr() const { |
1155 | switch (tag) { |
1156 | case Tag::None: |
1157 | return false; |
1158 | case Tag::Tensor: |
1159 | return false; |
1160 | case Tag::Storage: |
1161 | return true; |
1162 | case Tag::Generator: |
1163 | return true; |
1164 | case Tag::Double: |
1165 | return false; |
1166 | case Tag::ComplexDouble: |
1167 | return true; |
1168 | case Tag::Int: |
1169 | return false; |
1170 | case Tag::SymInt: |
1171 | return true; |
1172 | case Tag::SymFloat: |
1173 | return true; |
1174 | case Tag::Bool: |
1175 | return false; |
1176 | case Tag::Tuple: |
1177 | return true; |
1178 | case Tag::String: |
1179 | return true; |
1180 | case Tag::Blob: |
1181 | return true; |
1182 | case Tag::GenericList: |
1183 | return true; |
1184 | case Tag::GenericDict: |
1185 | return true; |
1186 | case Tag::Future: |
1187 | return true; |
1188 | case Tag::Await: |
1189 | return true; |
1190 | case Tag::Device: |
1191 | return false; |
1192 | case Tag::Stream: |
1193 | return true; |
1194 | case Tag::Object: |
1195 | return true; |
1196 | case Tag::PyObject: |
1197 | return true; |
1198 | case Tag::Uninitialized: |
1199 | return false; |
1200 | case Tag::Capsule: |
1201 | return true; |
1202 | case Tag::RRef: |
1203 | return true; |
1204 | case Tag::Quantizer: |
1205 | return true; |
1206 | case Tag::Enum: |
1207 | return true; |
1208 | } |
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(false, "unexpected tag ", static_cast<int>(tag));
1210 | return false; |
1211 | } |
1212 | |
1213 | // Storage and Generator were treated specially when |
1214 | // is_intrusive_ptr was stored as explicit state. This getter |
1215 | // preserves the old behavior for use with WeakIValue for now. |
1216 | bool isIntrusivePtrLegacyBehavior() const { |
1217 | if (tag == Tag::Storage || tag == Tag::Generator) { |
1218 | return payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(); |
1219 | } else { |
1220 | return isIntrusivePtr(); |
1221 | } |
1222 | } |
1223 | |
1224 | union Payload { |
1225 | // [TriviallyCopyablePayload] |
1226 | // We use a nested union here so that we can make the copy easy |
1227 | // and efficient in the non-tensor (i.e., trivially copyable) |
1228 | // case. Specifically, we do not have to do a switch-on-tag to |
1229 | // figure out which union member to assign; we can just use |
1230 | // TriviallyCopyablePayload::operator=. |
1231 | union TriviallyCopyablePayload { |
1232 | TriviallyCopyablePayload() : as_int(0) {} |
1233 | int64_t as_int; |
1234 | double as_double; |
1235 | bool as_bool; |
1236 | // Invariant: never nullptr; null state is represented as |
1237 | // c10::UndefinedTensorImpl::singleton() for consistency of |
1238 | // representation with Tensor. |
1239 | c10::intrusive_ptr_target* as_intrusive_ptr; |
1240 | struct { |
1241 | DeviceType type; |
1242 | DeviceIndex index; |
1243 | } as_device; |
1244 | } u; |
1245 | at::Tensor as_tensor; |
1246 | Payload() : u() {} |
1247 | ~Payload() {} |
1248 | }; |
1249 | |
1250 | IValue(const Payload& p, Tag t) : tag(t) { |
1251 | if (isTensor()) { |
1252 | new (&payload.as_tensor) at::Tensor(p.as_tensor); |
1253 | } else { |
1254 | payload.u = p.u; |
1255 | } |
1256 | } |
1257 | |
1258 | template <typename T> |
1259 | struct TagType {}; |
1260 | |
1261 | friend MaybeOwnedTraits<IValue>; |
1262 | |
1263 | Payload payload; |
1264 | Tag tag{IValue::Tag::None}; |
1265 | friend struct WeakIValue; |
1266 | }; |
1267 | |
1268 | struct TORCH_API WeakIValue final { |
1269 | WeakIValue() = default; |
1270 | |
1271 | WeakIValue(const WeakIValue& rhs) |
1272 | : payload(rhs.payload), |
1273 | tag(rhs.tag), |
1274 | is_intrusive_ptr(rhs.is_intrusive_ptr) { |
1275 | if (is_intrusive_ptr && payload.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()) { |
1276 | c10::raw::weak_intrusive_ptr::incref(payload.as_intrusive_ptr); |
1277 | } |
1278 | } |
1279 | WeakIValue(const IValue& rhs) |
1280 | : tag(rhs.tag), |
1281 | is_intrusive_ptr(rhs.isIntrusivePtrLegacyBehavior()) { |
1282 | if (rhs.isTensor()) { |
1283 | payload.as_intrusive_ptr = rhs.unsafeToTensorImpl(); |
1284 | is_intrusive_ptr = true; |
1285 | } else { |
1286 | payload = rhs.payload.u; |
1287 | } |
1288 | if (is_intrusive_ptr) { |
1289 | if (payload.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()) { |
1290 | c10::raw::weak_intrusive_ptr::incref(payload.as_intrusive_ptr); |
1291 | } |
1292 | } |
1293 | } |
1294 | WeakIValue(WeakIValue&& rhs) noexcept : WeakIValue() { |
1295 | swap(rhs); |
1296 | } |
1297 | ~WeakIValue() { |
1298 | if (is_intrusive_ptr && payload.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()) { |
1299 | c10::raw::weak_intrusive_ptr::decref(payload.as_intrusive_ptr); |
1300 | } |
1301 | } |
1302 | WeakIValue& operator=(WeakIValue&& rhs) & noexcept { |
1303 | WeakIValue(std::move(rhs)).swap(*this); // this also sets rhs to None |
1304 | return *this; |
1305 | } |
1306 | WeakIValue& operator=(WeakIValue const& rhs) & { |
1307 | WeakIValue(rhs).swap(*this); |
1308 | return *this; |
1309 | } |
1310 | void swap(WeakIValue& rhs) noexcept { |
1311 | std::swap(payload, rhs.payload); |
1312 | std::swap(is_intrusive_ptr, rhs.is_intrusive_ptr); |
1313 | std::swap(tag, rhs.tag); |
1314 | } |
1315 | |
1316 | bool isSameIdentity(const WeakIValue& rhs) const { |
1317 | return payload.as_int == rhs.payload.as_int && tag == rhs.tag && |
1318 | is_intrusive_ptr == rhs.is_intrusive_ptr; |
1319 | } |
1320 | |
1321 | IValue lock() const { |
1322 | if (!is_intrusive_ptr) { |
1323 | IValue::Payload newPayload; |
1324 | newPayload.u = payload; |
1325 | return IValue(newPayload, tag); |
1326 | } |
1327 | if (IValue::Tag::Tensor == tag) { |
1328 | auto temp = c10::weak_intrusive_ptr<at::TensorImpl, c10::UndefinedTensorImpl>::reclaim( |
1329 | static_cast<at::TensorImpl*>(payload.as_intrusive_ptr)); |
1330 | c10::intrusive_ptr<at::TensorImpl, c10::UndefinedTensorImpl> ip(temp.lock()); |
1331 | temp.release(); |
1332 | if (!ip) { |
1333 | return IValue(); |
1334 | } else { |
1335 | return IValue(at::Tensor(std::move(ip))); |
1336 | } |
1337 | } else { |
1338 | auto temp = c10::weak_intrusive_ptr<c10::intrusive_ptr_target>::reclaim( |
1339 | payload.as_intrusive_ptr == c10::UndefinedTensorImpl::singleton() |
1340 | ? nullptr |
1341 | : payload.as_intrusive_ptr); |
1342 | IValue::Payload pl; |
1343 | pl.u.as_intrusive_ptr = temp.lock().release(); |
1344 | temp.release(); |
1345 | if (!pl.u.as_intrusive_ptr) { |
1346 | return IValue(); |
1347 | } else { |
1348 | return IValue(pl, tag); |
1349 | } |
1350 | } |
1351 | } |
1352 | |
1353 | size_t use_count() const noexcept { |
1354 | if (!is_intrusive_ptr) { |
1355 | return 1; |
1356 | } |
1357 | auto temp = c10::weak_intrusive_ptr<c10::intrusive_ptr_target, c10::UndefinedTensorImpl>::reclaim( |
1358 | payload.as_intrusive_ptr); |
1359 | size_t result = temp.use_count(); |
1360 | temp.release(); |
1361 | return result; |
1362 | } |
1363 | |
1364 | size_t weak_use_count() const noexcept { |
1365 | if (!is_intrusive_ptr) { |
1366 | return 1; |
1367 | } |
1368 | auto temp = c10::weak_intrusive_ptr<c10::intrusive_ptr_target, c10::UndefinedTensorImpl>::reclaim( |
1369 | payload.as_intrusive_ptr); |
1370 | size_t result = temp.weak_use_count(); |
1371 | temp.release(); |
1372 | return result; |
1373 | } |
1374 | size_t hash() const { |
1375 | return payload.as_int; |
1376 | } |
1377 | |
1378 | private: |
1379 | using Payload = IValue::Payload::TriviallyCopyablePayload; |
1380 | Payload payload; |
1381 | IValue::Tag tag{IValue::Tag::None}; |
1382 | bool is_intrusive_ptr{false}; |
1383 | }; |
1384 | |
// An owning pointer to a type. When the type is a class type, it requires a
// pair of shared_ptrs: one to the class type and one to its owning CU, so
// that the class type is guaranteed to stay alive as long as we hold this
// object.
1388 | struct TORCH_API StrongTypePtr { |
1389 | StrongTypePtr( |
1390 | std::shared_ptr<torch::jit::CompilationUnit> cu, |
1391 | TypePtr type); |
1392 | |
1393 | std::shared_ptr<torch::jit::CompilationUnit> cu_; |
1394 | TypePtr type_; |
1395 | }; |
1396 | |
// [Constant Object Weak CompilationUnit Reference]
// A non-owning pointer to a type. When a class gets inserted as a constant
// into a graph, using a strong pointer would create a circular reference:
// Object -> CompilationUnit and CompilationUnit -> Graph (which owns the
// constant Object).
1402 | struct TORCH_API WeakTypePtr { |
1403 | WeakTypePtr( |
1404 | std::weak_ptr<torch::jit::CompilationUnit> cu, |
1405 | TypePtr type); |
1406 | |
1407 | std::weak_ptr<torch::jit::CompilationUnit> cu_; |
1408 | TypePtr type_; |
1409 | }; |
1410 | |
1411 | // internal build errors with std::variant :/ |
1412 | struct WeakOrStrongCompilationUnit { |
1413 | explicit WeakOrStrongCompilationUnit( |
1414 | std::shared_ptr<torch::jit::CompilationUnit> shared_cu) : strong_ptr_(std::move(shared_cu)), weak_ptr_(c10::nullopt) {} |
1415 | |
1416 | explicit WeakOrStrongCompilationUnit( |
1417 | std::weak_ptr<torch::jit::CompilationUnit> weak_cu) : strong_ptr_(c10::nullopt), weak_ptr_(std::move(weak_cu)) {} |
1418 | |
1419 | std::shared_ptr<torch::jit::CompilationUnit> getStrongRefOrThrow() const { |
1420 | TORCH_INTERNAL_ASSERT(strong_ptr_ != c10::nullopt); |
1421 | return *strong_ptr_; |
1422 | } |
1423 | |
1424 | std::weak_ptr<torch::jit::CompilationUnit> getWeakRefOrThrow() const { |
1425 | TORCH_INTERNAL_ASSERT(weak_ptr_ != c10::nullopt); |
1426 | return *weak_ptr_; |
1427 | } |
1428 | |
1429 | bool holdingStrongRef() const { |
1430 | return strong_ptr_ != c10::nullopt; |
1431 | } |
1432 | |
1433 | bool holdingEmptyStrongRef() const { |
1434 | return holdingStrongRef() && *strong_ptr_ == nullptr; |
1435 | } |
1436 | |
1437 | c10::optional<std::shared_ptr<torch::jit::CompilationUnit>> strong_ptr_; |
1438 | c10::optional<std::weak_ptr<torch::jit::CompilationUnit>> weak_ptr_; |
1439 | }; |
1440 | |
// An Object will hold a non-owning CompilationUnit reference if it is a
// constant in the graph, and an owning reference otherwise.
1443 | struct TORCH_API WeakOrStrongTypePtr { |
1444 | explicit WeakOrStrongTypePtr(WeakTypePtr weak) |
1445 | : cu_(WeakOrStrongCompilationUnit(std::move(weak.cu_))), type_(std::move(weak.type_)) {} |
1446 | explicit WeakOrStrongTypePtr(StrongTypePtr strong) |
1447 | : cu_(WeakOrStrongCompilationUnit(std::move(strong.cu_))), type_(std::move(strong.type_)) {} |
1448 | explicit WeakOrStrongTypePtr(WeakOrStrongCompilationUnit cu, TypePtr type) |
1449 | : cu_(std::move(cu)), type_(std::move(type)) {} |
1450 | WeakTypePtr asWeakTypePtr() const; |
1451 | |
1452 | WeakOrStrongCompilationUnit cu_; |
1453 | TypePtr type_; |
1454 | |
1455 | bool holds_strong_ref() const { |
1456 | return cu_.holdingStrongRef(); |
1457 | } |
1458 | |
1459 | bool holds_empty_strong_ref() const { |
1460 | return cu_.holdingEmptyStrongRef(); |
1461 | } |
1462 | }; |
1463 | |
1464 | |
1465 | } // namespace c10 |
1466 | |
1467 | #include <ATen/core/ivalue_inl.h> // IWYU pragma: keep |
1468 | |