1#pragma once
2
3#include <condition_variable>
4#include <memory>
5#include <type_traits>
6#include <utility>
7
8#include <ATen/core/Dict.h>
9#include <ATen/core/List.h>
10#include <ATen/core/IListRef.h>
11#include <ATen/core/functional.h>
12#include <ATen/core/jit_type.h>
13#include <ATen/core/qualified_name.h>
14#include <ATen/core/rref_interface.h>
15#include <ATen/core/symbol.h>
16#include <c10/core/DeviceGuard.h>
17#include <c10/core/Event.h>
18#include <c10/core/Scalar.h>
19#include <c10/core/Stream.h>
20#include <c10/core/StreamGuard.h>
21#include <c10/core/TensorImpl.h>
22#include <c10/core/UndefinedTensorImpl.h>
23#include <c10/core/impl/DeviceGuardImplInterface.h>
24#include <c10/util/FunctionRef.h>
25#include <c10/util/hash.h>
26#include <c10/util/intrusive_ptr.h>
27#include <c10/util/irange.h>
28
29namespace torch {
30namespace jit {
31struct Function;
32struct CompilationUnit;
33} // namespace jit
34TORCH_API bool isCustomClass(const c10::IValue& v);
35} // namespace torch
36namespace c10 {
37struct IValue;
38struct ClassType;
39struct TupleType;
40struct EnumType;
41struct InferredType;
42
43// For custom class __init__ registration, we need to pass in a function
44// that looks like this: [](IValue x, args...)
45
46// However, make_boxed_from_unboxed_functor.h automatically sets the input types
47// of the function by introspecting the types of the functor (which is IValue in
48// this case). However, we need the type it binds to be Foo.
49
// Instead, we pass in a lambda [](tagged_capsule<CurClass> x, args...) from
// which getTypePtr can recover the original class pointer.
52
53template <typename TaggedCapsuleType>
54struct tagged_capsule {
55 IValue ivalue;
56};
57
58template <class T, class NullType>
59c10::intrusive_ptr<T, NullType> IValue::moveToIntrusivePtr() {
60 auto t = c10::intrusive_ptr<T, NullType>::reclaim(
61 payload.u.as_intrusive_ptr == c10::UndefinedTensorImpl::singleton()
62 ? NullType::singleton()
63 : static_cast<T*>(payload.u.as_intrusive_ptr));
64 clearToNone();
65 return t;
66}
67template <typename T, class NullType>
68c10::intrusive_ptr<T, NullType> IValue::toIntrusivePtr() const {
69 if (payload.u.as_intrusive_ptr == c10::UndefinedTensorImpl::singleton()) {
70 return c10::intrusive_ptr<T, NullType>();
71 }
72 c10::raw::intrusive_ptr::incref(payload.u.as_intrusive_ptr);
73 return c10::intrusive_ptr<T, NullType>::reclaim(
74 static_cast<T*>(payload.u.as_intrusive_ptr));
75}
76
77template <class T, class U>
78intrusive_ptr<T> static_intrusive_pointer_cast(intrusive_ptr<U> r) {
79 return intrusive_ptr<T>::reclaim(static_cast<T*>(r.release()));
80}
81
82template <class T, class U>
83intrusive_ptr<T> dynamic_intrusive_pointer_cast(intrusive_ptr<U> r) {
84 return intrusive_ptr<T>::reclaim(dynamic_cast<T*>(r.release()));
85}
86
87inline c10::intrusive_ptr<ivalue::Future> IValue::toFuture() && {
88 AT_ASSERT(isFuture(), "Expected Future but got ", tagKind());
89 return moveToIntrusivePtr<ivalue::Future>();
90}
91inline c10::intrusive_ptr<ivalue::Future> IValue::toFuture() const& {
92 AT_ASSERT(isFuture(), "Expected Future but got ", tagKind());
93 return toIntrusivePtr<ivalue::Future>();
94}
95inline c10::intrusive_ptr<ivalue::Await> IValue::toAwait() && {
96 AT_ASSERT(isAwait(), "Expected Await but got ", tagKind());
97 return moveToIntrusivePtr<ivalue::Await>();
98}
99inline c10::intrusive_ptr<ivalue::Await> IValue::toAwait() const& {
100 AT_ASSERT(isAwait(), "Expected Await but got ", tagKind());
101 return toIntrusivePtr<ivalue::Await>();
102}
103inline c10::intrusive_ptr<c10::RRefInterface> IValue::toRRef() && {
104 AT_ASSERT(isRRef(), "Expected RRef but got ", tagKind());
105 return moveToIntrusivePtr<c10::RRefInterface>();
106}
107inline c10::intrusive_ptr<c10::RRefInterface> IValue::toRRef() const& {
108 AT_ASSERT(isRRef(), "Expected RRef but got ", tagKind());
109 return toIntrusivePtr<c10::RRefInterface>();
110}
111inline c10::intrusive_ptr<at::Quantizer> IValue::toQuantizer() && {
112 AT_ASSERT(isQuantizer(), "Expected Quantizer but got ", tagKind());
113 return moveToIntrusivePtr<at::Quantizer>();
114}
115inline c10::intrusive_ptr<at::Quantizer> IValue::toQuantizer() const& {
116 AT_ASSERT(isQuantizer(), "Expected Quantizer but got ", tagKind());
117 return toIntrusivePtr<at::Quantizer>();
118}
119inline c10::intrusive_ptr<ivalue::ConstantString> IValue::toString() && {
120 AT_ASSERT(isString(), "Expected String but got ", tagKind());
121 return moveToIntrusivePtr<ivalue::ConstantString>();
122}
123inline c10::intrusive_ptr<ivalue::ConstantString> IValue::toString() const& {
124 AT_ASSERT(isString(), "Expected String but got ", tagKind());
125 return toIntrusivePtr<ivalue::ConstantString>();
126}
127inline c10::intrusive_ptr<ivalue::Object> IValue::toObject() && {
128 AT_ASSERT(isObject(), "Expected Object but got ", tagKind());
129 return moveToIntrusivePtr<ivalue::Object>();
130}
131inline c10::intrusive_ptr<ivalue::Object> IValue::toObject() const& {
132 AT_ASSERT(isObject(), "Expected Object but got ", tagKind());
133 return toIntrusivePtr<ivalue::Object>();
134}
135inline c10::intrusive_ptr<ivalue::PyObjectHolder> IValue::
136 toPyObjectHolder() && {
137 TORCH_INTERNAL_ASSERT(isPyObject(), "Expected PyObject but got ", tagKind());
138 return moveToIntrusivePtr<ivalue::PyObjectHolder>();
139}
140inline c10::intrusive_ptr<ivalue::PyObjectHolder> IValue::toPyObjectHolder()
141 const& {
142 TORCH_INTERNAL_ASSERT(isPyObject(), "Expected PyObject but got ", tagKind());
143 return toIntrusivePtr<ivalue::PyObjectHolder>();
144}
145inline c10::intrusive_ptr<ivalue::EnumHolder> IValue::toEnumHolder() && {
146 TORCH_INTERNAL_ASSERT(isEnum(), "Expected Enum but got ", tagKind());
147 return moveToIntrusivePtr<ivalue::EnumHolder>();
148}
149inline c10::intrusive_ptr<ivalue::EnumHolder> IValue::toEnumHolder() const& {
150 TORCH_INTERNAL_ASSERT(isEnum(), "Expected Enum but got ", tagKind());
151 return toIntrusivePtr<ivalue::EnumHolder>();
152}
153inline c10::complex<double> IValue::toComplexDouble() const {
154 TORCH_INTERNAL_ASSERT(isComplexDouble(), "Expected ComplexDouble but got ", tagKind());
155 auto ptr = toIntrusivePtr<ivalue::ComplexHolder>();
156 return (*ptr).val;
157}
158inline at::Tensor IValue::toTensor() && {
159 if (C10_UNLIKELY(!isTensor())) {
160 reportToTensorTypeError();
161 }
162 auto result = std::move(payload.as_tensor);
163 // As far as I can tell, omitting the usual explicit destructor call
164 // is not UB in and of itself, and it's a slight perf win. The
165 // destructor is a no-op, because the moved-from Tensor is
166 // effectively an intrusive_ptr in the null state, so we don't need
167 // the behavior for correctness reasons either. Leaving this
168 // explanatory comment, including commented-out destructor call, to
169 // make this abundantly clear.
170 //
171 // payload.as_tensor.~Tensor();
172 clearToNone();
173 return result;
174}
175inline at::Tensor& IValue::toTensor() & {
176 if (C10_UNLIKELY(!isTensor())) {
177 reportToTensorTypeError();
178 }
179 return payload.as_tensor;
180}
181inline const at::Tensor& IValue::toTensor() const& {
182 if (C10_UNLIKELY(!isTensor())) {
183 reportToTensorTypeError();
184 }
185 return payload.as_tensor;
186}
187inline c10::Storage IValue::toStorage() && {
188 AT_ASSERT(isStorage(), "Expected Storage but got ", tagKind());
189 return c10::Storage(
190 moveToIntrusivePtr<at::StorageImpl>());
191}
192inline c10::Storage IValue::toStorage() const& {
193 AT_ASSERT(isStorage(), "Expected Storage but got ", tagKind());
194 return c10::Storage(toIntrusivePtr<at::StorageImpl>());
195}
196inline c10::Stream IValue::toStream() && {
197 AT_ASSERT(isStream(), "Expected Stream but got ", tagKind());
198 auto ptr = toIntrusivePtr<ivalue::StreamData3Holder>();
199 return c10::Stream::unpack3((*ptr).val.stream_id,
200 (*ptr).val.device_index,
201 (*ptr).val.device_type);
202}
203inline c10::Stream IValue::toStream() const& {
204 AT_ASSERT(isStream(), "Expected Stream but got ", tagKind());
205 auto ptr = toIntrusivePtr<ivalue::StreamData3Holder>();
206 return c10::Stream::unpack3((*ptr).val.stream_id,
207 (*ptr).val.device_index,
208 (*ptr).val.device_type);
209}
210inline c10::intrusive_ptr<caffe2::Blob> IValue::toBlob() && {
211 AT_ASSERT(isBlob(), "Expected Blob but got ", tagKind());
212 return moveToIntrusivePtr<caffe2::Blob>();
213}
214inline c10::intrusive_ptr<caffe2::Blob> IValue::toBlob() const& {
215 AT_ASSERT(isBlob(), "Expected Blob but got ", tagKind());
216 return toIntrusivePtr<caffe2::Blob>();
217 ;
218}
219inline c10::intrusive_ptr<torch::CustomClassHolder> IValue::toCapsule() && {
220 TORCH_INTERNAL_ASSERT(isCapsule());
221 return moveToIntrusivePtr<torch::CustomClassHolder>();
222}
223inline c10::intrusive_ptr<torch::CustomClassHolder> IValue::toCapsule() const& {
224 TORCH_INTERNAL_ASSERT(isCapsule());
225 return toIntrusivePtr<torch::CustomClassHolder>();
226}
227inline at::Generator IValue::toGenerator() && {
228 AT_ASSERT(isGenerator(), "Expected Generator but got ", tagKind());
229 return at::Generator(moveToIntrusivePtr<at::GeneratorImpl>());
230}
231inline at::Generator IValue::toGenerator() const& {
232 AT_ASSERT(isGenerator(), "Expected Generator but got ", tagKind());
233 return at::Generator(toIntrusivePtr<at::GeneratorImpl>());
234}
235inline c10::SymInt IValue::toSymInt() && {
236 AT_ASSERT(isSymInt() || isInt(), "Expected SymInt or int but got ", tagKind());
237 if (isSymInt()) {
238 return c10::SymInt(moveToIntrusivePtr<c10::SymNodeImpl>());
239 } else {
240 return c10::SymInt(payload.u.as_int);
241 }
242}
243inline c10::SymInt IValue::toSymInt() const& {
244 AT_ASSERT(isSymInt() || isInt(), "Expected SymInt or int but got ", tagKind());
245 if (isSymInt()) {
246 return c10::SymInt(toIntrusivePtr<c10::SymNodeImpl>());
247 } else {
248 return c10::SymInt(payload.u.as_int);
249 }
250}
251inline c10::SymFloat IValue::toSymFloat() && {
252 AT_ASSERT(isSymFloat() || isDouble(), "Expected SymFloat or double but got ", tagKind());
253 if (isSymFloat()) {
254 return c10::SymFloat(moveToIntrusivePtr<c10::SymNodeImpl>());
255 } else {
256 return c10::SymFloat(payload.u.as_double);
257 }
258}
259inline c10::SymFloat IValue::toSymFloat() const& {
260 AT_ASSERT(isSymFloat() || isDouble(), "Expected SymFloat or double but got ", tagKind());
261 if (isSymFloat()) {
262 return c10::SymFloat(toIntrusivePtr<c10::SymNodeImpl>());
263 } else {
264 return c10::SymFloat(payload.u.as_double);
265 }
266}
267
268namespace ivalue {
269
270void TORCH_API
271checkCustomClassType(const ClassType* expected_type, const Type* actual_type);
272
273template <typename T>
274using Shared = c10::intrusive_ptr<T>;
275
276// string
277struct TORCH_API ConstantString final : c10::intrusive_ptr_target {
278 private:
279 const std::string str_;
280
281 public:
282 ConstantString(std::string str) : str_(std::move(str)) {}
283 ConstantString(c10::string_view str) : str_(std::string(str)) {}
284 static c10::intrusive_ptr<ConstantString> create(std::string str_);
285 static c10::intrusive_ptr<ConstantString> create(c10::string_view str_);
286 static c10::intrusive_ptr<ConstantString> create(const char* str_);
287
288 const std::string& string() const {
289 return str_;
290 }
291 c10::string_view string_view() const {
292 return str_;
293 }
294
295 operator const std::string&() const {
296 return string();
297 }
298 TORCH_API friend std::ostream& operator<<(
299 std::ostream& out,
300 const ConstantString& v);
301};
302
303struct Future;
304
305struct TORCH_API TupleElements {
306 private:
307 size_t inlineSize_;
308 // We represent TupleElements this way to save doing a heap
309 // allocation in the common (at least for unpickling) case where we
310 // have only 3 elements. We have our own union instead of
311 // c10::SmallVector<IValue> because c10::SmallVector<IValue> always
312 // stores the begin/end/capacity pointers, which would be a waste of
313 // space in our use case.
314 union {
315 std::vector<IValue> elementsVector_;
316 // Don't want to declare a std::array because the convenient
317 // iteration and size members are a footgun in this case -- the
318 // actual size of the array may be smaller than 3!
319 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays)
320 IValue elementsInline_[3];
321 };
322
323 void destroyInline() {
324 for (const auto ii : c10::irange(inlineSize_)) {
325 elementsInline_[ii].~IValue();
326 }
327 }
328 public:
329
330 using iterator = IValue*;
331 using const_iterator = const IValue*;
332
333 TupleElements() : inlineSize_(0) {
334 new (&elementsVector_) std::vector<IValue>();
335 }
336
337 explicit TupleElements(std::vector<IValue> elements)
338 : inlineSize_(0), elementsVector_(std::move(elements)) {}
339
340 explicit TupleElements(c10::ArrayRef<IValue> elements)
341 : inlineSize_(elements.size() <= 3 ? elements.size() : 0) {
342 switch (inlineSize_) {
343 case 3:
344 new (&elementsInline_[2]) IValue(elements[2]);
345 C10_FALLTHROUGH;
346 case 2:
347 new (&elementsInline_[1]) IValue(elements[1]);
348 C10_FALLTHROUGH;
349 case 1:
350 new (&elementsInline_[0]) IValue(elements[0]);
351 break;
352 case 0:
353 new (&elementsVector_) std::vector<IValue>(elements.begin(), elements.end());
354 break;
355 }
356 }
357
358 explicit TupleElements(IValue&& e1)
359 : inlineSize_(1) {
360 new (&elementsInline_[0]) IValue(std::move(e1));
361 }
362
363 explicit TupleElements(IValue&& e1, IValue&& e2)
364 : inlineSize_(2) {
365 new (&elementsInline_[0]) IValue(std::move(e1));
366 new (&elementsInline_[1]) IValue(std::move(e2));
367 }
368
369 explicit TupleElements(IValue&& e1, IValue&& e2, IValue&& e3)
370 : inlineSize_(3) {
371 new (&elementsInline_[0]) IValue(std::move(e1));
372 new (&elementsInline_[1]) IValue(std::move(e2));
373 new (&elementsInline_[2]) IValue(std::move(e3));
374 }
375
376 ~TupleElements() {
377 if (inlineSize_) {
378 destroyInline();
379 } else {
380 elementsVector_.~vector();
381 }
382 }
383
384 // It would be nice to make this noncopyable to prevent people from
385 // writing code like `auto output =
386 // forward(...).toTupleRef().elements()` (which does refcount bumps on
387 // each element, unlike the more efficient but verbose
388 // ```
389 // auto outputIntrusivePtr = forward(...).toTuple();
390 // const auto& output = outputIntrusivePtr->elements();
391 // ```
392 // ), but there is simply an overwhelming amount of code that does
393 // it the inefficient way.
394 // See also operator std::vector below.
395 TupleElements(const TupleElements& rhs)
396 : inlineSize_(rhs.inlineSize_) {
397 if (rhs.inlineSize_) {
398 for (const auto ii : c10::irange(inlineSize_)) {
399 new (&elementsInline_[ii]) IValue(rhs.elementsInline_[ii]);
400 }
401 } else {
402 new (&elementsVector_) std::vector<IValue>(rhs.elementsVector_);
403 }
404 }
405
406 TupleElements& operator=(const TupleElements& rhs) {
407 if (inlineSize_) {
408 if (rhs.inlineSize_) {
409 for (const auto ii : c10::irange(std::min(inlineSize_, rhs.inlineSize_))) {
410 elementsInline_[ii] = rhs.elementsInline_[ii];
411 }
412 if (rhs.inlineSize_ > inlineSize_) {
413 for (const auto ii : c10::irange(inlineSize_, rhs.inlineSize_)) {
414 new (&elementsInline_[ii]) IValue(rhs.elementsInline_[ii]);
415 }
416 } else {
417 for (const auto ii : c10::irange(rhs.inlineSize_, inlineSize_)) {
418 elementsInline_[ii].~IValue();
419 }
420 }
421 } else {
422 destroyInline();
423 new (&elementsVector_) std::vector<IValue>(rhs.elementsVector_);
424 }
425 } else {
426 if (rhs.inlineSize_) {
427 elementsVector_.~vector();
428 for (const auto ii : c10::irange(rhs.inlineSize_)) {
429 new (&elementsInline_[ii]) IValue(rhs.elementsInline_[ii]);
430 }
431 } else {
432 elementsVector_ = rhs.elementsVector_;
433 }
434 }
435 inlineSize_ = rhs.inlineSize_;
436 return *this;
437 }
438
439 TupleElements(TupleElements&& rhs) noexcept
440 : inlineSize_(rhs.inlineSize_) {
441 if (inlineSize_) {
442 for (const auto ii : c10::irange(inlineSize_)) {
443 new (&elementsInline_[ii]) IValue(std::move(rhs.elementsInline_[ii]));
444 }
445 } else {
446 new (&elementsVector_) std::vector<IValue>(std::move(rhs.elementsVector_));
447 }
448 }
449
450 TupleElements& operator=(TupleElements&& rhs) noexcept {
451 if (inlineSize_) {
452 if (rhs.inlineSize_) {
453 for (const auto ii : c10::irange(std::min(inlineSize_, rhs.inlineSize_))) {
454 elementsInline_[ii] = std::move(rhs.elementsInline_[ii]);
455 }
456 if (rhs.inlineSize_ > inlineSize_) {
457 for (const auto ii : c10::irange(inlineSize_, rhs.inlineSize_)) {
458 new (&elementsInline_[ii]) IValue(std::move(rhs.elementsInline_[ii]));
459 }
460 } else {
461 for (const auto ii : c10::irange(rhs.inlineSize_, inlineSize_)) {
462 elementsInline_[ii].~IValue();
463 }
464 }
465 } else {
466 destroyInline();
467 new (&elementsVector_) std::vector<IValue>(std::move(rhs.elementsVector_));
468 }
469 } else {
470 if (rhs.inlineSize_) {
471 elementsVector_.~vector();
472 for (const auto ii : c10::irange(rhs.inlineSize_)) {
473 new (&elementsInline_[ii]) IValue(std::move(rhs.elementsInline_[ii]));
474 }
475 } else {
476 elementsVector_ = std::move(rhs.elementsVector_);
477 }
478 }
479 inlineSize_ = rhs.inlineSize_;
480 return *this;
481 }
482
483 C10_NODISCARD c10::ArrayRef<IValue> asArrayRef() const {
484 if (inlineSize_) {
485 return c10::ArrayRef<IValue>(elementsInline_, inlineSize_);
486 } else {
487 return elementsVector_;
488 }
489 }
490
491 // Mimic implicit conversion from std::vector to ArrayRef.
492 operator c10::ArrayRef<IValue>() const {
493 return asArrayRef();
494 }
495
496 static size_t hash(const TupleElements& v) {
497 return c10::hash<c10::ArrayRef<IValue>>()(v.asArrayRef());
498 }
499
500 void setContents(std::vector<IValue>&& contents) {
501 if (inlineSize_) {
502 destroyInline();
503 new (&elementsVector_) std::vector<IValue>(std::move(contents));
504 inlineSize_ = 0;
505 } else {
506 elementsVector_ = std::move(contents);
507 }
508 }
509
510 C10_NODISCARD bool empty() const {
511 return inlineSize_ ? false : elementsVector_.empty();
512 }
513
514 C10_NODISCARD size_t size() const {
515 return inlineSize_ ? inlineSize_ : elementsVector_.size();
516 }
517
518 C10_NODISCARD IValue& operator[](size_t idx) {
519 if (inlineSize_) {
520 return elementsInline_[idx];
521 } else {
522 return elementsVector_[idx];
523 }
524 }
525
526 C10_NODISCARD const IValue& operator[](size_t idx) const {
527 if (inlineSize_) {
528 return elementsInline_[idx];
529 } else {
530 return elementsVector_[idx];
531 }
532 }
533
534 C10_NODISCARD IValue& at(size_t idx) {
535 if (inlineSize_) {
536 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inlineSize_ <= 3);
537 TORCH_CHECK(idx < inlineSize_, "TupleElements: invalid index Index = ", idx, "; Length = ", inlineSize_);
538 return elementsInline_[idx];
539 } else {
540 return elementsVector_.at(idx);
541 }
542 }
543
544 C10_NODISCARD const IValue& at(size_t idx) const {
545 if (inlineSize_) {
546 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inlineSize_ <= 3);
547 TORCH_CHECK(idx < inlineSize_, "TupleElements: invalid index Index = ", idx, "; Length = ", inlineSize_);
548 return elementsInline_[idx];
549 } else {
550 TORCH_CHECK(idx < elementsVector_.size(), "TupleElements: invalid index Index = ", idx, "; Length = ", elementsVector_.size());
551 return elementsVector_.at(idx);
552 }
553 }
554
555 C10_NODISCARD iterator begin() {
556 if (inlineSize_) {
557 return elementsInline_;
558 } else {
559 return elementsVector_.data();
560 }
561 }
562
563 C10_NODISCARD iterator end() {
564 if (inlineSize_) {
565 return elementsInline_ + inlineSize_;
566 } else {
567 return elementsVector_.data() + elementsVector_.size();
568 }
569 }
570
571 C10_NODISCARD const_iterator begin() const {
572 if (inlineSize_) {
573 return elementsInline_;
574 } else {
575 return elementsVector_.data();
576 }
577 }
578
579 C10_NODISCARD const_iterator end() const {
580 if (inlineSize_) {
581 return elementsInline_ + inlineSize_;
582 } else {
583 return elementsVector_.data() + elementsVector_.size();
584 }
585 }
586
587 C10_NODISCARD const_iterator cbegin() const {
588 return begin();
589 }
590
591 C10_NODISCARD const_iterator cend() const {
592 return end();
593 }
594
595 C10_NODISCARD std::vector<IValue> vec() const & {
596 return asArrayRef().vec();
597 }
598
599 C10_NODISCARD IValue& back() {
600 return *(end() - 1);
601 }
602
603 C10_NODISCARD const IValue& back() const {
604 return *(end() - 1);
605 }
606
607 C10_NODISCARD std::vector<IValue> vec() && {
608 std::vector<IValue> result;
609 result.reserve(size());
610 for (auto&& iv : *this) {
611 result.push_back(std::move(iv));
612 }
613 return result;
614 }
615
616 // More compatibility shims for the overwhelming amount of code that
617 // likes to copy tuple elements into a vector; see comment above the
618 // copy constructor.
619 operator std::vector<IValue>() const & {
620 return vec();
621 }
622
623 operator std::vector<IValue>() && {
624 return vec();
625 }
626};
627
628template <typename T>
629struct TupleTypeFactory {};
630
631template <>
632struct TORCH_API TupleTypeFactory<TupleType> {
633 static TupleTypePtr create(std::vector<TypePtr> types) {
634 return TupleType::create(std::move(types));
635 }
636 static TupleTypePtr fallback(const Type& type);
637};
638
639template <>
640struct TORCH_API TupleTypeFactory<c10::DynamicType> {
641 static DynamicTypePtr create(std::vector<TypePtr> elemTypes);
642 static DynamicTypePtr fallback(const Type&);
643};
644
645struct TORCH_API Tuple : c10::intrusive_ptr_target {
646 private:
647 TupleElements elements_;
648 mutable c10::TypePtr type_; // lazily computed for unnamed tuples
649
650 public:
651 // named tuples have additional type information, so we
652 // directly create them tagged
653 static c10::intrusive_ptr<Tuple> createNamed(
654 std::vector<IValue> elements_,
655 c10::TypePtr type_) {
656 return c10::make_intrusive<Tuple>(std::move(elements_), std::move(type_));
657 }
658
659 static c10::intrusive_ptr<Tuple> createNamed(
660 TupleElements elements_,
661 std::shared_ptr<TupleType> type_) {
662 return c10::make_intrusive<Tuple>(std::move(elements_), std::move(type_));
663 }
664
665 static c10::intrusive_ptr<Tuple> createNamed(
666 std::initializer_list<IValue> elements_,
667 std::shared_ptr<TupleType> type_) {
668 return createNamed(TupleElements(c10::ArrayRef<IValue>(elements_)), std::move(type_));
669 }
670
671 // MSVC apparently can't disambiguate the other two overloads of
672 // create when passed an initializer_list without this.
673 static c10::intrusive_ptr<Tuple> create(std::initializer_list<IValue> elements_) {
674 return create(c10::ArrayRef<IValue>(elements_));
675 }
676
677 static c10::intrusive_ptr<Tuple> create(std::vector<IValue> elements_) {
678 return c10::make_intrusive<Tuple>(std::move(elements_));
679 }
680
681 static c10::intrusive_ptr<Tuple> create(TupleElements elements_) {
682 return c10::make_intrusive<Tuple>(std::move(elements_));
683 }
684
685 static c10::intrusive_ptr<Tuple> create(c10::ArrayRef<IValue> elements_) {
686 return create(TupleElements(elements_));
687 }
688
689 static c10::intrusive_ptr<Tuple> create(IValue e1) {
690 return c10::make_intrusive<Tuple>(std::move(e1));
691 }
692
693 static c10::intrusive_ptr<Tuple> create(IValue e1, IValue e2) {
694 return c10::make_intrusive<Tuple>(std::move(e1), std::move(e2));
695 }
696
697 static c10::intrusive_ptr<Tuple> create(IValue e1, IValue e2, IValue e3) {
698 return c10::make_intrusive<Tuple>(std::move(e1), std::move(e2), std::move(e3));
699 }
700
701 private:
702 // Workaround inability to use `>` operator in template argument list.
703 template <typename... Args>
704 static constexpr bool hasMoreThanThreeArgs() {
705 return sizeof...(Args) > 3;
706 }
707
708 public:
709 template <typename... Args>
710 static c10::intrusive_ptr<Tuple> create(Args&&... elements_) {
711 switch (sizeof...(Args)) {
712 case 1:
713 case 2:
714 case 3:
715 return create(IValue(std::forward<Args>(elements_))...);
716 default:
717 return create(
718 std::vector<IValue>{IValue(std::forward<Args>(elements_))...});
719 }
720 }
721
722 // Again, it would be nice to make this noncopyable, but there's a
723 // lot of extant code that copies Tuples.
724 // Tuple(const Tuple& rhs) = delete;
725
726 const TupleElements& elements() const& {
727 return elements_;
728 }
729
730 TupleElements elements() && {
731 return std::move(elements_);
732 }
733
734 void setElements(std::vector<IValue>&& elements) {
735 elements_.setContents(std::move(elements));
736 }
737
738 void setElements(TupleElements&& elements) {
739 elements_ = std::move(elements);
740 }
741
742 void unsafeSetElement(size_t idx, const IValue& element) {
743 elements_[idx] = element;
744 }
745
746 void unsafeSetElement(size_t idx, IValue&& element) {
747 elements_[idx] = std::move(element);
748 }
749
750 size_t size() const {
751 return elements_.size();
752 }
753
754 template <typename T = c10::TupleType>
755 std::shared_ptr<T> type() const {
756 if (!type_) {
757 type_ = TupleTypeFactory<T>::create(fmap(elements(), [&](const IValue& v) {
758 return v.type<typename T::ElementType>();
759 }));
760 }
761 if (auto t = type_->cast<T>()) {
762 return t;
763 }
764 return TupleTypeFactory<T>::fallback(*type_);
765 }
766
767 static size_t hash(const Tuple& t) {
768 return c10::get_hash(t.elements());
769 }
770
771 TORCH_API friend bool operator==(
772 const ivalue::Tuple& lhs,
773 const ivalue::Tuple& rhs);
774
775 private:
776 // NOTE: If we try to avoid the overloads without
777 // `std::shared_ptr<TupleType> type` by defaulting it to nullptr, we
778 // end up having to call (part of) the shared_ptr destructor for
779 // `type` even though we should know statically it won't do
780 // anything.
781 explicit Tuple(std::vector<IValue> elements)
782 : elements_(std::move(elements)){}
783
784 explicit Tuple(std::vector<IValue> elements, c10::TypePtr type)
785 : elements_(std::move(elements)), type_(std::move(type)) {}
786
787 explicit Tuple(TupleElements&& elements)
788 : elements_(std::move(elements)) {}
789
790 explicit Tuple(TupleElements&& elements, std::shared_ptr<TupleType> type)
791 : elements_(std::move(elements)), type_(std::move(type)) {}
792
793 explicit Tuple(IValue&& e1)
794 : elements_(std::move(e1)) {}
795
796 explicit Tuple(IValue&& e1, std::shared_ptr<TupleType> type)
797 : elements_(std::move(e1)), type_(std::move(type)) {}
798
799 explicit Tuple(IValue&& e1, IValue&& e2)
800 : elements_(std::move(e1), std::move(e2)) {}
801
802 explicit Tuple(IValue&& e1, IValue&& e2, std::shared_ptr<TupleType> type)
803 : elements_(std::move(e1), std::move(e2)), type_(std::move(type)) {}
804
805 explicit Tuple(IValue&& e1, IValue&& e2, IValue&& e3)
806 : elements_(std::move(e1), std::move(e2), std::move(e3)) {}
807
808 explicit Tuple(IValue&& e1, IValue&& e2, IValue&& e3, std::shared_ptr<TupleType> type)
809 : elements_(std::move(e1), std::move(e2), std::move(e3)), type_(std::move(type)) {}
810
811 friend class c10::intrusive_ptr<Tuple>;
812};
813
814struct Object;
815struct PyObjectHolder;
816struct EnumHolder;
817} // namespace ivalue
818
819// Future
820struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target {
821 private:
822 // Keep this private in order to force users to go through make_intrusive and
823 // thus prevent creating a Future that's not held by an intrusive_ptr.
824 explicit Future(TypePtr type, std::vector<c10::Device> devices={})
825 : type_(std::move(type)),
826 impl_(getTypeOfDevices(devices)),
827 devices_(sortAndDeduplicateDevices(impl_, std::move(devices))) {}
828
829 friend c10::intrusive_ptr<Future>;
830
831 public:
832 Future(const Future&) = delete;
833 Future(Future&&) = delete;
834 Future& operator=(const Future&) = delete;
835 Future& operator=(Future&&) = delete;
836
837 struct TORCH_API FutureError final : public std::exception {
838 explicit FutureError(std::string&& error_msg_)
839 : error_msg(std::move(error_msg_)) {}
840
841 FutureError() = default;
842
843 const char* what() const noexcept override {
844 return error_msg.c_str();
845 }
846
847 std::string error_msg;
848 };
849
850 /**
851 * Wait on the future until it completes.
852 */
853 void wait() {
854 std::unique_lock<std::mutex> lock(mutex_);
855 finished_cv_.wait(lock, [&]() -> bool { return completed_; });
856 synchronizeWithCurrentStreams();
857 }
858
859 /**
860 * Wait on the future until it completes and throw an
861 * exception if an error exists.
862 */
863 void waitAndThrow() {
864 wait();
865
866 if (eptr_) {
867 std::rethrow_exception(eptr_);
868 }
869 }
870
871 /**
872 * Explicitly mark the future as completed with the output value. Optionally,
873 * the storages for all tensors in IValue can be passed as well. The DataPtrs
874 * of these storages are used to synchronize CUDA streams. If storages isn't
875 * given we will attempt to extract it from the value, if we need to (this
876 * happens if a non-empty set of devices was given to the constructor). Thus
877 * one only needs to provide storages when 1) they cannot be extracted through
878 * IValue::getSubValues() or through pickling in case of Python object; or
879 * when 2) customized storage extraction is more efficient.
880 */
881 using WeakStorage = c10::weak_intrusive_ptr<c10::StorageImpl>;
882 void markCompleted(
883 IValue value,
884 c10::optional<std::vector<WeakStorage>> storages = c10::nullopt) {
885 // Start by performing all steps that can throw, before setting any field.
886 // Do this before even acquiring the mutex, because extractStorages might
887 // acquire the GIL, which could lead to a lock inversion with our mutex.
888 // See https://github.com/pytorch/pytorch/issues/58239.
889 std::vector<WeakStorage> actualStorages;
890 std::vector<c10::Device> usedDevices;
891 try {
892 // FIXME We should always extract DataPtrs, in order to catch the case of
893 // users using CUDA values but forgetting to set devices, which currently
894 // leads to a silent synchronization/correctness issue. However, as this
895 // might worsen perf in CPU-only cases, we should only do so after careful
896 // benchmarks.
897 if (impl_.type() != c10::kCPU) {
898 actualStorages =
899 storages.has_value() ? std::move(*storages) : extractStorages(value);
900 usedDevices = getDevicesOfStorages(impl_, actualStorages);
901 ensureIsSubsetOfDevices(usedDevices, devices_);
902 }
903 } catch (const std::exception&) {
904 setError(std::current_exception());
905 return;
906 }
907
908 std::unique_lock<std::mutex> lock(mutex_);
909 TORCH_CHECK(
910 !completed(),
911 "Attempting to mark a completed Future as complete again. Note that "
912 "a Future can only be marked completed once.");
913
914 // Only set value_ and completed_ flag once all checks and preparation steps
915 // have returned successfully to allow for proper error propagation.
916 value_ = std::move(value);
917 completed_ = true;
918
919 currentDevice_ = impl_.getDevice();
920 storages_ = std::move(actualStorages);
921 for (const c10::Device& device : usedDevices) {
922 c10::Event event(impl_.type());
923 event.record(impl_.getStream(device));
924 events_.push_back(std::move(event));
925 }
926
927 std::vector<std::function<void(Future&)>> cbs;
928 cbs.swap(callbacks_);
929 lock.unlock();
930
931 finished_cv_.notify_all();
932 for (auto& callback : cbs) {
933 invokeCallback(std::move(callback));
934 }
935 }
936
937 void markCompleted() {
938 markCompleted(IValue{});
939 }
940
941 void setError(std::exception_ptr eptr) {
942 std::unique_lock<std::mutex> lock(mutex_);
943 setErrorInternal(std::move(eptr), lock);
944 }
945
946 void setErrorIfNeeded(std::exception_ptr eptr) {
947 std::unique_lock<std::mutex> lock(mutex_);
948 if (completed_) {
949 // This should be rare and shouldn't cause log spew. Its important to
950 // log errors and thats why we have this log here.
951 std::string msg = c10::str(
952 "Skipping setting following error on the Future since "
953 "it is already marked completed (this is not necessarily "
954 "an error):\n",
955 tryRetrieveErrorMessageInternal(std::move(eptr)));
956 if (eptr_) {
957 msg += c10::str(
958 ", \nOriginal exception:\n",
959 tryRetrieveErrorMessageInternal(eptr_));
960 }
961 LOG(INFO) << msg;
962 return;
963 } else {
964 setErrorInternal(std::move(eptr), lock);
965 }
966 }
967
968 // Get the result of the current future.
969 IValue value() {
970 std::unique_lock<std::mutex> lock(mutex_);
971 AT_ASSERT(completed());
972 if (eptr_) {
973 std::rethrow_exception(eptr_);
974 }
975 return value_;
976 }
977
978 // This accessor should only be used if we know that the future is
979 // completed() with no error.
980 const IValue& constValue() const {
981 std::unique_lock<std::mutex> lock(mutex_);
982 AT_ASSERT(completed());
983 TORCH_INTERNAL_ASSERT(
984 !eptr_,
985 "value() accessor should only be used when future is not completed with ",
986 "an error, but future had the following error: ",
987 tryRetrieveErrorMessageInternal(eptr_)
988 );
989 return value_;
990 }
991
992 // This accessor should only be used if we know that the future is
993 // completed() with no error.
994 const std::vector<WeakStorage>& storages() const {
995 std::unique_lock<std::mutex> lock(mutex_);
996 AT_ASSERT(completed());
997 AT_ASSERT(!eptr_);
998 return storages_;
999 }
1000
1001 /**
1002 * Add a callback to the future.
1003 * The callbacks will be executed once the future completes.
1004 * If the future has already completed,
1005 * this function will execute the callback immediately.
1006 */
1007 template <typename T>
1008 void addCallback(T callback) {
1009#if __cpp_lib_is_invocable >= 201703
1010 static_assert(
1011 std::is_invocable_r<void, T, Future&>::value,
1012 "The callback must have signature void(Future&)");
1013#endif
1014 std::unique_lock<std::mutex> lock(mutex_);
1015 if (completed()) {
1016 lock.unlock();
1017 invokeCallback(std::move(callback));
1018 return;
1019 }
1020 callbacks_.emplace_back(std::move(callback));
1021 }
1022
1023 /**
1024 * Add a callback to the future, and return another Future to hold the return
1025 * value of the callback. This is necessary when the callback provider needs
1026 * to know for sure when the callback has finished.
1027 */
1028 template <typename T>
1029 c10::intrusive_ptr<Future> then(T callback, TypePtr type) {
1030 using IValueWithStorages = std::tuple<IValue, std::vector<WeakStorage>>;
1031#if __cpp_lib_is_invocable >= 201703
1032 static_assert(
1033 guts::disjunction<
1034 std::is_invocable_r<IValue, T, Future&>,
1035 std::is_invocable_r<IValueWithStorages, T, Future&>>::value,
1036 "The callback must have signature IValue(Future&) or "
1037 "std::tuple<IValue, std::vector<Storage>>(Future&)");
1038#endif
1039 auto childFut = createInstance(std::move(type));
1040 addCallback([childFut,
1041 cb = std::move(callback)](Future& parentFut) mutable {
1042 try {
1043 guts::if_constexpr<std::is_convertible<
1044 typename c10::invoke_result_t<T &&, Future&>,
1045 IValueWithStorages>::value>(
1046 [&](auto identity) {
1047 IValue value;
1048 std::vector<WeakStorage> storages;
1049 std::tie(value, storages) = identity(cb)(parentFut);
1050 childFut->markCompleted(std::move(value), std::move(storages));
1051 },
1052 [&](auto identity) {
1053 childFut->markCompleted(identity(cb)(parentFut));
1054 });
1055 } catch (std::exception&) {
1056 childFut->setError(std::current_exception());
1057 }
1058 });
1059 return childFut;
1060 }
1061
1062 template <typename T>
1063 c10::intrusive_ptr<Future> thenAsync(T callback, TypePtr type) {
1064#if __cpp_lib_is_invocable >= 201703
1065 static_assert(
1066 std::is_invocable_r<c10::intrusive_ptr<Future>, T, Future&>::value,
1067 "The callback must have signature c10::intrusive_ptr<Future>(Future&)");
1068#endif
1069 auto childFut = createInstance(std::move(type));
1070 addCallback(
1071 [childFut, cb = std::move(callback)](Future& parentFut) mutable {
1072 c10::intrusive_ptr<Future> intermediateFut;
1073 try {
1074 intermediateFut = cb(parentFut);
1075 } catch (std::exception&) {
1076 childFut->setError(std::current_exception());
1077 return;
1078 }
1079 intermediateFut->addCallback(
1080 [childFut = std::move(childFut)](Future& intermediateFut) {
1081 if (intermediateFut.hasError()) {
1082 childFut->setError(intermediateFut.exception_ptr());
1083 } else {
1084 childFut->markCompleted(
1085 intermediateFut.value(), intermediateFut.storages());
1086 }
1087 });
1088 });
1089 return childFut;
1090 }
1091
1092 // Tries to retrieve the error message from std::exception_ptr.
1093 std::string tryRetrieveErrorMessage() const {
1094 TORCH_CHECK(hasError(), "No error present on the future.");
1095 std::unique_lock<std::mutex> lock(mutex_);
1096 return tryRetrieveErrorMessageInternal(eptr_);
1097 }
1098
1099 // Check if the current future has completed
1100 bool completed() const {
1101 return completed_;
1102 }
1103
1104 bool hasValue() const {
1105 std::unique_lock<std::mutex> lock(mutex_);
1106 return completed_ && !eptr_;
1107 }
1108
1109 bool hasError() const {
1110 std::unique_lock<std::mutex> lock(mutex_);
1111 return eptr_ ? true : false;
1112 }
1113
1114 std::exception_ptr exception_ptr() const {
1115 std::unique_lock<std::mutex> lock(mutex_);
1116 return eptr_;
1117 }
1118
1119 TORCH_API friend std::ostream& operator<<(
1120 std::ostream& out,
1121 const Future& v);
1122
1123 TypePtr elementType() const {
1124 return type_;
1125 }
1126
1127 const std::vector<c10::Device>& devices() const {
1128 return devices_;
1129 }
1130
1131 // This method should be used when one intends to manually create a child
1132 // future, for example when implementing a customized version of then().
1133 c10::intrusive_ptr<Future> createInstance(at::TypePtr type) {
1134 return c10::make_intrusive<Future>(std::move(type), devices_);
1135 }
1136
1137 private:
1138
1139 // This method should always be used when invoking a callback (regardless of
1140 // how/when that happens) as it will ensure that the proper "environment" is
1141 // set up before running the callback, as in, it will set up the CUDA streams,
1142 // synchronize them with the value, and so on (if needed).
1143 template<typename T>
1144 void invokeCallback(T callback) {
1145#if __cpp_lib_is_invocable >= 201703
1146 static_assert(
1147 std::is_invocable_r<void, T, Future&>::value,
1148 "The callback must have signature void(Future&)");
1149#endif
1150
1151 c10::OptionalDeviceGuard deviceGuard(currentDevice_);
1152
1153 std::vector<c10::Stream> streams;
1154 streams.reserve(devices_.size());
1155 for (const c10::Device& device : devices_) {
1156 streams.push_back(impl_.getStreamFromGlobalPool(device));
1157 }
1158 c10::MultiStreamGuard streamGuard(streams);
1159 synchronizeWithCurrentStreams();
1160
1161 callback(*this);
1162 }
1163
1164 // This method should be called before this future's value is used, as it
1165 // ensures that the CUDA streams that are "current" at the callsite properly
1166 // synchronize with the value.
1167 void synchronizeWithCurrentStreams() {
1168 for (c10::Event& event : events_) {
1169 event.block(impl_.getStream(event.device()));
1170 }
1171
1172 for (const WeakStorage& weak_storage : storages_) {
1173 c10::intrusive_ptr<c10::StorageImpl> storage = weak_storage.lock();
1174 if (!storage) {
1175 continue;
1176 }
1177 if (!storage->device().is_cpu()) {
1178 impl_.recordDataPtrOnStream(
1179 storage->data_ptr(), impl_.getStream(storage->device()));
1180 }
1181 }
1182 }
1183
1184 void setErrorInternal(
1185 std::exception_ptr eptr,
1186 std::unique_lock<std::mutex>& lock) {
1187 TORCH_CHECK(
1188 !eptr_,
1189 "Error already set on this Future: ",
1190 tryRetrieveErrorMessageInternal(eptr_),
1191 ", trying to set error: ",
1192 tryRetrieveErrorMessageInternal(eptr));
1193 TORCH_INTERNAL_ASSERT(!completed(), "Future is already marked completed");
1194 completed_ = true;
1195 eptr_ = std::move(eptr);
1196
1197 std::vector<std::function<void(Future&)>> cbs;
1198 cbs.swap(callbacks_);
1199 lock.unlock();
1200
1201 finished_cv_.notify_all();
1202 for (auto& callback : cbs) {
1203 invokeCallback(std::move(callback));
1204 }
1205 }
1206
1207 // Tries to retrieve the error message from std::exception_ptr.
1208 std::string tryRetrieveErrorMessageInternal(std::exception_ptr eptr) const {
1209 try {
1210 std::rethrow_exception(std::move(eptr));
1211 } catch (const std::exception& e) {
1212 return e.what();
1213 } catch (...) {
1214 return "Unknown Exception Type";
1215 }
1216 }
1217
1218 // Defined in ivalue.cpp.
1219 static std::vector<WeakStorage> extractStorages(
1220 const at::IValue& value);
1221
1222 static std::vector<c10::Device> getDevicesOfStorages(
1223 const c10::impl::VirtualGuardImpl& impl,
1224 const std::vector<WeakStorage>& storages) {
1225 c10::DeviceIndex deviceCount = impl.deviceCount();
1226 std::vector<bool> isDeviceUsed(deviceCount, false);
1227 for (const WeakStorage& weak_storage : storages) {
1228 c10::intrusive_ptr<c10::StorageImpl> storage = weak_storage.lock();
1229 if (!storage) {
1230 continue;
1231 }
1232 c10::Device device = storage->device();
1233 if (!device.is_cpu()) {
1234 TORCH_CHECK_VALUE(
1235 device.type() == impl.type(),
1236 "Expected all data ptrs to be on a device of type ",
1237 impl.type(),
1238 ", got one on device ",
1239 device);
1240 isDeviceUsed[device.index()] = true;
1241 }
1242 }
1243 std::vector<c10::Device> devices;
1244 for (c10::DeviceIndex idx = 0; idx < deviceCount; idx++) {
1245 if (isDeviceUsed[idx]) {
1246 devices.emplace_back(impl.type(), idx);
1247 }
1248 }
1249 return devices;
1250 }
1251
1252 static std::string formatSetOfDevices(
1253 const std::vector<c10::Device>& devices) {
1254 if (devices.empty()) {
1255 return "(none)";
1256 }
1257 std::ostringstream oss;
1258 oss << devices[0];
1259 for (const auto idx : c10::irange(1, devices.size())) {
1260 if (idx == devices.size() - 1) {
1261 oss << " and ";
1262 } else {
1263 oss << ", ";
1264 }
1265 oss << devices[idx];
1266 }
1267 return oss.str();
1268 }
1269
1270 static c10::DeviceType getTypeOfDevices(
1271 const std::vector<c10::Device>& devices) {
1272 if (devices.empty()) {
1273 return c10::kCPU;
1274 }
1275 c10::DeviceType deviceType = devices[0].type();
1276 for (const auto idx : c10::irange(1, devices.size())) {
1277 TORCH_CHECK_VALUE(
1278 devices[idx].type() == deviceType,
1279 "Expected all devices to be of the same type, but got a mismatch between ",
1280 devices[0],
1281 " and ",
1282 devices[idx]);
1283 }
1284 return deviceType;
1285 }
1286
1287 // We need devices to be sorted in order to use ensureIsSubsetOfDevices.
1288 static std::vector<c10::Device> sortAndDeduplicateDevices(
1289 const c10::impl::VirtualGuardImpl& /*impl*/,
1290 std::vector<c10::Device> devices) {
1291 std::sort(
1292 devices.begin(), devices.end(),
1293 [](const c10::Device& a, const c10::Device& b) { return a.index() < b.index(); });
1294 // Deduplicate by compacting.
1295 size_t targetIdx = 0;
1296 for (const auto sourceIdx : c10::irange(devices.size())) {
1297 TORCH_CHECK_VALUE(
1298 devices[sourceIdx].has_index(),
1299 "Expected devices to have indices, got ", devices[sourceIdx]);
1300 if (targetIdx > 0 && devices[targetIdx - 1].index() == devices[sourceIdx].index()) {
1301 // It's a duplicate, skip it.
1302 continue;
1303 }
1304 if (sourceIdx != targetIdx) {
1305 devices[targetIdx] = devices[sourceIdx];
1306 }
1307 targetIdx++;
1308 }
1309 // If there were duplicates there's now a gap at the end: trim it. Resizing
1310 // requires the item type to be default-constructible (which c10::Device is
1311 // not) because in principle it could be required to create new items. Since
1312 // we know we'll shrink the vector, we provide a custom dummy value instead.
1313 devices.resize(targetIdx, c10::Device(c10::kCPU));
1314 return devices;
1315 }
1316
1317 static void ensureIsSubsetOfDevices(
1318 const std::vector<c10::Device>& subset,
1319 const std::vector<c10::Device>& superset) {
1320 // We assume the devices in both vectors have the same consistent type, and
1321 // their indices are unique and sorted.
1322 std::vector<c10::Device> excessDevices;
1323 std::set_difference(
1324 subset.begin(),
1325 subset.end(),
1326 superset.begin(),
1327 superset.end(),
1328 std::back_inserter(excessDevices),
1329 [](const c10::Device& a, const c10::Device& b) { return a.index() < b.index(); });
1330 TORCH_CHECK_VALUE(
1331 excessDevices.empty(),
1332 "The result contained tensors residing on device(s) ",
1333 formatSetOfDevices(excessDevices),
1334 " which are not among the expected device(s) ",
1335 formatSetOfDevices(superset));
1336 }
1337
1338 mutable std::mutex mutex_;
1339 std::atomic_bool completed_ = {false}; // is this future complete
1340 std::condition_variable finished_cv_;
1341
1342 IValue value_; // when finished the value
1343 TypePtr type_;
1344 std::vector<std::function<void(Future&)>> callbacks_;
1345 std::exception_ptr eptr_;
1346
1347 // An upcast pointer to a virtual class which allows us to manipulate events,
1348 // streams, ... in a generic way, without an explicit dependency on CUDA.
1349 const c10::impl::VirtualGuardImpl impl_;
1350
1351 // The device that was current when markCompleted was called, which we'll
1352 // restore when invoking callbacks. It's optional because we'll only store it
1353 // if the future completes successfully.
1354 optional<c10::Device> currentDevice_;
1355
1356 // The events that correspond to the completion of the async I/O kernels. They
1357 // are recorded on the appropriate streams when the future is marked completed
1358 // and can then be queried/waited/blocked on. There is one event for each
1359 // distinct device on which the value's tensors reside.
1360 std::vector<c10::Event> events_;
1361
1362 // A cached version of the storages extracted from the value when the future
1363 // is first marked completed.
1364 std::vector<WeakStorage> storages_;
1365
1366 // The bounding set of devices that this future, and any of its children, is
1367 // allowed to use. This is a superset of the set of devices used by the events
1368 // above. We need this to know what streams (for which devices) to set as
1369 // current when invoking a callback, thus allowing the callback to use devices
1370 // that the parent future didn't use. This field is set to the value provided
1371 // in the constructor and will be "inherited" by all child futures.
1372 const std::vector<c10::Device> devices_;
1373};
1374
1375struct C10_EXPORT ivalue::Await final : c10::intrusive_ptr_target {
1376 private:
1377 explicit Await(TypePtr elType, std::function<IValue()> fn)
1378 : elType_(std::move(elType)), type_(AwaitType::create(elType_)), fn_(std::move(fn)) {}
1379
1380 explicit Await(TypePtr elType) : elType_(std::move(elType)), type_(AwaitType::create(elType_)) { }
1381
1382 friend c10::intrusive_ptr<Await>;
1383
1384 public:
1385 Await(const Await&) = delete;
1386 Await(Await&&) = delete;
1387 Await& operator=(const Await&) = delete;
1388 Await& operator=(Await&&) = delete;
1389
1390 IValue wait() {
1391 if (!completed_) {
1392 TORCH_CHECK(fn_, "Incompleted Await: fn can't be None");
1393 value_ = fn_();
1394 completed_ = true;
1395 args_ = {};
1396 }
1397 return value_;
1398 }
1399
1400 IValue value() {
1401 TORCH_CHECK(completed_, "Await must be completed");
1402 return value_;
1403 }
1404
1405 void setFn(std::function<IValue()> fn) {
1406 fn_ = std::move(fn);
1407 }
1408
1409 bool completed() {
1410 return completed_;
1411 }
1412
1413 void markCompleted(IValue value) {
1414 value_ = std::move(value);
1415 completed_ = true;
1416 }
1417
1418 TORCH_API friend std::ostream& operator<<(
1419 std::ostream& out,
1420 const Await& v);
1421
1422 TypePtr elementType() const {
1423 return elType_;
1424 }
1425
1426 TypePtr type() const {
1427 return type_;
1428 }
1429
1430 void setArgs(std::vector<IValue> args) {
1431 args_ = std::move(args);
1432 }
1433
1434 std::vector<IValue>& args() {
1435 return args_;
1436 }
1437
1438 private:
1439 TypePtr elType_;
1440 TypePtr type_;
1441 std::vector<IValue> args_;
1442 std::function<IValue()> fn_;
1443 IValue value_;
1444 bool completed_{};
1445};
1446
1447// Input is a list of Futures with the same target type.
1448// Output is a Future to the List of completed Futures.
1449TORCH_API intrusive_ptr<ivalue::Future> collectAll(
1450 c10::List<c10::intrusive_ptr<ivalue::Future>> srcs);
1451// Input is a List of Futures with the same target type.
1452// Output is a Future that will be updated with a seen value.
1453TORCH_API intrusive_ptr<ivalue::Future> collectAny(
1454 c10::List<c10::intrusive_ptr<ivalue::Future>> srcs);
1455
1456// User-defined object.
1457struct C10_EXPORT ivalue::Object final : c10::intrusive_ptr_target {
1458 public:
1459 // In general, class types hold a shared_ptr to its owning CompilationUnit,
1460 // so that its type and methods do not get deallocated while the class exists.
1461 // However, the CompilationUnit holds ownership of the type's graphs, so
1462 // inserting a constant object into a Graph would create a reference cycle if
1463 // that constant object held a shared_ptr to its CU. For these objects we
1464 // instatiate them with non-owning references to its CU
1465 Object(WeakOrStrongTypePtr type, size_t numSlots) : type_(std::move(type)) {
1466 slots_.resize(numSlots);
1467 }
1468
1469 Object(StrongTypePtr type, size_t numSlots)
1470 : type_(WeakOrStrongTypePtr(std::move(type))) {
1471 slots_.resize(numSlots);
1472 }
1473
1474 static c10::intrusive_ptr<Object> create(
1475 WeakOrStrongTypePtr type,
1476 size_t numSlots) {
1477 return c10::make_intrusive<Object>(std::move(type), numSlots);
1478 }
1479
1480 static c10::intrusive_ptr<Object> create(
1481 StrongTypePtr type,
1482 size_t numSlots) {
1483 return c10::make_intrusive<Object>(std::move(type), numSlots);
1484 }
1485
1486 static c10::intrusive_ptr<Object> create(ClassTypePtr classType, size_t numSlots);
1487
1488 /**
1489 * Slot API.
1490 *
1491 * Attributes are stored as a simple vector so that lookups are fast at
1492 * runtime. A "slot" is just an index into that vector, which can be computed
1493 * statically if you have access to the class type. Use this API if you are
1494 * writing compiler stuff.
1495 */
1496 void setSlot(size_t slot, IValue v) {
1497 if (slot >= slots_.size()) {
1498 // for module types, it is possible that the members of the class have
1499 // expanded after the object was created. In this case, we expand
1500 // the slots to the right size
1501 resizeObject(slot);
1502 }
1503 slots_[slot] = std::move(v);
1504 }
1505
1506 const IValue& getSlot(size_t slot) const {
1507 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(slot < slots_.size());
1508 // NOTE: This lookup is fairly hot, so we use unchecked access to the
1509 // vector. Errors should still be detectable with ASan.
1510 return slots_[slot];
1511 }
1512
1513 void unsafeRemoveSlot(size_t slot) {
1514 TORCH_CHECK(slot < slots_.size());
1515 slots_.erase(slots_.begin() + slot);
1516 }
1517
1518 /**
1519 * Attribute API.
1520 *
1521 * Wrappers around the slot stuff so that users can access attributes
1522 * directly. Use this API if you are a user.
1523 *
1524 * Note: Unlike in Python, TorchScript must make a distinction between
1525 * attributes (which are IValues) and methods (which are Methods). If you
1526 * want a method, use `obj.type()->getMethod()`
1527 */
1528 IValue getAttr(const std::string& name) const;
1529 void setAttr(const std::string& name, IValue v);
1530 // Remove attribute by name, caller is responsible for
1531 // the safety of this operation
1532 // We didn't remove the attribute in the type because the type
1533 // might be shared by multiple objects.
1534 // Therefore after removing attribute, the object is in an inconsistent
1535 // state where it has more attribute types in its Type than
1536 // the attribute slots it has, user needs to make sure the object
1537 // has consistent by removing the attribute in type as well
1538 void unsafeRemoveAttr(const std::string& name);
1539
1540 std::string name() const;
1541
1542 const std::vector<IValue>& slots() const {
1543 return slots_;
1544 }
1545 std::shared_ptr<ClassType> type() const;
1546
1547 std::shared_ptr<torch::jit::CompilationUnit> compilation_unit() {
1548 if (type_.holds_strong_ref()) {
1549 return type_.cu_.getStrongRefOrThrow();
1550 } else {
1551 auto weak_ptr = type_.cu_.getWeakRefOrThrow();
1552 return std::shared_ptr<torch::jit::CompilationUnit>(weak_ptr);
1553 }
1554 }
1555
1556 c10::intrusive_ptr<Object> copy_to_weak_compilation_ref() const;
1557
1558 void unsafe_make_weak_compilation_ref() {
1559 type_ = WeakOrStrongTypePtr(type_.asWeakTypePtr());
1560 }
1561
1562 c10::intrusive_ptr<Object> copy() const;
1563
1564 c10::intrusive_ptr<Object> deepcopy() const;
1565
1566 c10::intrusive_ptr<Object> deepcopy(IValue::HashAliasedIValueMap& memo) const;
1567
1568 bool is_weak_compilation_ref() const {
1569 return !type_.holds_strong_ref();
1570 }
1571
1572 bool is_empty_strong_compilation_ref() const {
1573 return type_.holds_empty_strong_ref();
1574 }
1575
1576 private:
1577 void resizeObject(size_t slot);
1578 WeakOrStrongTypePtr type_;
1579 std::vector<IValue> slots_;
1580};
1581
1582// virtual ivalue PyObjectHolder that hold a py::object, we make this virtual
1583// because the py::object and refcounting logic should happen in libtorch_python
1584// see concrete implementation in python_ivalue.h
1585struct ivalue::PyObjectHolder : c10::intrusive_ptr_target {
1586 public:
1587 virtual PyObject* getPyObject() = 0;
1588 virtual c10::InferredType tryToInferType() = 0;
1589 virtual IValue toIValue(const TypePtr& type, c10::optional<int32_t> N = c10::nullopt) = 0;
1590 virtual std::string toStr() = 0;
1591 virtual std::vector<at::Tensor> extractTensors() = 0;
1592
1593 ~PyObjectHolder() override = default;
1594};
1595
1596struct ivalue::EnumHolder : c10::intrusive_ptr_target {
1597 public:
1598 EnumHolder(std::shared_ptr<EnumType> type, std::string name, IValue value)
1599 : type_(std::move(type)),
1600 name_(std::move(name)),
1601 value_(std::move(value)) {}
1602
1603 bool is(const ivalue::EnumHolder& rhs) {
1604 return *this == rhs;
1605 }
1606
1607 friend bool operator==(
1608 const ivalue::EnumHolder& lhs,
1609 const ivalue::EnumHolder& rhs);
1610
1611 TORCH_API friend std::ostream& operator<<(
1612 std::ostream& out,
1613 const EnumHolder& v);
1614
1615 TORCH_API const std::string qualifiedClassName() const;
1616
1617 const std::string unqualifiedClassName() const;
1618
1619 const std::string& name() const {
1620 return name_;
1621 }
1622
1623 const IValue& value() const {
1624 return value_;
1625 }
1626
1627 std::shared_ptr<EnumType> type() const {
1628 return type_;
1629 }
1630
1631 private:
1632 std::shared_ptr<EnumType> type_;
1633 std::string name_;
1634 IValue value_;
1635};
1636
1637#undef TORCH_FORALL_TAGS
1638
namespace detail {

// Placeholder type, constructible from int64_t, that stands in for
// `unsigned long` on platforms where `unsigned long` is the same type as
// uint32_t or uint64_t.
struct _guarded_unsigned_long_unique_dummy final {
  _guarded_unsigned_long_unique_dummy(int64_t) {}
};

// `unsigned long` when it is a type distinct from uint32_t and uint64_t;
// otherwise the unique dummy above. Presumably this keeps a per-type
// specialization for `unsigned long` from clashing with the ones for
// uint32_t/uint64_t.
using _guarded_unsigned_long = std::conditional<
    (std::is_same<unsigned long, uint32_t>::value ||
     std::is_same<unsigned long, uint64_t>::value),
    _guarded_unsigned_long_unique_dummy,
    unsigned long>::type;

} // namespace detail
1651
1652inline ivalue::Object& IValue::toObjectRef() const {
1653 AT_ASSERT(isObject(), "Expected Object but got ", tagKind());
1654 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), "Attempted to create null reference");
1655 return *static_cast<c10::ivalue::Object*>(payload.u.as_intrusive_ptr);
1656}
1657
// NOTE: whenever a DEFINE_TO case is added below, a matching named toX method
// should also be added to IValue. Those named methods are much more
// discoverable than the templated `to` function.
//
// DEFINE_TO(T, method_name) specializes IValue::to<T>() for both the rvalue
// and the const-lvalue overloads by forwarding to IValue::method_name() and
// casting the result to the overload's return type.

