1#pragma once
2
3#include <ATen/core/DimVector.h>
4#include <ATen/core/TensorBody.h>
5#include <ATen/core/blob.h>
6#include <ATen/core/custom_class.h>
7#include <ATen/core/ivalue_to.h>
8#include <ATen/core/jit_type_base.h>
9#include <ATen/core/type_factory.h>
10#include <c10/core/SymFloat.h>
11#include <c10/macros/Export.h>
12#include <c10/util/C++17.h>
13#include <c10/util/MaybeOwned.h>
14#include <c10/util/intrusive_ptr.h>
15#include <typeindex>
16#include <utility>
17
18namespace torch {
19class TORCH_API CustomClassHolder : public c10::intrusive_ptr_target {};
20namespace jit {
21using ::torch::CustomClassHolder;
22struct Function;
23struct CompilationUnit;
24struct Module;
25} // namespace jit
26} // namespace torch
27namespace c10 {
// Forward declarations of container and type classes that IValue refers to
// before their full definitions are available.
template <class Key, class Value>
class Dict;
template <class T>
class List;
template <class T>
class IListRef;
struct IValue;
struct ClassType;
struct Type;
class RRefInterface;

// Shared-ownership handle to a ClassType (forward-declared above; a second
// declaration is unnecessary).
using ClassTypePtr = std::shared_ptr<ClassType>;
41
42TORCH_API bool _fastEqualsForContainer(const IValue& lhs, const IValue& rhs);
43
44TORCH_API torch::jit::Function* checkObjectSortSchema(
45 const c10::ClassTypePtr& t,
46 std::stringstream& why_not);
47
48// A comparator that checks ordering of two IValues of same type.
49typedef std::function<bool(const IValue& a, const IValue& b)> IValueComparator;
50
51TORCH_API IValueComparator getLessThanComparator(const IValue& v);
52TORCH_API IValueComparator getGreaterThanComparator(const IValue& v);
53
54namespace ivalue {
55struct Tuple;
56struct Future;
57struct Await;
58struct ConstantString;
59struct GenericDict;
60struct Object;
61struct PyObjectHolder;
62struct EnumHolder;
63// We need a ComplexHolder because currently the payloads in the Union
64// only take 64 bits. Since ComplexDouble takes up 128 bits, and is too big
65// to fit in the IValue directly, we indirect complex numbers through an intrusive
66// pointer to ComplexHolder (which contains a c10::complex).
67struct ComplexHolder : c10::intrusive_ptr_target {
68 public:
69 template <typename T>
70 ComplexHolder(c10::complex<T> c) {
71 val = convert<decltype(val), c10::complex<T>>(c);
72 }
73 ComplexHolder() = default;
74 c10::complex<double> val;
75};
76
77// Similar to ComplexHolder, for StreamData3
78struct StreamData3Holder : c10::intrusive_ptr_target {
79 public:
80 StreamData3Holder(struct c10::StreamData3 d) {
81 val = d;
82 }
83 StreamData3Holder() = delete;
84 struct c10::StreamData3 val;
85};
86
87} // namespace ivalue
88
89// This is an owning wrapper for a c10::optional<std::vector<T>>
90// that can be implicitly converted to a (non-owning) optional<ArrayRef<T>>.
91// Its purpose is to be used in generated code to keep the vector alive
92// either until the end of a statement (as a temporary), or as a saved arg
93// in autograd.
94template <typename T>
95struct OptionalArray {
96 c10::optional<std::vector<T>> list;
97
98 OptionalArray()= default;
99 OptionalArray(std::vector<T> val) : list(std::move(val)) {}
100
101 // Used when saving an argument for the backwards pass.
102 OptionalArray& operator=(c10::optional<ArrayRef<T>> ref) {
103 if (ref) {
104 list = std::vector<T>(ref->begin(), ref->end());
105 } else {
106 list = nullopt;
107 }
108 return *this;
109 }
110
111 // Used when saving an argument for the backwards pass.
112 OptionalArray& operator=(c10::OptionalArrayRef<T> ref) {
113 if (ref) {
114 list = std::vector<T>(ref->begin(), ref->end());
115 } else {
116 list = nullopt;
117 }
118 return *this;
119 }
120
121 operator c10::optional<c10::ArrayRef<T>>() {
122 if (!list) {
123 return nullopt;
124 }
125 return *list;
126 }
127
128 operator c10::OptionalArrayRef<T>() {
129 if (!list) {
130 return nullopt;
131 }
132 return *list;
133 }
134};
135
// Capsule is an internal implementation detail of custom C++ classes. We
// define it as an owning wrapper for
// c10::intrusive_ptr<torch::CustomClassHolder>. This wrapper serves as an
// abstraction of the type-erased custom class object pointer. It also allows
// pybind11 to treat it as a standalone class registered with its own type
// caster, instead of as a custom pointer holder, which the pointer-holder type
// caster would try to "unwrap" automatically.
143struct Capsule {
144 c10::intrusive_ptr<torch::CustomClassHolder> obj_ptr;
145 explicit Capsule(c10::intrusive_ptr<torch::CustomClassHolder> ptr)
146 : obj_ptr(std::move(ptr)) {}
147};
148
149// IValue is the generic tagged union used by the interpreter to hold
150// all value types.
151// It is a 16-byte object with an 8-byte payload and an 8-byte tag.
152// The tag is currently 4 bytes to determine the type, and 1 byte
153// to mark whether that type is a subtype of c10::intrusive_ptr_target and needs
154// retain/release calls.
155
// X-macro listing every IValue tag name. It applies the caller-supplied macro
// `_` to each name in turn; IValue::tagKind() uses it to emit one switch case
// per tag. NOTE(review): the Tag enum declaration is not visible in this
// chunk; if it is also generated from this list, entry order determines tag
// values, so keep the order stable.
#define TORCH_FORALL_TAGS(_) \
  _(None)                    \
  _(Tensor)                  \
  _(Storage)                 \
  _(Double)                  \
  _(ComplexDouble)           \
  _(Int)                     \
  _(SymInt)                  \
  _(SymFloat)                \
  _(Bool)                    \
  _(Tuple)                   \
  _(String)                  \
  _(Blob)                    \
  _(GenericList)             \
  _(GenericDict)             \
  _(Future)                  \
  _(Await)                   \
  _(Device)                  \
  _(Stream)                  \
  _(Object)                  \
  _(PyObject)                \
  _(Uninitialized)           \
  _(Capsule)                 \
  _(RRef)                    \
  _(Quantizer)               \
  _(Generator)               \
  _(Enum)
183
184// [doxygen private]
185// These methods are not actually private but we don't want to document them, so
186// they are marked `@private`, which hides them on the doxygen documentation for
187// this page.
188
189/// IValue (Interpreter Value) is a tagged union over the types
190/// supported by the TorchScript interpreter. IValues contain their
191/// values as an `IValue::Payload`, which holds primitive types
192/// (`int64_t`, `bool`, `double`, `Device`) and `Tensor` as values,
193/// and all other types as a `c10::intrusive_ptr`. In order to
194/// optimize performance of the destructor and related operations by
195/// making the `Tensor` and `c10::intrusive_ptr` paths generate the
196/// same code, we represent a null `c10::intrusive_ptr` as
197/// `UndefinedTensorImpl::singleton()`, *not* `nullptr`.
198///
199/// IValues are used as inputs to and outputs from the TorchScript interpreter.
200/// To retrieve the value contained within an IValue, use the `.toX()` methods,
201/// where `X` is the type you are trying to get. Note that neither the `.toX()`
202/// methods nor the templated `.to<T>` functions do any kind of casting, they
203/// only unwrap the contained value. For example:
204///
205/// \rst
206/// .. code-block:: cpp
207///
208/// // Make the IValue
209/// torch::IValue my_ivalue(26);
210/// std::cout << my_ivalue << "\n";
211///
212/// // Unwrap the IValue
213/// int64_t my_int = my_ivalue.toInt();
214/// std::cout << my_int << "\n";
215///
216/// // This will throw an error!
217/// // `my_ivalue` is tagged as an int and cannot be used as another type
218/// torch::Tensor my_tensor = my_ivalue.toTensor();
219/// \endrst
220struct TORCH_API IValue final {
221 IValue(const IValue& rhs)
222 : IValue(rhs.payload, rhs.tag) {
223 if (isIntrusivePtr() && payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()) {
224 c10::raw::intrusive_ptr::incref(payload.u.as_intrusive_ptr);
225 }
226 }
227
228 IValue(IValue&& rhs) noexcept : tag(rhs.tag) {
229 moveFrom(std::move(rhs));
230 }
231
232 /// @private [doxygen private]
233 ~IValue() {
234 destroy();
235 }
236
237 C10_ALWAYS_INLINE IValue& operator=(IValue&& rhs) & noexcept {
238 if (&rhs == this) {
239 return *this;
240 }
241
242 destroy();
243 moveFrom(std::move(rhs));
244 return *this;
245 }
246
247 IValue& operator=(IValue const& rhs) & {
248 *this = IValue(rhs);
249 return *this;
250 }
251
252 void dump() const;
253
254 /**
255 * Equality comparison. The semantics are the same as Python's `==`:
256 * 1. Numerical types are compared by value.
257 * 2. Tensors compute element-wise equality, returning a BoolTensor (see:
258 * `torch.eq()`)
259 * 3. Strings are compared by value.
260 * 4. Sequence types (list, tuple) are compared lexicographically by
261 * comparing their elements. Different sequence types never compare equal.
262 * 5. Mappings (dict) must have equal (key, value) pairs.
263 * 6. If not listed above, the default behavior for is to test identity
264 * equality (e.g. pointer equality).
265 *
266 * Why does this return an IValue instead of a bool? Because in PyTorch,
267 * `tensor1 == tensor2` returns a `BoolTensor`, not a bool.
268 *
269 * NOTE: we (like Python) assume that identity equality implies value equality
270 * for efficiency.
271 * TODO: need to support customizing equality
272 */
273 IValue equals(const IValue& rhs) const;
274 /**
275 * This implements the same semantics as `bool(lhs == rhs)` in Python. which
276 * is the same as `equals()` except for Tensor types.
277 */
278 TORCH_API friend bool operator==(const IValue& lhs, const IValue& rhs);
279 TORCH_API friend bool operator!=(const IValue& lhs, const IValue& rhs);
280
281 /**
282 * Identity comparison. Checks if `this` is the same object as `rhs`. The
283 * semantics are the same as Python's `is` operator.
284 *
285 * NOTE: Like in Python, this operation is poorly defined for primitive types
286 * like numbers and strings. Prefer to use `==` unless you really want to
287 * check identity equality.
288 */
289 bool is(const IValue& rhs) const;
290
291 /**
292 * Hashing for IValues. Returns an IValue-boxed int.
293 *
294 * Some notes:
295 * - Like eager, Tensors are hashed by looking at the pointer. This is not
296 * strictly correct because two value-equal tensors with different tensor
297 * pointers will hash differently, but we choose to reproduce the eager
298 * semantics.
299 * - Hashing is not defined on all built-in IValue types (e.g. list and
300 * dict), following Python. Calling `hash()` on these types will throw.
301 */
302 IValue hash() const {
303 return (int64_t)IValue::hash(*this);
304 }
305 // This is defined because `c10::hash` dispatches to a function of this
306 // signature. See the member function `hash()`.
307 static size_t hash(const IValue& iv);
308
309 /**
310 * @private [doxygen private]
311 * [container equality]
312 * This is an equality implementation that assumes objects with the same
313 * identity equal themselves, for efficiency reasons. We primarily have this
314 * for consistency, because Python does the same thing. This actually
315 * provokes user-visible changes in behavior due to quirks in torch:
316 * [tensor1] == [tensor1] -> True (because container equality will first
317 * compare identity) [tensor1] == [tensor1_copy] -> RuntimeError:
318 * Boolean value of Tensor with more than one value is ambiguous
319 */
320 TORCH_API friend bool _fastEqualsForContainer(
321 const IValue& lhs,
322 const IValue& rhs);
323
324private:
325 static bool isAliasOf(const at::Tensor& a, const at::Tensor& b) {
326 if (a.is_sparse()) {
327 return isAliasOf(a._values(), b) || isAliasOf(a._indices(), b);
328 }
329 if (b.is_sparse()) {
330 return isAliasOf(a, b._values()) || isAliasOf(a, b._indices());
331 }
332 if (a.is_sparse_csr()) {
333 return isAliasOf(a.values(), b) ||
334 isAliasOf(a.crow_indices(), b) ||
335 isAliasOf(a.col_indices(), b);
336 }
337 if (b.is_sparse_csr()) {
338 return isAliasOf(a, b.values()) ||
339 isAliasOf(a, b.crow_indices()) ||
340 isAliasOf(a, b.col_indices());
341 }
342
343 // Opaque tensors such as the ones constructed by the MKL-DNN backend
344 // don't have storage so we just compare their TensorImpls.
345 // TODO: Find way to expose alias info for opaque tensors.
346 if (!a.has_storage() || !b.has_storage()) {
347 return a.unsafeGetTensorImpl() == b.unsafeGetTensorImpl();
348 }
349
350 return a.is_alias_of(b);
351 }
352
353 template <typename T>
354 bool isListOf() const;
355
356public:
357 /// @private [doxygen private]
358 bool isAliasOf(const IValue& rhs) const {
359 if (this->tag != rhs.tag) {
360 // Trivially don't alias if the type is different
361 return false;
362 }
363
364 // Tensors should be compared based on internal storage
365 if (this->isTensor()) {
366 return isAliasOf(this->toTensor(), rhs.toTensor());
367 }
368
369 if (!isIntrusivePtr()) {
370 // Primitive types don't alias anything
371 return false;
372 }
373
374 AT_ASSERT(rhs.isIntrusivePtr());
375
376 // Other types can be compared by their ptr value
377 return this->payload.u.as_intrusive_ptr == rhs.payload.u.as_intrusive_ptr;
378 }
379
380 /// @private [doxygen private]
381 size_t use_count() const noexcept {
382 if (isTensor()) {
383 return payload.as_tensor.use_count();
384 }
385
386 if (!isIntrusivePtrLegacyBehavior()) {
387 return 1;
388 }
389
390 if (payload.u.as_intrusive_ptr == c10::UndefinedTensorImpl::singleton()) {
391 return 0;
392 }
393 return c10::raw::intrusive_ptr::use_count(payload.u.as_intrusive_ptr);
394 }
395
396 /// @private [doxygen private]
397 void swap(IValue& rhs) noexcept {
398 if (isTensor() && rhs.isTensor()) {
399 std::swap(payload.as_tensor, rhs.payload.as_tensor);
400 } else if (isTensor()) {
401 at::Tensor t = std::move(payload.as_tensor);
402 // As far as I can tell, omitting the usual explicit destructor call
403 // is not UB in and of itself, and it's a slight perf win. The
404 // destructor is a no-op, because the moved-from Tensor is
405 // effectively an intrusive_ptr in the null state, so we don't need
406 // the behavior for correctness reasons either. Leaving this
407 // explanatory comment, including commented-out destructor call, to
408 // make this abundantly clear.
409 //
410 // payload.as_tensor.~Tensor();
411 payload.u = rhs.payload.u;
412 new (&rhs.payload.as_tensor) at::Tensor(std::move(t));
413 } else if (rhs.isTensor()) {
414 rhs.swap(*this);
415 return;
416 } else {
417 std::swap(payload.u, rhs.payload.u);
418 }
419 std::swap(tag, rhs.tag);
420 }
421
422 // Accessors for subtypes are arranged together below
423 // While some of these accessors could be generated through templates,
424 // we prefer to write them manually for clarity
425
426 IValue(at::TensorBase t) : tag(Tag::Tensor) {
427 new (&payload.as_tensor) at::Tensor(std::move(t));
428 }
429 bool isTensor() const {
430 return Tag::Tensor == tag;
431 }
432
433 private:
434 // Outlined error path so that toTensor() can be inlined.
435 [[noreturn]] void reportToTensorTypeError() const;
436
437 public:
438 at::Tensor toTensor() &&;
439 at::Tensor& toTensor() &;
440 const at::Tensor& toTensor() const&;
441 at::TensorImpl* unsafeToTensorImpl() const {
442 TORCH_INTERNAL_ASSERT(isTensor());
443 return payload.as_tensor.unsafeGetTensorImpl();
444 }
445
446 IValue(at::Storage s) : tag(Tag::Storage) {
447 payload.u.as_intrusive_ptr = null_to_undefined_tensor(s.unsafeReleaseStorageImpl());
448 }
449 bool isStorage() const {
450 return Tag::Storage == tag;
451 }
452 c10::Storage toStorage() &&;
453 c10::Storage toStorage() const&;
454
455 const IValue& toIValue() const {
456 return *this;
457 }
458 IValue& toIValue() {
459 return *this;
460 }
461
462 /// @private [doxygen private]
463 IValue(intrusive_ptr<caffe2::Blob> blob)
464 : tag(Tag::Blob) {
465 // TODO (after Tensor merge) If we pass in a Blob holding a Tensor, extract
466 // and store it as a Tensor instead.
467 payload.u.as_intrusive_ptr = null_to_undefined_tensor(blob.release());
468 }
469
470 /// @private [doxygen private]
471 bool isBlob() const {
472 return Tag::Blob == tag;
473 }
474
475 /// @private [doxygen private]
476 c10::intrusive_ptr<caffe2::Blob> toBlob() &&;
477
478 /// @private [doxygen private]
479 c10::intrusive_ptr<caffe2::Blob> toBlob() const&;
480
481 // Capsule. No new callsites of these APIs should
482 // be introduced.
483 static inline IValue make_capsule(
484 intrusive_ptr<torch::CustomClassHolder> blob);
485 bool isCapsule() const {
486 return Tag::Capsule == tag;
487 }
488 c10::intrusive_ptr<torch::CustomClassHolder> toCapsule() &&;
489 c10::intrusive_ptr<torch::CustomClassHolder> toCapsule() const&;
490
491 // Custom C++ classes
492 template <
493 typename T,
494 std::enable_if_t<
495 std::is_base_of<torch::CustomClassHolder, T>::value,
496 int> = 0>
497 IValue(intrusive_ptr<T> custom_class);
498 bool isCustomClass() const;
499 template <typename T>
500 c10::intrusive_ptr<T> toCustomClass() &&;
501 template <typename T>
502 c10::intrusive_ptr<T> toCustomClass() const&;
503
504 // Tuple
505 IValue(c10::intrusive_ptr<ivalue::Tuple> v);
506
507 template <
508 typename... Args,
509 std::enable_if_t<
510 !guts::disjunction<
511 std::is_lvalue_reference<Args>...,
512 guts::negation<std::is_constructible<IValue, Args>>...>::value,
513 std::nullptr_t> = nullptr>
514 IValue(const std::tuple<Args...>& t);
515 template <
516 typename... Args,
517 std::enable_if_t<
518 !guts::disjunction<
519 std::is_lvalue_reference<Args>...,
520 guts::negation<std::is_constructible<IValue, Args>>...>::value,
521 std::nullptr_t> = nullptr>
522 IValue(std::tuple<Args...>&& t);
523 bool isTuple() const {
524 return Tag::Tuple == tag;
525 }
526 c10::intrusive_ptr<ivalue::Tuple> toTuple() &&;
527 c10::intrusive_ptr<ivalue::Tuple> toTuple() const&;
528 C10_NODISCARD ivalue::Tuple& toTupleRef() const;
529
530 // Double
531 IValue(double d) : tag(Tag::Double) {
532 payload.u.as_double = d;
533 }
534 bool isDouble() const {
535 return Tag::Double == tag;
536 }
537 double toDouble() const {
538 AT_ASSERT(isDouble());
539 return payload.u.as_double;
540 }
541
542 // ComplexDouble
543 template <typename T>
544 IValue(c10::complex<T> c);
545 bool isComplexDouble() const { return Tag::ComplexDouble == tag; }
546 c10::complex<double> toComplexDouble() const;
547
548 // Future
549 IValue(c10::intrusive_ptr<ivalue::Future> v);
550 bool isFuture() const {
551 return Tag::Future == tag;
552 }
553 c10::intrusive_ptr<ivalue::Future> toFuture() &&;
554 c10::intrusive_ptr<ivalue::Future> toFuture() const&;
555
556 IValue(c10::intrusive_ptr<ivalue::Await> v);
557 bool isAwait() const {
558 return Tag::Await == tag;
559 }
560 c10::intrusive_ptr<ivalue::Await> toAwait() &&;
561 c10::intrusive_ptr<ivalue::Await> toAwait() const&;
562
563 // RRef
564 IValue(c10::intrusive_ptr<c10::RRefInterface> v);
565 bool isRRef() const {
566 return Tag::RRef == tag;
567 }
568 c10::intrusive_ptr<c10::RRefInterface> toRRef() &&;
569 c10::intrusive_ptr<c10::RRefInterface> toRRef() const&;
570
571 // Quantizer
572 IValue(c10::intrusive_ptr<at::Quantizer> v);
573 bool isQuantizer() const {
574 return Tag::Quantizer == tag;
575 }
576 c10::intrusive_ptr<at::Quantizer> toQuantizer() &&;
577 c10::intrusive_ptr<at::Quantizer> toQuantizer() const&;
578
579 // Int
580 IValue(int64_t i) : tag(Tag::Int) {
581 payload.u.as_int = i;
582 }
583
584 IValue(c10::SymInt i) {
585 if (i.is_symbolic()) {
586 tag = Tag::SymInt;
587 payload.u.as_intrusive_ptr = i.toSymNodeImpl().release();
588 } else {
589 tag = Tag::Int;
590 payload.u.as_int = i.as_int_unchecked();
591 }
592 }
593
594 bool isSymInt() const {
595 return Tag::SymInt == tag;
596 }
597
598 c10::SymInt toSymInt() &&;
599 c10::SymInt toSymInt() const&;
600
601 IValue(c10::SymFloat i) {
602 if (i.is_symbolic()) {
603 tag = Tag::SymFloat;
604 payload.u.as_intrusive_ptr = i.toSymNodeImpl().release();
605 } else {
606 tag = Tag::Double;
607 payload.u.as_double = i.as_float_unchecked();
608 }
609 }
610
611 bool isSymFloat() const {
612 return Tag::SymFloat == tag;
613 }
614
615 c10::SymFloat toSymFloat() &&;
616 c10::SymFloat toSymFloat() const&;
617
618 // allow you to pass literals (3, 4) without ambiguity
619 IValue(int32_t i) : IValue(static_cast<int64_t>(i)) {}
620
621 bool isInt() const {
622 return Tag::Int == tag;
623 }
624
625 int64_t toInt() const {
626 AT_ASSERT(isInt());
627 return payload.u.as_int;
628 }
629
630 // Bool
631 IValue(bool b) : tag(Tag::Bool) {
632#if defined(__clang__) && defined(__x86_64__)
633 // Initializing entire payload stops valgrind's from reporting
634 // "jump or move depends on uninitialised value" in IValue copy constructor
635 // See https://github.com/pytorch/pytorch/issues/37117
636 payload.u.as_int = b;
637#else
638 payload.u.as_bool = b;
639#endif
640 }
641 bool isBool() const {
642 return Tag::Bool == tag;
643 }
644 bool toBool() const {
645 AT_ASSERT(isBool());
646 return payload.u.as_bool;
647 }
648
649 // IntList
650 bool isIntList() const;
651 c10::List<int64_t> toIntList() &&;
652 c10::List<int64_t> toIntList() const&;
653 std::vector<int64_t> toIntVector() const;
654 at::DimVector toDimVector() const;
655
656 // ConstantString
657 IValue(c10::intrusive_ptr<ivalue::ConstantString> v);
658 IValue(std::string v);
659 IValue(const char* v) : IValue(std::string(v)) {}
660 IValue(c10::string_view v) : IValue(std::string(v)) {};
661 bool isString() const {
662 return Tag::String == tag;
663 }
664 c10::intrusive_ptr<ivalue::ConstantString> toString() &&;
665 c10::intrusive_ptr<ivalue::ConstantString> toString() const&;
666 const std::string& toStringRef() const;
667 c10::optional<std::reference_wrapper<const std::string>> toOptionalStringRef()
668 const;
669 c10::string_view toStringView() const;
670
671 // DoubleList
672 bool isDoubleList() const;
673 c10::List<double> toDoubleList() &&;
674 c10::List<double> toDoubleList() const&;
675 std::vector<double> toDoubleVector() const;
676
677 // ComplexDoubleList
678 bool isComplexDoubleList() const;
679 c10::List<c10::complex<double>> toComplexDoubleList() &&;
680 c10::List<c10::complex<double>> toComplexDoubleList() const&;
681 std::vector<c10::complex<double>> toComplexDoubleVector() const;
682
683 // BoolList
684 bool isBoolList() const;
685 c10::List<bool> toBoolList() &&;
686 c10::List<bool> toBoolList() const&;
687
688 // TensorList
689 bool isTensorList() const;
690 c10::List<at::Tensor> toTensorList() &&;
691 c10::List<at::Tensor> toTensorList() const&;
692 std::vector<at::Tensor> toTensorVector() const;
693
694 // OptionalTensorList
695 bool isOptionalTensorList() const;
696 c10::List<c10::optional<at::Tensor>> toOptionalTensorList() &&;
697 c10::List<c10::optional<at::Tensor>> toOptionalTensorList() const&;
698 std::vector<c10::optional<at::Tensor>> toOptionalTensorVector() const;
699
700 // GenericList
701 IValue(c10::List<IValue> v);
702 bool isList() const {
703 return Tag::GenericList == tag;
704 }
705 c10::List<IValue> toList() &&;
706 c10::List<IValue> toList() const&;
707 c10::ArrayRef<IValue> toListRef() const;
708
709 // Some template constructors of IValue calls another constructor recursively.
710 // This SFINAEs the called constructor exists.
711 template <class T>
712 using enable_if_ivalue_constructible =
713 std::enable_if_t<std::is_constructible<IValue, T>::value, std::nullptr_t>;
714
715 // The rule for lists is more complicated; the generic constructor is only
716 // acceptable if your element isn't SymInt. If you do have a SymInt element,
717 // then you must also, at construction time, check if you can decay the list
718 // into an int list (this is MANDATORY, as at a use site we may expect
719 // toIntList to work even if at the call site you had a SymIntArrayRef
720 // argument). In practice, only SymIntArrayRef is used this way, so we
721 // didn't bother making it work for the other constructors, we just make sure
722 // they're not selectable.
723 template <class T>
724 using enable_if_list_is_ivalue_constructible =
725 std::enable_if_t<std::is_constructible<IValue, T>::value &&
726 !std::is_same<T, c10::SymInt>::value, std::nullptr_t>;
727
728 template <class T, enable_if_list_is_ivalue_constructible<T> = nullptr>
729 IValue(c10::List<T>&& v);
730 template <class T, enable_if_list_is_ivalue_constructible<T> = nullptr>
731 IValue(const c10::List<T>& v);
732 template <class T, enable_if_list_is_ivalue_constructible<T> = nullptr>
733 IValue(at::ArrayRef<T> v);
734 template <class T, enable_if_list_is_ivalue_constructible<T> = nullptr>
735 IValue(const std::vector<T>& v);
736 template <class T, size_t N>
737 IValue(std::array<T, N> v);
738
739 // Manual constructors for lists of symints, which decay to int list if
740 // possible. To avoid ambiguous overload situations, we template them
741 // to prevent implicit conversions
742 template <class T>
743 using enable_if_symint =
744 std::enable_if_t<std::is_same<T, c10::SymInt>::value, std::nullptr_t>;
745
746 template <class T, enable_if_symint<T> = nullptr>
747 IValue(at::ArrayRef<T> v);
748 template <class T, enable_if_symint<T> = nullptr>
749 IValue(at::OptionalArrayRef<T> v);
750 template <class T, enable_if_symint<T> = nullptr>
751 IValue(const std::vector<T>& v);
752
753 template <class T>
754 using enable_if_ilist_is_ivalue_constructible = std::enable_if_t<
755 std::is_constructible<IValue, T>::value &&
756 std::is_constructible<IValue, typename IListRef<T>::boxed_type>::value &&
757 !std::is_same<T, c10::SymInt>::value,
758 std::nullptr_t>;
759
760 template <class T, enable_if_ilist_is_ivalue_constructible<T> = nullptr>
761 IValue(c10::IListRef<T> v);
762
763 // GenericDict
764 IValue(c10::Dict<IValue, IValue> v);
765 bool isGenericDict() const {
766 return Tag::GenericDict == tag;
767 }
768 c10::Dict<IValue, IValue> toGenericDict() &&;
769 c10::Dict<IValue, IValue> toGenericDict() const&;
770
771 template <class Key, class Value>
772 IValue(c10::Dict<Key, Value> v);
773
774 template <class Key, class Value>
775 /// \cond
776 /// DOXYGEN_CANNOT_HANDLE_CONSTRUCTORS_WITH_MACROS_SO_EXCLUDE_THIS_LINE_FROM_DOXYGEN
777 C10_DEPRECATED_MESSAGE(
778 "IValues based on std::unordered_map<K, V> are slow and deprecated. Please use c10::Dict<K, V> instead.")
779 /// \endcond
780 IValue(std::unordered_map<Key, Value> v);
781
782 template <class T, enable_if_ivalue_constructible<T> = nullptr>
783 IValue(c10::optional<T> v);
784 template <class T, enable_if_list_is_ivalue_constructible<T> = nullptr>
785 IValue(c10::OptionalArrayRef<T> v);
786 IValue(c10::nullopt_t);
787
788 // ClassType
789 IValue(c10::intrusive_ptr<ivalue::Object> v);
790 bool isObject() const {
791 return tag == Tag::Object;
792 }
793 c10::intrusive_ptr<ivalue::Object> toObject() &&;
794 c10::intrusive_ptr<ivalue::Object> toObject() const&;
795 ivalue::Object& toObjectRef() const;
796
797 torch::jit::Module toModule() const;
798 bool isModule() const;
799
800 // PyObject
801 IValue(c10::intrusive_ptr<ivalue::PyObjectHolder> v);
802 bool isPyObject() const {
803 return tag == Tag::PyObject;
804 }
805 c10::intrusive_ptr<ivalue::PyObjectHolder> toPyObjectHolder() &&;
806 c10::intrusive_ptr<ivalue::PyObjectHolder> toPyObjectHolder() const&;
807 PyObject* toPyObject() const;
808
809 // Enum
810 explicit IValue(c10::intrusive_ptr<ivalue::EnumHolder> v);
811 bool isEnum() const {
812 return tag == Tag::Enum;
813 }
814 c10::intrusive_ptr<ivalue::EnumHolder> toEnumHolder() &&;
815 c10::intrusive_ptr<ivalue::EnumHolder> toEnumHolder() const&;
816
817 // None
818 IValue() : tag(Tag::None) {}
819 bool isNone() const {
820 return Tag::None == tag;
821 }
822 std::string toNone() const {
823 AT_ASSERT(isNone());
824 return "None";
825 }
826
827 static IValue uninitialized() {
828 auto i = IValue();
829 i.tag = Tag::Uninitialized;
830 return i;
831 }
832
833 // Scalar, which gets encoded as either an Int, a Double or a ComplexDouble
834 IValue(const at::Scalar& s) : IValue() {
835 // NB: do the symbolic versions first, as isFloatingPoint is true
836 // for both SymFloat and double
837 if (s.isSymInt()) {
838 tag = Tag::SymInt;
839 payload.u.as_intrusive_ptr = s.toSymInt().toSymNodeImpl().release();
840 } else if (s.isSymFloat()) {
841 tag = Tag::SymFloat;
842 payload.u.as_intrusive_ptr = s.toSymFloat().toSymNodeImpl().release();
843 } else if (s.isFloatingPoint()) {
844 tag = Tag::Double;
845 payload.u.as_double = s.toDouble();
846 } else if (s.isComplex()) {
847 *this = s.toComplexDouble();
848 } else if (s.isBoolean()) {
849 tag = Tag::Bool;
850 payload.u.as_bool = s.toBool();
851 } else {
852 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(s.isIntegral(false), "Unknown type in Scalar");
853 tag = Tag::Int;
854 payload.u.as_int = s.toLong();
855 }
856 }
857
858 bool isScalar() const {
859 return isDouble() || isInt() || isComplexDouble() || isBool() || isSymInt() || isSymFloat();
860 }
861
862 at::Scalar toScalar() const {
863 if (isDouble())
864 return toDouble();
865 else if (isInt())
866 return toInt();
867 else if (isComplexDouble())
868 return toComplexDouble();
869 else if (isBool())
870 return toBool();
871 else if (isSymInt())
872 return toSymInt();
873 else if (isSymFloat())
874 return toSymFloat();
875 throw std::runtime_error("IValue is not a Scalar");
876 }
877
878 // Device
879 IValue(c10::Device d) : tag(Tag::Device) {
880 payload.u.as_device.type = d.type();
881 payload.u.as_device.index = d.index();
882 }
883 bool isDevice() const {
884 return Tag::Device == tag;
885 }
886 c10::Device toDevice() const {
887 AT_ASSERT(isDevice());
888 return c10::Device(payload.u.as_device.type, payload.u.as_device.index);
889 }
890
891 // Stream
892 IValue(c10::Stream s)
893 : tag(Tag::Stream) {
894 auto v = c10::make_intrusive<ivalue::StreamData3Holder>(s.pack3());
895 payload.u.as_intrusive_ptr = v.release();
896 }
897 c10::Stream toStream() &&;
898 c10::Stream toStream() const &;
899 bool isStream() const { return Tag::Stream == tag; }
900
901 // ScalarType
902 IValue(ScalarType t)
903 : IValue(static_cast<std::underlying_type<ScalarType>::type>(t)) {}
904 at::ScalarType toScalarType() const {
905 return static_cast<at::ScalarType>(toInt());
906 }
907
908 // Layout
909 IValue(Layout l)
910 : IValue(static_cast<std::underlying_type<Layout>::type>(l)) {}
911 at::Layout toLayout() const {
912 return static_cast<at::Layout>(toInt());
913 }
914
915 // MemoryFormat
916 IValue(MemoryFormat m)
917 : IValue(static_cast<std::underlying_type<MemoryFormat>::type>(m)) {}
918 at::MemoryFormat toMemoryFormat() const {
919 return static_cast<at::MemoryFormat>(toInt());
920 }
921
922 // QScheme
923 IValue(at::QScheme qscheme) : tag(Tag::Int) {
924 payload.u.as_int = static_cast<int64_t>(qscheme);
925 }
926
927 at::QScheme toQScheme() const {
928 return static_cast<at::QScheme>(toInt());
929 }
930
931 // Dimname
932 IValue(at::Dimname dimname) : IValue(dimname.symbol().toQualString()) {}
933
934 at::Dimname toDimname() const {
935 return at::Dimname::fromSymbol(Symbol::fromQualString(toStringRef()));
936 }
937
938 // Generator
939 IValue(at::Generator g) : tag(Tag::Generator) {
940 payload.u.as_intrusive_ptr = null_to_undefined_tensor(g.unsafeReleaseGeneratorImpl());
941 }
942 bool isGenerator() const {
943 return Tag::Generator == tag;
944 }
945 at::Generator toGenerator() &&;
946 at::Generator toGenerator() const&;
947
948 // for debugging
949 std::string tagKind() const {
950 switch (tag) {
951#define DEFINE_CASE(x) \
952 case Tag::x: \
953 return #x;
954 TORCH_FORALL_TAGS(DEFINE_CASE)
955#undef DEFINE_CASE
956 }
957 return "InvalidTag(" + c10::guts::to_string(static_cast<int>(tag)) + ")";
958 }
959
960 // generic v.to<at::Tensor>() implementations
961 // that can be used in special functions like pop/push
962 // that use template meta-programming.
963 // prefer the directly named methods when you can,
964 // since they are simpler to understand
965
966 // Note: if you get linker errors saying one of these is missing,
967 // change it to ... && = delete; and you will see better error messages for
968 // why However, we cannot commit this because some compiler versions barf on
969 // it.
970 template <typename T>
971 T to() &&;
972 template <typename T>
973 typename c10::detail::ivalue_to_const_ref_overload_return<T>::type to() const&;
974
975 // ToOptional: convert a IValue to the Optional obj that accepts both T and
976 // None
977 template <typename T>
978 optional<T> toOptional();
979 template <typename T>
980 optional<T> toOptional() const;
981
  /// @private [doxygen private]
  /// This is a shallow comparison of two IValues that tests object identity
  /// rather than value equality.
  bool isSameIdentity(const IValue& rhs) const;

  // Computes the "official" string representation of an IValue. This produces a
  // TorchScript expression that can be used to recreate an IValue with the same
  // value (e.g. when we are printing constants in the serializer).
  //
  // Callers can use `customFormatter` to override how `repr()` prints out an
  // IValue. This is useful if you have some other environment where you can
  // look up values, and you want to print a reference to that environment (like
  // the serializer's constant table).
  // NOTE(review): the bool returned by customFormatter presumably signals
  // "I printed this value myself" — confirm against the definition.
  //
  // repr() is not necessarily defined on all objects!
  std::ostream& repr(
      std::ostream& stream,
      std::function<bool(std::ostream&, const IValue& v)> customFormatter)
      const;

  // Computes an "informal" string representation of an IValue. This should be
  // used for debugging, or servicing `print()`-like functions.
  // This is different from `repr()` in that there is no expectation that we can
  // exactly reconstruct an IValue from the output; feel free to use a
  // concise/pretty form.
  TORCH_API friend std::ostream& operator<<(
      std::ostream& out,
      const IValue& v);
1009
1010 bool isPtrType() const {
1011 if (isTensor()) {
1012 return payload.as_tensor.defined();
1013 }
1014 return isIntrusivePtrLegacyBehavior();
1015 }
1016
1017 /// @private [doxygen private]
1018 const void* internalToPointer() const {
1019 TORCH_INTERNAL_ASSERT(
1020 isPtrType(), "Can only call internalToPointer() for pointer types");
1021 if (isTensor()) {
1022 return payload.as_tensor.unsafeGetTensorImpl();
1023 } else {
1024 return payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()
1025 ? payload.u.as_intrusive_ptr : nullptr;
1026 }
1027 }
1028
  // Returns the TypePtr describing this value (defined out of line). The
  // template parameter defaults to c10::PlatformType; presumably it selects
  // which type-factory flavour builds the result — confirm against the
  // definition.
  template <typename T = c10::PlatformType>
  TypePtr type() const;
1031
1032 // Detect aliased tensors.
1033 struct HashAliasedIValue {
1034 size_t hashTensor(const at::Tensor& ten) const {
1035 if (ten.is_sparse()) {
1036 // COO sparse tensors have a "values" tensor and an "indices" tensor
1037 // so this will detect overlap of sparse tensors that share a values
1038 // tensor, but not sparse tensors that share an indices tensor.
1039 return hashTensor(ten._values());
1040 } else if (ten.is_sparse_csr()) {
1041 // COO sparse tensors have a "values" tensor and an "indices" tensor
1042 // so this will detect overlap of sparse tensors that share a values
1043 // tensor, but not sparse tensors that share an indices tensor.
1044 return hashTensor(ten.values());
1045 } else if (!ten.has_storage()) {
1046 // Opaque tensors such as the ones constructed by the MKL-DNN backend
1047 // don't have storage so we just use their TensorImpls.
1048 // TODO: Find way to expose alias info for opaque tensors.
1049 return reinterpret_cast<size_t>(ten.unsafeGetTensorImpl());
1050 } else {
1051 return reinterpret_cast<size_t>(
1052 ten.storage().unsafeGetStorageImpl());
1053 }
1054 }
1055 size_t operator()(const IValue& val) const {
1056 if (val.isTensor()) {
1057 return hashTensor(val.toTensor());
1058 }
1059 // If it is not a Tensor, then two mutable IValues alias each other only
1060 // if they are the same pointer.
1061 return val.payload.u.as_int;
1062 }
1063 };
1064
1065 struct CompAliasedIValues {
1066 bool operator()(const IValue& lhs, const IValue& rhs) const {
1067 return lhs.isAliasOf(rhs);
1068 }
1069 };
1070
1071 using HashAliasedIValues =
1072 std::unordered_set<IValue, HashAliasedIValue, CompAliasedIValues>;
1073 using HashAliasedIValueMap =
1074 std::unordered_map<IValue, IValue, HashAliasedIValue, CompAliasedIValues>;
1075
  // Checks whether this and rhs have any subvalues in common,
  // e.g. [t1, t2] and [t2, t3] returns true.
  bool overlaps(const IValue& rhs) const;

  // Inserts all subvalues of this into subValues (an alias-keyed set).
  void getSubValues(HashAliasedIValues& subValues) const;

  // Apply visitor to every subvalue.
  // TODO: There are several places that recurse over IValue. This is fragile.
  // This visitor should be used to recurse over ivalues.
  // NOTE(review): the bool returned by visitor presumably controls whether
  // traversal descends further — confirm against the definition.
  void visit(const std::function<bool(const IValue&)>& visitor) const;
  // Deep copy. The memo overload takes an alias-keyed map, presumably so that
  // subvalues shared in the input stay shared in the copy — confirm against
  // the definition.
  IValue deepcopy() const;
  IValue deepcopy(HashAliasedIValueMap& memo) const;
1089
1090 private:
1091 static c10::intrusive_ptr_target* null_to_undefined_tensor(c10::intrusive_ptr_target* p) {
1092 return p ? p : static_cast<c10::intrusive_ptr_target*>(c10::UndefinedTensorImpl::singleton());
1093 }
1094
  // Defined out of line; from the name, presumably compares the pointer
  // identity of lhs and rhs — confirm against the definition.
  static bool ptrEqual(const IValue& lhs, const IValue& rhs);
  // NOTE: IValue tags are intentionally private. In the future we may encode
  // this value differently (e.g. using NaN boxing), and this would make it more
  // costly to determine the tag for all types vs just determining if something
  // is a particular type. Instead we want clients to use the `isX` methods when
  // possible. If for perf. reasons you really, absolutely, must have a jump
  // table, then we can revisit this.
  //
  // One enumerator per entry of TORCH_FORALL_TAGS, in that order.
  enum class Tag : uint32_t {
#define DEFINE_TAG(x) x,
    TORCH_FORALL_TAGS(DEFINE_TAG)
#undef DEFINE_TAG
  };

  // Typed extraction of an intrusive-pointer payload (defined out of line).
  // From the names, moveToIntrusivePtr() presumably transfers this IValue's
  // reference out while toIntrusivePtr() shares it — confirm against the
  // definitions.
  template <
      class T,
      class NullType = c10::detail::intrusive_target_default_null_type<T>>
  c10::intrusive_ptr<T, NullType> moveToIntrusivePtr();
  template <
      typename T,
      class NullType = c10::detail::intrusive_target_default_null_type<T>>
  c10::intrusive_ptr<T, NullType> toIntrusivePtr() const;
1116
  // Releases the strong reference held by the payload, if any: the TensorImpl
  // for tensors, or payload.u.as_intrusive_ptr for intrusive-pointer tags.
  // It does not reset `tag` or `payload` afterwards.
  void destroy() {
    // We carefully construct this call to both 1) avoid the UB of reading the
    // "wrong" (inactive) one of as_tensor and as_intrusive_ptr and 2) enable
    // the compiler to generate the same code for each case. It is
    // surprisingly difficult to get this right.
    if (isTensor() || isIntrusivePtr()) {
      c10::intrusive_ptr_target* p = isTensor() ? payload.as_tensor.unsafeGetTensorImpl() : payload.u.as_intrusive_ptr;
      // reclaim() adopts the existing reference without an incref, and the
      // temporary's destructor then drops it. Using UndefinedTensorImpl as the
      // NullType means the singleton "null" sentinel is never decref'd.
      c10::intrusive_ptr<intrusive_ptr_target, c10::UndefinedTensorImpl>::reclaim(p);
      // No need to make this destructor call!
      // payload.as_tensor.~Tensor();
    }
  }
1129
  // Moves rhs's value into this IValue and resets rhs to None. This writes
  // payload/tag without releasing any previous contents of *this, so it
  // presumably must only be used when *this holds nothing that needs
  // releasing (freshly constructed or already destroyed) — confirm at call
  // sites.
  C10_ALWAYS_INLINE void moveFrom(IValue&& rhs) noexcept {
    if (rhs.isTensor()) {
      new (&payload.as_tensor) at::Tensor(std::move(rhs.payload.as_tensor));
      // As far as I can tell, omitting the usual explicit destructor call
      // is not UB in and of itself, and it's a slight perf win. The
      // destructor is a no-op, because the moved-from Tensor is
      // effectively an intrusive_ptr in the null state, so we don't need
      // the behavior for correctness reasons either. Leaving this
      // explanatory comment, including commented-out destructor call, to
      // make this abundantly clear.
      //
      // rhs.payload.as_tensor.~Tensor();
    } else {
      // Trivially-copyable payload: the raw copy transfers rhs's reference
      // (if any) to *this; clearToNone() below stops rhs from also owning it.
      payload.u = rhs.payload.u;
    }
    tag = rhs.tag;
    rhs.clearToNone();
  }
1148
1149 void clearToNone() noexcept {
1150 payload.u.as_int = 0;
1151 tag = Tag::None;
1152 }
1153
1154 bool isIntrusivePtr() const {
1155 switch (tag) {
1156 case Tag::None:
1157 return false;
1158 case Tag::Tensor:
1159 return false;
1160 case Tag::Storage:
1161 return true;
1162 case Tag::Generator:
1163 return true;
1164 case Tag::Double:
1165 return false;
1166 case Tag::ComplexDouble:
1167 return true;
1168 case Tag::Int:
1169 return false;
1170 case Tag::SymInt:
1171 return true;
1172 case Tag::SymFloat:
1173 return true;
1174 case Tag::Bool:
1175 return false;
1176 case Tag::Tuple:
1177 return true;
1178 case Tag::String:
1179 return true;
1180 case Tag::Blob:
1181 return true;
1182 case Tag::GenericList:
1183 return true;
1184 case Tag::GenericDict:
1185 return true;
1186 case Tag::Future:
1187 return true;
1188 case Tag::Await:
1189 return true;
1190 case Tag::Device:
1191 return false;
1192 case Tag::Stream:
1193 return true;
1194 case Tag::Object:
1195 return true;
1196 case Tag::PyObject:
1197 return true;
1198 case Tag::Uninitialized:
1199 return false;
1200 case Tag::Capsule:
1201 return true;
1202 case Tag::RRef:
1203 return true;
1204 case Tag::Quantizer:
1205 return true;
1206 case Tag::Enum:
1207 return true;
1208 }
1209 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(false, "unexpected tag ", static_cast<int>(tag));
1210 return false;
1211 }
1212
1213 // Storage and Generator were treated specially when
1214 // is_intrusive_ptr was stored as explicit state. This getter
1215 // preserves the old behavior for use with WeakIValue for now.
1216 bool isIntrusivePtrLegacyBehavior() const {
1217 if (tag == Tag::Storage || tag == Tag::Generator) {
1218 return payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton();
1219 } else {
1220 return isIntrusivePtr();
1221 }
1222 }
1223
  // Storage for an IValue's value. Which member is active is determined by
  // `tag`: `as_tensor` when the tag is Tensor, otherwise a member of `u`.
  union Payload {
    // [TriviallyCopyablePayload]
    // We use a nested union here so that we can make the copy easy
    // and efficient in the non-tensor (i.e., trivially copyable)
    // case. Specifically, we do not have to do a switch-on-tag to
    // figure out which union member to assign; we can just use
    // TriviallyCopyablePayload::operator=.
    union TriviallyCopyablePayload {
      // Starts as the all-zero word (the same state clearToNone() writes).
      TriviallyCopyablePayload() : as_int(0) {}
      int64_t as_int;
      double as_double;
      bool as_bool;
      // Invariant: never nullptr; null state is represented as
      // c10::UndefinedTensorImpl::singleton() for consistency of
      // representation with Tensor.
      c10::intrusive_ptr_target* as_intrusive_ptr;
      struct {
        DeviceType type;
        DeviceIndex index;
      } as_device;
    } u;
    at::Tensor as_tensor;
    // Starts with `u` active; as_tensor is placement-new'd into place when an
    // IValue holds a Tensor.
    Payload() : u() {}
    // Deliberately empty: the owning IValue releases whichever member is
    // active (see IValue::destroy()).
    ~Payload() {}
  };
1249
1250 IValue(const Payload& p, Tag t) : tag(t) {
1251 if (isTensor()) {
1252 new (&payload.as_tensor) at::Tensor(p.as_tensor);
1253 } else {
1254 payload.u = p.u;
1255 }
1256 }
1257
  // Empty type-carrying struct; presumably used to dispatch overloads on T
  // elsewhere — confirm at use sites.
  template <typename T>
  struct TagType {};

  friend MaybeOwnedTraits<IValue>;

  // The value itself; the active union member is selected by `tag`.
  Payload payload;
  Tag tag{IValue::Tag::None};
  // WeakIValue reads `payload`/`tag` directly when building weak references.
  friend struct WeakIValue;
1266};
1267
// Weak (non-owning) counterpart of IValue. For tensors and intrusive-pointer
// payloads it holds a weak reference to the target object, and lock() tries to
// recover a strong IValue. For all other tags it simply stores a copy of the
// trivially-copyable payload word.
struct TORCH_API WeakIValue final {
  WeakIValue() = default;

  // Copy: shares rhs's target and takes an additional weak reference to it.
  // The UndefinedTensorImpl singleton stands for null and is never refcounted.
  WeakIValue(const WeakIValue& rhs)
      : payload(rhs.payload),
        tag(rhs.tag),
        is_intrusive_ptr(rhs.is_intrusive_ptr) {
    if (is_intrusive_ptr && payload.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()) {
      c10::raw::weak_intrusive_ptr::incref(payload.as_intrusive_ptr);
    }
  }
  // From a strong IValue: tensors store their TensorImpl pointer and are
  // always treated as pointers. Other tags copy the payload word and use
  // isIntrusivePtrLegacyBehavior(), so a null Storage/Generator is stored as a
  // plain (non-pointer) value.
  WeakIValue(const IValue& rhs)
      : tag(rhs.tag),
        is_intrusive_ptr(rhs.isIntrusivePtrLegacyBehavior()) {
    if (rhs.isTensor()) {
      payload.as_intrusive_ptr = rhs.unsafeToTensorImpl();
      is_intrusive_ptr = true;
    } else {
      payload = rhs.payload.u;
    }
    if (is_intrusive_ptr) {
      if (payload.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()) {
        c10::raw::weak_intrusive_ptr::incref(payload.as_intrusive_ptr);
      }
    }
  }
  // Move: start as None, then swap, which leaves rhs as None.
  WeakIValue(WeakIValue&& rhs) noexcept : WeakIValue() {
    swap(rhs);
  }
  // Drops this object's weak reference, if it holds one.
  ~WeakIValue() {
    if (is_intrusive_ptr && payload.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()) {
      c10::raw::weak_intrusive_ptr::decref(payload.as_intrusive_ptr);
    }
  }
  WeakIValue& operator=(WeakIValue&& rhs) & noexcept {
    WeakIValue(std::move(rhs)).swap(*this); // this also sets rhs to None
    return *this;
  }
  // Copy-and-swap: the temporary takes the new weak reference, and its
  // destructor releases whatever *this previously held.
  WeakIValue& operator=(WeakIValue const& rhs) & {
    WeakIValue(rhs).swap(*this);
    return *this;
  }
  void swap(WeakIValue& rhs) noexcept {
    std::swap(payload, rhs.payload);
    std::swap(is_intrusive_ptr, rhs.is_intrusive_ptr);
    std::swap(tag, rhs.tag);
  }

  // Identity comparison on the raw payload word, the tag and the pointer flag.
  bool isSameIdentity(const WeakIValue& rhs) const {
    return payload.as_int == rhs.payload.as_int && tag == rhs.tag &&
        is_intrusive_ptr == rhs.is_intrusive_ptr;
  }

  // Attempts to produce a strong IValue. Non-pointer values are rebuilt from
  // the stored payload. Pointer values yield None if the target has expired
  // (or was the null sentinel).
  IValue lock() const {
    if (!is_intrusive_ptr) {
      IValue::Payload newPayload;
      newPayload.u = payload;
      return IValue(newPayload, tag);
    }
    if (IValue::Tag::Tensor == tag) {
      // Temporarily adopt our weak reference (no refcount change), try to
      // promote it to a strong one, then release() so the weak reference
      // stays owned by this WeakIValue.
      auto temp = c10::weak_intrusive_ptr<at::TensorImpl, c10::UndefinedTensorImpl>::reclaim(
          static_cast<at::TensorImpl*>(payload.as_intrusive_ptr));
      c10::intrusive_ptr<at::TensorImpl, c10::UndefinedTensorImpl> ip(temp.lock());
      temp.release();
      if (!ip) {
        return IValue();
      } else {
        return IValue(at::Tensor(std::move(ip)));
      }
    } else {
      // Same pattern for other intrusive payloads. This weak_intrusive_ptr
      // uses nullptr as its null value, so the UndefinedTensorImpl sentinel is
      // mapped to nullptr first.
      auto temp = c10::weak_intrusive_ptr<c10::intrusive_ptr_target>::reclaim(
          payload.as_intrusive_ptr == c10::UndefinedTensorImpl::singleton()
          ? nullptr
          : payload.as_intrusive_ptr);
      IValue::Payload pl;
      // lock().release() yields a raw owned strong reference (or nullptr if
      // expired); IValue(pl, tag) below takes ownership of it.
      pl.u.as_intrusive_ptr = temp.lock().release();
      temp.release();
      if (!pl.u.as_intrusive_ptr) {
        return IValue();
      } else {
        return IValue(pl, tag);
      }
    }
  }

  // Strong reference count of the target; 1 for non-pointer values.
  size_t use_count() const noexcept {
    if (!is_intrusive_ptr) {
      return 1;
    }
    auto temp = c10::weak_intrusive_ptr<c10::intrusive_ptr_target, c10::UndefinedTensorImpl>::reclaim(
        payload.as_intrusive_ptr);
    size_t result = temp.use_count();
    // Give the borrowed weak reference back without decref'ing it.
    temp.release();
    return result;
  }

  // Weak reference count of the target; 1 for non-pointer values.
  size_t weak_use_count() const noexcept {
    if (!is_intrusive_ptr) {
      return 1;
    }
    auto temp = c10::weak_intrusive_ptr<c10::intrusive_ptr_target, c10::UndefinedTensorImpl>::reclaim(
        payload.as_intrusive_ptr);
    size_t result = temp.weak_use_count();
    // Give the borrowed weak reference back without decref'ing it.
    temp.release();
    return result;
  }
  // Hash of the raw payload word; equal under isSameIdentity implies equal
  // hash.
  size_t hash() const {
    return payload.as_int;
  }

 private:
  // Same trivially-copyable union IValue uses; tensors are stored here as
  // their TensorImpl pointer.
  using Payload = IValue::Payload::TriviallyCopyablePayload;
  Payload payload;
  IValue::Tag tag{IValue::Tag::None};
  // True when payload.as_intrusive_ptr is a pointer this object holds a weak
  // reference to (unless it is the UndefinedTensorImpl singleton).
  bool is_intrusive_ptr{false};
};
1384
// An owning pointer to a type. When the type is a class type, it requires a
// pair of shared_ptrs to the class type and its owning CompilationUnit, so that
// the class type is guaranteed to stay alive as long as we hold this object.
struct TORCH_API StrongTypePtr {
  // Defined out of line.
  StrongTypePtr(
      std::shared_ptr<torch::jit::CompilationUnit> cu,
      TypePtr type);

  std::shared_ptr<torch::jit::CompilationUnit> cu_; // owning CU reference
  TypePtr type_;
};
1396
// [Constant Object Weak CompilationUnit Reference]
// A type pointer whose CompilationUnit is held only weakly (the TypePtr itself
// is still a regular TypePtr). When a class gets inserted as a constant into a
// graph, if we used a strong pointer we would have a circular reference from
// Object -> CompilationUnit and CompilationUnit -> Graph (which owns the
// Constant Object).
struct TORCH_API WeakTypePtr {
  // Defined out of line.
  WeakTypePtr(
      std::weak_ptr<torch::jit::CompilationUnit> cu,
      TypePtr type);

  std::weak_ptr<torch::jit::CompilationUnit> cu_; // non-owning CU reference
  TypePtr type_;
};
1410
1411// internal build errors with std::variant :/
1412struct WeakOrStrongCompilationUnit {
1413 explicit WeakOrStrongCompilationUnit(
1414 std::shared_ptr<torch::jit::CompilationUnit> shared_cu) : strong_ptr_(std::move(shared_cu)), weak_ptr_(c10::nullopt) {}
1415
1416 explicit WeakOrStrongCompilationUnit(
1417 std::weak_ptr<torch::jit::CompilationUnit> weak_cu) : strong_ptr_(c10::nullopt), weak_ptr_(std::move(weak_cu)) {}
1418
1419 std::shared_ptr<torch::jit::CompilationUnit> getStrongRefOrThrow() const {
1420 TORCH_INTERNAL_ASSERT(strong_ptr_ != c10::nullopt);
1421 return *strong_ptr_;
1422 }
1423
1424 std::weak_ptr<torch::jit::CompilationUnit> getWeakRefOrThrow() const {
1425 TORCH_INTERNAL_ASSERT(weak_ptr_ != c10::nullopt);
1426 return *weak_ptr_;
1427 }
1428
1429 bool holdingStrongRef() const {
1430 return strong_ptr_ != c10::nullopt;
1431 }
1432
1433 bool holdingEmptyStrongRef() const {
1434 return holdingStrongRef() && *strong_ptr_ == nullptr;
1435 }
1436
1437 c10::optional<std::shared_ptr<torch::jit::CompilationUnit>> strong_ptr_;
1438 c10::optional<std::weak_ptr<torch::jit::CompilationUnit>> weak_ptr_;
1439};
1440
1441// An Object will hold a non-owning Compilation Unit reference if it is a
1442// Constant in the graph and a Owning reference otherwise
1443struct TORCH_API WeakOrStrongTypePtr {
1444 explicit WeakOrStrongTypePtr(WeakTypePtr weak)
1445 : cu_(WeakOrStrongCompilationUnit(std::move(weak.cu_))), type_(std::move(weak.type_)) {}
1446 explicit WeakOrStrongTypePtr(StrongTypePtr strong)
1447 : cu_(WeakOrStrongCompilationUnit(std::move(strong.cu_))), type_(std::move(strong.type_)) {}
1448 explicit WeakOrStrongTypePtr(WeakOrStrongCompilationUnit cu, TypePtr type)
1449 : cu_(std::move(cu)), type_(std::move(type)) {}
1450 WeakTypePtr asWeakTypePtr() const;
1451
1452 WeakOrStrongCompilationUnit cu_;
1453 TypePtr type_;
1454
1455 bool holds_strong_ref() const {
1456 return cu_.holdingStrongRef();
1457 }
1458
1459 bool holds_empty_strong_ref() const {
1460 return cu_.holdingEmptyStrongRef();
1461 }
1462};
1463
1464
1465} // namespace c10
1466
1467#include <ATen/core/ivalue_inl.h> // IWYU pragma: keep
1468