#define DEFINE_TO(T, method_name)                                        \
  template <>                                                            \
  inline T IValue::to<T>()&& {                                           \
    return static_cast<T>(std::move(*this).method_name());               \
  }                                                                      \
  template <>                                                            \
  inline c10::detail::ivalue_to_const_ref_overload_return<T>::type       \
  IValue::to<T>() const& {                                               \
    using return_type =                                                  \
        c10::detail::ivalue_to_const_ref_overload_return<T>::type;       \
    return static_cast<return_type>(this->method_name());                \
  }
1672
1673DEFINE_TO(at::Tensor, toTensor)
1674DEFINE_TO(at::Storage, toStorage)
1675DEFINE_TO(c10::Stream, toStream)
1676DEFINE_TO(float, toDouble)
1677DEFINE_TO(double, toDouble)
1678DEFINE_TO(c10::complex<double>, toComplexDouble)
1679DEFINE_TO(unsigned char, toInt)
1680DEFINE_TO(signed char, toInt)
1681DEFINE_TO(unsigned short, toInt)
1682DEFINE_TO(short, toInt)
1683DEFINE_TO(int, toInt)
1684DEFINE_TO(uint32_t, toInt)
1685DEFINE_TO(uint64_t, toInt)
1686DEFINE_TO(detail::_guarded_unsigned_long, toInt)
1687DEFINE_TO(int64_t, toInt)
1688DEFINE_TO(bool, toBool)
1689DEFINE_TO(c10::intrusive_ptr<caffe2::Blob>, toBlob);
1690DEFINE_TO(c10::intrusive_ptr<ivalue::ConstantString>, toString)
1691DEFINE_TO(c10::intrusive_ptr<ivalue::Object>, toObject)
1692DEFINE_TO(at::Scalar, toScalar)
1693DEFINE_TO(c10::List<int64_t>, toIntList)
1694DEFINE_TO(c10::List<double>, toDoubleList)
1695DEFINE_TO(c10::List<c10::complex<double>>, toComplexDoubleList)
1696DEFINE_TO(c10::List<bool>, toBoolList)
1697DEFINE_TO(c10::List<at::Tensor>, toTensorList)
1698DEFINE_TO(c10::impl::GenericList, toList)
1699DEFINE_TO(c10::impl::GenericDict, toGenericDict)
1700DEFINE_TO(c10::intrusive_ptr<ivalue::Tuple>, toTuple)
1701DEFINE_TO(std::string, toStringRef)
1702DEFINE_TO(c10::string_view, toStringView)
1703DEFINE_TO(c10::intrusive_ptr<ivalue::Future>, toFuture)
1704DEFINE_TO(c10::intrusive_ptr<ivalue::Await>, toAwait)
1705DEFINE_TO(c10::intrusive_ptr<c10::RRefInterface>, toRRef)
1706DEFINE_TO(c10::intrusive_ptr<at::Quantizer>, toQuantizer)
1707DEFINE_TO(IValue, toIValue)
1708DEFINE_TO(c10::Device, toDevice)
1709DEFINE_TO(at::ScalarType, toScalarType)
1710DEFINE_TO(at::Layout, toLayout)
1711DEFINE_TO(at::MemoryFormat, toMemoryFormat)
1712DEFINE_TO(at::QScheme, toQScheme)
1713DEFINE_TO(at::Dimname, toDimname)
1714DEFINE_TO(at::Generator, toGenerator)
1715DEFINE_TO(c10::SymInt, toSymInt)
1716DEFINE_TO(c10::SymFloat, toSymFloat)
1717
// Empty tag type carrying only a type parameter. It is passed as a dummy
// argument so that functions can be overloaded on the type they return.
template <typename T>
struct _fake_type {};
1720
1721// generic_to<T> converts an IValue from a generic list or generic dict
1722// to a concrete list/dict type likelike List<T>, Dict<...> or optional<T>.
1723// Note that in the case of lists, this only works for IValue-based lists,
1724// i.e. not for int64_t, double, ...
1725// generic_to<T> is an implementation detail of IValue::to<T> and not
1726// supposed to be called directly.
1727// The _fake_type<T> parameter allows us to overload
1728// based on the return type.
1729template <class Elem>
1730// TODO this is deprecated but we don't throw a warning because a lot of ops in
1731// native_functions.yaml still return std::vector.
1732// C10_DEPRECATED_MESSAGE("IValues based on std::vector<T> are potentially slow
1733// and deprecated. Please use torch::List<T> instead.")
1734std::vector<Elem> generic_to(IValue ivalue, _fake_type<std::vector<Elem>>) {
1735 // We need to do a deep copy of the vector because there might be other
1736 // references to this same IValue that also use the list. We can't just
1737 // move the elements out.
1738 auto list = std::move(ivalue).to<List<Elem>>();
1739 std::vector<Elem> result;
1740 result.reserve(list.size());
1741 for (Elem v : list) {
1742 result.push_back(std::move(v));
1743 }
1744 return result;
1745}
1746
1747template <typename T>
1748c10::intrusive_ptr<T> IValue::toCustomClass() && {
1749 static_assert(
1750 std::is_base_of<torch::CustomClassHolder, T>::value == true,
1751 "toCustomClass requires that template parameter T must inherit "
1752 "from torch::CustomClassHolder");
1753 auto obj = toObject();
1754 TORCH_CHECK(
1755 obj->slots().size() == 1,
1756 "Tried to cast IValue to custom class but it did "
1757 "not contain a custom class!");
1758 const auto* expected_type = c10::getCustomClassType<c10::intrusive_ptr<T>>().get();
1759 ivalue::checkCustomClassType(expected_type, type().get());
1760 auto userObj =
1761 c10::static_intrusive_pointer_cast<T>(obj->getSlot(0).toCapsule());
1762 return userObj;
1763}
1764
1765template <typename T>
1766c10::intrusive_ptr<T> IValue::toCustomClass() const& {
1767 static_assert(
1768 std::is_base_of<torch::CustomClassHolder, T>::value == true,
1769 "toCustomClass requires that template parameter T must inherit "
1770 "from torch::CustomClassHolder");
1771 auto obj = toObject();
1772 TORCH_CHECK(
1773 obj->slots().size() == 1,
1774 "Tried to cast IValue to custom class but it did "
1775 "not contain a custom class!");
1776 const auto* expected_type = c10::getCustomClassType<c10::intrusive_ptr<T>>().get();
1777 ivalue::checkCustomClassType(expected_type, type().get());
1778 auto userObj =
1779 c10::static_intrusive_pointer_cast<T>(obj->getSlot(0).toCapsule());
1780 return userObj;
1781}
1782
1783template <typename T>
1784T generic_to(IValue ivalue, _fake_type<T>) {
1785 using ElemType = typename std::remove_pointer<T>::type::element_type;
1786 return std::move(ivalue).toCustomClass<ElemType>();
1787}
1788
1789template <typename T>
1790tagged_capsule<T> generic_to(IValue ivalue, _fake_type<tagged_capsule<T>>) {
1791 return tagged_capsule<T>{std::move(ivalue)};
1792}
1793
1794template <typename Elem>
1795c10::List<Elem> generic_to(IValue ivalue, _fake_type<c10::List<Elem>>) {
1796 return impl::toTypedList<Elem>(std::move(ivalue).toList());
1797}
1798
1799template <typename T>
1800static T createVectorLikeFromList(const c10::detail::ListImpl* impl) {
1801 T result;
1802 result.reserve(impl->list.size());
1803 for (const auto & i : impl->list) {
1804 result.push_back(i.to<typename T::value_type>());
1805 }
1806 return result;
1807}
1808
1809template <typename T>
1810static std::vector<T> createVectorFromList(const c10::detail::ListImpl* impl) {
1811 return createVectorLikeFromList<std::vector<T>>(impl);
1812}
1813
1814template <typename T>
1815std::vector<T> createVectorFromList(const c10::List<T>& impl) {
1816 std::vector<T> result;
1817 result.reserve(impl.size());
1818 for (size_t i = 0, N = impl.size(); i < N; ++i) {
1819 result.push_back(impl[i]);
1820 }
1821 return result;
1822}
1823
1824template <typename T>
1825OptionalArray<T> generic_to(IValue ivalue, _fake_type<OptionalArray<T>>) {
1826 if (ivalue.isNone()) {
1827 return {};
1828 }
1829 return createVectorFromList<T>(
1830 std::move(ivalue).to<c10::List<T>>()
1831 );
1832}
1833
1834namespace detail {
1835template <typename Elem, size_t... I>
1836std::array<Elem, sizeof...(I)> generic_to_array(
1837 IValue ivalue,
1838 _fake_type<std::array<Elem, sizeof...(I)>>,
1839 std::index_sequence<I...>) {
1840 // We need to do a deep copy of the array because there might be other
1841 // references to this same IValue that also use the list. We can't just
1842 // move the elements out.
1843 auto list = std::move(ivalue).to<List<Elem>>();
1844 TORCH_CHECK(
1845 list.size() == sizeof...(I),
1846 "Tried to convert a List with ",
1847 list.size(),
1848 " elements to a fixed-size array of size ",
1849 sizeof...(I));
1850 return {list[I]...};
1851}
1852} // namespace detail
1853
1854template <typename Elem, size_t N>
1855std::array<Elem, N> generic_to(
1856 IValue ivalue,
1857 _fake_type<std::array<Elem, N>> ft) {
1858 return detail::generic_to_array(ivalue, ft, std::make_index_sequence<N>());
1859}
1860
1861template <typename Key, typename Value>
1862c10::Dict<Key, Value> generic_to(
1863 IValue ivalue,
1864 _fake_type<c10::Dict<Key, Value>>) {
1865 return impl::toTypedDict<Key, Value>(std::move(ivalue).toGenericDict());
1866}
1867
1868template <typename K, typename V>
1869C10_DEPRECATED_MESSAGE(
1870 "IValues based on std::unordered_map are slow and deprecated. Please use c10::Dict<K, V> instead.")
1871std::unordered_map<K, V> generic_to(
1872 IValue ivalue,
1873 _fake_type<std::unordered_map<K, V>>) {
1874 std::unordered_map<K, V> specialized_dict;
1875
1876 for (const auto& item : std::move(ivalue).toGenericDict()) {
1877 specialized_dict[item.key().template to<K>()] = item.value().template to<V>();
1878 }
1879
1880 return specialized_dict;
1881}
1882
1883template <typename T>
1884c10::optional<T> generic_to(IValue ivalue, _fake_type<c10::optional<T>>) {
1885 if (ivalue.isNone()) {
1886 return c10::nullopt;
1887 }
1888 return std::move(ivalue).to<T>();
1889}
1890
1891namespace detail {
1892template <typename Tuple, std::size_t... INDEX>
1893Tuple generic_to_tuple_impl(
1894 const ivalue::TupleElements& t,
1895 std::index_sequence<INDEX...>) {
1896 return std::make_tuple(
1897 t[INDEX].to<typename std::tuple_element<INDEX, Tuple>::type>()...);
1898}
1899} // namespace detail
1900
1901template <
1902 typename... Args,
1903 typename Indices = std::make_index_sequence<sizeof...(Args)>,
1904 std::enable_if_t<
1905 !guts::disjunction<
1906 std::is_lvalue_reference<Args>...,
1907 guts::negation<std::is_constructible<IValue, Args>>...>::value,
1908 std::nullptr_t> = nullptr>
1909std::tuple<Args...> generic_to(IValue ivalue, _fake_type<std::tuple<Args...>>) {
1910 const auto& vals = ivalue.toTupleRef().elements();
1911 TORCH_CHECK(vals.size() == sizeof...(Args));
1912 return detail::generic_to_tuple_impl<std::tuple<Args...>>(vals, Indices{});
1913}
1914
1915template <typename T>
1916inline T IValue::to() && {
1917 return generic_to(std::move(*this), _fake_type<T>{});
1918}
1919
1920template <>
1921inline c10::optional<c10::string_view> IValue::to() && {
1922 // In the default implementation, the IValue is destroyed with std::move.
1923 // But if the unboxed type is optional<string_view> we cannot destroy
1924 // the IValue.
1925 return generic_to(*this, _fake_type<c10::optional<c10::string_view>>{});
1926}
1927
1928template <typename T>
1929inline typename c10::detail::ivalue_to_const_ref_overload_return<T>::type IValue::to() const& {
1930 return generic_to(*this, _fake_type<T>{});
1931}
1932
1933inline c10::List<int64_t> IValue::toIntList() && {
1934 AT_ASSERT(isIntList(), "Expected IntList but got ", tagKind());
1935 return c10::List<int64_t>(moveToIntrusivePtr<c10::detail::ListImpl>());
1936}
1937inline c10::List<int64_t> IValue::toIntList() const& {
1938 AT_ASSERT(isIntList(), "Expected IntList but got ", tagKind());
1939 return c10::List<int64_t>(toIntrusivePtr<c10::detail::ListImpl>());
1940}
1941inline std::vector<int64_t> IValue::toIntVector() const {
1942 AT_ASSERT(isIntList(), "Expected IntList but got ", tagKind());
1943 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
1944 payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
1945 "called toIntVector on null intrusive_ptr IValue");
1946 return createVectorFromList<int64_t>(
1947 static_cast<const c10::detail::ListImpl*>(payload.u.as_intrusive_ptr));
1948}
1949inline at::DimVector IValue::toDimVector() const {
1950 AT_ASSERT(isIntList(), "Expected IntList but got ", tagKind());
1951 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
1952 payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
1953 "called toDimVector on null intrusive_ptr IValue");
1954 return createVectorLikeFromList<at::DimVector>(
1955 static_cast<const c10::detail::ListImpl*>(payload.u.as_intrusive_ptr));
1956}
1957inline c10::List<double> IValue::toDoubleList() && {
1958 AT_ASSERT(isDoubleList(), "Expected DoubleList but got ", tagKind());
1959 return c10::List<double>(moveToIntrusivePtr<c10::detail::ListImpl>());
1960}
1961inline c10::List<double> IValue::toDoubleList() const& {
1962 AT_ASSERT(isDoubleList(), "Expected DoubleList but got ", tagKind());
1963 return c10::List<double>(toIntrusivePtr<c10::detail::ListImpl>());
1964}
1965inline std::vector<double> IValue::toDoubleVector() const {
1966 AT_ASSERT(isDoubleList(), "Expected DoubleList but got ", tagKind());
1967 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
1968 payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
1969 "called toDoubleVector on null intrusive_ptr IValue");
1970 return createVectorFromList<double>(
1971 static_cast<const c10::detail::ListImpl*>(payload.u.as_intrusive_ptr));
1972}
1973inline c10::List<c10::complex<double>> IValue::toComplexDoubleList() && {
1974 AT_ASSERT(isComplexDoubleList(), "Expected ComplexDoubleList but got ", tagKind());
1975 return c10::List<c10::complex<double>>(moveToIntrusivePtr<c10::detail::ListImpl>());
1976}
1977inline c10::List<c10::complex<double>> IValue::toComplexDoubleList() const& {
1978 AT_ASSERT(isComplexDoubleList(), "Expected ComplexDoubleList but got ", tagKind());
1979 return c10::List<c10::complex<double>>(toIntrusivePtr<c10::detail::ListImpl>());
1980}
1981inline std::vector<c10::complex<double>> IValue::toComplexDoubleVector() const {
1982 AT_ASSERT(isComplexDoubleList(), "Expected ComplexDoubleList but got ", tagKind());
1983 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
1984 payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
1985 "called toComplexDoubleVector on null intrusive_ptr IValue");
1986 return createVectorFromList<c10::complex<double>>(
1987 static_cast<const c10::detail::ListImpl*>(payload.u.as_intrusive_ptr));
1988}
1989inline c10::List<bool> IValue::toBoolList() && {
1990 AT_ASSERT(isBoolList(), "Expected BoolList but got ", tagKind());
1991 return c10::List<bool>(moveToIntrusivePtr<c10::detail::ListImpl>());
1992}
1993inline c10::List<bool> IValue::toBoolList() const& {
1994 AT_ASSERT(isBoolList(), "Expected BoolList but got ", tagKind());
1995 return c10::List<bool>(toIntrusivePtr<c10::detail::ListImpl>());
1996}
1997inline c10::List<at::Tensor> IValue::toTensorList() && {
1998 AT_ASSERT(isTensorList(), "Expected TensorList but got ", tagKind());
1999 return c10::List<at::Tensor>(moveToIntrusivePtr<c10::detail::ListImpl>());
2000}
2001inline c10::List<at::Tensor> IValue::toTensorList() const& {
2002 AT_ASSERT(isTensorList(), "Expected TensorList but got ", tagKind());
2003 return c10::List<at::Tensor>(toIntrusivePtr<c10::detail::ListImpl>());
2004}
2005inline std::vector<at::Tensor> IValue::toTensorVector() const {
2006 AT_ASSERT(isTensorList(), "Expected TensorList but got ", tagKind());
2007 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
2008 payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
2009 "called toTensorVector on null intrusive_ptr IValue");
2010 return createVectorFromList<at::Tensor>(
2011 static_cast<const c10::detail::ListImpl*>(payload.u.as_intrusive_ptr));
2012}
2013inline c10::List<c10::optional<at::Tensor>> IValue::toOptionalTensorList() && {
2014 AT_ASSERT(isOptionalTensorList(), "Expected OptionalTensorList but got ", tagKind());
2015 return c10::List<c10::optional<at::Tensor>>(moveToIntrusivePtr<c10::detail::ListImpl>());
2016}
2017inline c10::List<c10::optional<at::Tensor>> IValue::toOptionalTensorList() const& {
2018 AT_ASSERT(isOptionalTensorList(), "Expected OptionalTensorList but got ", tagKind());
2019 return c10::List<c10::optional<at::Tensor>>(toIntrusivePtr<c10::detail::ListImpl>());
2020}
2021inline std::vector<c10::optional<at::Tensor>> IValue::toOptionalTensorVector() const {
2022 AT_ASSERT(isOptionalTensorList(), "Expected OptionalTensorList but got ", tagKind());
2023 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
2024 payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
2025 "called toOptionalTensorVector on null intrusive_ptr IValue");
2026 return createVectorFromList<c10::optional<at::Tensor>>(
2027 static_cast<const c10::detail::ListImpl*>(payload.u.as_intrusive_ptr));
2028}
2029inline c10::List<IValue> IValue::toList() && {
2030 AT_ASSERT(isList(), "Expected GenericList but got ", tagKind());
2031 return c10::List<IValue>(moveToIntrusivePtr<c10::detail::ListImpl>());
2032}
2033inline c10::List<IValue> IValue::toList() const& {
2034 AT_ASSERT(isList(), "Expected GenericList but got ", tagKind());
2035 return c10::List<IValue>(toIntrusivePtr<c10::detail::ListImpl>());
2036}
2037inline c10::ArrayRef<IValue> IValue::toListRef() const {
2038 AT_ASSERT(isList(), "Expected GenericList but got ", tagKind());
2039 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
2040 payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
2041 "called toListRef on null intrusive_ptr IValue");
2042 return static_cast<const c10::detail::ListImpl*>(payload.u.as_intrusive_ptr)
2043 ->list;
2044}
2045inline c10::Dict<IValue, IValue> IValue::toGenericDict() && {
2046 AT_ASSERT(isGenericDict(), "Expected GenericDict but got ", tagKind());
2047 return c10::Dict<IValue, IValue>(moveToIntrusivePtr<c10::detail::DictImpl>());
2048}
2049inline c10::Dict<IValue, IValue> IValue::toGenericDict() const& {
2050 AT_ASSERT(isGenericDict(), "Expected GenericDict but got ", tagKind());
2051 return c10::Dict<IValue, IValue>(toIntrusivePtr<c10::detail::DictImpl>());
2052}
2053inline c10::intrusive_ptr<ivalue::Tuple> IValue::toTuple() && {
2054 AT_ASSERT(isTuple(), "Expected Tuple but got ", tagKind());
2055 return moveToIntrusivePtr<ivalue::Tuple>();
2056}
2057inline c10::intrusive_ptr<ivalue::Tuple> IValue::toTuple() const& {
2058 AT_ASSERT(isTuple(), "Expected Tuple but got ", tagKind());
2059 return toIntrusivePtr<ivalue::Tuple>();
2060}
2061inline ivalue::Tuple& IValue::toTupleRef() const {
2062 AT_ASSERT(isTuple(), "Expected Tuple but got ", tagKind());
2063 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
2064 payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
2065 "called toTupleRef on null intrusive_ptr IValue");
2066 return *static_cast<c10::ivalue::Tuple*>(
2067 payload.u.as_intrusive_ptr);
2068}
2069
2070inline IValue::IValue(c10::intrusive_ptr<ivalue::Tuple> v)
2071 : tag(Tag::Tuple) {
2072 payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
2073}
2074template <
2075 typename... Args,
2076 std::enable_if_t<
2077 !guts::disjunction<
2078 std::is_lvalue_reference<Args>...,
2079 guts::negation<std::is_constructible<IValue, Args>>...>::value,
2080 std::nullptr_t>>
2081inline IValue::IValue(const std::tuple<Args...>& t)
2082 : IValue(
2083 std::move(c10::guts::apply(c10::ivalue::Tuple::create<const Args&...>, t))) {
2084}
2085
2086template <
2087 typename... Args,
2088 std::enable_if_t<
2089 !guts::disjunction<
2090 std::is_lvalue_reference<Args>...,
2091 guts::negation<std::is_constructible<IValue, Args>>...>::value,
2092 std::nullptr_t>>
2093inline IValue::IValue(std::tuple<Args...>&& t)
2094 : IValue(
2095 std::move(c10::guts::apply(c10::ivalue::Tuple::create<Args&&...>, std::move(t)))) {
2096}
2097
2098inline IValue::IValue(c10::intrusive_ptr<ivalue::ConstantString> v)
2099 : tag(Tag::String) {
2100 payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
2101}
// Boxes a std::string: the string is moved into ConstantString::create and
// the result is handed to another IValue constructor (delegation).
inline IValue::IValue(std::string v)
 : IValue(ivalue::ConstantString::create(std::move(v))) {}
2104
2105inline IValue::IValue(c10::impl::GenericList v)
2106 : tag(Tag::GenericList) {
2107 payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.impl_.release());
2108}
2109
// Boxes a typed list taken by rvalue: the list is moved through
// impl::toList<T> and the result is handed to another IValue constructor.
template <class T, IValue::enable_if_list_is_ivalue_constructible<T>>
inline IValue::IValue(c10::List<T>&& v) : IValue(impl::toList<T>(std::move(v))) {}
// Boxes a typed list taken by const reference: the list is passed (not
// moved) through impl::toList<T> and the result is handed to another IValue
// constructor.
template <class T, IValue::enable_if_list_is_ivalue_constructible<T>>
inline IValue::IValue(const c10::List<T>& v) : IValue(impl::toList<T>(v)) {}
2114template <class T, IValue::enable_if_list_is_ivalue_constructible<T>>
2115inline IValue::IValue(at::ArrayRef<T> v) : IValue(c10::List<T>()) {
2116 auto list = to<c10::List<T>>();
2117 list.reserve(v.size());
2118 for (const auto& e : v) {
2119 list.push_back(e);
2120 }
2121}
2122template <class T, IValue::enable_if_symint<T>>
2123inline IValue::IValue(at::ArrayRef<T> v) : IValue() {
2124 auto vi = c10::asIntArrayRefSlowOpt(v);
2125 if (vi.has_value()) {
2126 // This list is entirely integers; ensure it is typed as
2127 // an IntList so toIntList works
2128 *this = IValue(*vi);
2129 } else {
2130 // This list has SymInts; type it as a SymInt
2131 *this = IValue(impl::toList<c10::SymInt>(c10::List<c10::SymInt>()));
2132 auto list = to<c10::List<c10::SymInt>>();
2133 list.reserve(v.size());
2134 for (const auto& e : v) {
2135 list.push_back(e);
2136 }
2137 }
2138}
2139template <class T, IValue::enable_if_symint<T>>
2140inline IValue::IValue(at::OptionalArrayRef<T> mb_v) : IValue() {
2141 if (!mb_v.has_value()) return;
2142 *this = IValue(*mb_v);
2143}
2144template <class T, IValue::enable_if_symint<T>>
2145inline IValue::IValue(const std::vector<T>& v) : IValue() {
2146 *this = IValue(at::ArrayRef<T>(v));
2147}
2148template <class T, IValue::enable_if_list_is_ivalue_constructible<T>>
2149inline IValue::IValue(const std::vector<T>& v) : IValue(c10::List<T>()) {
2150 auto list = to<c10::List<T>>();
2151 list.reserve(v.size());
2152 for (const auto& e : v) {
2153 list.push_back(e);
2154 }
2155}
2156template <class T, IValue::enable_if_list_is_ivalue_constructible<T>>
2157inline IValue::IValue(c10::OptionalArrayRef<T> v) : IValue() {
2158 if (v.has_value()) {
2159 *this = IValue(std::move(*v));
2160 }
2161}
2162
2163template <class T, size_t N>
2164inline IValue::IValue(std::array<T, N> v) : IValue(c10::List<T>()) {
2165 auto list = to<c10::List<T>>();
2166 list.reserve(v.size());
2167 for (auto& e : v) {
2168 list.push_back(std::move(e));
2169 }
2170}
2171
2172template <class T, IValue::enable_if_ilist_is_ivalue_constructible<T>>
2173inline IValue::IValue(c10::IListRef<T> v) : IValue() {
2174 constexpr bool boxed_type_constructs_ivalue =
2175 std::is_constructible<IValue, typename c10::IListRef<T>::boxed_type>::value;
2176 // First, we try to use the boxed value.
2177 // If we fail (either it's not in the boxed state, or its boxed type
2178 // can not construct an IValue), we fallback to copying the list.
2179 if (boxed_type_constructs_ivalue && v.isBoxed()) {
2180 *this = IValue(impl::toList(v.toBoxed()));
2181 } else {
2182 c10::List<T> list;
2183 list.reserve(v.size());
2184 for (const auto& t : v) {
2185 list.push_back(t);
2186 }
2187 *this = IValue(impl::toList(std::move(list)));
2188 }
2189}
2190
2191inline IValue::IValue(c10::impl::GenericDict v)
2192 : tag(Tag::GenericDict) {
2193 payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.impl_.release());
2194}
// Boxes a typed dict: the dict is moved through impl::toGenericDict and the
// result is handed to another IValue constructor.
template <class Key, class Value>
inline IValue::IValue(c10::Dict<Key, Value> v)
 : IValue(impl::toGenericDict(std::move(v))) {}
2198
2199template <class Key, class Value>
2200inline IValue::IValue(std::unordered_map<Key, Value> v)
2201 : IValue(Dict<Key, Value>()) {
2202 auto dict = to<c10::Dict<Key, Value>>();
2203 dict.reserve(v.size());
2204 for (auto& e : v) {
2205 dict.insert(std::move(e.first), std::move(e.second));
2206 }
2207}
2208
2209template <class T, IValue::enable_if_ivalue_constructible<T>>
2210inline IValue::IValue(c10::optional<T> v) : IValue() {
2211 if (v.has_value()) {
2212 *this = IValue(std::move(*v));
2213 }
2214}
2215
// Boxes c10::nullopt by delegating to the default IValue constructor.
inline IValue::IValue(c10::nullopt_t) : IValue() {}
2217
2218inline IValue::IValue(c10::intrusive_ptr<ivalue::Object> v)
2219 : tag(Tag::Object) {
2220 payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
2221}
2222
2223inline IValue::IValue(c10::intrusive_ptr<ivalue::PyObjectHolder> v)
2224 : tag(Tag::PyObject) {
2225 payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
2226}
2227
2228inline IValue::IValue(c10::intrusive_ptr<ivalue::EnumHolder> v)
2229 : tag(Tag::Enum) {
2230 payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
2231}
2232
2233inline IValue IValue::make_capsule(
2234 intrusive_ptr<torch::CustomClassHolder> blob) {
2235 IValue iv;
2236 iv.tag = Tag::Capsule;
2237 iv.payload.u.as_intrusive_ptr = null_to_undefined_tensor(blob.release());
2238 return iv;
2239}
2240
2241template <
2242 typename T,
2243 std::enable_if_t<std::is_base_of<torch::CustomClassHolder, T>::value, int>>
2244IValue::IValue(c10::intrusive_ptr<T> custom_class) : tag(Tag::Object) {
2245 auto classType = []() {
2246 try {
2247 return c10::getCustomClassType<c10::intrusive_ptr<T>>();
2248 } catch (const c10::Error&) {
2249 throw c10::Error(
2250 "Trying to instantiate a class that isn't a registered custom class: " +
2251 std::string(c10::util::get_fully_qualified_type_name<T>()),
2252 "");
2253 }
2254 }();
2255 auto ivalue_obj = c10::ivalue::Object::create(std::move(classType), /* numSlots */1);
2256 ivalue_obj->setSlot(0, IValue::make_capsule(std::move(custom_class)));
2257 payload.u.as_intrusive_ptr = null_to_undefined_tensor(ivalue_obj.release());
2258
2259}
2260
2261inline IValue::IValue(c10::intrusive_ptr<ivalue::Future> v)
2262 : tag(Tag::Future) {
2263 payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
2264}
2265
2266inline IValue::IValue(c10::intrusive_ptr<ivalue::Await> v)
2267 : tag(Tag::Await) {
2268 payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
2269}
2270
2271inline IValue::IValue(c10::intrusive_ptr<c10::RRefInterface> v)
2272 : tag(Tag::RRef) {
2273 payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
2274}
2275
2276inline IValue::IValue(c10::intrusive_ptr<at::Quantizer> v)
2277 : tag(Tag::Quantizer) {
2278 payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
2279}
2280
2281template <typename T>
2282inline IValue::IValue(c10::complex<T> c)
2283 : tag(Tag::ComplexDouble) {
2284 auto v = c10::make_intrusive<ivalue::ComplexHolder>(c);
2285 payload.u.as_intrusive_ptr = v.release();
2286}
2287
2288inline const std::string& IValue::toStringRef() const {
2289 AT_ASSERT(isString(), "Expected String but got ", tagKind());
2290 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
2291 payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
2292 "called toStringRef on null intrusive_ptr IValue");
2293 return static_cast<const c10::ivalue::ConstantString*>(
2294 payload.u.as_intrusive_ptr)
2295 ->string();
2296}
2297inline c10::optional<std::reference_wrapper<const std::string>> IValue::
2298 toOptionalStringRef() const {
2299 if (isNone()) {
2300 return c10::nullopt;
2301 }
2302 AT_ASSERT(isString(), "Expected optional<string> but got ", tagKind());
2303 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
2304 payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
2305 "called toOptionalStringRef on null intrusive_ptr IValue");
2306 return std::reference_wrapper<const std::string>(
2307 static_cast<const c10::ivalue::ConstantString*>(payload.u.as_intrusive_ptr)
2308 ->string());
2309}
2310
2311inline c10::string_view IValue::toStringView() const {
2312 AT_ASSERT(isString(), "Expected String but got ", tagKind());
2313 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
2314 payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
2315 "called toStringView on null intrusive_ptr IValue");
2316 return static_cast<const c10::ivalue::ConstantString*>(
2317 payload.u.as_intrusive_ptr)
2318 ->string_view();
2319}
2320
2321inline PyObject* IValue::toPyObject() const {
2322 return toPyObjectHolder()->getPyObject();
2323}
2324
2325template <typename T>
2326inline optional<T> IValue::toOptional() {
2327 if (this->isNone()) {
2328 return nullopt;
2329 }
2330 return this->to<T>();
2331}
2332
2333template <typename T>
2334inline optional<T> IValue::toOptional() const {
2335 if (this->isNone()) {
2336 return nullopt;
2337 }
2338 return this->to<T>();
2339}
2340
// Reports whether this IValue is a custom class by forwarding to the free
// function torch::isCustomClass.
inline bool IValue::isCustomClass() const {
 return torch::isCustomClass(*this);
}
2344
2345inline bool IValue::isSameIdentity(const IValue& rhs) const {
2346 // We choose to not use memcmp for payload check due to potential random
2347 // padding characters on union type
2348
2349 // Semantics:
2350 // 1. Immutable primitive values of the same type (Int, Double, None, Bool,
2351 // Str) return value equality
2352 // 2. If it is a tensor type, we need to take undefined tensor into account
2353 // 3. Undefined_tensor is None and vice versa should be true
2354 // 4. If it is a reference type (i.e. isIntrusivePtr()), then is True when
2355 // the pointed-to object is the same.
2356 // 5. False for all other comparisons.
2357 if (this->isNone() && rhs.isNone()) {
2358 return true;
2359 } else if (this->isBool() && rhs.isBool()) {
2360 // for bool type, do equality check
2361 return this->toBool() == rhs.toBool();
2362 } else if (this->isTensor() && rhs.isTensor()) {
2363 return this->payload.as_tensor.is_same(rhs.payload.as_tensor);
2364 } else if (this->isTensor() && rhs.isNone()) {
2365 // special case: undefined tensor and None are the same identity
2366 return !this->payload.as_tensor.defined();
2367 } else if (this->isNone() && rhs.isTensor()) {
2368 // special case: undefined tensor and None are the same identity
2369 return !rhs.payload.as_tensor.defined();
2370 } else if (this->isInt() && rhs.isInt()) {
2371 return this->toInt() == rhs.toInt();
2372 } else if (this->isDouble() && rhs.isDouble()) {
2373 return this->toDouble() == rhs.toDouble();
2374 } else if (this->isString() && rhs.isString()) {
2375 return this->toStringRef() == rhs.toStringRef();
2376 } else {
2377 // for objects holding in IValue, do shallow compare on pointer address to
2378 // testify the identity
2379 return this->isIntrusivePtr() && rhs.isIntrusivePtr() &&
2380 this->payload.u.as_intrusive_ptr == rhs.payload.u.as_intrusive_ptr;
2381 }
2382}
2383
2384namespace ivalue {
2385namespace detail {
2386
// Overload chosen by from() when IValue is constructible from T
// (std::true_type tag): perfect-forwards x into the IValue constructor.
template <typename T>
IValue from_(T&& x, std::true_type) {
 return IValue(std::forward<T>(x));
}
// Overload for the std::false_type tag when the argument is an
// intrusive_ptr<T>: moves the pointer into an IValue constructor.
template <typename T>
IValue from_(c10::intrusive_ptr<T> x, std::false_type) {
 return IValue(std::move(x));
}
// Catch-all overload for the std::false_type tag: instantiating it always
// trips the static_assert below, turning an unsupported argument type into a
// compile-time error. The return statement only satisfies the signature.
template <typename T>
IValue from_(T&& /*x*/, std::false_type) {
 static_assert(
 guts::false_t<T>::value,
 "You are calling from with a type that it doesn't support, and isn't a potential custom class (ie: is an intrusive_ptr)");
 return IValue();
}
2402} // namespace detail
2403
2404template <typename T>
2405IValue from(T&& x) {
2406 return detail::from_(
2407 std::forward<T>(x), typename std::is_constructible<IValue, T>::type{});
2408}
2409
2410} // namespace ivalue
2411
2412
2413template <>
2414struct MaybeOwnedTraits<IValue> {
2415 using owned_type = IValue;
2416 using borrow_type = IValue;
2417
2418 static borrow_type createBorrow(const owned_type& from) {
2419 if (!from.isPtrType()) {
2420 return from;
2421 }
2422 if (from.isTensor()) {
2423 return IValue(MaybeOwnedTraits<at::Tensor>::createBorrow(from.toTensor()));
2424 } else {
2425 return IValue(from.payload, from.tag);
2426 }
2427 }
2428
2429 static void assignBorrow(borrow_type& lhs, const borrow_type& rhs) {
2430 lhs.clearToNone();
2431 if (!rhs.isPtrType()) {
2432 lhs = rhs;
2433 } else if (rhs.isTensor()) {
2434 lhs = IValue(MaybeOwnedTraits<at::Tensor>::createBorrow(rhs.toTensor()));
2435 } else {
2436 lhs = IValue(rhs.payload, rhs.tag);
2437 }
2438 }
2439
2440 static void destroyBorrow(borrow_type& toDestroy) {
2441 toDestroy.clearToNone();
2442 }
2443
2444 static const owned_type& referenceFromBorrow(const borrow_type& borrow) {
2445 return borrow;
2446 }
2447
2448 static const owned_type* pointerFromBorrow(const borrow_type& borrow) {
2449 return &borrow;
2450 }
2451
2452 static bool debugBorrowIsValid(const borrow_type&) {
2453 return true;
2454 }
2455};
2456
// Specialization of IValue::TagType for c10::Type. Only declares `get`; its
// definition is not part of this section. Used by IValue::type<T>().
template <>
struct IValue::TagType<c10::Type> {
 static TORCH_API c10::TypePtr get(const IValue&);
};
2461
// Specialization of IValue::TagType for c10::DynamicType. Only declares
// `get`; its definition is not part of this section. Used by
// IValue::type<T>().
template <>
struct IValue::TagType<c10::DynamicType> {
 static TORCH_API c10::TypePtr get(const IValue&);
};
2466
2467template <typename T>
2468TypePtr IValue::type() const {
2469 return IValue::TagType<T>::get(*this);
2470}
2471
2472} // namespace c10
2473